You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

160 lines
3.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import shutil
import uuid
import uvicorn
app = FastAPI()
# =========================
# Basic configuration
# =========================
# All paths are derived from the directory containing this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEMPLATE_DIR = os.path.join(BASE_DIR, "templates")
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
# Upper bound, in bytes, enforced on uploaded files by the upload handler.
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
# Lower-case extensions (without the dot) accepted as images / videos.
ALLOWED_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif"}
ALLOWED_VIDEO_EXTENSIONS = {"mp4", "webm"}
# =========================
# Directory initialisation
# =========================
# Create the upload directory up front so it exists before it is mounted below.
os.makedirs(UPLOAD_DIR, exist_ok=True)
templates = Jinja2Templates(directory=TEMPLATE_DIR)
# Expose the files stored in UPLOAD_DIR under the /uploads URL prefix.
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
# =========================
# WebSocket state
# =========================
# Sockets that have been accepted and not yet removed; used as the broadcast list.
active_connections: list[WebSocket] = []
# Display name for each socket, recorded from the first message it sends.
usernames: dict[WebSocket, str] = {}
# =========================
# 工具函数
# =========================
def get_ext(filename: str) -> str:
    """Return the lower-cased extension of *filename*, without the dot.

    Only the text after the last dot counts, so ``"a.tar.gz"`` yields
    ``"gz"``. An empty string is returned when *filename* is empty/``None``
    or contains no dot at all, so a bare name such as ``"png"`` is not
    mistaken for an extension.
    """
    if not filename:
        return ""
    # rpartition reports whether a dot was actually found via ``dot``.
    _, dot, ext = filename.rpartition(".")
    return ext.lower() if dot else ""
def is_image(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted image extensions."""
    image_exts = ALLOWED_IMAGE_EXTENSIONS
    return ext in image_exts
def is_video(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted video extensions."""
    video_exts = ALLOWED_VIDEO_EXTENSIONS
    return ext in video_exts
# =========================
# 页面路由
# =========================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the chat page rendered from the ``chat.html`` template."""
    # The template context carries the incoming request under "request".
    context = {"request": request}
    return templates.TemplateResponse("chat.html", context)
# =========================
# 文件上传接口
# =========================
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded image or video and announce it to all chat clients.

    Args:
        file: The uploaded file. Its extension must be in the image or
            video allow-list and its size must not exceed MAX_FILE_SIZE.

    Returns:
        ``{"url": <path under /uploads>}`` on success, or
        ``{"error": <message>}`` when the extension is not allowed or the
        file is too large.
    """
    # The filename may be missing; treat that as "no extension".
    ext = get_ext(file.filename or "")
    if not (is_image(ext) or is_video(ext)):
        return {"error": "不支持的文件类型"}

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large body into memory.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        return {"error": "文件过大(最大 10MB)"}

    # Store under a random name; the client-supplied name is never used
    # as part of the path.
    filename = f"{uuid.uuid4().hex}.{ext}"
    save_path = os.path.join(UPLOAD_DIR, filename)
    with open(save_path, "wb") as f:
        f.write(contents)

    file_url = f"/uploads/{filename}"
    # ext was validated above as image or video, so only two tags occur.
    tag = "IMG" if is_image(ext) else "VIDEO"
    await broadcast(f"{tag}::{file_url}")
    return {"url": file_url}
# =========================
# WebSocket 聊天
# =========================
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Run one chat session over the WebSocket *ws*.

    The socket is accepted and registered, the first text message is taken
    as the user's name, and every later text message is relayed to all
    connections. When the session ends the socket is unregistered and a
    leave notice plus the new online count are broadcast.
    """
    await ws.accept()
    active_connections.append(ws)
    await broadcast_online_count()
    try:
        # The first message received is used as the username.
        username = await ws.receive_text()
        usernames[ws] = username
        await broadcast(f"SYSTEM::{username} 加入聊天室")
        await broadcast_online_count()
        # Relay every subsequent message until the peer disconnects.
        while True:
            incoming = await ws.receive_text()
            await broadcast(f"TEXT::{username}::{incoming}")
    except WebSocketDisconnect:
        # A disconnect simply ends the session; cleanup runs in ``finally``.
        pass
    finally:
        try:
            active_connections.remove(ws)
        except ValueError:
            # Already removed elsewhere (e.g. by a failed broadcast).
            pass
        # Fall back to a placeholder name if no username was ever recorded.
        departed = usernames.pop(ws, "匿名用户")
        await broadcast(f"SYSTEM::{departed} 离开聊天室")
        await broadcast_online_count()
# =========================
# 广播工具
# =========================
async def broadcast(message: str):
    """Send *message* to every registered connection, pruning failed ones.

    Delivery is best effort: a connection whose ``send_text`` raises is
    removed from ``active_connections`` and ``usernames`` instead of
    aborting the broadcast for everyone else.

    Args:
        message: The text frame to send to each connection.
    """
    dead = []
    # Iterate over a snapshot: every ``await`` yields to the event loop, and
    # other handlers may register or unregister connections in the meantime.
    for ws in list(active_connections):
        try:
            await ws.send_text(message)
        except Exception:
            # Any send failure marks the connection for removal below.
            dead.append(ws)
    for ws in dead:
        # The connection's own handler may already have unregistered it
        # while we were awaiting, so guard against a double removal.
        if ws in active_connections:
            active_connections.remove(ws)
        usernames.pop(ws, None)
async def broadcast_online_count():
    """Broadcast the current number of registered connections."""
    online = len(active_connections)
    await broadcast(f"COUNT::{online}")
# =========================
# 启动
# =========================
if __name__ == "__main__":
    # Script entry point: start uvicorn for the "chat:app" import string,
    # listening on all interfaces at port 8202 with reload enabled.
    uvicorn.run("chat:app", host="0.0.0.0", port=8202, reload=True)