|
|
"""
|
|
|
ChatHub - 实时聊天室
|
|
|
FastAPI + WebSocket
|
|
|
"""
|
|
|
import asyncio
import collections
import json
import logging
import math
import os
import shutil
import time
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Tuple

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Form, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
|
|
|
|
|
|
# FastAPI application; every route in this module is registered on it.
app = FastAPI()

# ── Paths ───────────────────────────────────────────────────
# All paths are resolved relative to this source file, not to the current
# working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Jinja2 templates (chat.html is rendered by the index route).
TEMPLATE_DIR = os.path.join(BASE_DIR, "templates")

# Stored uploads, served publicly under /uploads (see the mount below).
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")

# Timestamped copies of the chat log written by nightly_cleanup.
BACKUP_DIR = os.path.join(BASE_DIR, "backups")

# Created at import time so they exist before the StaticFiles mount below
# is handed UPLOAD_DIR.
os.makedirs(UPLOAD_DIR, exist_ok=True)

os.makedirs(BACKUP_DIR, exist_ok=True)

# Chat log: one "[timestamp] message" line per entry, appended by
# log_to_file and read back by load_recent_history.
CHAT_LOG_FILE = os.path.join(BASE_DIR, "chat_history.txt")

MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB per file (in bytes); enforced by /upload

# Accepted upload extensions: lower-case and without the leading dot, i.e.
# the form get_ext returns. Anything in neither set is rejected by /upload.
ALLOWED_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "bmp", "tiff", "webp", "heic"}

ALLOWED_VIDEO_EXTENSIONS = {"mp4", "mov", "wmv", "avi", "m4v", "mpg", "mpeg", "flv", "mkv", "3gp", "webm"}

templates = Jinja2Templates(directory=TEMPLATE_DIR)

# Serve stored files at /uploads/<name>; upload_file hands out URLs with
# this prefix.
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
|
|
|
|
|
|
# ── Global State ────────────────────────────────────────────

# Currently joined sockets -> the display name each one announced.
connections: Dict[WebSocket, str] = {}  # ws -> username

# Single asyncio lock shared by connections, upload_history and msg_history;
# the helpers in this module take it before touching any of those dicts.
glock = asyncio.Lock()

# ── Rate Limiting ─────────────────────────────────────────

# Upload rate limit: {client ip: [timestamps of recently accepted uploads]},
# maintained by _check_upload_rate.
upload_history: Dict[str, List[float]] = {}

UPLOAD_COOLDOWN = 3.0  # minimum seconds between two uploads from one IP

# Message rate limit: per-connection timestamps of recently accepted
# messages, maintained by _check_msg_rate.
msg_history: Dict[WebSocket, List[float]] = {}

MSG_BURST = 8  # max accepted messages per connection ...

MSG_WINDOW = 3.0  # ... within any sliding window of this many seconds
|
|
|
|
|
|
|
|
|
def get_ext(fn: str) -> str:
    """Return the lower-cased extension of filename *fn*, without the dot.

    Returns "" when *fn* is empty/None or contains no "." at all, so a bare
    name such as "png" is not mistaken for a file with a ".png" extension.
    A trailing dot ("name.") also yields "".
    """
    if not fn or "." not in fn:
        return ""
    return fn.rsplit(".", 1)[-1].lower()
|
|
|
|
|
|
|
|
|
def is_image(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted image extensions.

    The comparison is exact, so callers pass the extension lower-cased and
    without a leading dot (the form get_ext produces).
    """
    accepted = ALLOWED_IMAGE_EXTENSIONS
    return ext in accepted
|
|
|
|
|
|
|
|
|
def is_video(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted video extensions.

    The comparison is exact, so callers pass the extension lower-cased and
    without a leading dot (the form get_ext produces).
    """
    accepted = ALLOWED_VIDEO_EXTENSIONS
    return ext in accepted
|
|
|
|
|
|
|
|
|
def log_to_file(msg: str, path: Optional[str] = None) -> None:
    """Append *msg* to the chat log as one ``[YYYY-mm-dd HH:MM:SS] msg`` line.

    Args:
        msg: Protocol message to record (e.g. ``TEXT::alice::hi``).
        path: File to append to; defaults to CHAT_LOG_FILE.

    Line breaks inside *msg* are collapsed to single spaces. History is read
    back line by line, so an embedded newline would otherwise split one
    entry into several and let a user forge extra, timestamped-looking
    entries.

    Logging is best-effort: I/O and encoding errors are reported through the
    ``logging`` module instead of being raised into the caller.
    """
    target = CHAT_LOG_FILE if path is None else path
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    one_line = " ".join(msg.splitlines())
    try:
        with open(target, "a", encoding="utf-8") as f:
            f.write(f"[{ts}] {one_line}\n")
    except (OSError, ValueError) as exc:
        logging.getLogger(__name__).warning(
            "could not append to chat log %s: %s", target, exc
        )
|
|
|
|
|
|
|
|
|
def load_recent_history(n: int = 200, path: Optional[str] = None) -> List[str]:
    """Return the last *n* chat-log lines, whitespace-stripped, oldest first.

    Args:
        n: Maximum number of lines to return; ``n <= 0`` yields ``[]``.
        path: Log file to read; defaults to CHAT_LOG_FILE.

    Returns:
        The stripped lines. A missing log yields ``[]``. An unreadable or
        undecodable log also yields ``[]`` and is reported via ``logging``.

    Only the last *n* lines are held in memory while the file is scanned.
    """
    if n <= 0:
        # Slicing with ``lines[-0:]`` would have returned the *whole* file.
        return []
    source = CHAT_LOG_FILE if path is None else path
    try:
        with open(source, "r", encoding="utf-8") as f:
            tail = collections.deque(f, maxlen=n)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as exc:
        logging.getLogger(__name__).warning(
            "could not read chat log %s: %s", source, exc
        )
        return []
    return [line.strip() for line in tail]
|
|
|
|
|
|
|
|
|
# "/" (index) and "/members" (get_members) are defined once, in the
# HTTP Endpoints section further down next to /upload. Defining them here as
# well registered each route twice and redefined the same module names.
|
|
|
|
|
|
|
|
|
# ── Rate Limiter Helpers ────────────────────────────────
|
|
|
|
|
|
async def _check_upload_rate(request: Request) -> Tuple[bool, str]:
    """Enforce a per-client-IP cooldown of UPLOAD_COOLDOWN seconds between uploads.

    Returns:
        ``(True, "")`` when the upload may proceed (the attempt is recorded),
        otherwise ``(False, reason)`` with a user-facing message giving the
        whole seconds left to wait (rounded up, never 0).

    A monotonic clock is used so wall-clock adjustments cannot stretch or
    skip a cooldown. IPs whose cooldown has expired are dropped, so
    ``upload_history`` does not keep an entry for every client ever seen.
    """
    client = getattr(request, "client", None)
    ip = client.host if client else "unknown"
    now = time.monotonic()

    async with glock:
        # Sweep expired entries for all IPs; uploads are infrequent, so this
        # linear pass is cheap and keeps the table bounded.
        expired = [
            key
            for key, stamps in upload_history.items()
            if not stamps or now - stamps[-1] >= UPLOAD_COOLDOWN
        ]
        for key in expired:
            del upload_history[key]

        recent = upload_history.get(ip)
        if recent:
            wait = UPLOAD_COOLDOWN - (now - recent[-1])
            # Round up: truncating with int() told users to retry in "0" seconds.
            remaining = max(1, math.ceil(wait))
            return False, f"上传太频繁,请 {remaining} 秒后再试"

        upload_history[ip] = [now]
        return True, ""
|
|
|
|
|
|
|
|
|
async def _check_msg_rate(ws: WebSocket) -> bool:
    """Allow at most MSG_BURST messages per MSG_WINDOW seconds on one socket.

    Returns True (and records the message) when *ws* is under the limit, and
    False when the message should be rejected. Rejected attempts are not
    recorded, so they do not prolong the block.

    A monotonic clock is used so a wall-clock step backwards cannot keep old
    timestamps "recent" and lock a user out for a long time.
    """
    now = time.monotonic()
    async with glock:
        window = [t for t in msg_history.get(ws, ()) if now - t < MSG_WINDOW]
        if len(window) >= MSG_BURST:
            return False
        window.append(now)
        msg_history[ws] = window
        return True
|
|
|
|
|
|
|
|
|
# ── HTTP Endpoints ──────────────────────────────────────────
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render chat.html pre-filled with the latest 200 chat-log lines."""
    context = {
        "request": request,
        "history": load_recent_history(200),
    }
    return templates.TemplateResponse("chat.html", context)
|
|
|
|
|
|
|
|
|
@app.get("/members")
async def get_members():
    """Report the display names of everyone currently connected."""
    async with glock:
        snapshot = [*connections.values()]
    return JSONResponse(content={"members": snapshot})
|
|
|
|
|
|
|
|
|
@app.post("/upload")
async def upload_file(file: UploadFile = File(...), request: Request = None, sender: str = Form("匿名")):
    """Accept one image/video upload, store it and announce it to the room.

    Checks are made in this order:
      * the per-IP upload cooldown (429);
      * the file extension against the image/video allow-lists (400);
      * the size limit MAX_FILE_SIZE (413).

    On success the file is saved under a random name in UPLOAD_DIR, and
    ``IMG::sender::url`` or ``VIDEO::sender::url`` is broadcast.

    Returns:
        ``{"url": "/uploads/<name>"}`` on success. Otherwise a JSON
        ``{"error": ...}`` response with one of the status codes above.
    """
    # Rate limit check
    allowed, reason = await _check_upload_rate(request)
    if not allowed:
        return JSONResponse({"error": reason}, status_code=429)

    # File type check; a missing filename is treated as an unsupported type.
    ext = get_ext(file.filename or "")
    if not (is_image(ext) or is_video(ext)):
        return JSONResponse({"error": "不支持的文件类型,仅接受图片和视频"}, status_code=400)

    # Read at most one byte past the limit. That is enough to detect an
    # oversized file without pulling an arbitrarily large upload into memory.
    data = await file.read(MAX_FILE_SIZE + 1)
    if len(data) > MAX_FILE_SIZE:
        total = getattr(file, "size", None) or len(data)
        size_mb = total / (1024 * 1024)
        limit_mb = MAX_FILE_SIZE // (1024 * 1024)
        return JSONResponse({
            "error": f"文件过大({size_mb:.1f}MB),最大允许 {limit_mb}MB"
        }, status_code=413)

    # Same normalisation as WebSocket usernames: trimmed, non-empty, <= 20 chars.
    sender = (sender or "").strip()[:20] or "匿名"

    # Save under a random name; ext is from the allow-lists, so the path is safe.
    fn = f"{uuid.uuid4().hex}.{ext}"
    path = os.path.join(UPLOAD_DIR, fn)
    with open(path, "wb") as f:
        f.write(data)

    url = f"/uploads/{fn}"
    # Format: TYPE::sender::url — the frontend parses the two "::" separators.
    kind = "IMG" if is_image(ext) else "VIDEO"
    msg = f"{kind}::{sender}::{url}"
    asyncio.create_task(_broadcast(msg))
    return {"url": url}
|
|
|
|
|
|
|
|
|
# ── Broadcast ───────────────────────────────────────────────
|
|
|
|
|
|
async def _broadcast(message: str):
    """Fan *message* out to every connected client.

    The recipient list is copied while holding ``glock``, and the sends then
    run concurrently after the lock is released. Per-socket failures are
    handled inside _safe_send. ``gather`` is told to collect, not raise, any
    exception.
    """
    async with glock:
        recipients = list(connections)

    if not recipients:
        return
    await asyncio.gather(
        *(_safe_send(sock, message) for sock in recipients),
        return_exceptions=True,
    )
|
|
|
|
|
|
|
|
|
async def _safe_send(ws: WebSocket, message: str):
    """Deliver *message* to one socket, forgetting the socket if that fails.

    Any exception from ``send_text`` is swallowed. *ws* is then removed from
    ``connections`` (under ``glock``) so later broadcasts skip it.
    """
    try:
        await ws.send_text(message)
        return
    except Exception:
        pass
    async with glock:
        connections.pop(ws, None)
|
|
|
|
|
|
|
|
|
async def _broadcast_members():
    """Push the current roster, then the head count, to every client.

    Two frames are sent, in this order:
    ``MEMBERS::{"type": "members", "data": [...names]}`` and ``COUNT::<n>``.
    """
    async with glock:
        roster = [*connections.values()]
    members_frame = "MEMBERS::" + json.dumps({"type": "members", "data": roster})
    await _broadcast(members_frame)
    await _broadcast("COUNT::" + str(len(roster)))
|
|
|
|
|
|
|
|
|
# ── WebSocket Endpoint ──────────────────────────────────────
|
|
|
|
|
|
@app.websocket("/wschat")
async def ws_endpoint(ws: WebSocket):
    """Chat socket: the first frame is the username, each later frame a message.

    Protocol:
      1. The client must send its display name as the first text frame
         within 30 seconds, otherwise the socket is closed with code 1008.
         Whitespace in the name is collapsed and the name is cut to 20
         characters. An empty name becomes ``匿名<n>``.
      2. Each later non-blank frame is rate-limited (_check_msg_rate),
         logged, and broadcast as ``TEXT::<username>::<text>``.

    Join and leave events are logged and broadcast as ``SYSTEM::`` messages,
    each followed by a refreshed member list.
    """
    await ws.accept()

    # ── 1) Receive username as first text frame ──
    try:
        raw = await asyncio.wait_for(ws.receive_text(), timeout=30)
    except Exception:
        try:
            await ws.close(1008)
        except Exception:
            # The peer may already be gone, in which case close() can raise.
            pass
        return

    # Collapse all whitespace (including newlines) so the name stays on one log line.
    username = (" ".join(raw.split()) or f"匿名{id(ws) % 10000}")[:20]

    async with glock:
        connections[ws] = username

    join_msg = f"SYSTEM::{username} 加入聊天室"
    log_to_file(join_msg)
    asyncio.create_task(_broadcast(join_msg))
    asyncio.create_task(_broadcast_members())

    # ── 2) Message loop ──
    try:
        while True:
            text = await ws.receive_text()
            body = text.strip() if text else ""
            if not body:
                continue
            if not await _check_msg_rate(ws):
                asyncio.create_task(_safe_send(ws, "SYSTEM::⚠️ 发送太频繁,请稍后再试"))
                continue
            msg = f"TEXT::{username}::{body}"
            log_to_file(msg)
            asyncio.create_task(_broadcast(msg))
    except WebSocketDisconnect:
        pass  # normal client close
    except Exception:
        pass  # any other receive failure also ends the session
    finally:
        async with glock:
            connections.pop(ws, None)
            # Drop the rate-limit bookkeeping too, or dead sockets accumulate.
            msg_history.pop(ws, None)
        # Use the name captured at join time. _safe_send may already have
        # removed this socket from `connections` after a failed send, so
        # looking the name up there could come back empty.
        leave_msg = f"SYSTEM::{username} 离开聊天室"
        log_to_file(leave_msg)
        asyncio.create_task(_broadcast(leave_msg))
        asyncio.create_task(_broadcast_members())
|
|
|
|
|
|
|
|
|
# ── Nightly Cleanup ─────────────────────────────────────────
|
|
|
|
|
|
def _backup_and_clear_log(backup_path: str) -> bool:
    """Copy a non-empty chat log to *backup_path*, then truncate the log.

    Returns True when a backup file was written. If the copy raises, the log
    is left untouched, so no history is lost.
    """
    if not (os.path.exists(CHAT_LOG_FILE) and os.path.getsize(CHAT_LOG_FILE) > 0):
        return False
    shutil.copy2(CHAT_LOG_FILE, backup_path)
    open(CHAT_LOG_FILE, "w", encoding="utf-8").close()
    return True


def _purge_old_uploads(max_age: float) -> int:
    """Delete files in UPLOAD_DIR last modified more than *max_age* seconds ago.

    Returns the number of files removed. Entries that vanish meanwhile or
    cannot be stat'ed or removed are skipped.
    """
    cutoff = time.time() - max_age
    removed = 0
    with os.scandir(UPLOAD_DIR) as entries:
        for entry in entries:
            try:
                if entry.is_file() and entry.stat().st_mtime < cutoff:
                    os.remove(entry.path)
                    removed += 1
            except OSError:
                continue
    return removed


@app.post("/api/nightly-cleanup")
async def nightly_cleanup():
    """Back up and clear the chat log, then purge uploads older than 24 hours.

    Called by cron / systemd timer.

    Returns:
        ``{"status": "ok", "backup": <backup file name or None>,
        "cleaned_uploads": <count>}``. ``backup`` is None when the log was
        missing or empty, so no backup file was written.
    """
    # NOTE(review): this endpoint is unauthenticated. Any client that can
    # reach the server can wipe the live chat log. Consider protecting it,
    # e.g. with a shared-secret header or by exposing it on localhost only.
    #
    # The work below runs synchronously, with no await. Within one worker
    # process, no log_to_file call can therefore interleave between the
    # backup copy and the truncate.
    now = datetime.now()
    backup_name = f"chat_history_{now.strftime('%Y%m%d_%H%M%S')}.txt"
    backed_up = _backup_and_clear_log(os.path.join(BACKUP_DIR, backup_name))

    # Clean old uploads (> 24h)
    cleaned = _purge_old_uploads(24 * 60 * 60)

    return {
        "status": "ok",
        "backup": backup_name if backed_up else None,
        "cleaned_uploads": cleaned,
    }
|
|
|
|
|
|
|
|
|
# ── Main ────────────────────────────────────────────────────
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "chat:app" import string. The
    # string made uvicorn import this file a second time under the module
    # name "chat", and it fails outright if the file has any other name.
    # An import string is only needed for reload=True or multiple workers.
    uvicorn.run(app, host="0.0.0.0", port=8202, reload=False, log_level="info")
|