初始化
This commit is contained in:
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
8
.idea/Video.iml
generated
Normal file
8
.idea/Video.iml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
98
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
98
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="76">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="scipy" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="protobuf" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="thop" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="opencv-python" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="PyYAML" />
|
||||||
|
<item index="5" class="java.lang.String" itemvalue="ipython" />
|
||||||
|
<item index="6" class="java.lang.String" itemvalue="torch" />
|
||||||
|
<item index="7" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
<item index="8" class="java.lang.String" itemvalue="requests" />
|
||||||
|
<item index="9" class="java.lang.String" itemvalue="torchvision" />
|
||||||
|
<item index="10" class="java.lang.String" itemvalue="psutil" />
|
||||||
|
<item index="11" class="java.lang.String" itemvalue="tqdm" />
|
||||||
|
<item index="12" class="java.lang.String" itemvalue="pandas" />
|
||||||
|
<item index="13" class="java.lang.String" itemvalue="tensorboard" />
|
||||||
|
<item index="14" class="java.lang.String" itemvalue="seaborn" />
|
||||||
|
<item index="15" class="java.lang.String" itemvalue="matplotlib" />
|
||||||
|
<item index="16" class="java.lang.String" itemvalue="Pillow" />
|
||||||
|
<item index="17" class="java.lang.String" itemvalue="fastapi" />
|
||||||
|
<item index="18" class="java.lang.String" itemvalue="uvicorn" />
|
||||||
|
<item index="19" class="java.lang.String" itemvalue="python-jose" />
|
||||||
|
<item index="20" class="java.lang.String" itemvalue="passlib" />
|
||||||
|
<item index="21" class="java.lang.String" itemvalue="pydantic" />
|
||||||
|
<item index="22" class="java.lang.String" itemvalue="sqlalchemy" />
|
||||||
|
<item index="23" class="java.lang.String" itemvalue="imageio_ffmpeg" />
|
||||||
|
<item index="24" class="java.lang.String" itemvalue="ultralytics" />
|
||||||
|
<item index="25" class="java.lang.String" itemvalue="future" />
|
||||||
|
<item index="26" class="java.lang.String" itemvalue="jose" />
|
||||||
|
<item index="27" class="java.lang.String" itemvalue="ffmpeg-python" />
|
||||||
|
<item index="28" class="java.lang.String" itemvalue="setuptools" />
|
||||||
|
<item index="29" class="java.lang.String" itemvalue="opencv_python" />
|
||||||
|
<item index="30" class="java.lang.String" itemvalue="rsa" />
|
||||||
|
<item index="31" class="java.lang.String" itemvalue="greenlet" />
|
||||||
|
<item index="32" class="java.lang.String" itemvalue="networkx" />
|
||||||
|
<item index="33" class="java.lang.String" itemvalue="python-dateutil" />
|
||||||
|
<item index="34" class="java.lang.String" itemvalue="SQLAlchemy" />
|
||||||
|
<item index="35" class="java.lang.String" itemvalue="cffi" />
|
||||||
|
<item index="36" class="java.lang.String" itemvalue="python-dotenv" />
|
||||||
|
<item index="37" class="java.lang.String" itemvalue="h11" />
|
||||||
|
<item index="38" class="java.lang.String" itemvalue="py-cpuinfo" />
|
||||||
|
<item index="39" class="java.lang.String" itemvalue="cycler" />
|
||||||
|
<item index="40" class="java.lang.String" itemvalue="MarkupSafe" />
|
||||||
|
<item index="41" class="java.lang.String" itemvalue="pyasn1" />
|
||||||
|
<item index="42" class="java.lang.String" itemvalue="pycparser" />
|
||||||
|
<item index="43" class="java.lang.String" itemvalue="Jinja2" />
|
||||||
|
<item index="44" class="java.lang.String" itemvalue="sniffio" />
|
||||||
|
<item index="45" class="java.lang.String" itemvalue="ultralytics-thop" />
|
||||||
|
<item index="46" class="java.lang.String" itemvalue="fsspec" />
|
||||||
|
<item index="47" class="java.lang.String" itemvalue="filelock" />
|
||||||
|
<item index="48" class="java.lang.String" itemvalue="starlette" />
|
||||||
|
<item index="49" class="java.lang.String" itemvalue="certifi" />
|
||||||
|
<item index="50" class="java.lang.String" itemvalue="anyio" />
|
||||||
|
<item index="51" class="java.lang.String" itemvalue="urllib3" />
|
||||||
|
<item index="52" class="java.lang.String" itemvalue="pyparsing" />
|
||||||
|
<item index="53" class="java.lang.String" itemvalue="sympy" />
|
||||||
|
<item index="54" class="java.lang.String" itemvalue="annotated-types" />
|
||||||
|
<item index="55" class="java.lang.String" itemvalue="pydantic-settings" />
|
||||||
|
<item index="56" class="java.lang.String" itemvalue="six" />
|
||||||
|
<item index="57" class="java.lang.String" itemvalue="tzdata" />
|
||||||
|
<item index="58" class="java.lang.String" itemvalue="ecdsa" />
|
||||||
|
<item index="59" class="java.lang.String" itemvalue="kiwisolver" />
|
||||||
|
<item index="60" class="java.lang.String" itemvalue="packaging" />
|
||||||
|
<item index="61" class="java.lang.String" itemvalue="python-multipart" />
|
||||||
|
<item index="62" class="java.lang.String" itemvalue="click" />
|
||||||
|
<item index="63" class="java.lang.String" itemvalue="contourpy" />
|
||||||
|
<item index="64" class="java.lang.String" itemvalue="fonttools" />
|
||||||
|
<item index="65" class="java.lang.String" itemvalue="pydantic_core" />
|
||||||
|
<item index="66" class="java.lang.String" itemvalue="av" />
|
||||||
|
<item index="67" class="java.lang.String" itemvalue="colorama" />
|
||||||
|
<item index="68" class="java.lang.String" itemvalue="mpmath" />
|
||||||
|
<item index="69" class="java.lang.String" itemvalue="argon2-cffi-bindings" />
|
||||||
|
<item index="70" class="java.lang.String" itemvalue="typing_extensions" />
|
||||||
|
<item index="71" class="java.lang.String" itemvalue="charset-normalizer" />
|
||||||
|
<item index="72" class="java.lang.String" itemvalue="pillow" />
|
||||||
|
<item index="73" class="java.lang.String" itemvalue="argon2-cffi" />
|
||||||
|
<item index="74" class="java.lang.String" itemvalue="pytz" />
|
||||||
|
<item index="75" class="java.lang.String" itemvalue="idna" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredErrors">
|
||||||
|
<list>
|
||||||
|
<option value="N802" />
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" />
|
||||||
|
</profile>
|
||||||
|
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="video" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/Video.iml" filepath="$PROJECT_DIR$/.idea/Video.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
BIN
__pycache__/main.cpython-312.pyc
Normal file
BIN
__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
19
config.ini
Normal file
19
config.ini
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
[server]
|
||||||
|
port = 8000
|
||||||
|
|
||||||
|
[mysql]
|
||||||
|
host = 192.168.110.65
|
||||||
|
port = 6975
|
||||||
|
user = video_check
|
||||||
|
password = fsjPfhxCs8NrFGmL
|
||||||
|
database = video_check
|
||||||
|
charset = utf8mb4
|
||||||
|
|
||||||
|
[jwt]
|
||||||
|
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
|
||||||
|
algorithm = HS256
|
||||||
|
access_token_expire_minutes = 30
|
||||||
|
|
||||||
|
[live]
|
||||||
|
rtmp_url = rtmp://192.168.110.65:1935/live/
|
||||||
|
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=
|
BIN
ds/__pycache__/config.cpython-312.pyc
Normal file
BIN
ds/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ds/__pycache__/db.cpython-312.pyc
Normal file
BIN
ds/__pycache__/db.cpython-312.pyc
Normal file
Binary file not shown.
17
ds/config.py
Normal file
17
ds/config.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import configparser
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 读取配置文件路径
|
||||||
|
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../config.ini")
|
||||||
|
|
||||||
|
# 初始化配置解析器
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
|
||||||
|
# 读取配置文件
|
||||||
|
config.read(config_path, encoding="utf-8")
|
||||||
|
|
||||||
|
# 暴露配置项(方便其他文件调用)
|
||||||
|
SERVER_CONFIG = config["server"]
|
||||||
|
MYSQL_CONFIG = config["mysql"]
|
||||||
|
JWT_CONFIG = config["jwt"]
|
||||||
|
LIVE_CONFIG = config["live"]
|
46
ds/db.py
Normal file
46
ds/db.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import mysql.connector
|
||||||
|
from mysql.connector import Error
|
||||||
|
|
||||||
|
from .config import MYSQL_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
"""MySQL 连接池管理类"""
|
||||||
|
pool_config = {
|
||||||
|
"host": MYSQL_CONFIG.get("host", "localhost"),
|
||||||
|
"port": int(MYSQL_CONFIG.get("port", 3306)),
|
||||||
|
"user": MYSQL_CONFIG.get("user", "root"),
|
||||||
|
"password": MYSQL_CONFIG.get("password", ""),
|
||||||
|
"database": MYSQL_CONFIG.get("database", ""),
|
||||||
|
"charset": MYSQL_CONFIG.get("charset", "utf8mb4"),
|
||||||
|
"pool_name": "fastapi_pool",
|
||||||
|
"pool_size": 5,
|
||||||
|
"pool_reset_session": True
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_connection(cls):
|
||||||
|
"""获取数据库连接"""
|
||||||
|
try:
|
||||||
|
# 从连接池获取连接
|
||||||
|
conn = mysql.connector.connect(**cls.pool_config)
|
||||||
|
if conn.is_connected():
|
||||||
|
return conn
|
||||||
|
except Error as e:
|
||||||
|
# 抛出数据库连接错误(会被全局异常处理器捕获)
|
||||||
|
raise Exception(f"MySQL 连接失败: {str(e)}") from e
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def close_connection(cls, conn, cursor=None):
|
||||||
|
"""关闭连接和游标"""
|
||||||
|
try:
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
if conn and conn.is_connected():
|
||||||
|
conn.close()
|
||||||
|
except Error as e:
|
||||||
|
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# 暴露数据库操作工具
|
||||||
|
db = Database()
|
43
main.py
Normal file
43
main.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from ds.config import SERVER_CONFIG
|
||||||
|
from middle.error_handler import global_exception_handler
|
||||||
|
from service.user_service import router as user_router
|
||||||
|
from service.device_service import router as device_router
|
||||||
|
from ws.ws import ws_router, lifespan
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 初始化 FastAPI 应用、指定生命周期管理
|
||||||
|
# ------------------------------
|
||||||
|
app = FastAPI(
|
||||||
|
title="内容安全审核后台",
|
||||||
|
description="内容安全审核后台",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 注册路由
|
||||||
|
# ------------------------------
|
||||||
|
app.include_router(user_router)
|
||||||
|
app.include_router(device_router)
|
||||||
|
app.include_router(ws_router)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 注册全局异常处理器
|
||||||
|
# ------------------------------
|
||||||
|
app.add_exception_handler(Exception, global_exception_handler)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 启动服务
|
||||||
|
# ------------------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
port = int(SERVER_CONFIG.get("port", 8000))
|
||||||
|
uvicorn.run(
|
||||||
|
app="main:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=port,
|
||||||
|
reload=True,
|
||||||
|
ws="websockets"
|
||||||
|
)
|
BIN
middle/__pycache__/auth_middleware.cpython-312.pyc
Normal file
BIN
middle/__pycache__/auth_middleware.cpython-312.pyc
Normal file
Binary file not shown.
BIN
middle/__pycache__/error_handler.cpython-312.pyc
Normal file
BIN
middle/__pycache__/error_handler.cpython-312.pyc
Normal file
Binary file not shown.
96
middle/auth_middleware.py
Normal file
96
middle/auth_middleware.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
from ds.config import JWT_CONFIG
|
||||||
|
from ds.db import db
|
||||||
|
from service.user_service import UserResponse
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 密码加密配置
|
||||||
|
# ------------------------------
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# JWT 配置
|
||||||
|
# ------------------------------
|
||||||
|
SECRET_KEY = JWT_CONFIG["secret_key"]
|
||||||
|
ALGORITHM = JWT_CONFIG["algorithm"]
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
||||||
|
|
||||||
|
# OAuth2 依赖(从请求头获取 Token、格式:Bearer <token>)
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 密码工具函数
|
||||||
|
# ------------------------------
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
"""验证明文密码与加密密码是否匹配"""
|
||||||
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
def get_password_hash(password: str) -> str:
|
||||||
|
"""对明文密码进行 bcrypt 加密"""
|
||||||
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# JWT 工具函数
|
||||||
|
# ------------------------------
|
||||||
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||||
|
"""生成 JWT Token"""
|
||||||
|
to_encode = data.copy()
|
||||||
|
# 设置过期时间
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
||||||
|
# 添加过期时间到 Token 数据
|
||||||
|
to_encode.update({"exp": expire})
|
||||||
|
# 生成 Token
|
||||||
|
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 认证依赖(获取当前登录用户)
|
||||||
|
# ------------------------------
|
||||||
|
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
|
||||||
|
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
||||||
|
# 认证失败异常
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token 无效或已过期",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解码 Token
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
# 获取 Token 中的用户名
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise credentials_exception
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
# 从数据库查询用户(验证用户是否存在)
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
|
||||||
|
query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
|
||||||
|
cursor.execute(query, (username,))
|
||||||
|
user = cursor.fetchone()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception # 用户不存在
|
||||||
|
|
||||||
|
# 转换为 UserResponse 模型(自动校验字段)
|
||||||
|
return UserResponse(** user)
|
||||||
|
except Exception as e:
|
||||||
|
raise credentials_exception from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
68
middle/error_handler.py
Normal file
68
middle/error_handler.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from fastapi import Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.exceptions import HTTPException, RequestValidationError
|
||||||
|
from mysql.connector import Error as MySQLError
|
||||||
|
from jose import JWTError
|
||||||
|
|
||||||
|
from schema.response_schema import APIResponse
|
||||||
|
|
||||||
|
|
||||||
|
async def global_exception_handler(request: Request, exc: Exception):
|
||||||
|
"""全局异常处理器:所有未捕获的异常都会在这里统一处理"""
|
||||||
|
# 1. 请求参数验证错误(Pydantic 校验失败)
|
||||||
|
if isinstance(exc, RequestValidationError):
|
||||||
|
error_details = []
|
||||||
|
for err in exc.errors():
|
||||||
|
error_details.append(f"{err['loc'][1]}: {err['msg']}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content=APIResponse(
|
||||||
|
code=400,
|
||||||
|
message=f"请求参数错误:{'; '.join(error_details)}",
|
||||||
|
data=None
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. HTTP 异常(主动抛出的业务错误、如 401/404)
|
||||||
|
if isinstance(exc, HTTPException):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=APIResponse(
|
||||||
|
code=exc.status_code,
|
||||||
|
message=exc.detail,
|
||||||
|
data=None
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. JWT 相关错误(Token 无效/过期)
|
||||||
|
if isinstance(exc, JWTError):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content=APIResponse(
|
||||||
|
code=401,
|
||||||
|
message="Token 无效或已过期",
|
||||||
|
data=None
|
||||||
|
).model_dump(),
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. MySQL 数据库错误
|
||||||
|
if isinstance(exc, MySQLError):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
content=APIResponse(
|
||||||
|
code=500,
|
||||||
|
message=f"数据库错误:{str(exc)}",
|
||||||
|
data=None
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 其他未知错误(兜底处理)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
content=APIResponse(
|
||||||
|
code=500,
|
||||||
|
message=f"服务器内部错误:{str(exc)}",
|
||||||
|
data=None
|
||||||
|
).model_dump()
|
||||||
|
)
|
BIN
schema/__pycache__/device_schema.cpython-312.pyc
Normal file
BIN
schema/__pycache__/device_schema.cpython-312.pyc
Normal file
Binary file not shown.
BIN
schema/__pycache__/response_schema.cpython-312.pyc
Normal file
BIN
schema/__pycache__/response_schema.cpython-312.pyc
Normal file
Binary file not shown.
BIN
schema/__pycache__/user_schema.cpython-312.pyc
Normal file
BIN
schema/__pycache__/user_schema.cpython-312.pyc
Normal file
Binary file not shown.
51
schema/device_schema.py
Normal file
51
schema/device_schema.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import hashlib
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 请求模型(前端传参校验)
|
||||||
|
# ------------------------------
|
||||||
|
class DeviceCreateRequest(BaseModel):
|
||||||
|
"""设备流信息创建请求模型"""
|
||||||
|
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
|
||||||
|
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||||
|
params: Optional[Dict] = Field(None, description="设备详细信息")
|
||||||
|
|
||||||
|
|
||||||
|
def md5_encrypt(text: str) -> str:
|
||||||
|
"""对字符串进行MD5加密"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
md5_hash = hashlib.md5()
|
||||||
|
md5_hash.update(text.encode('utf-8'))
|
||||||
|
return md5_hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 响应模型(后端返回设备数据)
|
||||||
|
# ------------------------------
|
||||||
|
class DeviceResponse(BaseModel):
|
||||||
|
"""设备流信息响应模型(字段与表结构完全对齐)"""
|
||||||
|
id: int = Field(..., description="设备ID")
|
||||||
|
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||||
|
rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址")
|
||||||
|
live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址")
|
||||||
|
detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址")
|
||||||
|
device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)")
|
||||||
|
device_type: Optional[str] = Field(None, description="设备类型")
|
||||||
|
alarm_count: int = Field(..., description="报警次数")
|
||||||
|
params: Optional[str] = Field(None, description="设备详细信息")
|
||||||
|
created_at: datetime = Field(..., description="记录创建时间")
|
||||||
|
updated_at: datetime = Field(..., description="记录更新时间")
|
||||||
|
|
||||||
|
# 支持从数据库查询结果转换
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceListResponse(BaseModel):
|
||||||
|
"""设备流信息列表响应模型"""
|
||||||
|
total: int = Field(..., description="设备总数")
|
||||||
|
devices: List[DeviceResponse] = Field(..., description="设备列表")
|
13
schema/response_schema.py
Normal file
13
schema/response_schema.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class APIResponse(BaseModel):
|
||||||
|
"""统一 API 响应模型(所有接口必返此格式)"""
|
||||||
|
code: int = Field(..., description="状态码:200=成功、4xx=客户端错误、5xx=服务端错误")
|
||||||
|
message: str = Field(..., description="响应信息:成功/错误描述")
|
||||||
|
data: Optional[Any] = Field(None, description="响应数据:成功时返回、错误时为 None")
|
||||||
|
|
||||||
|
# Pydantic V2 配置(支持从 ORM 对象转换)
|
||||||
|
model_config = {"from_attributes": True}
|
32
schema/user_schema.py
Normal file
32
schema/user_schema.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 请求模型(前端传参校验)
|
||||||
|
# ------------------------------
|
||||||
|
class UserRegisterRequest(BaseModel):
|
||||||
|
"""用户注册请求模型"""
|
||||||
|
username: str = Field(..., min_length=3, max_length=50, description="用户名(3-50字符)")
|
||||||
|
password: str = Field(..., min_length=6, max_length=100, description="密码(6-100字符)")
|
||||||
|
|
||||||
|
|
||||||
|
class UserLoginRequest(BaseModel):
|
||||||
|
"""用户登录请求模型"""
|
||||||
|
username: str = Field(..., description="用户名")
|
||||||
|
password: str = Field(..., description="密码")
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 响应模型(后端返回用户数据)
|
||||||
|
# ------------------------------
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
"""用户信息响应模型(隐藏密码等敏感字段)"""
|
||||||
|
id: int = Field(..., description="用户ID")
|
||||||
|
username: str = Field(..., description="用户名")
|
||||||
|
created_at: datetime = Field(..., description="创建时间")
|
||||||
|
updated_at: datetime = Field(..., description="更新时间")
|
||||||
|
|
||||||
|
# Pydantic V2 配置(支持从数据库查询结果转换)
|
||||||
|
model_config = {"from_attributes": True}
|
BIN
service/__pycache__/device_service.cpython-312.pyc
Normal file
BIN
service/__pycache__/device_service.cpython-312.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/user_service.cpython-312.pyc
Normal file
BIN
service/__pycache__/user_service.cpython-312.pyc
Normal file
Binary file not shown.
251
service/device_service.py
Normal file
251
service/device_service.py
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Query, APIRouter, Depends, Request
|
||||||
|
from mysql.connector import Error as MySQLError
|
||||||
|
|
||||||
|
from ds.config import LIVE_CONFIG
|
||||||
|
from ds.db import db
|
||||||
|
from middle.auth_middleware import get_current_user
|
||||||
|
# 注意:导入的Schema已更新字段
|
||||||
|
from schema.device_schema import (
|
||||||
|
DeviceCreateRequest,
|
||||||
|
DeviceResponse,
|
||||||
|
DeviceListResponse,
|
||||||
|
md5_encrypt
|
||||||
|
)
|
||||||
|
from schema.response_schema import APIResponse
|
||||||
|
from schema.user_schema import UserResponse
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/devices",
|
||||||
|
tags=["设备管理"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 1. 创建设备信息
|
||||||
|
# ------------------------------
|
||||||
|
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||||
|
async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 新增:检查client_ip是否已存在
|
||||||
|
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (device_data.ip,))
|
||||||
|
existing_device = cursor.fetchone()
|
||||||
|
if existing_device:
|
||||||
|
raise Exception(f"客户端IP {device_data.ip} 已存在,无法重复添加")
|
||||||
|
|
||||||
|
# 获取RTMP URL
|
||||||
|
rtmp_url = str(LIVE_CONFIG.get("rtmp_url", ""))
|
||||||
|
webrtc_url = str(LIVE_CONFIG.get("webrtc_url", ""))
|
||||||
|
|
||||||
|
# 将设备详细信息(params)转换为JSON字符串(对应表中params字段)
|
||||||
|
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||||
|
|
||||||
|
# 对JSON字符串进行MD5加密(用于生成唯一RTMP地址)
|
||||||
|
device_md5 = md5_encrypt(device_params_json) if device_params_json else ""
|
||||||
|
|
||||||
|
# 解析User-Agent获取设备类型
|
||||||
|
user_agent = request.headers.get("User-Agent", "").lower()
|
||||||
|
|
||||||
|
# 优先处理User-Agent为default的情况
|
||||||
|
if user_agent == "default":
|
||||||
|
# 检查params中是否存在os键
|
||||||
|
if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params:
|
||||||
|
device_type = device_data.params["os"]
|
||||||
|
else:
|
||||||
|
device_type = "unknown"
|
||||||
|
elif "windows" in user_agent:
|
||||||
|
device_type = "windows"
|
||||||
|
elif "android" in user_agent:
|
||||||
|
device_type = "android"
|
||||||
|
elif "linux" in user_agent:
|
||||||
|
device_type = "linux"
|
||||||
|
else:
|
||||||
|
device_type = "unknown"
|
||||||
|
|
||||||
|
# SQL字段对齐表结构
|
||||||
|
insert_query = """
|
||||||
|
INSERT INTO devices
|
||||||
|
(client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url,
|
||||||
|
device_online_status, device_type, alarm_count, params)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
|
"""
|
||||||
|
cursor.execute(insert_query, (
|
||||||
|
device_data.ip,
|
||||||
|
device_data.hostname,
|
||||||
|
rtmp_url + device_md5,
|
||||||
|
webrtc_url + device_md5,
|
||||||
|
"",
|
||||||
|
1,
|
||||||
|
device_type,
|
||||||
|
0,
|
||||||
|
device_params_json
|
||||||
|
))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# 获取刚创建的设备信息
|
||||||
|
device_id = cursor.lastrowid
|
||||||
|
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
|
||||||
|
device = cursor.fetchone()
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="设备创建成功",
|
||||||
|
data=DeviceResponse(**device)
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise Exception(f"创建设备失败:{str(e)}") from e
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise Exception(f"设备信息JSON序列化失败:{str(e)}") from e
|
||||||
|
except Exception as e:
|
||||||
|
# 捕获IP已存在的自定义异常
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 2. 获取设备列表
|
||||||
|
# ------------------------------
|
||||||
|
@router.get("/", response_model=APIResponse, summary="获取设备列表")
|
||||||
|
async def get_device_list(
|
||||||
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
|
page_size: int = Query(10, ge=1, le=100, description="每页条数"),
|
||||||
|
device_type: str = Query(None, description="设备类型筛选"),
|
||||||
|
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选(1-在线、0-离线)")
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 构建查询条件
|
||||||
|
where_clause = []
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if device_type:
|
||||||
|
where_clause.append("device_type = %s")
|
||||||
|
params.append(device_type)
|
||||||
|
|
||||||
|
if online_status is not None:
|
||||||
|
where_clause.append("device_online_status = %s")
|
||||||
|
params.append(online_status)
|
||||||
|
|
||||||
|
# 总条数查询
|
||||||
|
count_query = "SELECT COUNT(*) as total FROM devices"
|
||||||
|
if where_clause:
|
||||||
|
count_query += " WHERE " + " AND ".join(where_clause)
|
||||||
|
|
||||||
|
cursor.execute(count_query, params)
|
||||||
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 分页查询(SELECT * 会自动匹配表字段、响应模型已对齐)
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
query = "SELECT * FROM devices"
|
||||||
|
if where_clause:
|
||||||
|
query += " WHERE " + " AND ".join(where_clause)
|
||||||
|
query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
|
||||||
|
cursor.execute(query, params)
|
||||||
|
devices = cursor.fetchall()
|
||||||
|
|
||||||
|
# 响应模型已更新为params字段、直接转换即可
|
||||||
|
device_list = [DeviceResponse(**device) for device in devices]
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="获取设备列表成功",
|
||||||
|
data=DeviceListResponse(total=total, devices=device_list)
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
raise Exception(f"获取设备列表失败:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 3. 获取单个设备详情
|
||||||
|
# ------------------------------
|
||||||
|
@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情")
|
||||||
|
async def get_device_detail(
|
||||||
|
device_id: int,
|
||||||
|
current_user: UserResponse = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 查询设备信息(SELECT * 匹配表字段)
|
||||||
|
query = "SELECT * FROM devices WHERE id = %s"
|
||||||
|
cursor.execute(query, (device_id,))
|
||||||
|
device = cursor.fetchone()
|
||||||
|
|
||||||
|
if not device:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"设备ID为 {device_id} 的设备不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 响应模型已更新为params字段
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="获取设备详情成功",
|
||||||
|
data=DeviceResponse(**device)
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
raise Exception(f"获取设备详情失败:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 4. 删除设备信息
|
||||||
|
# ------------------------------
|
||||||
|
@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息")
|
||||||
|
async def delete_device(
|
||||||
|
device_id: int,
|
||||||
|
current_user: UserResponse = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 检查设备是否存在
|
||||||
|
cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,))
|
||||||
|
if not cursor.fetchone():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"设备ID为 {device_id} 的设备不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行删除
|
||||||
|
delete_query = "DELETE FROM devices WHERE id = %s"
|
||||||
|
cursor.execute(delete_query, (device_id,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"设备ID为 {device_id} 的设备已成功删除",
|
||||||
|
data=None
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise Exception(f"删除设备失败:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
154
service/user_service.py
Normal file
154
service/user_service.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from mysql.connector import Error as MySQLError
|
||||||
|
|
||||||
|
from ds.db import db
|
||||||
|
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
|
||||||
|
from schema.response_schema import APIResponse
|
||||||
|
from middle.auth_middleware import (
|
||||||
|
get_password_hash,
|
||||||
|
verify_password,
|
||||||
|
create_access_token,
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
|
get_current_user
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/users",
|
||||||
|
tags=["用户管理"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 1. 用户注册接口
|
||||||
|
# ------------------------------
|
||||||
|
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||||
|
async def user_register(request: UserRegisterRequest):
|
||||||
|
"""
|
||||||
|
用户注册:
|
||||||
|
- 校验用户名是否已存在
|
||||||
|
- 加密密码后插入数据库
|
||||||
|
- 返回注册成功信息
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 1. 检查用户名是否已存在(唯一索引)
|
||||||
|
check_query = "SELECT username FROM users WHERE username = %s"
|
||||||
|
cursor.execute(check_query, (request.username,))
|
||||||
|
existing_user = cursor.fetchone()
|
||||||
|
if existing_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 加密密码
|
||||||
|
hashed_password = get_password_hash(request.password)
|
||||||
|
|
||||||
|
# 3. 插入新用户到数据库
|
||||||
|
insert_query = """
|
||||||
|
INSERT INTO users (username, password)
|
||||||
|
VALUES (%s, %s)
|
||||||
|
"""
|
||||||
|
cursor.execute(insert_query, (request.username, hashed_password))
|
||||||
|
conn.commit() # 提交事务
|
||||||
|
|
||||||
|
# 4. 返回注册成功响应
|
||||||
|
return APIResponse(
|
||||||
|
code=201, # 201 表示资源创建成功
|
||||||
|
message=f"用户 '{request.username}' 注册成功",
|
||||||
|
data=None
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
conn.rollback() # 数据库错误时回滚事务
|
||||||
|
raise Exception(f"注册失败:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 2. 用户登录接口
|
||||||
|
# ------------------------------
|
||||||
|
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
||||||
|
async def user_login(request: UserLoginRequest):
|
||||||
|
"""
|
||||||
|
用户登录:
|
||||||
|
- 校验用户名是否存在
|
||||||
|
- 校验密码是否正确
|
||||||
|
- 生成 JWT Token 并返回
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 修复:SQL查询添加 created_at 和 updated_at 字段
|
||||||
|
query = """
|
||||||
|
SELECT id, username, password, created_at, updated_at
|
||||||
|
FROM users
|
||||||
|
WHERE username = %s
|
||||||
|
"""
|
||||||
|
cursor.execute(query, (request.username,))
|
||||||
|
user = cursor.fetchone()
|
||||||
|
|
||||||
|
# 2. 校验用户名和密码
|
||||||
|
if not user or not verify_password(request.password, user["password"]):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="用户名或密码错误",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 生成 Token(过期时间从配置读取)
|
||||||
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
access_token = create_access_token(
|
||||||
|
data={"sub": user["username"]},
|
||||||
|
expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 返回 Token 和用户基本信息
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="登录成功",
|
||||||
|
data={
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"user": UserResponse(
|
||||||
|
id=user["id"],
|
||||||
|
username=user["username"],
|
||||||
|
created_at=user.get("created_at"),
|
||||||
|
updated_at=user.get("updated_at")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
raise Exception(f"登录失败:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 3. 获取当前登录用户信息(需认证)
|
||||||
|
# ------------------------------
|
||||||
|
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
|
||||||
|
async def get_current_user_info(
|
||||||
|
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取当前登录用户信息:
|
||||||
|
- 需在请求头携带 Token(格式:Bearer <token>)
|
||||||
|
- 认证通过后返回用户信息
|
||||||
|
"""
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="获取用户信息成功",
|
||||||
|
data=current_user
|
||||||
|
)
|
||||||
|
|
482
ws.html
Normal file
482
ws.html
Normal file
@ -0,0 +1,482 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>WebSocket 测试工具</title>
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
max-width: 1200px;
|
||||||
|
margin: 20px auto;
|
||||||
|
padding: 0 20px;
|
||||||
|
background-color: #f5f7fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
background: white;
|
||||||
|
border-radius: 8px;
|
||||||
|
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||||
|
padding: 25px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
color: #2c3e50;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
font-size: 24px;
|
||||||
|
border-bottom: 2px solid #3498db;
|
||||||
|
padding-bottom: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-bar {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 15px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
padding: 12px 15px;
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
border-radius: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-label {
|
||||||
|
font-weight: bold;
|
||||||
|
color: #495057;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-value {
|
||||||
|
padding: 4px 10px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-connected {
|
||||||
|
background-color: #d4edda;
|
||||||
|
color: #155724;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-disconnected {
|
||||||
|
background-color: #f8d7da;
|
||||||
|
color: #721c24;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-connecting {
|
||||||
|
background-color: #fff3cd;
|
||||||
|
color: #856404;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn {
|
||||||
|
padding: 8px 16px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 14px;
|
||||||
|
font-weight: 500;
|
||||||
|
transition: background-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary {
|
||||||
|
background-color: #3498db;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:hover {
|
||||||
|
background-color: #2980b9;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-danger {
|
||||||
|
background-color: #e74c3c;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-danger:hover {
|
||||||
|
background-color: #c0392b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-success {
|
||||||
|
background-color: #2ecc71;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-success:hover {
|
||||||
|
background-color: #27ae60;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-group {
|
||||||
|
display: flex;
|
||||||
|
gap: 15px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.input-group {
|
||||||
|
display: flex;
|
||||||
|
gap: 10px;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.input-group label {
|
||||||
|
color: #495057;
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
|
||||||
|
.input-group input, .input-group select {
|
||||||
|
padding: 8px 12px;
|
||||||
|
border: 1px solid #ced4da;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-area {
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-input {
|
||||||
|
width: 100%;
|
||||||
|
height: 100px;
|
||||||
|
padding: 12px;
|
||||||
|
border: 1px solid #ced4da;
|
||||||
|
border-radius: 6px;
|
||||||
|
resize: none;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-area {
|
||||||
|
width: 100%;
|
||||||
|
height: 300px;
|
||||||
|
padding: 15px;
|
||||||
|
border: 1px solid #ced4da;
|
||||||
|
border-radius: 6px;
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
overflow-y: auto;
|
||||||
|
font-size: 14px;
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-item {
|
||||||
|
margin-bottom: 8px;
|
||||||
|
padding-bottom: 8px;
|
||||||
|
border-bottom: 1px dashed #e9ecef;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-time {
|
||||||
|
color: #6c757d;
|
||||||
|
font-size: 12px;
|
||||||
|
margin-right: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-send {
|
||||||
|
color: #2980b9;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-receive {
|
||||||
|
color: #27ae60;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-status {
|
||||||
|
color: #856404;
|
||||||
|
}
|
||||||
|
|
||||||
|
.log-error {
|
||||||
|
color: #e74c3c;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>WebSocket 测试工具</h1>
|
||||||
|
|
||||||
|
<!-- 连接状态区 -->
|
||||||
|
<div class="status-bar">
|
||||||
|
<div class="status-label">连接状态:</div>
|
||||||
|
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
|
||||||
|
<div class="status-label">服务地址:</div>
|
||||||
|
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
|
||||||
|
<div class="status-label">连接时间:</div>
|
||||||
|
<div id="connectTime" class="status-value">-</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 控制按钮区 -->
|
||||||
|
<div class="control-group">
|
||||||
|
<button id="connectBtn" class="btn btn-primary">建立连接</button>
|
||||||
|
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
|
||||||
|
|
||||||
|
<!-- 心跳控制 -->
|
||||||
|
<div class="input-group">
|
||||||
|
<label>自动心跳:</label>
|
||||||
|
<select id="autoHeartbeat">
|
||||||
|
<option value="on">开启</option>
|
||||||
|
<option value="off">关闭</option>
|
||||||
|
</select>
|
||||||
|
<label>间隔(秒):</label>
|
||||||
|
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
|
||||||
|
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 自定义消息发送区 -->
|
||||||
|
<div class="message-area">
|
||||||
|
<h3>发送自定义消息</h3>
|
||||||
|
<textarea id="messageInput" class="message-input"
|
||||||
|
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
|
||||||
|
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 日志显示区 -->
|
||||||
|
<div class="message-area">
|
||||||
|
<h3>消息日志</h3>
|
||||||
|
<div id="logContainer" class="log-area">
|
||||||
|
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
|
||||||
|
</div>
|
||||||
|
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// 全局变量
|
||||||
|
let ws = null;
|
||||||
|
let heartbeatTimer = null;
|
||||||
|
const wsUrl = "ws://192.168.110.25:8000/ws";
|
||||||
|
|
||||||
|
// DOM 元素
|
||||||
|
const connectionStatus = document.getElementById('connectionStatus');
|
||||||
|
const connectTime = document.getElementById('connectTime');
|
||||||
|
const connectBtn = document.getElementById('connectBtn');
|
||||||
|
const disconnectBtn = document.getElementById('disconnectBtn');
|
||||||
|
const sendMessageBtn = document.getElementById('sendMessageBtn');
|
||||||
|
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
|
||||||
|
const autoHeartbeat = document.getElementById('autoHeartbeat');
|
||||||
|
const heartbeatInterval = document.getElementById('heartbeatInterval');
|
||||||
|
const messageInput = document.getElementById('messageInput');
|
||||||
|
const logContainer = document.getElementById('logContainer');
|
||||||
|
const clearLogBtn = document.getElementById('clearLogBtn');
|
||||||
|
|
||||||
|
// 工具函数:添加日志
|
||||||
|
function addLog(content, type = 'status') {
|
||||||
|
const now = new Date().toLocaleString('zh-CN', {
|
||||||
|
year: 'numeric', month: '2-digit', day: '2-digit',
|
||||||
|
hour: '2-digit', minute: '2-digit', second: '2-digit'
|
||||||
|
});
|
||||||
|
const logItem = document.createElement('div');
|
||||||
|
logItem.className = 'log-item';
|
||||||
|
|
||||||
|
let logClass = '';
|
||||||
|
switch (type) {
|
||||||
|
case 'send':
|
||||||
|
logClass = 'log-send';
|
||||||
|
break;
|
||||||
|
case 'receive':
|
||||||
|
logClass = 'log-receive';
|
||||||
|
break;
|
||||||
|
case 'error':
|
||||||
|
logClass = 'log-error';
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
logClass = 'log-status';
|
||||||
|
}
|
||||||
|
|
||||||
|
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
|
||||||
|
logContainer.appendChild(logItem);
|
||||||
|
// 滚动到最新日志
|
||||||
|
logContainer.scrollTop = logContainer.scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 工具函数:格式化JSON(便于日志显示)
|
||||||
|
function formatJson(jsonStr) {
|
||||||
|
try {
|
||||||
|
const obj = JSON.parse(jsonStr);
|
||||||
|
return JSON.stringify(obj, null, 2);
|
||||||
|
} catch (e) {
|
||||||
|
return jsonStr; // 非JSON格式直接返回
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 建立WebSocket连接
|
||||||
|
function connectWebSocket() {
|
||||||
|
if (ws) {
|
||||||
|
addLog('已存在连接,无需重复建立', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
ws = new WebSocket(wsUrl);
|
||||||
|
|
||||||
|
// 连接成功
|
||||||
|
ws.onopen = function () {
|
||||||
|
connectionStatus.className = 'status-value status-connected';
|
||||||
|
connectionStatus.textContent = '已连接';
|
||||||
|
const now = new Date().toLocaleString('zh-CN');
|
||||||
|
connectTime.textContent = now;
|
||||||
|
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
|
||||||
|
|
||||||
|
// 更新按钮状态
|
||||||
|
connectBtn.disabled = true;
|
||||||
|
disconnectBtn.disabled = false;
|
||||||
|
sendMessageBtn.disabled = false;
|
||||||
|
|
||||||
|
// 开启自动心跳(默认开启)
|
||||||
|
if (autoHeartbeat.value === 'on') {
|
||||||
|
startAutoHeartbeat();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 接收消息
|
||||||
|
ws.onmessage = function (event) {
|
||||||
|
const message = event.data;
|
||||||
|
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
|
||||||
|
};
|
||||||
|
|
||||||
|
// 连接关闭
|
||||||
|
ws.onclose = function (event) {
|
||||||
|
connectionStatus.className = 'status-value status-disconnected';
|
||||||
|
connectionStatus.textContent = '已断开';
|
||||||
|
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
|
||||||
|
|
||||||
|
// 清除自动心跳
|
||||||
|
stopAutoHeartbeat();
|
||||||
|
|
||||||
|
// 更新按钮状态
|
||||||
|
connectBtn.disabled = false;
|
||||||
|
disconnectBtn.disabled = true;
|
||||||
|
sendMessageBtn.disabled = true;
|
||||||
|
|
||||||
|
// 重置WebSocket对象
|
||||||
|
ws = null;
|
||||||
|
};
|
||||||
|
|
||||||
|
// 连接错误
|
||||||
|
ws.onerror = function (error) {
|
||||||
|
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
|
||||||
|
};
|
||||||
|
|
||||||
|
} catch (e) {
|
||||||
|
addLog(`建立连接失败:${e.message}`, 'error');
|
||||||
|
ws = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 断开WebSocket连接
|
||||||
|
function disconnectWebSocket() {
|
||||||
|
if (!ws) {
|
||||||
|
addLog('当前无连接,无需断开', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.close(1000, '手动断开连接');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"})
|
||||||
|
function sendHeartbeat() {
|
||||||
|
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||||
|
addLog('发送心跳失败:当前无有效连接', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const heartbeatMsg = {
|
||||||
|
timestamp: Date.now(), // 当前毫秒时间戳
|
||||||
|
type: "heartbeat"
|
||||||
|
};
|
||||||
|
const msgStr = JSON.stringify(heartbeatMsg);
|
||||||
|
|
||||||
|
ws.send(msgStr);
|
||||||
|
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开启自动心跳
|
||||||
|
function startAutoHeartbeat() {
|
||||||
|
// 先停止已有定时器
|
||||||
|
stopAutoHeartbeat();
|
||||||
|
|
||||||
|
const interval = parseInt(heartbeatInterval.value) * 1000;
|
||||||
|
if (isNaN(interval) || interval < 10000) {
|
||||||
|
addLog('自动心跳间隔无效,已重置为30秒', 'error');
|
||||||
|
heartbeatInterval.value = 30;
|
||||||
|
return startAutoHeartbeat();
|
||||||
|
}
|
||||||
|
|
||||||
|
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}秒`, 'status');
|
||||||
|
heartbeatTimer = setInterval(sendHeartbeat, interval);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 停止自动心跳
|
||||||
|
function stopAutoHeartbeat() {
|
||||||
|
if (heartbeatTimer) {
|
||||||
|
clearInterval(heartbeatTimer);
|
||||||
|
heartbeatTimer = null;
|
||||||
|
addLog('已停止自动心跳', 'status');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送自定义消息
|
||||||
|
function sendCustomMessage() {
|
||||||
|
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||||
|
addLog('发送消息失败:当前无有效连接', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const msgStr = messageInput.value.trim();
|
||||||
|
if (!msgStr) {
|
||||||
|
addLog('发送消息失败:消息内容不能为空', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 验证JSON格式(可选,仅提示不强制)
|
||||||
|
JSON.parse(msgStr);
|
||||||
|
ws.send(msgStr);
|
||||||
|
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
|
||||||
|
} catch (e) {
|
||||||
|
addLog(`JSON格式错误:${e.message},仍尝试发送原始内容`, 'error');
|
||||||
|
ws.send(msgStr);
|
||||||
|
addLog(`发送自定义消息(非JSON):\n${msgStr}`, 'send');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 绑定按钮事件
|
||||||
|
connectBtn.addEventListener('click', connectWebSocket);
|
||||||
|
disconnectBtn.addEventListener('click', disconnectWebSocket);
|
||||||
|
sendMessageBtn.addEventListener('click', sendCustomMessage);
|
||||||
|
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
|
||||||
|
clearLogBtn.addEventListener('click', () => {
|
||||||
|
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
|
||||||
|
});
|
||||||
|
|
||||||
|
// 自动心跳开关变更事件
|
||||||
|
autoHeartbeat.addEventListener('change', function () {
|
||||||
|
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||||
|
if (this.value === 'on') {
|
||||||
|
startAutoHeartbeat();
|
||||||
|
} else {
|
||||||
|
stopAutoHeartbeat();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
addLog('需先建立有效连接才能控制自动心跳', 'error');
|
||||||
|
// 重置选择
|
||||||
|
this.value = 'off';
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 心跳间隔变更事件(实时生效)
|
||||||
|
heartbeatInterval.addEventListener('change', function () {
|
||||||
|
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
|
||||||
|
startAutoHeartbeat();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 快捷键支持(Ctrl+Enter发送消息)
|
||||||
|
messageInput.addEventListener('keydown', function (e) {
|
||||||
|
if (e.ctrlKey && e.key === 'Enter') {
|
||||||
|
sendCustomMessage();
|
||||||
|
e.preventDefault();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
BIN
ws/__pycache__/ws.cpython-312.pyc
Normal file
BIN
ws/__pycache__/ws.cpython-312.pyc
Normal file
Binary file not shown.
200
ws/ws.py
Normal file
200
ws/ws.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
import datetime
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
# 创建WebSocket路由
|
||||||
|
ws_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# 客户端连接信息数据结构
|
||||||
|
class ClientConnection:
|
||||||
|
def __init__(self, websocket: WebSocket, client_ip: str):
|
||||||
|
self.websocket = websocket
|
||||||
|
self.client_ip = client_ip
|
||||||
|
self.last_heartbeat = datetime.datetime.now() # 初始心跳时间为连接时间
|
||||||
|
|
||||||
|
def update_heartbeat(self):
|
||||||
|
"""更新心跳时间为当前时间"""
|
||||||
|
self.last_heartbeat = datetime.datetime.now()
|
||||||
|
# 打印心跳更新日志
|
||||||
|
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {self.client_ip} 心跳时间已更新")
|
||||||
|
|
||||||
|
def is_alive(self, timeout_seconds: int = 60) -> bool:
|
||||||
|
"""检查客户端是否活跃(心跳超时阈值:60秒)"""
|
||||||
|
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
||||||
|
# 打印心跳检查明细(便于排查超时原因)
|
||||||
|
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {self.client_ip} 心跳检查:"
|
||||||
|
f"上次心跳距今 {timeout:.1f} 秒(阈值:{timeout_seconds}秒)")
|
||||||
|
return timeout < timeout_seconds
|
||||||
|
|
||||||
|
|
||||||
|
# 存储所有已连接的客户端(key:客户端IP、value:ClientConnection对象)
|
||||||
|
connected_clients: Dict[str, ClientConnection] = {}
|
||||||
|
|
||||||
|
# 心跳检查任务引用(全局变量、用于应用关闭时取消任务)
|
||||||
|
heartbeat_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def heartbeat_checker():
|
||||||
|
"""定期检查客户端心跳(每30秒一次)、超时直接剔除(不发通知)"""
|
||||||
|
while True:
|
||||||
|
current_time = datetime.datetime.now()
|
||||||
|
print(f"\n[{current_time:%Y-%m-%d %H:%M:%S}] === 开始新一轮心跳检查 ===")
|
||||||
|
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 当前在线客户端总数:{len(connected_clients)}")
|
||||||
|
|
||||||
|
# 1. 收集超时客户端IP(避免遍历中修改字典)
|
||||||
|
timeout_clients = []
|
||||||
|
for client_ip, connection in connected_clients.items():
|
||||||
|
if not connection.is_alive():
|
||||||
|
timeout_clients.append(client_ip)
|
||||||
|
|
||||||
|
# 2. 处理超时客户端(关闭连接+移除记录)
|
||||||
|
if timeout_clients:
|
||||||
|
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 发现超时客户端:{timeout_clients}(共{len(timeout_clients)}个)")
|
||||||
|
for client_ip in timeout_clients:
|
||||||
|
try:
|
||||||
|
connection = connected_clients[client_ip]
|
||||||
|
# 直接关闭连接(不发送任何通知)
|
||||||
|
await connection.websocket.close(code=1008, reason="心跳超时(>60秒)")
|
||||||
|
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已关闭(超时)")
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"[{current_time:%Y-%m-%d %H:%M:%S}] 关闭客户端 {client_ip} 失败:{str(e)}(错误类型:{type(e).__name__})")
|
||||||
|
finally:
|
||||||
|
# 确保从客户端列表中移除(无论关闭是否成功)
|
||||||
|
if client_ip in connected_clients:
|
||||||
|
del connected_clients[client_ip]
|
||||||
|
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已从连接列表移除")
|
||||||
|
else:
|
||||||
|
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 无超时客户端、心跳检查完成")
|
||||||
|
|
||||||
|
# 3. 等待30秒后进行下一轮检查
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理:启动时创建心跳任务、关闭时取消任务"""
|
||||||
|
global heartbeat_task
|
||||||
|
# 启动阶段:创建心跳检查任务
|
||||||
|
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
||||||
|
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 心跳检查任务已启动(任务ID:{id(heartbeat_task)})")
|
||||||
|
yield # 应用运行中
|
||||||
|
# 关闭阶段:取消心跳任务
|
||||||
|
if heartbeat_task and not heartbeat_task.done():
|
||||||
|
heartbeat_task.cancel()
|
||||||
|
try:
|
||||||
|
await heartbeat_task # 等待任务优雅退出
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 心跳检查任务已正常取消")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 取消心跳任务时出错:{str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
async def send_heartbeat_ack(client_ip: str, client_timestamp: Any) -> bool:
|
||||||
|
"""向客户端回复心跳确认(严格遵循 {"timestamp":xxxxx, "type":"heartbeat"} 格式)"""
|
||||||
|
if client_ip not in connected_clients:
|
||||||
|
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 回复心跳失败:客户端 {client_ip} 不在连接列表中")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 修复:将这部分代码移出if语句块,确保始终定义ack_msg
|
||||||
|
# 服务端当前格式化时间戳(字符串类型,与日志时间格式匹配)
|
||||||
|
server_latest_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
ack_msg = {
|
||||||
|
"timestamp": server_latest_timestamp,
|
||||||
|
"type": "heartbeat"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
connection = connected_clients[client_ip]
|
||||||
|
await connection.websocket.send_json(ack_msg)
|
||||||
|
print(
|
||||||
|
f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 已向客户端 {client_ip} 回复心跳:{json.dumps(ack_msg, ensure_ascii=False)}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 回复客户端 {client_ip} 心跳失败:{str(e)}(错误类型:{type(e).__name__})")
|
||||||
|
# 发送失败时移除客户端(避免无效连接残留)
|
||||||
|
if client_ip in connected_clients:
|
||||||
|
del connected_clients[client_ip]
|
||||||
|
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 因心跳回复失败被移除")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ws_router.websocket("/ws")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
"""WebSocket核心端点:处理连接建立/消息接收/连接关闭"""
|
||||||
|
current_time = datetime.datetime.now()
|
||||||
|
# 1. 接受客户端连接请求
|
||||||
|
await websocket.accept()
|
||||||
|
# 获取客户端IP(作为唯一标识)
|
||||||
|
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||||
|
print(f"\n[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 连接请求已接受(WebSocket握手成功)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 2. 处理"同一IP重复连接"场景:关闭旧连接、保留新连接
|
||||||
|
if client_ip in connected_clients:
|
||||||
|
old_connection = connected_clients[client_ip]
|
||||||
|
await old_connection.websocket.close(code=1008, reason="同一IP新连接已建立")
|
||||||
|
del connected_clients[client_ip]
|
||||||
|
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 已关闭客户端 {client_ip} 的旧连接(新连接已建立)")
|
||||||
|
|
||||||
|
# 3. 注册新客户端到连接列表
|
||||||
|
new_connection = ClientConnection(websocket, client_ip)
|
||||||
|
connected_clients[client_ip] = new_connection
|
||||||
|
print(
|
||||||
|
f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已注册到连接列表、当前在线数:{len(connected_clients)}")
|
||||||
|
|
||||||
|
# 4. 循环接收客户端消息(持续监听)
|
||||||
|
while True:
|
||||||
|
# 接收原始文本消息(避免提前解析JSON、便于日志打印)
|
||||||
|
raw_data = await websocket.receive_text()
|
||||||
|
recv_time = datetime.datetime.now()
|
||||||
|
print(f"\n[{recv_time:%Y-%m-%d %H:%M:%S}] 收到客户端 {client_ip} 的消息:{raw_data}")
|
||||||
|
|
||||||
|
# 尝试解析JSON消息
|
||||||
|
try:
|
||||||
|
message = json.loads(raw_data)
|
||||||
|
print(
|
||||||
|
f"[{recv_time:%Y-%m-%d %H:%M:%S}] 消息解析成功:{json.dumps(message, ensure_ascii=False, indent=2)}")
|
||||||
|
|
||||||
|
# 5. 区分消息类型:仅处理心跳、其他消息不回复
|
||||||
|
if message.get("type") == "heartbeat":
|
||||||
|
# 验证心跳消息是否包含timestamp字段
|
||||||
|
client_timestamp = message.get("timestamp")
|
||||||
|
if client_timestamp is None:
|
||||||
|
print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 警告:客户端 {client_ip} 发送的心跳缺少'timestamp'字段")
|
||||||
|
continue # 不回复无效心跳
|
||||||
|
|
||||||
|
# 更新心跳时间 + 回复心跳确认
|
||||||
|
new_connection.update_heartbeat()
|
||||||
|
await send_heartbeat_ack(client_ip, client_timestamp)
|
||||||
|
else:
|
||||||
|
# 非心跳消息:仅打印日志、不回复任何内容
|
||||||
|
print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 非心跳消息(类型:{message.get('type')})、不回复")
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
# JSON格式错误:仅打印日志、不回复
|
||||||
|
print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 消息格式错误:无效JSON(错误:{str(e)})")
|
||||||
|
except Exception as e:
|
||||||
|
# 其他未知错误:仅打印日志、不回复
|
||||||
|
print(
|
||||||
|
f"[{recv_time:%Y-%m-%d %H:%M:%S}] 处理客户端 {client_ip} 消息时出错:{str(e)}(错误类型:{type(e).__name__})")
|
||||||
|
|
||||||
|
except WebSocketDisconnect as e:
|
||||||
|
# 客户端主动断开连接(如关闭页面、网络中断)
|
||||||
|
print(
|
||||||
|
f"\n[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 主动断开连接(代码:{e.code}、原因:{e.reason})")
|
||||||
|
except Exception as e:
|
||||||
|
# 其他连接级错误(如网络异常)
|
||||||
|
print(
|
||||||
|
f"\n[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 连接异常:{str(e)}(错误类型:{type(e).__name__})")
|
||||||
|
finally:
|
||||||
|
# 无论何种退出原因、确保客户端从列表中移除
|
||||||
|
if client_ip in connected_clients:
|
||||||
|
del connected_clients[client_ip]
|
||||||
|
print(
|
||||||
|
f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已从连接列表移除、当前在线数:{len(connected_clients)}")
|
Reference in New Issue
Block a user