|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Request
|
|
|
from fastapi.responses import HTMLResponse
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from fastapi.templating import Jinja2Templates
|
|
|
import os
|
|
|
import shutil
|
|
|
import uuid
|
|
|
import uvicorn
|
|
|
|
|
|
# The single ASGI application; routes, the WebSocket endpoint and the static
# /uploads mount below are all registered on it.
app: FastAPI = FastAPI()
|
|
|
|
|
|
# =========================
|
|
|
# 基础配置
|
|
|
# =========================
|
|
|
|
|
|
# Directory containing this script; templates and uploads live beside it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

TEMPLATE_DIR, UPLOAD_DIR = (
    os.path.join(BASE_DIR, "templates"),
    os.path.join(BASE_DIR, "uploads"),
)

# Upload size cap in bytes (10 * 1024 * 1024).
MAX_FILE_SIZE = 10 << 20

# Lower-case extensions (without the dot) accepted by the upload endpoint.
ALLOWED_IMAGE_EXTENSIONS = set("png jpg jpeg gif".split())
ALLOWED_VIDEO_EXTENSIONS = set("mp4 webm".split())
|
|
|
|
|
|
# =========================
|
|
|
# 目录初始化
|
|
|
# =========================
|
|
|
|
|
|
# Create the upload directory up front; it is handed to StaticFiles below.
os.makedirs(name=UPLOAD_DIR, exist_ok=True)

# Jinja2 loader for the HTML page served by the "/" route.
templates = Jinja2Templates(directory=TEMPLATE_DIR)

# Serve saved uploads at /uploads/<name>, the URLs broadcast to chat clients.
app.mount("/uploads", app=StaticFiles(directory=UPLOAD_DIR), name="uploads")
|
|
|
|
|
|
# =========================
|
|
|
# WebSocket 状态
|
|
|
# =========================
|
|
|
|
|
|
# Every accepted WebSocket that has not been cleaned up yet, in accept order.
active_connections: list[WebSocket] = list()

# Display name per socket, taken from the first text frame the client sends.
usernames: dict[WebSocket, str] = dict()
|
|
|
|
|
|
# =========================
|
|
|
# 工具函数
|
|
|
# =========================
|
|
|
|
|
|
def get_ext(filename: str) -> str:
    """Return the lower-cased extension of *filename*, without the dot.

    The extension is the text after the last ``.``. A name with no dot, an
    empty name, or ``None`` (an upload part may carry no filename) yields
    ``""``. No allow-list contains ``""``, so such files are rejected instead
    of having their whole name mistaken for an extension.

    Examples: ``"a.JPG" -> "jpg"``, ``"x.tar.gz" -> "gz"``, ``"png" -> ""``.
    """
    if not filename or "." not in filename:
        return ""
    return filename.rsplit(".", 1)[1].lower()
|
|
|
|
|
|
def is_image(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted image extensions.

    The comparison is exact (no lower-casing or dot stripping), so pass the
    value produced by ``get_ext``.
    """
    accepted = ALLOWED_IMAGE_EXTENSIONS
    return ext in accepted
|
|
|
|
|
|
def is_video(ext: str) -> bool:
    """Tell whether *ext* is one of the accepted video extensions.

    The comparison is exact (no lower-casing or dot stripping), so pass the
    value produced by ``get_ext``.
    """
    accepted = ALLOWED_VIDEO_EXTENSIONS
    return ext in accepted
|
|
|
|
|
|
# =========================
|
|
|
# 页面路由
|
|
|
# =========================
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat page template ``chat.html``."""
    # The template context carries the request object under the "request" key.
    context = {"request": request}
    return templates.TemplateResponse("chat.html", context)
|
|
|
|
|
|
# =========================
|
|
|
# 文件上传接口
|
|
|
# =========================
|
|
|
|
|
|
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded image/video and announce it to every chat client.

    The client's filename is used only for its extension. The file is saved
    under ``UPLOAD_DIR`` with a random UUID name, then ``IMG::<url>`` or
    ``VIDEO::<url>`` is broadcast to all WebSocket connections.

    Returns:
        ``{"url": "/uploads/<name>"}`` on success. On rejection the body is
        ``{"error": <message>}`` with HTTP 415 (unsupported type) or
        HTTP 413 (larger than ``MAX_FILE_SIZE``).
    """
    # Function-scope imports keep this block self-contained.
    import asyncio

    from fastapi.responses import JSONResponse

    # The multipart part may have no filename; treat that as "no extension".
    ext = get_ext(file.filename or "")
    if not (is_image(ext) or is_video(ext)):
        return JSONResponse({"error": "不支持的文件类型"}, status_code=415)

    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without pulling an arbitrarily large file into memory.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        return JSONResponse({"error": "文件过大(最大 10MB)"}, status_code=413)

    filename = f"{uuid.uuid4().hex}.{ext}"
    save_path = os.path.join(UPLOAD_DIR, filename)

    def _save() -> None:
        # Plain blocking write; run off the event loop below.
        with open(save_path, "wb") as out:
            out.write(contents)

    # Writing up to MAX_FILE_SIZE bytes synchronously would stall the event
    # loop that also serves every WebSocket, so do it in a worker thread.
    await asyncio.to_thread(_save)

    file_url = f"/uploads/{filename}"

    # The allow-list check above guarantees ext is an image or a video here.
    msg = f"IMG::{file_url}" if is_image(ext) else f"VIDEO::{file_url}"
    await broadcast(msg)

    return {"url": file_url}
|
|
|
|
|
|
# =========================
|
|
|
# WebSocket 聊天
|
|
|
# =========================
|
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Run one chat session over a WebSocket.

    Server-side protocol:
      * the first text frame from the client is its display name (blank names
        become "匿名用户");
      * every later text frame is re-broadcast as ``TEXT::<name>::<text>``;
      * join/leave notices go out as ``SYSTEM::...`` and the online count via
        ``broadcast_online_count``.
    """
    await ws.accept()
    active_connections.append(ws)
    # Set once the client has sent its name. A client that leaves before then
    # was never announced, so it gets no leave notice either.
    username = None

    try:
        # Inside the try so the socket is cleaned up even if this is cancelled.
        await broadcast_online_count()

        # The first message is the username.
        username = (await ws.receive_text()).strip() or "匿名用户"
        usernames[ws] = username

        await broadcast(f"SYSTEM::{username} 加入聊天室")
        await broadcast_online_count()

        while True:
            text = await ws.receive_text()
            await broadcast(f"TEXT::{username}::{text}")

    except WebSocketDisconnect:
        pass  # Normal client disconnect; cleanup happens in finally.

    finally:
        if ws in active_connections:
            active_connections.remove(ws)
        usernames.pop(ws, None)
        # Use the local name rather than looking it up in `usernames`:
        # broadcast() may already have dropped this socket's entry after a
        # failed send, which previously turned the name into "匿名用户".
        if username is not None:
            await broadcast(f"SYSTEM::{username} 离开聊天室")
        await broadcast_online_count()
|
|
|
|
|
|
# =========================
|
|
|
# 广播工具
|
|
|
# =========================
|
|
|
|
|
|
async def broadcast(message: str):
    """Send *message* to every tracked connection, pruning ones that fail.

    This is best-effort. A send error on one socket is not propagated; that
    socket is removed from ``active_connections`` and ``usernames`` instead.
    """
    dead = []

    # Iterate over a snapshot. Each send awaits, and other handlers may add or
    # remove connections meanwhile; iterating the live list could then skip
    # recipients.
    for ws in list(active_connections):
        try:
            await ws.send_text(message)
        except Exception:
            dead.append(ws)

    for ws in dead:
        # The socket's own handler may already have removed it while we were
        # awaiting, so guard against ValueError from list.remove().
        if ws in active_connections:
            active_connections.remove(ws)
        usernames.pop(ws, None)
|
|
|
|
|
|
async def broadcast_online_count():
    """Broadcast ``COUNT::<n>``, where n is the current number of tracked sockets."""
    await broadcast("COUNT::" + str(len(active_connections)))
|
|
|
|
|
|
# =========================
|
|
|
# 启动
|
|
|
# =========================
|
|
|
|
|
|
if __name__ == "__main__":
    # With reload=True uvicorn needs an import string ("module:attribute").
    # Derive the module name from this file instead of hard-coding "chat", so
    # the launcher still works if the script is renamed. For chat.py the
    # result is the same "chat:app" as before.
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=8202,
        reload=True,
    )