You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

301 lines
10 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
ChatHub - 实时聊天室
FastAPI + WebSocket
"""
import asyncio
import collections
import json
import logging
import math
import os
import shutil
import time
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Tuple

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Form, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# ASGI application object; every route and the static mount below attach to it.
app = FastAPI()
# Filesystem layout, all resolved relative to the directory containing this module.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEMPLATE_DIR = os.path.join(BASE_DIR, "templates")
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
BACKUP_DIR = os.path.join(BASE_DIR, "backups")
# Make sure the upload and backup directories exist as soon as the module is imported.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(BACKUP_DIR, exist_ok=True)
# Text log of chat lines: appended to by log_to_file, read back by load_recent_history.
CHAT_LOG_FILE = os.path.join(BASE_DIR, "chat_history.txt")
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB per file
# Extensions accepted by /upload, lower-case and without the leading dot.
ALLOWED_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "bmp", "tiff", "webp", "heic"}
ALLOWED_VIDEO_EXTENSIONS = {"mp4", "mov", "wmv", "avi", "m4v", "mpg", "mpeg", "flv", "mkv", "3gp", "webm"}
templates = Jinja2Templates(directory=TEMPLATE_DIR)
# Saved uploads are served back to clients under the /uploads URL prefix.
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
# ── Global State ────────────────────────────────────────────
connections: Dict[WebSocket, str] = {} # ws -> username
# Single lock taken around every read/write of connections, upload_history and msg_history.
glock = asyncio.Lock()
# ── Rate Limiting ─────────────────────────────────────────
# Upload rate limit: {ip: [timestamp, ...]}
upload_history: Dict[str, List[float]] = {}
UPLOAD_COOLDOWN = 3.0 # seconds between uploads
# Message rate limit: per-connection message timestamps
msg_history: Dict[WebSocket, List[float]] = {}
MSG_BURST = 8 # max messages
MSG_WINDOW = 3.0 # per this many seconds
def get_ext(fn: str) -> str:
    """Return the lower-cased extension of *fn*, without the leading dot.

    Only the text after the last "." counts, so "a.tar.GZ" yields "gz".
    An empty or missing name, or a name that contains no dot at all, yields
    "" instead of echoing the whole name back; otherwise a file literally
    called "png" would be mistaken for a PNG image.
    """
    if not fn:
        return ""
    _, dot, ext = fn.rpartition(".")
    return ext.lower() if dot else ""
def is_image(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted image extensions."""
    accepted = ALLOWED_IMAGE_EXTENSIONS
    return ext in accepted
def is_video(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted video extensions."""
    accepted = ALLOWED_VIDEO_EXTENSIONS
    return ext in accepted
def log_to_file(msg: str, path: Optional[str] = None):
    """Append *msg* to the chat log as a single timestamped line.

    Args:
        msg: Protocol line to record (e.g. "TEXT::user::hello").
        path: Log file to append to; defaults to CHAT_LOG_FILE.

    Logging is best effort: a failure to write is reported through the
    ``logging`` module and never propagates to the caller.
    """
    target = CHAT_LOG_FILE if path is None else path
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # The log is read back one entry per line, so a line break inside user
    # text would split the entry and let a sender forge extra log lines.
    flat = msg.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
    try:
        with open(target, "a", encoding="utf-8") as f:
            f.write(f"[{ts}] {flat}\n")
    except (OSError, UnicodeError) as exc:
        logging.getLogger(__name__).warning(
            "could not append to chat log %s: %s", target, exc
        )
def load_recent_history(n: int = 200, path: Optional[str] = None) -> List[str]:
    """Return the last *n* lines of the chat log, stripped of whitespace.

    Args:
        n: Maximum number of lines to return; values <= 0 yield [].
        path: Log file to read; defaults to CHAT_LOG_FILE.

    Returns:
        The trailing lines in file order, or [] when the file is missing,
        unreadable or not valid UTF-8.
    """
    # Guard explicitly: a slice such as lines[-0:] would return every line.
    if n <= 0:
        return []
    target = CHAT_LOG_FILE if path is None else path
    try:
        with open(target, "r", encoding="utf-8") as f:
            # A bounded deque streams the file and keeps only the tail, so
            # memory use does not grow with the size of the log.
            tail = collections.deque(f, maxlen=n)
    except (OSError, UnicodeError):
        return []
    return [line.strip() for line in tail]
# NOTE: the "/" and "/members" routes are declared exactly once, in the
# "HTTP Endpoints" section further down next to /upload. Declaring the same
# handlers here as well would register each route twice on `app`.
# ── Rate Limiter Helpers ────────────────────────────────
async def _check_upload_rate(request: Request) -> Tuple[bool, str]:
    """Enforce the per-IP upload cooldown.

    Args:
        request: The incoming HTTP request; its client host is the rate-limit
            key. A missing request or client falls back to the key "unknown".

    Returns:
        ``(True, "")`` when the upload may proceed (the attempt is recorded),
        otherwise ``(False, reason)`` with a user-facing message that says how
        many seconds remain.
    """
    client = request.client if request is not None else None
    ip = client.host if client else "unknown"
    now = time.time()
    async with glock:
        # Forget every IP whose cooldown has fully expired so the table does
        # not keep one entry per address forever.
        expired = [
            key
            for key, stamps in upload_history.items()
            if not stamps or now - stamps[-1] >= UPLOAD_COOLDOWN
        ]
        for key in expired:
            del upload_history[key]

        recent = [t for t in upload_history.get(ip, []) if now - t < UPLOAD_COOLDOWN]
        if recent:
            # Round up: truncating would tell the user to wait "0 seconds"
            # while the request is still being rejected.
            remaining = max(1, math.ceil(UPLOAD_COOLDOWN - (now - recent[-1])))
            return False, f"上传太频繁,请 {remaining} 秒后再试"
        upload_history[ip] = [now]
    return True, ""
async def _check_msg_rate(ws: WebSocket) -> bool:
    """Return True if *ws* may send another message right now.

    A connection is allowed at most MSG_BURST messages per MSG_WINDOW
    seconds; an accepted message is recorded, a rejected one is not.
    """
    stamp = time.time()
    async with glock:
        in_window = [
            earlier
            for earlier in msg_history.get(ws, [])
            if stamp - earlier < MSG_WINDOW
        ]
        accepted = len(in_window) < MSG_BURST
        if accepted:
            in_window.append(stamp)
            msg_history[ws] = in_window
    return accepted
# ── HTTP Endpoints ──────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat page, pre-filled with the latest 200 history lines."""
    context = {
        "request": request,
        "history": load_recent_history(200),
    }
    return templates.TemplateResponse("chat.html", context)
@app.get("/members")
async def get_members():
    """Report the usernames of everyone currently connected."""
    async with glock:
        online = [*connections.values()]
    body = {"members": online}
    return JSONResponse(body)
async def _read_capped(file: UploadFile, limit: int, chunk_size: int = 1024 * 1024) -> Tuple[bytes, int]:
    """Read *file* chunk by chunk and return ``(data, total_size)``.

    Once the running total exceeds *limit* the buffered bytes are dropped, so
    an oversized upload is measured without being held in memory; ``data`` is
    then ``b""``.
    """
    kept = []
    total = 0
    while True:
        chunk = await file.read(chunk_size)
        if not chunk:
            break
        total += len(chunk)
        if total > limit:
            kept.clear()
        else:
            kept.append(chunk)
    return b"".join(kept), total


@app.post("/upload")
async def upload_file(file: UploadFile = File(...), request: Request = None, sender: str = Form("匿名")):
    """Accept an image or video upload, store it and announce it to the room.

    Args:
        file: The uploaded file (multipart form field).
        request: The HTTP request, used only for per-IP rate limiting.
        sender: Display name of the uploader (form field).

    Returns:
        ``{"url": ...}`` on success, or a JSON ``{"error": ...}`` response
        with status 429 (rate limited), 400 (unsupported type) or
        413 (larger than MAX_FILE_SIZE).
    """
    # Rate limit check
    allowed, reason = await _check_upload_rate(request)
    if not allowed:
        return JSONResponse({"error": reason}, status_code=429)

    # File type check (the client may omit the filename altogether)
    ext = get_ext(file.filename or "")
    if not (is_image(ext) or is_video(ext)):
        return JSONResponse({"error": "不支持的文件类型,仅接受图片和视频"}, status_code=400)

    # Read with a cap so an oversized upload cannot exhaust memory
    data, size = await _read_capped(file, MAX_FILE_SIZE)
    if size > MAX_FILE_SIZE:
        mib = 1024 * 1024
        return JSONResponse({
            "error": f"文件过大({size / mib:.1f}MB),最大允许 {MAX_FILE_SIZE / mib:g}MB"
        }, status_code=413)

    # Same normalisation as websocket usernames: trimmed, non-empty, <= 20 chars
    sender = (sender.strip() or "匿名")[:20]

    # Save under a random name so client-supplied names never reach the disk
    fn = f"{uuid.uuid4().hex}.{ext}"
    path = os.path.join(UPLOAD_DIR, fn)
    with open(path, "wb") as f:
        f.write(data)
    url = f"/uploads/{fn}"

    # Format: TYPE::sender::url — frontend parses two colons
    kind = "IMG" if is_image(ext) else "VIDEO"
    asyncio.create_task(_broadcast(f"{kind}::{sender}::{url}"))
    return {"url": url}
# ── Broadcast ───────────────────────────────────────────────
async def _broadcast(message: str):
    """Deliver *message* to every client that is connected right now.

    The recipient list is copied while holding the lock; the sends then run
    concurrently after the lock has been released.
    """
    async with glock:
        recipients = list(connections)
    if not recipients:
        return
    sends = [_safe_send(sock, message) for sock in recipients]
    await asyncio.gather(*sends, return_exceptions=True)
async def _safe_send(ws: WebSocket, message: str):
    """Send *message* to *ws*; unregister the socket if the send fails."""
    try:
        await ws.send_text(message)
    except Exception:
        delivered = False
    else:
        delivered = True
    if delivered:
        return
    async with glock:
        if ws in connections:
            del connections[ws]
async def _broadcast_members():
    """Push the current member list, then the member count, to all clients."""
    async with glock:
        roster = [*connections.values()]
    members_frame = "MEMBERS::" + json.dumps({"type": "members", "data": roster})
    count_frame = f"COUNT::{len(roster)}"
    await _broadcast(members_frame)
    await _broadcast(count_frame)
# ── WebSocket Endpoint ──────────────────────────────────────
@app.websocket("/wschat")
async def ws_endpoint(ws: WebSocket):
    """Serve one chat client for the lifetime of its WebSocket.

    Protocol: the first text frame is the username (trimmed, at most 20
    characters, with a generated fallback when blank); every later non-blank
    text frame is a chat message that is rate limited, logged and broadcast
    as ``TEXT::<username>::<text>``. Join and leave are announced with
    ``SYSTEM::`` lines followed by a refreshed member list.
    """
    log = logging.getLogger(__name__)
    await ws.accept()

    # ── 1) Receive username as first text frame ──
    try:
        raw = await asyncio.wait_for(ws.receive_text(), timeout=30)
    except WebSocketDisconnect:
        # The peer left before naming itself; there is nothing left to close.
        return
    except Exception:
        # Timed out or sent something that is not a text frame: refuse it.
        try:
            await ws.close(1008)
        except Exception:
            log.debug("closing unnamed websocket failed", exc_info=True)
        return

    username = (raw.strip() or f"匿名{id(ws) % 10000}")[:20]
    async with glock:
        connections[ws] = username
    join_msg = f"SYSTEM::{username} 加入聊天室"
    log_to_file(join_msg)
    asyncio.create_task(_broadcast(join_msg))
    asyncio.create_task(_broadcast_members())

    # ── 2) Message loop ──
    try:
        while True:
            text = await ws.receive_text()
            body = text.strip() if text else ""
            if not body:
                continue
            # Rate limit check
            if not await _check_msg_rate(ws):
                asyncio.create_task(_safe_send(ws, "SYSTEM::⚠️ 发送太频繁,请稍后再试"))
                continue
            msg = f"TEXT::{username}::{body}"
            log_to_file(msg)
            asyncio.create_task(_broadcast(msg))
    except WebSocketDisconnect:
        pass
    except Exception:
        # Still end the session cleanly, but leave a trace of what went wrong.
        log.exception("websocket session for %r ended with an error", username)
    finally:
        async with glock:
            connections.pop(ws, None)
            # Release the per-connection rate-limit record as well.
            msg_history.pop(ws, None)
        # Use the local name: a failed send may already have removed this
        # socket from `connections`, which must not anonymise the message.
        leave_msg = f"SYSTEM::{username} 离开聊天室"
        log_to_file(leave_msg)
        asyncio.create_task(_broadcast(leave_msg))
        asyncio.create_task(_broadcast_members())
# ── Nightly Cleanup ─────────────────────────────────────────
@app.post("/api/nightly-cleanup")
async def nightly_cleanup(max_upload_age: float = 86400):
    """Back up and clear the chat log, then purge old uploads.

    Called by cron / systemd timer.

    Args:
        max_upload_age: Uploads whose modification time is older than this
            many seconds are deleted (default: 24 hours).

    Returns:
        ``status``; ``backup`` (the backup file name derived from the current
        time); ``backup_created`` (whether a backup file was actually
        written — False when the log was missing or empty); and
        ``cleaned_uploads`` (number of upload files removed).
    """
    # NOTE(review): nothing in this handler authenticates the caller, yet it
    # wipes the chat log — confirm access is restricted at the proxy/network.
    now = datetime.now()
    backup_name = f"chat_history_{now.strftime('%Y%m%d_%H%M%S')}.txt"
    backup_path = os.path.join(BACKUP_DIR, backup_name)

    try:
        has_content = os.path.getsize(CHAT_LOG_FILE) > 0
    except OSError:
        # Missing or unreadable log: nothing to back up.
        has_content = False

    if has_content:
        shutil.copy2(CHAT_LOG_FILE, backup_path)
        # Clear the log only once the copy has succeeded.
        with open(CHAT_LOG_FILE, "w", encoding="utf-8"):
            pass

    # Clean old uploads
    cutoff = time.time() - max_upload_age
    cleaned = 0
    with os.scandir(UPLOAD_DIR) as entries:
        for entry in entries:
            try:
                # stat() sits inside the try: the file may vanish between
                # being listed and being inspected.
                if entry.is_file() and entry.stat().st_mtime < cutoff:
                    os.remove(entry.path)
                    cleaned += 1
            except OSError:
                continue

    return {
        "status": "ok",
        "backup": backup_name,
        "backup_created": has_content,
        "cleaned_uploads": cleaned,
    }
# ── Main ────────────────────────────────────────────────────
if __name__ == "__main__":
    # Imported lazily: only needed when the module is launched as a script.
    import uvicorn

    # "chat:app" presumably refers to this module being saved as chat.py —
    # verify against the deployed file name.
    uvicorn.run(
        "chat:app",
        host="0.0.0.0",
        port=8202,
        reload=False,
        log_level="info",
    )