commit fe1b33a6e5e8f6137a96675e18bdb20459cffe55 Author: ZZX9599 <536509593@qq.com> Date: Tue Sep 2 18:51:50 2025 +0800 初始化 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/Video.iml b/.idea/Video.iml new file mode 100644 index 0000000..8f67bb8 --- /dev/null +++ b/.idea/Video.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..9f642d0 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,98 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..0f99d01 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..9e5508f --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/__pycache__/main.cpython-312.pyc b/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000..47530bf Binary files /dev/null and b/__pycache__/main.cpython-312.pyc differ diff --git a/config.ini b/config.ini new file mode 100644 index 0000000..9f431f7 --- /dev/null +++ b/config.ini @@ -0,0 +1,19 @@ +[server] +port = 8000 + +[mysql] +host = 192.168.110.65 +port = 6975 +user = video_check +password = fsjPfhxCs8NrFGmL +database = video_check +charset = utf8mb4 + +[jwt] +secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd +algorithm = HS256 +access_token_expire_minutes = 30 + +[live] +rtmp_url = rtmp://192.168.110.65:1935/live/ +webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream= diff --git a/ds/__pycache__/config.cpython-312.pyc b/ds/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000..c86a219 Binary files /dev/null and b/ds/__pycache__/config.cpython-312.pyc differ diff --git a/ds/__pycache__/db.cpython-312.pyc b/ds/__pycache__/db.cpython-312.pyc new file mode 100644 index 0000000..cf7f378 Binary files /dev/null and b/ds/__pycache__/db.cpython-312.pyc differ diff --git a/ds/config.py b/ds/config.py new file mode 100644 index 0000000..ced1b84 --- /dev/null +++ b/ds/config.py @@ -0,0 +1,17 @@ +import configparser +import os + +# 读取配置文件路径 +config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../config.ini") + +# 初始化配置解析器 +config = configparser.ConfigParser() + +# 读取配置文件 +config.read(config_path, encoding="utf-8") + +# 暴露配置项(方便其他文件调用) +SERVER_CONFIG = config["server"] +MYSQL_CONFIG = config["mysql"] +JWT_CONFIG = config["jwt"] +LIVE_CONFIG = config["live"] diff --git a/ds/db.py b/ds/db.py new file mode 100644 index 0000000..ae891b7 --- /dev/null +++ b/ds/db.py @@ -0,0 +1,46 @@ +import mysql.connector +from mysql.connector import Error + +from .config import MYSQL_CONFIG + + +class Database: + """MySQL 连接池管理类""" + pool_config = { + "host": MYSQL_CONFIG.get("host", "localhost"), + "port": int(MYSQL_CONFIG.get("port", 3306)), + "user": MYSQL_CONFIG.get("user", "root"), + "password": MYSQL_CONFIG.get("password", ""), + "database": MYSQL_CONFIG.get("database", ""), + "charset": MYSQL_CONFIG.get("charset", "utf8mb4"), + "pool_name": "fastapi_pool", + "pool_size": 5, + "pool_reset_session": True + } + + @classmethod + def get_connection(cls): + """获取数据库连接""" + try: + # 从连接池获取连接 + conn = mysql.connector.connect(**cls.pool_config) + if conn.is_connected(): + return conn + except Error as e: + # 抛出数据库连接错误(会被全局异常处理器捕获) + raise Exception(f"MySQL 连接失败: {str(e)}") from e + + @classmethod + def close_connection(cls, conn, cursor=None): + """关闭连接和游标""" + try: + if cursor: + cursor.close() + if conn and conn.is_connected(): + conn.close() + except Error as e: + raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e + + +# 暴露数据库操作工具 +db = Database() diff --git a/main.py b/main.py new file mode 100644 index 0000000..12cfe33 --- /dev/null +++ b/main.py @@ -0,0 +1,43 @@ +import uvicorn +from fastapi import FastAPI + +from ds.config import SERVER_CONFIG +from middle.error_handler import global_exception_handler +from service.user_service import router as user_router +from service.device_service import router as device_router +from ws.ws import ws_router, lifespan + +# ------------------------------ +# 初始化 FastAPI 应用、指定生命周期管理 +# ------------------------------ +app = FastAPI( + title="内容安全审核后台", + description="内容安全审核后台", + version="1.0.0", + lifespan=lifespan +) + +# ------------------------------ +# 注册路由 +# ------------------------------ +app.include_router(user_router) +app.include_router(device_router) +app.include_router(ws_router) + +# ------------------------------ +# 注册全局异常处理器 +# ------------------------------ +app.add_exception_handler(Exception, global_exception_handler) + +# ------------------------------ +# 启动服务 +# ------------------------------ +if __name__ == "__main__": + port = int(SERVER_CONFIG.get("port", 8000)) + uvicorn.run( + app="main:app", + host="0.0.0.0", + port=port, + reload=True, + ws="websockets" + ) diff --git a/middle/__pycache__/auth_middleware.cpython-312.pyc b/middle/__pycache__/auth_middleware.cpython-312.pyc new file mode 100644 index 0000000..e11b84e Binary files /dev/null and b/middle/__pycache__/auth_middleware.cpython-312.pyc differ diff --git a/middle/__pycache__/error_handler.cpython-312.pyc b/middle/__pycache__/error_handler.cpython-312.pyc new file mode 100644 index 0000000..7728b6c Binary files /dev/null and b/middle/__pycache__/error_handler.cpython-312.pyc differ diff --git a/middle/auth_middleware.py b/middle/auth_middleware.py new file mode 100644 index 0000000..9cac02d --- /dev/null +++ b/middle/auth_middleware.py @@ -0,0 +1,96 @@ +from datetime import datetime, timedelta, timezone +from typing import Optional + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from passlib.context import CryptContext + +from ds.config import JWT_CONFIG +from ds.db import db +from service.user_service import UserResponse + +# ------------------------------ +# 密码加密配置 +# ------------------------------ +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# ------------------------------ +# JWT 配置 +# ------------------------------ +SECRET_KEY = JWT_CONFIG["secret_key"] +ALGORITHM = JWT_CONFIG["algorithm"] +ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"]) + +# OAuth2 依赖(从请求头获取 Token、格式:Bearer ) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") + +# ------------------------------ +# 密码工具函数 +# ------------------------------ +def verify_password(plain_password: str, hashed_password: str) -> bool: + """验证明文密码与加密密码是否匹配""" + return pwd_context.verify(plain_password, hashed_password) + +def get_password_hash(password: str) -> str: + """对明文密码进行 bcrypt 加密""" + return pwd_context.hash(password) + +# ------------------------------ +# JWT 工具函数 +# ------------------------------ +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """生成 JWT Token""" + to_encode = data.copy() + # 设置过期时间 + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(minutes=15) + # 添加过期时间到 Token 数据 + to_encode.update({"exp": expire}) + # 生成 Token + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + +# ------------------------------ +# 认证依赖(获取当前登录用户) +# ------------------------------ +def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse: + """从 Token 中解析用户信息、验证通过后返回当前用户""" + # 认证失败异常 + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token 无效或已过期", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + # 解码 Token + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + # 获取 Token 中的用户名 + username: str = payload.get("sub") + if username is None: + raise credentials_exception + except JWTError: + raise credentials_exception + + # 从数据库查询用户(验证用户是否存在) + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) # 返回字典格式结果 + query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s" + cursor.execute(query, (username,)) + user = cursor.fetchone() + + if user is None: + raise credentials_exception # 用户不存在 + + # 转换为 UserResponse 模型(自动校验字段) + return UserResponse(** user) + except Exception as e: + raise credentials_exception from e + finally: + db.close_connection(conn, cursor) \ No newline at end of file diff --git a/middle/error_handler.py b/middle/error_handler.py new file mode 100644 index 0000000..11521eb --- /dev/null +++ b/middle/error_handler.py @@ -0,0 +1,68 @@ +from fastapi import Request, status +from fastapi.responses import JSONResponse +from fastapi.exceptions import HTTPException, RequestValidationError +from mysql.connector import Error as MySQLError +from jose import JWTError + +from schema.response_schema import APIResponse + + +async def global_exception_handler(request: Request, exc: Exception): + """全局异常处理器:所有未捕获的异常都会在这里统一处理""" + # 1. 请求参数验证错误(Pydantic 校验失败) + if isinstance(exc, RequestValidationError): + error_details = [] + for err in exc.errors(): + error_details.append(f"{err['loc'][1]}: {err['msg']}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=APIResponse( + code=400, + message=f"请求参数错误:{'; '.join(error_details)}", + data=None + ).model_dump() + ) + + # 2. HTTP 异常(主动抛出的业务错误、如 401/404) + if isinstance(exc, HTTPException): + return JSONResponse( + status_code=exc.status_code, + content=APIResponse( + code=exc.status_code, + message=exc.detail, + data=None + ).model_dump() + ) + + # 3. JWT 相关错误(Token 无效/过期) + if isinstance(exc, JWTError): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content=APIResponse( + code=401, + message="Token 无效或已过期", + data=None + ).model_dump(), + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 4. MySQL 数据库错误 + if isinstance(exc, MySQLError): + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=APIResponse( + code=500, + message=f"数据库错误:{str(exc)}", + data=None + ).model_dump() + ) + + # 5. 其他未知错误(兜底处理) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=APIResponse( + code=500, + message=f"服务器内部错误:{str(exc)}", + data=None + ).model_dump() + ) diff --git a/schema/__pycache__/device_schema.cpython-312.pyc b/schema/__pycache__/device_schema.cpython-312.pyc new file mode 100644 index 0000000..01a7f7e Binary files /dev/null and b/schema/__pycache__/device_schema.cpython-312.pyc differ diff --git a/schema/__pycache__/response_schema.cpython-312.pyc b/schema/__pycache__/response_schema.cpython-312.pyc new file mode 100644 index 0000000..fae4d8b Binary files /dev/null and b/schema/__pycache__/response_schema.cpython-312.pyc differ diff --git a/schema/__pycache__/user_schema.cpython-312.pyc b/schema/__pycache__/user_schema.cpython-312.pyc new file mode 100644 index 0000000..3d6f487 Binary files /dev/null and b/schema/__pycache__/user_schema.cpython-312.pyc differ diff --git a/schema/device_schema.py b/schema/device_schema.py new file mode 100644 index 0000000..bf80632 --- /dev/null +++ b/schema/device_schema.py @@ -0,0 +1,51 @@ +import hashlib +from datetime import datetime +from typing import Optional, List, Dict + +from pydantic import BaseModel, Field + + +# ------------------------------ +# 请求模型(前端传参校验) +# ------------------------------ +class DeviceCreateRequest(BaseModel): + """设备流信息创建请求模型""" + ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") + hostname: Optional[str] = Field(None, max_length=100, description="设备别名") + params: Optional[Dict] = Field(None, description="设备详细信息") + + +def md5_encrypt(text: str) -> str: + """对字符串进行MD5加密""" + if not text: + return "" + md5_hash = hashlib.md5() + md5_hash.update(text.encode('utf-8')) + return md5_hash.hexdigest() + + +# ------------------------------ +# 响应模型(后端返回设备数据) +# ------------------------------ +class DeviceResponse(BaseModel): + """设备流信息响应模型(字段与表结构完全对齐)""" + id: int = Field(..., description="设备ID") + hostname: Optional[str] = Field(None, max_length=100, description="设备别名") + rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址") + live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址") + detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址") + device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)") + device_type: Optional[str] = Field(None, description="设备类型") + alarm_count: int = Field(..., description="报警次数") + params: Optional[str] = Field(None, description="设备详细信息") + created_at: datetime = Field(..., description="记录创建时间") + updated_at: datetime = Field(..., description="记录更新时间") + + # 支持从数据库查询结果转换 + model_config = {"from_attributes": True} + + +class DeviceListResponse(BaseModel): + """设备流信息列表响应模型""" + total: int = Field(..., description="设备总数") + devices: List[DeviceResponse] = Field(..., description="设备列表") diff --git a/schema/response_schema.py b/schema/response_schema.py new file mode 100644 index 0000000..0461a3b --- /dev/null +++ b/schema/response_schema.py @@ -0,0 +1,13 @@ +from typing import Optional, Any + +from pydantic import BaseModel, Field + + +class APIResponse(BaseModel): + """统一 API 响应模型(所有接口必返此格式)""" + code: int = Field(..., description="状态码:200=成功、4xx=客户端错误、5xx=服务端错误") + message: str = Field(..., description="响应信息:成功/错误描述") + data: Optional[Any] = Field(None, description="响应数据:成功时返回、错误时为 None") + + # Pydantic V2 配置(支持从 ORM 对象转换) + model_config = {"from_attributes": True} diff --git a/schema/user_schema.py b/schema/user_schema.py new file mode 100644 index 0000000..6d8d9b1 --- /dev/null +++ b/schema/user_schema.py @@ -0,0 +1,32 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + + +# ------------------------------ +# 请求模型(前端传参校验) +# ------------------------------ +class UserRegisterRequest(BaseModel): + """用户注册请求模型""" + username: str = Field(..., min_length=3, max_length=50, description="用户名(3-50字符)") + password: str = Field(..., min_length=6, max_length=100, description="密码(6-100字符)") + + +class UserLoginRequest(BaseModel): + """用户登录请求模型""" + username: str = Field(..., description="用户名") + password: str = Field(..., description="密码") + + +# ------------------------------ +# 响应模型(后端返回用户数据) +# ------------------------------ +class UserResponse(BaseModel): + """用户信息响应模型(隐藏密码等敏感字段)""" + id: int = Field(..., description="用户ID") + username: str = Field(..., description="用户名") + created_at: datetime = Field(..., description="创建时间") + updated_at: datetime = Field(..., description="更新时间") + + # Pydantic V2 配置(支持从数据库查询结果转换) + model_config = {"from_attributes": True} diff --git a/service/__pycache__/device_service.cpython-312.pyc b/service/__pycache__/device_service.cpython-312.pyc new file mode 100644 index 0000000..dba856b Binary files /dev/null and b/service/__pycache__/device_service.cpython-312.pyc differ diff --git a/service/__pycache__/user_service.cpython-312.pyc b/service/__pycache__/user_service.cpython-312.pyc new file mode 100644 index 0000000..af778d2 Binary files /dev/null and b/service/__pycache__/user_service.cpython-312.pyc differ diff --git a/service/device_service.py b/service/device_service.py new file mode 100644 index 0000000..6431193 --- /dev/null +++ b/service/device_service.py @@ -0,0 +1,251 @@ +import json + +from fastapi import HTTPException, Query, APIRouter, Depends, Request +from mysql.connector import Error as MySQLError + +from ds.config import LIVE_CONFIG +from ds.db import db +from middle.auth_middleware import get_current_user +# 注意:导入的Schema已更新字段 +from schema.device_schema import ( + DeviceCreateRequest, + DeviceResponse, + DeviceListResponse, + md5_encrypt +) +from schema.response_schema import APIResponse +from schema.user_schema import UserResponse + +router = APIRouter( + prefix="/devices", + tags=["设备管理"] +) + + +# ------------------------------ +# 1. 创建设备信息 +# ------------------------------ +@router.post("/add", response_model=APIResponse, summary="创建设备信息") +async def create_device(request: Request, device_data: DeviceCreateRequest): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 新增:检查client_ip是否已存在 + cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (device_data.ip,)) + existing_device = cursor.fetchone() + if existing_device: + raise Exception(f"客户端IP {device_data.ip} 已存在,无法重复添加") + + # 获取RTMP URL + rtmp_url = str(LIVE_CONFIG.get("rtmp_url", "")) + webrtc_url = str(LIVE_CONFIG.get("webrtc_url", "")) + + # 将设备详细信息(params)转换为JSON字符串(对应表中params字段) + device_params_json = json.dumps(device_data.params) if device_data.params else None + + # 对JSON字符串进行MD5加密(用于生成唯一RTMP地址) + device_md5 = md5_encrypt(device_params_json) if device_params_json else "" + + # 解析User-Agent获取设备类型 + user_agent = request.headers.get("User-Agent", "").lower() + + # 优先处理User-Agent为default的情况 + if user_agent == "default": + # 检查params中是否存在os键 + if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params: + device_type = device_data.params["os"] + else: + device_type = "unknown" + elif "windows" in user_agent: + device_type = "windows" + elif "android" in user_agent: + device_type = "android" + elif "linux" in user_agent: + device_type = "linux" + else: + device_type = "unknown" + + # SQL字段对齐表结构 + insert_query = """ + INSERT INTO devices + (client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url, + device_online_status, device_type, alarm_count, params) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + """ + cursor.execute(insert_query, ( + device_data.ip, + device_data.hostname, + rtmp_url + device_md5, + webrtc_url + device_md5, + "", + 1, + device_type, + 0, + device_params_json + )) + conn.commit() + + # 获取刚创建的设备信息 + device_id = cursor.lastrowid + cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) + device = cursor.fetchone() + + return APIResponse( + code=200, + message="设备创建成功", + data=DeviceResponse(**device) + ) + except MySQLError as e: + if conn: + conn.rollback() + raise Exception(f"创建设备失败:{str(e)}") from e + except json.JSONDecodeError as e: + raise Exception(f"设备信息JSON序列化失败:{str(e)}") from e + except Exception as e: + # 捕获IP已存在的自定义异常 + if conn: + conn.rollback() + raise e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 2. 获取设备列表 +# ------------------------------ +@router.get("/", response_model=APIResponse, summary="获取设备列表") +async def get_device_list( + page: int = Query(1, ge=1, description="页码"), + page_size: int = Query(10, ge=1, le=100, description="每页条数"), + device_type: str = Query(None, description="设备类型筛选"), + online_status: int = Query(None, ge=0, le=1, description="在线状态筛选(1-在线、0-离线)") +): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 构建查询条件 + where_clause = [] + params = [] + + if device_type: + where_clause.append("device_type = %s") + params.append(device_type) + + if online_status is not None: + where_clause.append("device_online_status = %s") + params.append(online_status) + + # 总条数查询 + count_query = "SELECT COUNT(*) as total FROM devices" + if where_clause: + count_query += " WHERE " + " AND ".join(where_clause) + + cursor.execute(count_query, params) + total = cursor.fetchone()["total"] + + # 分页查询(SELECT * 会自动匹配表字段、响应模型已对齐) + offset = (page - 1) * page_size + query = "SELECT * FROM devices" + if where_clause: + query += " WHERE " + " AND ".join(where_clause) + query += " ORDER BY id DESC LIMIT %s OFFSET %s" + params.extend([page_size, offset]) + + cursor.execute(query, params) + devices = cursor.fetchall() + + # 响应模型已更新为params字段、直接转换即可 + device_list = [DeviceResponse(**device) for device in devices] + + return APIResponse( + code=200, + message="获取设备列表成功", + data=DeviceListResponse(total=total, devices=device_list) + ) + except MySQLError as e: + raise Exception(f"获取设备列表失败:{str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 3. 获取单个设备详情 +# ------------------------------ +@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情") +async def get_device_detail( + device_id: int, + current_user: UserResponse = Depends(get_current_user) +): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 查询设备信息(SELECT * 匹配表字段) + query = "SELECT * FROM devices WHERE id = %s" + cursor.execute(query, (device_id,)) + device = cursor.fetchone() + + if not device: + raise HTTPException( + status_code=404, + detail=f"设备ID为 {device_id} 的设备不存在" + ) + + # 响应模型已更新为params字段 + return APIResponse( + code=200, + message="获取设备详情成功", + data=DeviceResponse(**device) + ) + except MySQLError as e: + raise Exception(f"获取设备详情失败:{str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 4. 删除设备信息 +# ------------------------------ +@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息") +async def delete_device( + device_id: int, + current_user: UserResponse = Depends(get_current_user) +): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 检查设备是否存在 + cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,)) + if not cursor.fetchone(): + raise HTTPException( + status_code=404, + detail=f"设备ID为 {device_id} 的设备不存在" + ) + + # 执行删除 + delete_query = "DELETE FROM devices WHERE id = %s" + cursor.execute(delete_query, (device_id,)) + conn.commit() + + return APIResponse( + code=200, + message=f"设备ID为 {device_id} 的设备已成功删除", + data=None + ) + except MySQLError as e: + if conn: + conn.rollback() + raise Exception(f"删除设备失败:{str(e)}") from e + finally: + db.close_connection(conn, cursor) diff --git a/service/user_service.py b/service/user_service.py new file mode 100644 index 0000000..4ab04ec --- /dev/null +++ b/service/user_service.py @@ -0,0 +1,154 @@ +from datetime import timedelta + +from fastapi import APIRouter, Depends, HTTPException +from mysql.connector import Error as MySQLError + +from ds.db import db +from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse +from schema.response_schema import APIResponse +from middle.auth_middleware import ( + get_password_hash, + verify_password, + create_access_token, + ACCESS_TOKEN_EXPIRE_MINUTES, + get_current_user +) + +# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) +router = APIRouter( + prefix="/users", + tags=["用户管理"] +) + + +# ------------------------------ +# 1. 用户注册接口 +# ------------------------------ +@router.post("/register", response_model=APIResponse, summary="用户注册") +async def user_register(request: UserRegisterRequest): + """ + 用户注册: + - 校验用户名是否已存在 + - 加密密码后插入数据库 + - 返回注册成功信息 + """ + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 1. 检查用户名是否已存在(唯一索引) + check_query = "SELECT username FROM users WHERE username = %s" + cursor.execute(check_query, (request.username,)) + existing_user = cursor.fetchone() + if existing_user: + raise HTTPException( + status_code=400, + detail=f"用户名 '{request.username}' 已存在、请更换其他用户名" + ) + + # 2. 加密密码 + hashed_password = get_password_hash(request.password) + + # 3. 插入新用户到数据库 + insert_query = """ + INSERT INTO users (username, password) + VALUES (%s, %s) + """ + cursor.execute(insert_query, (request.username, hashed_password)) + conn.commit() # 提交事务 + + # 4. 返回注册成功响应 + return APIResponse( + code=201, # 201 表示资源创建成功 + message=f"用户 '{request.username}' 注册成功", + data=None + ) + except MySQLError as e: + conn.rollback() # 数据库错误时回滚事务 + raise Exception(f"注册失败:{str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 2. 用户登录接口 +# ------------------------------ +@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)") +async def user_login(request: UserLoginRequest): + """ + 用户登录: + - 校验用户名是否存在 + - 校验密码是否正确 + - 生成 JWT Token 并返回 + """ + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 修复:SQL查询添加 created_at 和 updated_at 字段 + query = """ + SELECT id, username, password, created_at, updated_at + FROM users + WHERE username = %s + """ + cursor.execute(query, (request.username,)) + user = cursor.fetchone() + + # 2. 校验用户名和密码 + if not user or not verify_password(request.password, user["password"]): + raise HTTPException( + status_code=401, + detail="用户名或密码错误", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # 3. 生成 Token(过期时间从配置读取) + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": user["username"]}, + expires_delta=access_token_expires + ) + + # 4. 返回 Token 和用户基本信息 + return APIResponse( + code=200, + message="登录成功", + data={ + "access_token": access_token, + "token_type": "bearer", + "user": UserResponse( + id=user["id"], + username=user["username"], + created_at=user.get("created_at"), + updated_at=user.get("updated_at") + ) + } + ) + except MySQLError as e: + raise Exception(f"登录失败:{str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 3. 获取当前登录用户信息(需认证) +# ------------------------------ +@router.get("/me", response_model=APIResponse, summary="获取当前用户信息") +async def get_current_user_info( + current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件 +): + """ + 获取当前登录用户信息: + - 需在请求头携带 Token(格式:Bearer ) + - 认证通过后返回用户信息 + """ + return APIResponse( + code=200, + message="获取用户信息成功", + data=current_user + ) + diff --git a/ws.html b/ws.html new file mode 100644 index 0000000..d81ceb2 --- /dev/null +++ b/ws.html @@ -0,0 +1,482 @@ + + + + + + WebSocket 测试工具 + + + +
+

WebSocket 测试工具

+ + +
+
连接状态:
+
未连接
+
服务地址:
+
ws://192.168.110.25:8000/ws
+
连接时间:
+
-
+
+ + +
+ + + + +
+ + + + + +
+
+ + +
+

发送自定义消息

+ + +
+ + +
+

消息日志

+
+
[加载完成] 请点击「建立连接」开始测试
+
+ +
+
+ + + + \ No newline at end of file diff --git a/ws/__pycache__/ws.cpython-312.pyc b/ws/__pycache__/ws.cpython-312.pyc new file mode 100644 index 0000000..c5b72cf Binary files /dev/null and b/ws/__pycache__/ws.cpython-312.pyc differ diff --git a/ws/ws.py b/ws/ws.py new file mode 100644 index 0000000..738c4f3 --- /dev/null +++ b/ws/ws.py @@ -0,0 +1,200 @@ +from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI +from typing import Dict, Any, Optional +import datetime +import asyncio +import json +from contextlib import asynccontextmanager + +# 创建WebSocket路由 +ws_router = APIRouter() + + +# 客户端连接信息数据结构 +class ClientConnection: + def __init__(self, websocket: WebSocket, client_ip: str): + self.websocket = websocket + self.client_ip = client_ip + self.last_heartbeat = datetime.datetime.now() # 初始心跳时间为连接时间 + + def update_heartbeat(self): + """更新心跳时间为当前时间""" + self.last_heartbeat = datetime.datetime.now() + # 打印心跳更新日志 + print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {self.client_ip} 心跳时间已更新") + + def is_alive(self, timeout_seconds: int = 60) -> bool: + """检查客户端是否活跃(心跳超时阈值:60秒)""" + timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() + # 打印心跳检查明细(便于排查超时原因) + print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {self.client_ip} 心跳检查:" + f"上次心跳距今 {timeout:.1f} 秒(阈值:{timeout_seconds}秒)") + return timeout < timeout_seconds + + +# 存储所有已连接的客户端(key:客户端IP、value:ClientConnection对象) +connected_clients: Dict[str, ClientConnection] = {} + +# 心跳检查任务引用(全局变量、用于应用关闭时取消任务) +heartbeat_task: Optional[asyncio.Task] = None + + +async def heartbeat_checker(): + """定期检查客户端心跳(每30秒一次)、超时直接剔除(不发通知)""" + while True: + current_time = datetime.datetime.now() + print(f"\n[{current_time:%Y-%m-%d %H:%M:%S}] === 开始新一轮心跳检查 ===") + print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 当前在线客户端总数:{len(connected_clients)}") + + # 1. 收集超时客户端IP(避免遍历中修改字典) + timeout_clients = [] + for client_ip, connection in connected_clients.items(): + if not connection.is_alive(): + timeout_clients.append(client_ip) + + # 2. 处理超时客户端(关闭连接+移除记录) + if timeout_clients: + print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 发现超时客户端:{timeout_clients}(共{len(timeout_clients)}个)") + for client_ip in timeout_clients: + try: + connection = connected_clients[client_ip] + # 直接关闭连接(不发送任何通知) + await connection.websocket.close(code=1008, reason="心跳超时(>60秒)") + print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已关闭(超时)") + except Exception as e: + print( + f"[{current_time:%Y-%m-%d %H:%M:%S}] 关闭客户端 {client_ip} 失败:{str(e)}(错误类型:{type(e).__name__})") + finally: + # 确保从客户端列表中移除(无论关闭是否成功) + if client_ip in connected_clients: + del connected_clients[client_ip] + print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已从连接列表移除") + else: + print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 无超时客户端、心跳检查完成") + + # 3. 等待30秒后进行下一轮检查 + await asyncio.sleep(30) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理:启动时创建心跳任务、关闭时取消任务""" + global heartbeat_task + # 启动阶段:创建心跳检查任务 + heartbeat_task = asyncio.create_task(heartbeat_checker()) + print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 心跳检查任务已启动(任务ID:{id(heartbeat_task)})") + yield # 应用运行中 + # 关闭阶段:取消心跳任务 + if heartbeat_task and not heartbeat_task.done(): + heartbeat_task.cancel() + try: + await heartbeat_task # 等待任务优雅退出 + except asyncio.CancelledError: + print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 心跳检查任务已正常取消") + except Exception as e: + print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 取消心跳任务时出错:{str(e)}") + + +async def send_heartbeat_ack(client_ip: str, client_timestamp: Any) -> bool: + """向客户端回复心跳确认(严格遵循 {"timestamp":xxxxx, "type":"heartbeat"} 格式)""" + if client_ip not in connected_clients: + print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 回复心跳失败:客户端 {client_ip} 不在连接列表中") + return False + + # 修复:将这部分代码移出if语句块,确保始终定义ack_msg + # 服务端当前格式化时间戳(字符串类型,与日志时间格式匹配) + server_latest_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ack_msg = { + "timestamp": server_latest_timestamp, + "type": "heartbeat" + } + + try: + connection = connected_clients[client_ip] + await connection.websocket.send_json(ack_msg) + print( + f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 已向客户端 {client_ip} 回复心跳:{json.dumps(ack_msg, ensure_ascii=False)}") + return True + except Exception as e: + print( + f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 回复客户端 {client_ip} 心跳失败:{str(e)}(错误类型:{type(e).__name__})") + # 发送失败时移除客户端(避免无效连接残留) + if client_ip in connected_clients: + del connected_clients[client_ip] + print(f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 因心跳回复失败被移除") + return False + + +@ws_router.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + """WebSocket核心端点:处理连接建立/消息接收/连接关闭""" + current_time = datetime.datetime.now() + # 1. 接受客户端连接请求 + await websocket.accept() + # 获取客户端IP(作为唯一标识) + client_ip = websocket.client.host if websocket.client else "unknown_ip" + print(f"\n[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 连接请求已接受(WebSocket握手成功)") + + try: + # 2. 处理"同一IP重复连接"场景:关闭旧连接、保留新连接 + if client_ip in connected_clients: + old_connection = connected_clients[client_ip] + await old_connection.websocket.close(code=1008, reason="同一IP新连接已建立") + del connected_clients[client_ip] + print(f"[{current_time:%Y-%m-%d %H:%M:%S}] 已关闭客户端 {client_ip} 的旧连接(新连接已建立)") + + # 3. 注册新客户端到连接列表 + new_connection = ClientConnection(websocket, client_ip) + connected_clients[client_ip] = new_connection + print( + f"[{current_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已注册到连接列表、当前在线数:{len(connected_clients)}") + + # 4. 循环接收客户端消息(持续监听) + while True: + # 接收原始文本消息(避免提前解析JSON、便于日志打印) + raw_data = await websocket.receive_text() + recv_time = datetime.datetime.now() + print(f"\n[{recv_time:%Y-%m-%d %H:%M:%S}] 收到客户端 {client_ip} 的消息:{raw_data}") + + # 尝试解析JSON消息 + try: + message = json.loads(raw_data) + print( + f"[{recv_time:%Y-%m-%d %H:%M:%S}] 消息解析成功:{json.dumps(message, ensure_ascii=False, indent=2)}") + + # 5. 区分消息类型:仅处理心跳、其他消息不回复 + if message.get("type") == "heartbeat": + # 验证心跳消息是否包含timestamp字段 + client_timestamp = message.get("timestamp") + if client_timestamp is None: + print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 警告:客户端 {client_ip} 发送的心跳缺少'timestamp'字段") + continue # 不回复无效心跳 + + # 更新心跳时间 + 回复心跳确认 + new_connection.update_heartbeat() + await send_heartbeat_ack(client_ip, client_timestamp) + else: + # 非心跳消息:仅打印日志、不回复任何内容 + print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 非心跳消息(类型:{message.get('type')})、不回复") + + except json.JSONDecodeError as e: + # JSON格式错误:仅打印日志、不回复 + print(f"[{recv_time:%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 消息格式错误:无效JSON(错误:{str(e)})") + except Exception as e: + # 其他未知错误:仅打印日志、不回复 + print( + f"[{recv_time:%Y-%m-%d %H:%M:%S}] 处理客户端 {client_ip} 消息时出错:{str(e)}(错误类型:{type(e).__name__})") + + except WebSocketDisconnect as e: + # 客户端主动断开连接(如关闭页面、网络中断) + print( + f"\n[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 主动断开连接(代码:{e.code}、原因:{e.reason})") + except Exception as e: + # 其他连接级错误(如网络异常) + print( + f"\n[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 连接异常:{str(e)}(错误类型:{type(e).__name__})") + finally: + # 无论何种退出原因、确保客户端从列表中移除 + if client_ip in connected_clients: + del connected_clients[client_ip] + print( + f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] 客户端 {client_ip} 已从连接列表移除、当前在线数:{len(connected_clients)}") \ No newline at end of file