"""Minimal FastAPI chat room: WebSocket text chat plus image/video uploads.

Messages pushed to clients are plain strings with a type prefix:
``TEXT::<user>::<text>``, ``SYSTEM::<notice>``, ``COUNT::<n>``,
``IMG::<url>`` and ``VIDEO::<url>``.
"""

import asyncio
import os
import shutil
import uuid
from typing import Optional

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

app = FastAPI()

# =========================
# Basic configuration
# =========================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEMPLATE_DIR = os.path.join(BASE_DIR, "templates")
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")

MAX_FILE_SIZE = 10 * 1024 * 1024  # 10 MiB

ALLOWED_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif"}
ALLOWED_VIDEO_EXTENSIONS = {"mp4", "webm"}

# =========================
# Directory setup
# =========================
os.makedirs(UPLOAD_DIR, exist_ok=True)

templates = Jinja2Templates(directory=TEMPLATE_DIR)
# Uploaded files are served back to clients under /uploads/<name>.
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")

# =========================
# WebSocket state
# =========================
# Connected sockets; every broadcast is sent to each of them.
active_connections: list[WebSocket] = []
# Display name per socket, taken from the first message it sends.
usernames: dict[WebSocket, str] = {}


# =========================
# Helpers
# =========================
def get_ext(filename: Optional[str]) -> str:
    """Return the lower-cased extension of *filename*, without the dot.

    Returns ``""`` when *filename* is empty/None or contains no ``.``, so a
    bare name such as ``"png"`` is not mistaken for a ``.png`` file.
    """
    if not filename or "." not in filename:
        return ""
    return filename.rsplit(".", 1)[-1].lower()


def is_image(ext: str) -> bool:
    """Return True if *ext* is an allowed image extension."""
    return ext in ALLOWED_IMAGE_EXTENSIONS


def is_video(ext: str) -> bool:
    """Return True if *ext* is an allowed video extension."""
    return ext in ALLOWED_VIDEO_EXTENSIONS


def _write_bytes(path: str, data: bytes) -> None:
    """Write *data* to *path*, overwriting any existing file (blocking call)."""
    with open(path, "wb") as f:
        f.write(data)


# =========================
# Page routes
# =========================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat page from templates/chat.html."""
    return templates.TemplateResponse("chat.html", {"request": request})


# =========================
# File upload endpoint
# =========================
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded image/video and announce it to every chat client.

    Returns ``{"url": "/uploads/<name>"}`` on success, or ``{"error": ...}``
    when the extension is not allowed or the file exceeds MAX_FILE_SIZE.
    """
    ext = get_ext(file.filename)
    if not (is_image(ext) or is_video(ext)):
        return {"error": "不支持的文件类型"}

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        return {"error": "文件过大(最大 10MB)"}

    # Random on-disk name; the client-supplied filename is only used for its extension.
    filename = f"{uuid.uuid4().hex}.{ext}"
    save_path = os.path.join(UPLOAD_DIR, filename)
    # Disk I/O is blocking; run it in a worker thread so the event loop
    # (and every WebSocket session) is not stalled by a large write.
    await asyncio.to_thread(_write_bytes, save_path, contents)

    file_url = f"/uploads/{filename}"
    # ext was validated above, so it is either an image or a video extension.
    kind = "IMG" if is_image(ext) else "VIDEO"
    await broadcast(f"{kind}::{file_url}")

    return {"url": file_url}


# =========================
# WebSocket chat
# =========================
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Run one chat session.

    The first text frame is taken as the username; every later frame is
    broadcast as ``TEXT::<username>::<text>``. Join/leave notices and the
    online count are broadcast as the session starts and ends.
    """
    await ws.accept()
    active_connections.append(ws)
    # Kept locally because broadcast() drops this socket's `usernames` entry
    # if a send to it fails; the leave notice still needs the real name.
    username: Optional[str] = None
    try:
        await broadcast_online_count()

        # The first message is used as the username.
        username = await ws.receive_text()
        usernames[ws] = username
        await broadcast(f"SYSTEM::{username} 加入聊天室")
        await broadcast_online_count()

        while True:
            text = await ws.receive_text()
            await broadcast(f"TEXT::{username}::{text}")
    except WebSocketDisconnect:
        pass
    finally:
        if ws in active_connections:
            active_connections.remove(ws)
        usernames.pop(ws, None)
        name = username if username is not None else "匿名用户"
        await broadcast(f"SYSTEM::{name} 离开聊天室")
        await broadcast_online_count()


# =========================
# Broadcast helpers
# =========================
async def broadcast(message: str) -> None:
    """Send *message* to every connected socket, dropping ones that fail."""
    dead: list[WebSocket] = []
    # Iterate over a snapshot: each send awaits, and other tasks may connect
    # or disconnect (mutating active_connections) while this one is suspended.
    for ws in list(active_connections):
        try:
            await ws.send_text(message)
        except Exception:
            dead.append(ws)
    for ws in dead:
        # The socket's own endpoint cleanup may already have removed it.
        if ws in active_connections:
            active_connections.remove(ws)
        usernames.pop(ws, None)


async def broadcast_online_count() -> None:
    """Broadcast the current number of connected sockets as ``COUNT::<n>``."""
    msg = f"COUNT::{len(active_connections)}"
    await broadcast(msg)


# =========================
# Startup
# =========================
if __name__ == "__main__":
    # uvicorn's reload mode needs an import string; derive the module name
    # from this file so it keeps working if the file is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=8202,
        reload=True,
    )