初始化

This commit is contained in:
ZZX9599
2025-09-02 18:51:50 +08:00
commit fe1b33a6e5
30 changed files with 1607 additions and 0 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

8
.idea/Video.iml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -0,0 +1,98 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="76">
<item index="0" class="java.lang.String" itemvalue="scipy" />
<item index="1" class="java.lang.String" itemvalue="protobuf" />
<item index="2" class="java.lang.String" itemvalue="thop" />
<item index="3" class="java.lang.String" itemvalue="opencv-python" />
<item index="4" class="java.lang.String" itemvalue="PyYAML" />
<item index="5" class="java.lang.String" itemvalue="ipython" />
<item index="6" class="java.lang.String" itemvalue="torch" />
<item index="7" class="java.lang.String" itemvalue="numpy" />
<item index="8" class="java.lang.String" itemvalue="requests" />
<item index="9" class="java.lang.String" itemvalue="torchvision" />
<item index="10" class="java.lang.String" itemvalue="psutil" />
<item index="11" class="java.lang.String" itemvalue="tqdm" />
<item index="12" class="java.lang.String" itemvalue="pandas" />
<item index="13" class="java.lang.String" itemvalue="tensorboard" />
<item index="14" class="java.lang.String" itemvalue="seaborn" />
<item index="15" class="java.lang.String" itemvalue="matplotlib" />
<item index="16" class="java.lang.String" itemvalue="Pillow" />
<item index="17" class="java.lang.String" itemvalue="fastapi" />
<item index="18" class="java.lang.String" itemvalue="uvicorn" />
<item index="19" class="java.lang.String" itemvalue="python-jose" />
<item index="20" class="java.lang.String" itemvalue="passlib" />
<item index="21" class="java.lang.String" itemvalue="pydantic" />
<item index="22" class="java.lang.String" itemvalue="sqlalchemy" />
<item index="23" class="java.lang.String" itemvalue="imageio_ffmpeg" />
<item index="24" class="java.lang.String" itemvalue="ultralytics" />
<item index="25" class="java.lang.String" itemvalue="future" />
<item index="26" class="java.lang.String" itemvalue="jose" />
<item index="27" class="java.lang.String" itemvalue="ffmpeg-python" />
<item index="28" class="java.lang.String" itemvalue="setuptools" />
<item index="29" class="java.lang.String" itemvalue="opencv_python" />
<item index="30" class="java.lang.String" itemvalue="rsa" />
<item index="31" class="java.lang.String" itemvalue="greenlet" />
<item index="32" class="java.lang.String" itemvalue="networkx" />
<item index="33" class="java.lang.String" itemvalue="python-dateutil" />
<item index="34" class="java.lang.String" itemvalue="SQLAlchemy" />
<item index="35" class="java.lang.String" itemvalue="cffi" />
<item index="36" class="java.lang.String" itemvalue="python-dotenv" />
<item index="37" class="java.lang.String" itemvalue="h11" />
<item index="38" class="java.lang.String" itemvalue="py-cpuinfo" />
<item index="39" class="java.lang.String" itemvalue="cycler" />
<item index="40" class="java.lang.String" itemvalue="MarkupSafe" />
<item index="41" class="java.lang.String" itemvalue="pyasn1" />
<item index="42" class="java.lang.String" itemvalue="pycparser" />
<item index="43" class="java.lang.String" itemvalue="Jinja2" />
<item index="44" class="java.lang.String" itemvalue="sniffio" />
<item index="45" class="java.lang.String" itemvalue="ultralytics-thop" />
<item index="46" class="java.lang.String" itemvalue="fsspec" />
<item index="47" class="java.lang.String" itemvalue="filelock" />
<item index="48" class="java.lang.String" itemvalue="starlette" />
<item index="49" class="java.lang.String" itemvalue="certifi" />
<item index="50" class="java.lang.String" itemvalue="anyio" />
<item index="51" class="java.lang.String" itemvalue="urllib3" />
<item index="52" class="java.lang.String" itemvalue="pyparsing" />
<item index="53" class="java.lang.String" itemvalue="sympy" />
<item index="54" class="java.lang.String" itemvalue="annotated-types" />
<item index="55" class="java.lang.String" itemvalue="pydantic-settings" />
<item index="56" class="java.lang.String" itemvalue="six" />
<item index="57" class="java.lang.String" itemvalue="tzdata" />
<item index="58" class="java.lang.String" itemvalue="ecdsa" />
<item index="59" class="java.lang.String" itemvalue="kiwisolver" />
<item index="60" class="java.lang.String" itemvalue="packaging" />
<item index="61" class="java.lang.String" itemvalue="python-multipart" />
<item index="62" class="java.lang.String" itemvalue="click" />
<item index="63" class="java.lang.String" itemvalue="contourpy" />
<item index="64" class="java.lang.String" itemvalue="fonttools" />
<item index="65" class="java.lang.String" itemvalue="pydantic_core" />
<item index="66" class="java.lang.String" itemvalue="av" />
<item index="67" class="java.lang.String" itemvalue="colorama" />
<item index="68" class="java.lang.String" itemvalue="mpmath" />
<item index="69" class="java.lang.String" itemvalue="argon2-cffi-bindings" />
<item index="70" class="java.lang.String" itemvalue="typing_extensions" />
<item index="71" class="java.lang.String" itemvalue="charset-normalizer" />
<item index="72" class="java.lang.String" itemvalue="pillow" />
<item index="73" class="java.lang.String" itemvalue="argon2-cffi" />
<item index="74" class="java.lang.String" itemvalue="pytz" />
<item index="75" class="java.lang.String" itemvalue="idna" />
</list>
</value>
</option>
</inspection_tool>
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="N802" />
</list>
</option>
</inspection_tool>
<inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" />
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml generated Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="video" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Video.iml" filepath="$PROJECT_DIR$/.idea/Video.iml" />
</modules>
</component>
</project>

Binary file not shown.

19
config.ini Normal file
View File

@ -0,0 +1,19 @@
[server]
port = 8000
[mysql]
host = 192.168.110.65
port = 6975
user = video_check
password = fsjPfhxCs8NrFGmL
database = video_check
charset = utf8mb4
[jwt]
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
algorithm = HS256
access_token_expire_minutes = 30
[live]
rtmp_url = rtmp://192.168.110.65:1935/live/
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=

Binary file not shown.

Binary file not shown.

17
ds/config.py Normal file
View File

@ -0,0 +1,17 @@
import configparser
import os
# 读取配置文件路径
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../config.ini")
# 初始化配置解析器
config = configparser.ConfigParser()
# 读取配置文件
config.read(config_path, encoding="utf-8")
# 暴露配置项(方便其他文件调用)
SERVER_CONFIG = config["server"]
MYSQL_CONFIG = config["mysql"]
JWT_CONFIG = config["jwt"]
LIVE_CONFIG = config["live"]

46
ds/db.py Normal file
View File

@ -0,0 +1,46 @@
import mysql.connector
from mysql.connector import Error
from .config import MYSQL_CONFIG
class Database:
"""MySQL 连接池管理类"""
pool_config = {
"host": MYSQL_CONFIG.get("host", "localhost"),
"port": int(MYSQL_CONFIG.get("port", 3306)),
"user": MYSQL_CONFIG.get("user", "root"),
"password": MYSQL_CONFIG.get("password", ""),
"database": MYSQL_CONFIG.get("database", ""),
"charset": MYSQL_CONFIG.get("charset", "utf8mb4"),
"pool_name": "fastapi_pool",
"pool_size": 5,
"pool_reset_session": True
}
@classmethod
def get_connection(cls):
"""获取数据库连接"""
try:
# 从连接池获取连接
conn = mysql.connector.connect(**cls.pool_config)
if conn.is_connected():
return conn
except Error as e:
# 抛出数据库连接错误(会被全局异常处理器捕获)
raise Exception(f"MySQL 连接失败: {str(e)}") from e
@classmethod
def close_connection(cls, conn, cursor=None):
"""关闭连接和游标"""
try:
if cursor:
cursor.close()
if conn and conn.is_connected():
conn.close()
except Error as e:
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
# 暴露数据库操作工具
db = Database()

43
main.py Normal file
View File

@ -0,0 +1,43 @@
import uvicorn
from fastapi import FastAPI
from ds.config import SERVER_CONFIG
from middle.error_handler import global_exception_handler
from service.user_service import router as user_router
from service.device_service import router as device_router
from ws.ws import ws_router, lifespan
# ------------------------------
# 初始化 FastAPI 应用、指定生命周期管理
# ------------------------------
app = FastAPI(
title="内容安全审核后台",
description="内容安全审核后台",
version="1.0.0",
lifespan=lifespan
)
# ------------------------------
# 注册路由
# ------------------------------
app.include_router(user_router)
app.include_router(device_router)
app.include_router(ws_router)
# ------------------------------
# 注册全局异常处理器
# ------------------------------
app.add_exception_handler(Exception, global_exception_handler)
# ------------------------------
# 启动服务
# ------------------------------
if __name__ == "__main__":
port = int(SERVER_CONFIG.get("port", 8000))
uvicorn.run(
app="main:app",
host="0.0.0.0",
port=port,
reload=True,
ws="websockets"
)

Binary file not shown.

Binary file not shown.

96
middle/auth_middleware.py Normal file
View File

@ -0,0 +1,96 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from ds.config import JWT_CONFIG
from ds.db import db
from service.user_service import UserResponse
# ------------------------------
# 密码加密配置
# ------------------------------
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# ------------------------------
# JWT 配置
# ------------------------------
SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------
# 密码工具函数
# ------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证明文密码与加密密码是否匹配"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""对明文密码进行 bcrypt 加密"""
return pwd_context.hash(password)
# ------------------------------
# JWT 工具函数
# ------------------------------
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""生成 JWT Token"""
to_encode = data.copy()
# 设置过期时间
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
# 添加过期时间到 Token 数据
to_encode.update({"exp": expire})
# 生成 Token
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# ------------------------------
# 认证依赖(获取当前登录用户)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
# 认证失败异常
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 无效或已过期",
headers={"WWW-Authenticate": "Bearer"},
)
try:
# 解码 Token
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# 获取 Token 中的用户名
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
# 从数据库查询用户(验证用户是否存在)
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
cursor.execute(query, (username,))
user = cursor.fetchone()
if user is None:
raise credentials_exception # 用户不存在
# 转换为 UserResponse 模型(自动校验字段)
return UserResponse(** user)
except Exception as e:
raise credentials_exception from e
finally:
db.close_connection(conn, cursor)

68
middle/error_handler.py Normal file
View File

@ -0,0 +1,68 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import HTTPException, RequestValidationError
from mysql.connector import Error as MySQLError
from jose import JWTError
from schema.response_schema import APIResponse
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理器:所有未捕获的异常都会在这里统一处理"""
# 1. 请求参数验证错误Pydantic 校验失败)
if isinstance(exc, RequestValidationError):
error_details = []
for err in exc.errors():
error_details.append(f"{err['loc'][1]}: {err['msg']}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=APIResponse(
code=400,
message=f"请求参数错误:{'; '.join(error_details)}",
data=None
).model_dump()
)
# 2. HTTP 异常(主动抛出的业务错误、如 401/404
if isinstance(exc, HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=APIResponse(
code=exc.status_code,
message=exc.detail,
data=None
).model_dump()
)
# 3. JWT 相关错误Token 无效/过期)
if isinstance(exc, JWTError):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=APIResponse(
code=401,
message="Token 无效或已过期",
data=None
).model_dump(),
headers={"WWW-Authenticate": "Bearer"},
)
# 4. MySQL 数据库错误
if isinstance(exc, MySQLError):
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"数据库错误:{str(exc)}",
data=None
).model_dump()
)
# 5. 其他未知错误(兜底处理)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"服务器内部错误:{str(exc)}",
data=None
).model_dump()
)

Binary file not shown.

Binary file not shown.

Binary file not shown.

51
schema/device_schema.py Normal file
View File

@ -0,0 +1,51 @@
import hashlib
from datetime import datetime
from typing import Optional, List, Dict
from pydantic import BaseModel, Field
# ------------------------------
# 请求模型(前端传参校验)
# ------------------------------
class DeviceCreateRequest(BaseModel):
"""设备流信息创建请求模型"""
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
params: Optional[Dict] = Field(None, description="设备详细信息")
def md5_encrypt(text: str) -> str:
"""对字符串进行MD5加密"""
if not text:
return ""
md5_hash = hashlib.md5()
md5_hash.update(text.encode('utf-8'))
return md5_hash.hexdigest()
# ------------------------------
# 响应模型(后端返回设备数据)
# ------------------------------
class DeviceResponse(BaseModel):
"""设备流信息响应模型(字段与表结构完全对齐)"""
id: int = Field(..., description="设备ID")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址")
live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址")
detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址")
device_online_status: int = Field(..., description="设备在线状态1-在线、0-离线)")
device_type: Optional[str] = Field(None, description="设备类型")
alarm_count: int = Field(..., description="报警次数")
params: Optional[str] = Field(None, description="设备详细信息")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库查询结果转换
model_config = {"from_attributes": True}
class DeviceListResponse(BaseModel):
"""设备流信息列表响应模型"""
total: int = Field(..., description="设备总数")
devices: List[DeviceResponse] = Field(..., description="设备列表")

13
schema/response_schema.py Normal file
View File

@ -0,0 +1,13 @@
from typing import Optional, Any
from pydantic import BaseModel, Field
class APIResponse(BaseModel):
"""统一 API 响应模型(所有接口必返此格式)"""
code: int = Field(..., description="状态码200=成功、4xx=客户端错误、5xx=服务端错误")
message: str = Field(..., description="响应信息:成功/错误描述")
data: Optional[Any] = Field(None, description="响应数据:成功时返回、错误时为 None")
# Pydantic V2 配置(支持从 ORM 对象转换)
model_config = {"from_attributes": True}

32
schema/user_schema.py Normal file
View File

@ -0,0 +1,32 @@
from datetime import datetime
from pydantic import BaseModel, Field
# ------------------------------
# 请求模型(前端传参校验)
# ------------------------------
class UserRegisterRequest(BaseModel):
"""用户注册请求模型"""
username: str = Field(..., min_length=3, max_length=50, description="用户名3-50字符")
password: str = Field(..., min_length=6, max_length=100, description="密码6-100字符")
class UserLoginRequest(BaseModel):
"""用户登录请求模型"""
username: str = Field(..., description="用户名")
password: str = Field(..., description="密码")
# ------------------------------
# 响应模型(后端返回用户数据)
# ------------------------------
class UserResponse(BaseModel):
"""用户信息响应模型(隐藏密码等敏感字段)"""
id: int = Field(..., description="用户ID")
username: str = Field(..., description="用户名")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
# Pydantic V2 配置(支持从数据库查询结果转换)
model_config = {"from_attributes": True}

Binary file not shown.

Binary file not shown.

251
service/device_service.py Normal file
View File

@ -0,0 +1,251 @@
import json
from fastapi import HTTPException, Query, APIRouter, Depends, Request
from mysql.connector import Error as MySQLError
from ds.config import LIVE_CONFIG
from ds.db import db
from middle.auth_middleware import get_current_user
# 注意导入的Schema已更新字段
from schema.device_schema import (
DeviceCreateRequest,
DeviceResponse,
DeviceListResponse,
md5_encrypt
)
from schema.response_schema import APIResponse
from schema.user_schema import UserResponse
router = APIRouter(
prefix="/devices",
tags=["设备管理"]
)
# ------------------------------
# 1. 创建设备信息
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(request: Request, device_data: DeviceCreateRequest):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 新增检查client_ip是否已存在
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (device_data.ip,))
existing_device = cursor.fetchone()
if existing_device:
raise Exception(f"客户端IP {device_data.ip} 已存在,无法重复添加")
# 获取RTMP URL
rtmp_url = str(LIVE_CONFIG.get("rtmp_url", ""))
webrtc_url = str(LIVE_CONFIG.get("webrtc_url", ""))
# 将设备详细信息params转换为JSON字符串对应表中params字段
device_params_json = json.dumps(device_data.params) if device_data.params else None
# 对JSON字符串进行MD5加密用于生成唯一RTMP地址
device_md5 = md5_encrypt(device_params_json) if device_params_json else ""
# 解析User-Agent获取设备类型
user_agent = request.headers.get("User-Agent", "").lower()
# 优先处理User-Agent为default的情况
if user_agent == "default":
# 检查params中是否存在os键
if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params:
device_type = device_data.params["os"]
else:
device_type = "unknown"
elif "windows" in user_agent:
device_type = "windows"
elif "android" in user_agent:
device_type = "android"
elif "linux" in user_agent:
device_type = "linux"
else:
device_type = "unknown"
# SQL字段对齐表结构
insert_query = """
INSERT INTO devices
(client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url,
device_online_status, device_type, alarm_count, params)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(insert_query, (
device_data.ip,
device_data.hostname,
rtmp_url + device_md5,
webrtc_url + device_md5,
"",
1,
device_type,
0,
device_params_json
))
conn.commit()
# 获取刚创建的设备信息
device_id = cursor.lastrowid
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone()
return APIResponse(
code=200,
message="设备创建成功",
data=DeviceResponse(**device)
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建设备失败:{str(e)}") from e
except json.JSONDecodeError as e:
raise Exception(f"设备信息JSON序列化失败{str(e)}") from e
except Exception as e:
# 捕获IP已存在的自定义异常
if conn:
conn.rollback()
raise e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 2. 获取设备列表
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备列表")
async def get_device_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页条数"),
device_type: str = Query(None, description="设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选1-在线、0-离线)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 构建查询条件
where_clause = []
params = []
if device_type:
where_clause.append("device_type = %s")
params.append(device_type)
if online_status is not None:
where_clause.append("device_online_status = %s")
params.append(online_status)
# 总条数查询
count_query = "SELECT COUNT(*) as total FROM devices"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 分页查询SELECT * 会自动匹配表字段、响应模型已对齐)
offset = (page - 1) * page_size
query = "SELECT * FROM devices"
if where_clause:
query += " WHERE " + " AND ".join(where_clause)
query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(query, params)
devices = cursor.fetchall()
# 响应模型已更新为params字段、直接转换即可
device_list = [DeviceResponse(**device) for device in devices]
return APIResponse(
code=200,
message="获取设备列表成功",
data=DeviceListResponse(total=total, devices=device_list)
)
except MySQLError as e:
raise Exception(f"获取设备列表失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 3. 获取单个设备详情
# ------------------------------
@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情")
async def get_device_detail(
device_id: int,
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询设备信息SELECT * 匹配表字段)
query = "SELECT * FROM devices WHERE id = %s"
cursor.execute(query, (device_id,))
device = cursor.fetchone()
if not device:
raise HTTPException(
status_code=404,
detail=f"设备ID为 {device_id} 的设备不存在"
)
# 响应模型已更新为params字段
return APIResponse(
code=200,
message="获取设备详情成功",
data=DeviceResponse(**device)
)
except MySQLError as e:
raise Exception(f"获取设备详情失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 4. 删除设备信息
# ------------------------------
@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息")
async def delete_device(
device_id: int,
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,))
if not cursor.fetchone():
raise HTTPException(
status_code=404,
detail=f"设备ID为 {device_id} 的设备不存在"
)
# 执行删除
delete_query = "DELETE FROM devices WHERE id = %s"
cursor.execute(delete_query, (device_id,))
conn.commit()
return APIResponse(
code=200,
message=f"设备ID为 {device_id} 的设备已成功删除",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除设备失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)

154
service/user_service.py Normal file
View File

@ -0,0 +1,154 @@
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from mysql.connector import Error as MySQLError
from ds.db import db
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
from schema.response_schema import APIResponse
from middle.auth_middleware import (
get_password_hash,
verify_password,
create_access_token,
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user
)
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
router = APIRouter(
prefix="/users",
tags=["用户管理"]
)
# ------------------------------
# 1. 用户注册接口
# ------------------------------
@router.post("/register", response_model=APIResponse, summary="用户注册")
async def user_register(request: UserRegisterRequest):
"""
用户注册:
- 校验用户名是否已存在
- 加密密码后插入数据库
- 返回注册成功信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 检查用户名是否已存在(唯一索引)
check_query = "SELECT username FROM users WHERE username = %s"
cursor.execute(check_query, (request.username,))
existing_user = cursor.fetchone()
if existing_user:
raise HTTPException(
status_code=400,
detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
)
# 2. 加密密码
hashed_password = get_password_hash(request.password)
# 3. 插入新用户到数据库
insert_query = """
INSERT INTO users (username, password)
VALUES (%s, %s)
"""
cursor.execute(insert_query, (request.username, hashed_password))
conn.commit() # 提交事务
# 4. 返回注册成功响应
return APIResponse(
code=201, # 201 表示资源创建成功
message=f"用户 '{request.username}' 注册成功",
data=None
)
except MySQLError as e:
conn.rollback() # 数据库错误时回滚事务
raise Exception(f"注册失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 2. 用户登录接口
# ------------------------------
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
async def user_login(request: UserLoginRequest):
"""
用户登录:
- 校验用户名是否存在
- 校验密码是否正确
- 生成 JWT Token 并返回
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 修复SQL查询添加 created_at 和 updated_at 字段
query = """
SELECT id, username, password, created_at, updated_at
FROM users
WHERE username = %s
"""
cursor.execute(query, (request.username,))
user = cursor.fetchone()
# 2. 校验用户名和密码
if not user or not verify_password(request.password, user["password"]):
raise HTTPException(
status_code=401,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 3. 生成 Token过期时间从配置读取
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user["username"]},
expires_delta=access_token_expires
)
# 4. 返回 Token 和用户基本信息
return APIResponse(
code=200,
message="登录成功",
data={
"access_token": access_token,
"token_type": "bearer",
"user": UserResponse(
id=user["id"],
username=user["username"],
created_at=user.get("created_at"),
updated_at=user.get("updated_at")
)
}
)
except MySQLError as e:
raise Exception(f"登录失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 3. 获取当前登录用户信息(需认证)
# ------------------------------
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
async def get_current_user_info(
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
):
"""
获取当前登录用户信息:
- 需在请求头携带 Token格式Bearer <token>
- 认证通过后返回用户信息
"""
return APIResponse(
code=200,
message="获取用户信息成功",
data=current_user
)

482
ws.html Normal file
View File

@ -0,0 +1,482 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WebSocket 测试工具</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
}
body {
max-width: 1200px;
margin: 20px auto;
padding: 0 20px;
background-color: #f5f7fa;
}
.container {
background: white;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
padding: 25px;
margin-bottom: 20px;
}
h1 {
color: #2c3e50;
margin-bottom: 20px;
font-size: 24px;
border-bottom: 2px solid #3498db;
padding-bottom: 10px;
}
.status-bar {
display: flex;
align-items: center;
gap: 15px;
margin-bottom: 20px;
padding: 12px 15px;
background-color: #f8f9fa;
border-radius: 6px;
}
.status-label {
font-weight: bold;
color: #495057;
}
.status-value {
padding: 4px 10px;
border-radius: 4px;
font-weight: bold;
}
.status-connected {
background-color: #d4edda;
color: #155724;
}
.status-disconnected {
background-color: #f8d7da;
color: #721c24;
}
.status-connecting {
background-color: #fff3cd;
color: #856404;
}
.btn {
padding: 8px 16px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
font-weight: 500;
transition: background-color 0.2s;
}
.btn-primary {
background-color: #3498db;
color: white;
}
.btn-primary:hover {
background-color: #2980b9;
}
.btn-danger {
background-color: #e74c3c;
color: white;
}
.btn-danger:hover {
background-color: #c0392b;
}
.btn-success {
background-color: #2ecc71;
color: white;
}
.btn-success:hover {
background-color: #27ae60;
}
.control-group {
display: flex;
gap: 15px;
margin-bottom: 20px;
align-items: center;
}
.input-group {
display: flex;
gap: 10px;
align-items: center;
}
.input-group label {
color: #495057;
font-weight: 500;
}
.input-group input, .input-group select {
padding: 8px 12px;
border: 1px solid #ced4da;
border-radius: 4px;
font-size: 14px;
}
.message-area {
margin-top: 20px;
}
.message-input {
width: 100%;
height: 100px;
padding: 12px;
border: 1px solid #ced4da;
border-radius: 6px;
resize: none;
font-size: 14px;
margin-bottom: 10px;
}
.log-area {
width: 100%;
height: 300px;
padding: 15px;
border: 1px solid #ced4da;
border-radius: 6px;
background-color: #f8f9fa;
overflow-y: auto;
font-size: 14px;
line-height: 1.6;
}
.log-item {
margin-bottom: 8px;
padding-bottom: 8px;
border-bottom: 1px dashed #e9ecef;
}
.log-time {
color: #6c757d;
font-size: 12px;
margin-right: 10px;
}
.log-send {
color: #2980b9;
}
.log-receive {
color: #27ae60;
}
.log-status {
color: #856404;
}
.log-error {
color: #e74c3c;
}
</style>
</head>
<body>
<div class="container">
<h1>WebSocket 测试工具</h1>
<!-- 连接状态区 -->
<div class="status-bar">
<div class="status-label">连接状态:</div>
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
<div class="status-label">服务地址:</div>
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
<div class="status-label">连接时间:</div>
<div id="connectTime" class="status-value">-</div>
</div>
<!-- 控制按钮区 -->
<div class="control-group">
<button id="connectBtn" class="btn btn-primary">建立连接</button>
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
<!-- 心跳控制 -->
<div class="input-group">
<label>自动心跳:</label>
<select id="autoHeartbeat">
<option value="on">开启</option>
<option value="off">关闭</option>
</select>
<label>间隔(秒)</label>
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
</div>
</div>
<!-- 自定义消息发送区 -->
<div class="message-area">
<h3>发送自定义消息</h3>
<textarea id="messageInput" class="message-input"
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
</div>
<!-- 日志显示区 -->
<div class="message-area">
<h3>消息日志</h3>
<div id="logContainer" class="log-area">
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
</div>
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
</div>
</div>
<script>
// 全局变量
let ws = null;
let heartbeatTimer = null;
const wsUrl = "ws://192.168.110.25:8000/ws";
// DOM 元素
const connectionStatus = document.getElementById('connectionStatus');
const connectTime = document.getElementById('connectTime');
const connectBtn = document.getElementById('connectBtn');
const disconnectBtn = document.getElementById('disconnectBtn');
const sendMessageBtn = document.getElementById('sendMessageBtn');
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
const autoHeartbeat = document.getElementById('autoHeartbeat');
const heartbeatInterval = document.getElementById('heartbeatInterval');
const messageInput = document.getElementById('messageInput');
const logContainer = document.getElementById('logContainer');
const clearLogBtn = document.getElementById('clearLogBtn');
// 工具函数:添加日志
function addLog(content, type = 'status') {
const now = new Date().toLocaleString('zh-CN', {
year: 'numeric', month: '2-digit', day: '2-digit',
hour: '2-digit', minute: '2-digit', second: '2-digit'
});
const logItem = document.createElement('div');
logItem.className = 'log-item';
let logClass = '';
switch (type) {
case 'send':
logClass = 'log-send';
break;
case 'receive':
logClass = 'log-receive';
break;
case 'error':
logClass = 'log-error';
break;
default:
logClass = 'log-status';
}
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
logContainer.appendChild(logItem);
// 滚动到最新日志
logContainer.scrollTop = logContainer.scrollHeight;
}
// 工具函数格式化JSON便于日志显示
function formatJson(jsonStr) {
try {
const obj = JSON.parse(jsonStr);
return JSON.stringify(obj, null, 2);
} catch (e) {
return jsonStr; // 非JSON格式直接返回
}
}
// 建立WebSocket连接
function connectWebSocket() {
if (ws) {
addLog('已存在连接,无需重复建立', 'error');
return;
}
try {
ws = new WebSocket(wsUrl);
// 连接成功
ws.onopen = function () {
connectionStatus.className = 'status-value status-connected';
connectionStatus.textContent = '已连接';
const now = new Date().toLocaleString('zh-CN');
connectTime.textContent = now;
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
// 更新按钮状态
connectBtn.disabled = true;
disconnectBtn.disabled = false;
sendMessageBtn.disabled = false;
// 开启自动心跳(默认开启)
if (autoHeartbeat.value === 'on') {
startAutoHeartbeat();
}
};
// 接收消息
ws.onmessage = function (event) {
const message = event.data;
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
};
// 连接关闭
ws.onclose = function (event) {
connectionStatus.className = 'status-value status-disconnected';
connectionStatus.textContent = '已断开';
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
// 清除自动心跳
stopAutoHeartbeat();
// 更新按钮状态
connectBtn.disabled = false;
disconnectBtn.disabled = true;
sendMessageBtn.disabled = true;
// 重置WebSocket对象
ws = null;
};
// 连接错误
ws.onerror = function (error) {
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
};
} catch (e) {
addLog(`建立连接失败:${e.message}`, 'error');
ws = null;
}
}
// 断开WebSocket连接
function disconnectWebSocket() {
if (!ws) {
addLog('当前无连接,无需断开', 'error');
return;
}
ws.close(1000, '手动断开连接');
}
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"}
function sendHeartbeat() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送心跳失败:当前无有效连接', 'error');
return;
}
const heartbeatMsg = {
timestamp: Date.now(), // 当前毫秒时间戳
type: "heartbeat"
};
const msgStr = JSON.stringify(heartbeatMsg);
ws.send(msgStr);
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
}
// 开启自动心跳
function startAutoHeartbeat() {
// 先停止已有定时器
stopAutoHeartbeat();
const interval = parseInt(heartbeatInterval.value) * 1000;
if (isNaN(interval) || interval < 10000) {
addLog('自动心跳间隔无效已重置为30秒', 'error');
heartbeatInterval.value = 30;
return startAutoHeartbeat();
}
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}`, 'status');
heartbeatTimer = setInterval(sendHeartbeat, interval);
}
// 停止自动心跳
function stopAutoHeartbeat() {
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
addLog('已停止自动心跳', 'status');
}
}
// 发送自定义消息
function sendCustomMessage() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送消息失败:当前无有效连接', 'error');
return;
}
const msgStr = messageInput.value.trim();
if (!msgStr) {
addLog('发送消息失败:消息内容不能为空', 'error');
return;
}
try {
// 验证JSON格式可选仅提示不强制
JSON.parse(msgStr);
ws.send(msgStr);
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
} catch (e) {
addLog(`JSON格式错误${e.message},仍尝试发送原始内容`, 'error');
ws.send(msgStr);
addLog(`发送自定义消息非JSON\n${msgStr}`, 'send');
}
}
// 绑定按钮事件
connectBtn.addEventListener('click', connectWebSocket);
disconnectBtn.addEventListener('click', disconnectWebSocket);
sendMessageBtn.addEventListener('click', sendCustomMessage);
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
clearLogBtn.addEventListener('click', () => {
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
});
// 自动心跳开关变更事件
autoHeartbeat.addEventListener('change', function () {
if (ws && ws.readyState === WebSocket.OPEN) {
if (this.value === 'on') {
startAutoHeartbeat();
} else {
stopAutoHeartbeat();
}
} else {
addLog('需先建立有效连接才能控制自动心跳', 'error');
// 重置选择
this.value = 'off';
}
});
// 心跳间隔变更事件(实时生效)
heartbeatInterval.addEventListener('change', function () {
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
startAutoHeartbeat();
}
});
// 快捷键支持Ctrl+Enter发送消息
messageInput.addEventListener('keydown', function (e) {
if (e.ctrlKey && e.key === 'Enter') {
sendCustomMessage();
e.preventDefault();
}
});
</script>
</body>
</html>

Binary file not shown.

200
ws/ws.py Normal file
View File

@ -0,0 +1,200 @@
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from typing import Dict, Any, Optional
import datetime
import asyncio
import json
from contextlib import asynccontextmanager
# 创建WebSocket路由
ws_router = APIRouter()
# 客户端连接信息数据结构
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now() # 初始心跳时间为连接时间
def update_heartbeat(self):
"""更新心跳时间为当前时间"""
self.last_heartbeat = datetime.datetime.now()
# 打印心跳更新日志
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {self.client_ip} 心跳时间已更新")
def is_alive(self, timeout_seconds: int = 60) -> bool:
"""检查客户端是否活跃心跳超时阈值60秒"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
# 打印心跳检查明细(便于排查超时原因)
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {self.client_ip} 心跳检查:"
f"上次心跳距今 {timeout:.1f} 秒(阈值:{timeout_seconds}秒)")
return timeout < timeout_seconds
# 存储所有已连接的客户端key客户端IP、valueClientConnection对象
connected_clients: Dict[str, ClientConnection] = {}
# 心跳检查任务引用(全局变量、用于应用关闭时取消任务)
heartbeat_task: Optional[asyncio.Task] = None
async def heartbeat_checker():
"""定期检查客户端心跳每30秒一次、超时直接剔除不发通知"""
while True:
current_time = datetime.datetime.now()
print(f"\n[{current_time:%Y-%m-%d %H:%M:%S}] === 开始新一轮心跳检查 ===")
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 当前在线客户端总数:{len(connected_clients)}")
# 1. 收集超时客户端IP避免遍历中修改字典
timeout_clients = []
for client_ip, connection in connected_clients.items():
if not connection.is_alive():
timeout_clients.append(client_ip)
# 2. 处理超时客户端(关闭连接+移除记录)
if timeout_clients:
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 发现超时客户端:{timeout_clients}(共{len(timeout_clients)}个)")
for client_ip in timeout_clients:
try:
connection = connected_clients[client_ip]
# 直接关闭连接(不发送任何通知)
await connection.websocket.close(code=1008, reason="心跳超时(>60秒")
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已关闭(超时)")
except Exception as e:
print(
f"[{current_time:%Y-%m-%d %H:%M:%S}] 关闭客户端 {client_ip} 失败:{str(e)}(错误类型:{type(e).__name__}")
finally:
# 确保从客户端列表中移除(无论关闭是否成功)
if client_ip in connected_clients:
del connected_clients[client_ip]
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已从连接列表移除")
else:
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 无超时客户端、心跳检查完成")
# 3. 等待30秒后进行下一轮检查
await asyncio.sleep(30)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:启动时创建心跳任务、关闭时取消任务"""
global heartbeat_task
# 启动阶段:创建心跳检查任务
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 心跳检查任务已启动任务ID{id(heartbeat_task)}")
yield # 应用运行中
# 关闭阶段:取消心跳任务
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
try:
await heartbeat_task # 等待任务优雅退出
except asyncio.CancelledError:
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 心跳检查任务已正常取消")
except Exception as e:
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 取消心跳任务时出错:{str(e)}")
async def send_heartbeat_ack(client_ip: str, client_timestamp: Any) -> bool:
"""向客户端回复心跳确认(严格遵循 {"timestamp":xxxxx, "type":"heartbeat"} 格式)"""
if client_ip not in connected_clients:
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 回复心跳失败:客户端 {client_ip} 不在连接列表中")
return False
# 修复将这部分代码移出if语句块确保始终定义ack_msg
# 服务端当前格式化时间戳(字符串类型,与日志时间格式匹配)
server_latest_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
ack_msg = {
"timestamp": server_latest_timestamp,
"type": "heartbeat"
}
try:
connection = connected_clients[client_ip]
await connection.websocket.send_json(ack_msg)
print(
f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 已向客户端 {client_ip} 回复心跳:{json.dumps(ack_msg, ensure_ascii=False)}")
return True
except Exception as e:
print(
f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 回复客户端 {client_ip} 心跳失败:{str(e)}(错误类型:{type(e).__name__}")
# 发送失败时移除客户端(避免无效连接残留)
if client_ip in connected_clients:
del connected_clients[client_ip]
print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 因心跳回复失败被移除")
return False
@ws_router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket核心端点处理连接建立/消息接收/连接关闭"""
current_time = datetime.datetime.now()
# 1. 接受客户端连接请求
await websocket.accept()
# 获取客户端IP作为唯一标识
client_ip = websocket.client.host if websocket.client else "unknown_ip"
print(f"\n[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 连接请求已接受WebSocket握手成功")
try:
# 2. 处理"同一IP重复连接"场景:关闭旧连接、保留新连接
if client_ip in connected_clients:
old_connection = connected_clients[client_ip]
await old_connection.websocket.close(code=1008, reason="同一IP新连接已建立")
del connected_clients[client_ip]
print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 已关闭客户端 {client_ip} 的旧连接(新连接已建立)")
# 3. 注册新客户端到连接列表
new_connection = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_connection
print(
f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已注册到连接列表、当前在线数:{len(connected_clients)}")
# 4. 循环接收客户端消息(持续监听)
while True:
# 接收原始文本消息避免提前解析JSON、便于日志打印
raw_data = await websocket.receive_text()
recv_time = datetime.datetime.now()
print(f"\n[{recv_time:%Y-%m-%d %H:%M:%S}] 收到客户端 {client_ip} 的消息:{raw_data}")
# 尝试解析JSON消息
try:
message = json.loads(raw_data)
print(
f"[{recv_time:%Y-%m-%d %H:%M:%S}] 消息解析成功:{json.dumps(message, ensure_ascii=False, indent=2)}")
# 5. 区分消息类型:仅处理心跳、其他消息不回复
if message.get("type") == "heartbeat":
# 验证心跳消息是否包含timestamp字段
client_timestamp = message.get("timestamp")
if client_timestamp is None:
print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 警告:客户端 {client_ip} 发送的心跳缺少'timestamp'字段")
continue # 不回复无效心跳
# 更新心跳时间 + 回复心跳确认
new_connection.update_heartbeat()
await send_heartbeat_ack(client_ip, client_timestamp)
else:
# 非心跳消息:仅打印日志、不回复任何内容
print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 非心跳消息(类型:{message.get('type')})、不回复")
except json.JSONDecodeError as e:
# JSON格式错误仅打印日志、不回复
print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 消息格式错误无效JSON错误{str(e)}")
except Exception as e:
# 其他未知错误:仅打印日志、不回复
print(
f"[{recv_time:%Y-%m-%d %H:%M:%S}] 处理客户端 {client_ip} 消息时出错:{str(e)}(错误类型:{type(e).__name__}")
except WebSocketDisconnect as e:
# 客户端主动断开连接(如关闭页面、网络中断)
print(
f"\n[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 主动断开连接(代码:{e.code}、原因:{e.reason}")
except Exception as e:
# 其他连接级错误(如网络异常)
print(
f"\n[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 连接异常:{str(e)}(错误类型:{type(e).__name__}")
finally:
# 无论何种退出原因、确保客户端从列表中移除
if client_ip in connected_clients:
del connected_clients[client_ip]
print(
f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已从连接列表移除、当前在线数:{len(connected_clients)}")