"""ChatHub - real-time chat room built on FastAPI + WebSocket.

Text frames sent to clients use a ``::``-separated protocol:

* ``SYSTEM::<text>``          - join/leave notices and warnings
* ``TEXT::<sender>::<text>``  - chat message
* ``IMG::<sender>::<url>``    - uploaded image
* ``VIDEO::<sender>::<url>``  - uploaded video
* ``MEMBERS::<json>``         - ``{"type": "members", "data": [names...]}``
* ``COUNT::<n>``              - number of online members
"""
import asyncio
import collections
import json
import logging
import math
import os
import shutil
import time
import uuid
from datetime import datetime
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Form, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

logger = logging.getLogger(__name__)

app = FastAPI()

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEMPLATE_DIR = os.path.join(BASE_DIR, "templates")
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
BACKUP_DIR = os.path.join(BASE_DIR, "backups")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(BACKUP_DIR, exist_ok=True)
CHAT_LOG_FILE = os.path.join(BASE_DIR, "chat_history.txt")

MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB per file
UPLOAD_CHUNK_SIZE = 1024 * 1024  # uploads are read 1MB at a time
MAX_NAME_LEN = 20  # display names are truncated to this many characters
UPLOAD_RETENTION_SECONDS = 24 * 60 * 60  # nightly cleanup purges older uploads
ALLOWED_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "bmp", "tiff", "webp", "heic"}
ALLOWED_VIDEO_EXTENSIONS = {"mp4", "mov", "wmv", "avi", "m4v", "mpg", "mpeg", "flv", "mkv", "3gp", "webm"}

templates = Jinja2Templates(directory=TEMPLATE_DIR)
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")

# ── Global State ────────────────────────────────────────────
connections: Dict[WebSocket, str] = {}  # ws -> username
glock = asyncio.Lock()
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected before
# it finishes; each task removes itself from this set when done.
_background_tasks: Set["asyncio.Task[Any]"] = set()

# ── Rate Limiting ───────────────────────────────────────────
# Upload rate limit: {ip: [monotonic timestamp, ...]}
upload_history: Dict[str, List[float]] = {}
UPLOAD_COOLDOWN = 3.0  # seconds between uploads

# Message rate limit: per-connection message timestamps (monotonic clock)
msg_history: Dict[WebSocket, List[float]] = {}
MSG_BURST = 8  # max messages
MSG_WINDOW = 3.0  # per this many seconds


def get_ext(fn: Optional[str]) -> str:
    """Return the lower-cased extension of *fn* without the leading dot.

    Returns "" when *fn* is empty/None or contains no dot, so a bare name such
    as "png" is not mistaken for a file with a ".png" extension.
    """
    if not fn or "." not in fn:
        return ""
    return fn.rsplit(".", 1)[-1].lower()


def is_image(ext: str) -> bool:
    """Return True if *ext* is an accepted image extension."""
    return ext in ALLOWED_IMAGE_EXTENSIONS


def is_video(ext: str) -> bool:
    """Return True if *ext* is an accepted video extension."""
    return ext in ALLOWED_VIDEO_EXTENSIONS


def clean_name(raw: Optional[str], default: str) -> str:
    """Normalize a user-supplied display name for the ``::`` wire protocol.

    Strips surrounding whitespace and replaces ASCII ':' with the full-width
    '：' so a name can never introduce or extend a ``::`` field separator.
    Falls back to *default* when the result is empty, then truncates to
    MAX_NAME_LEN characters.
    """
    name = (raw or "").strip().replace(":", "：")
    return (name or default)[:MAX_NAME_LEN]


def log_to_file(msg: str) -> None:
    """Append *msg* to CHAT_LOG_FILE as one timestamped line (best effort).

    Embedded CR/LF are flattened to spaces so that one message is always one
    log line. Otherwise a multi-line message would be replayed by
    load_recent_history as several (possibly forged-looking) history entries.
    """
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    one_line = msg.replace("\r", " ").replace("\n", " ")
    try:
        with open(CHAT_LOG_FILE, "a", encoding="utf-8") as f:
            f.write(f"[{ts}] {one_line}\n")
    except (OSError, ValueError):
        # Logging must never break chat delivery, but don't hide the failure.
        logger.warning("could not append to %s", CHAT_LOG_FILE, exc_info=True)


def load_recent_history(n: int = 200) -> List[str]:
    """Return the last *n* lines of the chat log, whitespace-stripped.

    Returns [] when *n* <= 0, when the log does not exist, or when it cannot
    be read. Only the last *n* lines are held in memory while reading.
    """
    if n <= 0:
        return []
    try:
        with open(CHAT_LOG_FILE, "r", encoding="utf-8") as f:
            tail = collections.deque(f, maxlen=n)
    except FileNotFoundError:
        return []
    except (OSError, ValueError):
        logger.warning("could not read %s", CHAT_LOG_FILE, exc_info=True)
        return []
    return [line.strip() for line in tail]


def _spawn(coro: Coroutine[Any, Any, Any]) -> "asyncio.Task[Any]":
    """Schedule *coro* as a background task and keep it referenced until done."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


# ── HTTP Endpoints ──────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat page with the most recent 200 log lines as history."""
    history = load_recent_history(200)
    return templates.TemplateResponse("chat.html", {"request": request, "history": history})


@app.get("/members")
async def get_members():
    """Return current online member list."""
    async with glock:
        names = list(connections.values())
    return JSONResponse({"members": names})


# ── Rate Limiter Helpers ────────────────────────────────────
async def _check_upload_rate(request: Optional[Request]) -> Tuple[bool, str]:
    """Check the per-client-IP upload cooldown and record this attempt.

    Returns (allowed, reason); *reason* is a user-facing message when the
    upload is refused. Entries for IPs outside the cooldown window are pruned
    on every call, so the table stays bounded by recently active clients.
    """
    ip = request.client.host if request is not None and request.client else "unknown"
    now = time.monotonic()
    async with glock:
        stale = [
            key
            for key, stamps in upload_history.items()
            if not stamps or now - stamps[-1] >= UPLOAD_COOLDOWN
        ]
        for key in stale:
            del upload_history[key]

        times = upload_history.get(ip)
        if times:
            # Round up so the user is never told to wait "0" seconds.
            remaining = math.ceil(UPLOAD_COOLDOWN - (now - times[-1]))
            return False, f"上传太频繁,请 {remaining} 秒后再试"
        upload_history[ip] = [now]
        return True, ""


async def _check_msg_rate(ws: WebSocket) -> bool:
    """Allow at most MSG_BURST messages per MSG_WINDOW seconds per connection.

    Records the attempt when it is allowed; returns False when over the limit.
    """
    now = time.monotonic()
    async with glock:
        times = [t for t in msg_history.get(ws, []) if now - t < MSG_WINDOW]
        if len(times) >= MSG_BURST:
            msg_history[ws] = times
            return False
        times.append(now)
        msg_history[ws] = times
        return True


async def _read_limited(upload: UploadFile, limit: int) -> Tuple[bytes, int]:
    """Read *upload* in chunks, buffering at most *limit* bytes.

    Returns (data, total_size). When total_size > limit, *data* is b"" and must
    not be saved; the rest of the stream is only consumed to report its size.
    """
    chunks: List[bytes] = []
    total = 0
    while True:
        chunk = await upload.read(UPLOAD_CHUNK_SIZE)
        if not chunk:
            break
        total += len(chunk)
        if total > limit:
            chunks.clear()  # oversize: stop buffering, keep counting
        else:
            chunks.append(chunk)
    return b"".join(chunks), total


@app.post("/upload")
async def upload_file(file: UploadFile = File(...), request: Request = None, sender: str = Form("匿名")):
    """Accept an image/video upload, store it, and broadcast its URL.

    Responses: 429 when rate-limited, 400 for an unsupported extension,
    413 when larger than MAX_FILE_SIZE, otherwise {"url": "/uploads/<name>"}.
    """
    allowed, reason = await _check_upload_rate(request)
    if not allowed:
        return JSONResponse({"error": reason}, status_code=429)

    ext = get_ext(file.filename)
    if not (is_image(ext) or is_video(ext)):
        return JSONResponse({"error": "不支持的文件类型,仅接受图片和视频"}, status_code=400)

    data, total = await _read_limited(file, MAX_FILE_SIZE)
    if total > MAX_FILE_SIZE:
        size_mb = total / (1024 * 1024)
        limit_mb = MAX_FILE_SIZE // (1024 * 1024)
        return JSONResponse({
            "error": f"文件过大({size_mb:.1f}MB),最大允许 {limit_mb}MB"
        }, status_code=413)

    # Random server-side name: the client filename only contributes its
    # (whitelisted) extension, so it cannot influence the storage path.
    fn = f"{uuid.uuid4().hex}.{ext}"
    path = os.path.join(UPLOAD_DIR, fn)
    with open(path, "wb") as f:
        f.write(data)

    url = f"/uploads/{fn}"
    name = clean_name(sender, "匿名")
    # NOTE(review): *sender* is client-supplied and unauthenticated, so any
    # client can post under any name.
    # Format: TYPE::sender::url — frontend parses two colons
    kind = "IMG" if is_image(ext) else "VIDEO"
    _spawn(_broadcast(f"{kind}::{name}::{url}"))
    return {"url": url}


# ── Broadcast ───────────────────────────────────────────────
async def _broadcast(message: str):
    """Send *message* to every connected client.

    Snapshot-then-send: the target list is copied under the lock, then sends
    run concurrently without holding it. Per-client failures are handled by
    _safe_send.
    """
    async with glock:
        targets = list(connections)
    if targets:
        await asyncio.gather(*(_safe_send(ws, message) for ws in targets), return_exceptions=True)


async def _safe_send(ws: WebSocket, message: str):
    """Send text to one WebSocket; on failure drop it from `connections`."""
    try:
        await ws.send_text(message)
    except Exception:
        async with glock:
            connections.pop(ws, None)


async def _broadcast_members():
    """Broadcast the JSON member list, then the member count, to all clients."""
    async with glock:
        names = list(connections.values())
    payload = json.dumps({"type": "members", "data": names})
    await _broadcast(f"MEMBERS::{payload}")
    await _broadcast(f"COUNT::{len(names)}")


# ── WebSocket Endpoint ──────────────────────────────────────
@app.websocket("/wschat")
async def ws_endpoint(ws: WebSocket):
    """Chat socket: the first text frame is the username, later frames are messages.

    The connection is closed with code 1008 if no text frame arrives within
    30 seconds.
    """
    await ws.accept()

    # ── 1) Receive username as first text frame ──
    try:
        raw = await asyncio.wait_for(ws.receive_text(), timeout=30)
    except Exception:
        try:
            await ws.close(1008)
        except Exception:
            pass  # peer already gone; nothing left to close
        return

    username = clean_name(raw, f"匿名{id(ws) % 10000}")
    async with glock:
        connections[ws] = username

    join_msg = f"SYSTEM::{username} 加入聊天室"
    log_to_file(join_msg)
    _spawn(_broadcast(join_msg))
    _spawn(_broadcast_members())

    # ── 2) Message loop ──
    try:
        while True:
            text = await ws.receive_text()
            content = text.strip()
            if not content:
                continue
            if not await _check_msg_rate(ws):
                _spawn(_safe_send(ws, "SYSTEM::⚠️ 发送太频繁,请稍后再试"))
                continue
            msg = f"TEXT::{username}::{content}"
            log_to_file(msg)
            _spawn(_broadcast(msg))
    except WebSocketDisconnect:
        pass
    except Exception:
        logger.exception("websocket loop for %r ended with an error", username)
    finally:
        async with glock:
            # _safe_send may already have removed this socket, so don't rely on
            # the popped value for the name; also drop its rate-limit state.
            connections.pop(ws, None)
            msg_history.pop(ws, None)
        leave_msg = f"SYSTEM::{username} 离开聊天室"
        log_to_file(leave_msg)
        _spawn(_broadcast(leave_msg))
        _spawn(_broadcast_members())


# ── Nightly Cleanup ─────────────────────────────────────────
def _purge_old_uploads(max_age: float) -> int:
    """Delete regular files in UPLOAD_DIR whose mtime is older than *max_age* seconds.

    Returns the number of files removed; files that vanish or cannot be removed
    are skipped (and logged) without aborting the sweep.
    """
    cutoff = time.time() - max_age
    removed = 0
    with os.scandir(UPLOAD_DIR) as entries:
        for entry in entries:
            try:
                if entry.is_file() and entry.stat().st_mtime < cutoff:
                    os.remove(entry.path)
                    removed += 1
            except OSError:
                logger.warning("could not purge %s", entry.path, exc_info=True)
    return removed


@app.post("/api/nightly-cleanup")
async def nightly_cleanup():
    """Back up the chat log to backups/, clear it, and purge uploads older than 24h.

    Called by cron / systemd timer. "backup" in the response is the backup file
    name, or null when the log was missing or empty and nothing was backed up.

    NOTE(review): this endpoint is unauthenticated — anyone who can reach the
    server can wipe the chat history. Restrict it (proxy rule, token, or
    loopback-only check) before exposing the port publicly.
    """
    backup_name: Optional[str] = None
    if os.path.exists(CHAT_LOG_FILE) and os.path.getsize(CHAT_LOG_FILE) > 0:
        backup_name = f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
        shutil.copy2(CHAT_LOG_FILE, os.path.join(BACKUP_DIR, backup_name))
        # Truncate only after the copy succeeded.
        open(CHAT_LOG_FILE, "w", encoding="utf-8").close()

    cleaned = _purge_old_uploads(UPLOAD_RETENTION_SECONDS)
    return {"status": "ok", "backup": backup_name, "cleaned_uploads": cleaned}


# ── Main ────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Pass the app object (reload is off), so this works whatever the file is named.
    uvicorn.run(app, host="0.0.0.0", port=8202, reload=False, log_level="info")