yolo模型识别不到
This commit is contained in:
		
							
								
								
									
										10
									
								
								core/all.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								core/all.py
									
									
									
									
									
								
							| @ -67,7 +67,7 @@ def detect(client_ip, frame): | ||||
|         if full_save_path: | ||||
|             cv2.imwrite(full_save_path, frame) | ||||
|             print(f"✅ yolo违规图片已保存:{display_path}")  # 日志也修正 | ||||
|         save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) | ||||
|         save_db(model_type="yolo", client_ip=client_ip, result=str(display_path)) | ||||
|         return (True, yolo_result, "yolo") | ||||
|  | ||||
|     # 2. 人脸检测 | ||||
| @ -77,17 +77,19 @@ def detect(client_ip, frame): | ||||
|         if full_save_path: | ||||
|             cv2.imwrite(full_save_path, frame) | ||||
|             print(f"✅ face违规图片已保存:{display_path}")  # 日志也修正 | ||||
|         save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) | ||||
|         save_db(model_type="face", client_ip=client_ip, result=str(display_path)) | ||||
|         return (True, face_result, "face") | ||||
|  | ||||
|     # 3. OCR检测 | ||||
|     ocr_flag, ocr_result = ocrDetect(frame) | ||||
|     if ocr_flag: | ||||
|         full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)  # 这里改了 | ||||
|         full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) | ||||
|         print(f"✅ ocr违规图片已保存:{display_path}") | ||||
|         # 这里改了 | ||||
|         if full_save_path: | ||||
|             cv2.imwrite(full_save_path, frame) | ||||
|             print(f"✅ ocr违规图片已保存:{display_path}")  # 日志也修正 | ||||
|         save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) | ||||
|         save_db(model_type="ocr", client_ip=client_ip, result=str(display_path)) | ||||
|         return (True, ocr_result, "ocr") | ||||
|  | ||||
|     # 4. 无违规内容(不保存图片) | ||||
|  | ||||
| @ -1,30 +1,29 @@ | ||||
| import datetime | ||||
| from pathlib import Path | ||||
| from typing import List, Tuple | ||||
|  | ||||
| from service.device_service import get_unique_client_ips | ||||
|  | ||||
|  | ||||
| def create_directory_structure(): | ||||
|     """创建项目所需的目录结构,为所有客户端IP预创建基础目录""" | ||||
|     try: | ||||
|         # 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容) | ||||
|         resource_dir = Path("resource") | ||||
|         resource_dir.mkdir(exist_ok=True) | ||||
|         # print(f"确保resource目录存在: {resource_dir.absolute()}") | ||||
|  | ||||
|         # 2. 在resource下创建dect文件夹 | ||||
|         dect_dir = resource_dir / "dect" | ||||
|         dect_dir.mkdir(exist_ok=True) | ||||
|         # print(f"确保dect目录存在: {dect_dir.absolute()}") | ||||
|  | ||||
|         # 3. 在dect下创建三个模型文件夹 | ||||
|         model_dirs = ["ocr", "face", "yolo"] | ||||
|         for model in model_dirs: | ||||
|             model_dir = dect_dir / model | ||||
|             model_dir.mkdir(exist_ok=True) | ||||
|             # print(f"确保{model}模型目录存在: {model_dir.absolute()}") | ||||
|  | ||||
|         # 4. 调用外部方法获取所有客户端IP地址 | ||||
|         try: | ||||
|             # 调用外部ip_read()方法获取所有客户端IP地址列表 | ||||
|             all_ip_addresses = get_unique_client_ips() | ||||
|  | ||||
|             # 确保返回的是列表类型 | ||||
| @ -58,7 +57,6 @@ def create_directory_structure(): | ||||
|  | ||||
|                     # 递归创建目录(存在则跳过,不覆盖) | ||||
|                     month_dir.mkdir(parents=True, exist_ok=True) | ||||
|                     # print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}") | ||||
|  | ||||
|         except Exception as e: | ||||
|             print(f"处理客户端IP和日期目录时发生错误: {str(e)}") | ||||
| @ -67,52 +65,68 @@ def create_directory_structure(): | ||||
|         print(f"创建基础目录结构时发生错误: {str(e)}") | ||||
|  | ||||
|  | ||||
| def get_image_save_path(model_type: str, client_ip: str) -> tuple: | ||||
| def get_image_save_path(model_type: str, client_ip: str) -> Tuple[str, str]: | ||||
|     """ | ||||
|     获取图片保存的「完整路径」和「显示用短路径」 | ||||
|     获取图片保存的「本地完整路径」和「带路由前缀的显示路径」 | ||||
|  | ||||
|     参数: | ||||
|         model_type: 模型类型,应为"ocr"、"face"或"yolo" | ||||
|         client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101) | ||||
|  | ||||
|     返回: | ||||
|         元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "") | ||||
|         元组 (本地完整保存路径, 带/api/file/前缀的显示路径);若出错则返回 ("", "") | ||||
|     """ | ||||
|     try: | ||||
|         # 验证模型类型有效性 | ||||
|         valid_models = ["ocr", "face", "yolo"] | ||||
|         if model_type not in valid_models: | ||||
|             raise ValueError(f"无效的模型类型: {model_type},必须是{valid_models}之一") | ||||
|  | ||||
|         # 1. 验证客户端IP有效性(检查是否在已知IP列表中) | ||||
|         all_ip_addresses = get_unique_client_ips() | ||||
|         if not isinstance(all_ip_addresses, list): | ||||
|             all_ip_addresses = [all_ip_addresses] | ||||
|         valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] | ||||
|  | ||||
|         if client_ip.strip() not in valid_ips: | ||||
|         client_ip_stripped = client_ip.strip() | ||||
|         if client_ip_stripped not in valid_ips: | ||||
|             raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件") | ||||
|  | ||||
|         # 2. 处理IP地址(与目录创建逻辑一致,将.替换为_) | ||||
|         safe_ip = client_ip.strip().replace(".", "_") | ||||
|         # 2. 处理IP地址(将.替换为_,避免路径问题) | ||||
|         safe_ip = client_ip_stripped.replace(".", "_") | ||||
|  | ||||
|         # 3. 获取当前日期和毫秒级时间戳(确保文件名唯一) | ||||
|         now = datetime.datetime.now() | ||||
|         current_year = str(now.year) | ||||
|         current_month = str(now.month) | ||||
|         current_day = str(now.day) | ||||
|         # 时间戳格式:年月日时分秒毫秒(如20250910143050123) | ||||
|         timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] | ||||
|         current_month = str(now.month).zfill(2)  # 确保月份为两位数 | ||||
|         current_day = str(now.day).zfill(2)      # 确保日期为两位数 | ||||
|         timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]  # 取毫秒级时间戳 | ||||
|  | ||||
|         # 4. 定义基础目录(用于生成相对路径) | ||||
|         base_dir = Path("resource") / "dect"  # 显示路径会去掉这个前缀 | ||||
|         base_dir = Path("resource") / "dect" | ||||
|         # 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日}) | ||||
|         day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day | ||||
|         day_dir.mkdir(parents=True, exist_ok=True)  # 确保日目录存在 | ||||
|         day_dir.mkdir(parents=True, exist_ok=True)  # 确保目录存在 | ||||
|  | ||||
|         # 5. 构建唯一文件名 | ||||
|         image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg" | ||||
|  | ||||
|         # 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印) | ||||
|         full_path = day_dir / image_filename  # 完整路径:resource/dect/.../xxx.jpg | ||||
|         display_path = full_path.relative_to(base_dir)  # 短路径:{model}/.../xxx.jpg(去掉resource/dect) | ||||
|         # 6. 生成「本地完整路径」(使用系统路径,但在字符串表示时统一为正斜杠) | ||||
|         local_full_path = day_dir / image_filename | ||||
|         # 转换为字符串并统一使用正斜杠 | ||||
|         local_full_path_str = str(local_full_path).replace("\\", "/") | ||||
|  | ||||
|         return str(full_path), str(display_path) | ||||
|         # 7. 生成带路由前缀的显示路径(核心修改部分) | ||||
|         # 获取项目根目录(base_dir是resource/dect,向上两级即为项目根目录) | ||||
|         project_root = base_dir.parents[1] | ||||
|         # 计算相对于项目根目录的路径(包含resource/dect层级) | ||||
|         relative_path = local_full_path.relative_to(project_root) | ||||
|         # 转换为字符串并统一使用正斜杠 | ||||
|         relative_path_str = str(relative_path).replace("\\", "/") | ||||
|         # 拼接路由前缀 | ||||
|         routed_display_path = f"/api/file/{relative_path_str}" | ||||
|  | ||||
|         return local_full_path_str, routed_display_path | ||||
|  | ||||
|     except Exception as e: | ||||
|         print(f"获取图片保存路径时发生错误: {str(e)}") | ||||
|  | ||||
							
								
								
									
										83
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										83
									
								
								main.py
									
									
									
									
									
								
							| @ -1,10 +1,8 @@ | ||||
| import uvicorn | ||||
| import threading | ||||
| import time | ||||
| import os | ||||
| from fastapi import FastAPI | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from service.file_service import app as flask_app | ||||
|  | ||||
| # 原有业务导入 | ||||
| from core.all import load_model | ||||
| @ -15,32 +13,12 @@ from service.sensitive_service import router as sensitive_router | ||||
| from service.face_service import router as face_router | ||||
| from service.device_service import router as device_router | ||||
| from service.model_service import router as model_router | ||||
| from service.file_service import router as file_router | ||||
| from service.device_danger_service import router as device_danger_router | ||||
| from ws.ws import ws_router, lifespan | ||||
| from core.establish import create_directory_structure | ||||
|  | ||||
|  | ||||
| # Flask 服务启动函数(不变) | ||||
| def start_flask_service(): | ||||
|     try: | ||||
|         print(f"\n[Flask 服务] 准备启动,端口:5000") | ||||
|         print(f"[Flask 服务] 访问示例:http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n") | ||||
|  | ||||
|         BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect")) | ||||
|         if not os.path.exists(BASE_IMAGE_DIR): | ||||
|             print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}") | ||||
|             os.makedirs(BASE_IMAGE_DIR, exist_ok=True) | ||||
|  | ||||
|         flask_app.run( | ||||
|             host="0.0.0.0", | ||||
|             port=5000, | ||||
|             debug=False, | ||||
|             use_reloader=False | ||||
|         ) | ||||
|     except Exception as e: | ||||
|         print(f"[Flask 服务] 启动失败:{str(e)}") | ||||
|  | ||||
|  | ||||
| # 初始化 FastAPI 应用(新增 CORS 配置) | ||||
| # 初始化 FastAPI 应用 | ||||
| app = FastAPI( | ||||
|     title="内容安全审核后台", | ||||
|     description="含图片访问服务和动态模型管理", | ||||
| @ -48,38 +26,33 @@ app = FastAPI( | ||||
|     lifespan=lifespan | ||||
| ) | ||||
|  | ||||
| # ------------------------------ | ||||
| # 新增:完整 CORS 配置(解决跨域问题) | ||||
| # ------------------------------ | ||||
| # 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等) | ||||
| ALLOWED_ORIGINS = [ | ||||
|     # "http://localhost:8080",  # 前端本地开发地址(必改,填实际前端地址) | ||||
|     # "http://127.0.0.1:8080", | ||||
|     # "http://服务器IP:8080",    # 部署后前端地址(如适用) | ||||
|     "*" #表示允许所有域名(开发环境可用,生产环境不推荐) | ||||
|     "*" | ||||
| ] | ||||
|  | ||||
| # 2. 配置 CORS 中间件 | ||||
| # 配置 CORS 中间件 | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|     allow_origins=ALLOWED_ORIGINS,        # 允许的前端域名 | ||||
|     allow_credentials=True,               # 允许携带 Cookie(如需登录态则必开) | ||||
|     allow_methods=["*"],                  # 允许所有 HTTP 方法(包括 PUT/DELETE) | ||||
|     allow_headers=["*"],                  # 允许所有请求头(包括 Content-Type) | ||||
|     allow_credentials=True,               # 允许携带 Cookie | ||||
|     allow_methods=["*"],                  # 允许所有 HTTP 方法 | ||||
|     allow_headers=["*"],                  # 允许所有请求头 | ||||
| ) | ||||
|  | ||||
| # 注册路由(不变) | ||||
| # 注册路由 | ||||
| app.include_router(user_router) | ||||
| app.include_router(device_router) | ||||
| app.include_router(face_router) | ||||
| app.include_router(sensitive_router) | ||||
| app.include_router(model_router)  # 模型管理路由 | ||||
| app.include_router(model_router) | ||||
| app.include_router(file_router) | ||||
| app.include_router(device_danger_router) | ||||
| app.include_router(ws_router) | ||||
|  | ||||
| # 注册全局异常处理器(不变) | ||||
| # 注册全局异常处理器 | ||||
| app.add_exception_handler(Exception, global_exception_handler) | ||||
|  | ||||
| # 主服务启动入口(不变) | ||||
| # 主服务启动入口 | ||||
| if __name__ == "__main__": | ||||
|     create_directory_structure() | ||||
|     print(f"[初始化] 目录结构创建完成") | ||||
| @ -89,11 +62,11 @@ if __name__ == "__main__": | ||||
|     os.makedirs(MODEL_SAVE_DIR, exist_ok=True) | ||||
|     print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}") | ||||
|  | ||||
|     # # 模型路径配置 | ||||
|     # YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt") | ||||
|     # OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml") | ||||
|     # print(f"[初始化] 默认YOLO模型路径:{YOLO_MODEL_PATH}") | ||||
|     # print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}") | ||||
|     # 确保图片目录存在(原Flask服务负责的目录) | ||||
|     BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect")) | ||||
|     if not os.path.exists(BASE_IMAGE_DIR): | ||||
|         print(f"[初始化] 图片根目录不存在,创建:{BASE_IMAGE_DIR}") | ||||
|         os.makedirs(BASE_IMAGE_DIR, exist_ok=True) | ||||
|  | ||||
|     # 加载检测模型 | ||||
|     try: | ||||
| @ -105,23 +78,7 @@ if __name__ == "__main__": | ||||
|     except Exception as e: | ||||
|         print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)") | ||||
|  | ||||
|  | ||||
|  | ||||
|     # 2. 启动 Flask 服务(子线程) | ||||
|     flask_thread = threading.Thread( | ||||
|         target=start_flask_service, | ||||
|         daemon=True | ||||
|     ) | ||||
|     flask_thread.start() | ||||
|  | ||||
|     # 等待 Flask 初始化 | ||||
|     time.sleep(1) | ||||
|     if flask_thread.is_alive(): | ||||
|         print(f"[Flask 服务] 启动成功(运行中)") | ||||
|     else: | ||||
|         print(f"[Flask 服务] 启动失败!图片访问不可用") | ||||
|  | ||||
|     # 3. 启动 FastAPI 主服务 | ||||
|     # 启动 FastAPI 主服务(仅使用8000端口) | ||||
|     port = int(SERVER_CONFIG.get("port", 8000)) | ||||
|     print(f"\n[FastAPI 服务] 准备启动,端口:{port}") | ||||
|     print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n") | ||||
|  | ||||
							
								
								
									
										33
									
								
								schema/device_danger_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								schema/device_danger_schema.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | ||||
| from datetime import datetime | ||||
| from typing import Optional, List | ||||
| from pydantic import BaseModel, Field | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 请求模型 | ||||
| # ------------------------------ | ||||
| class DeviceDangerCreateRequest(BaseModel): | ||||
|     """设备危险记录创建请求模型""" | ||||
|     client_ip: str = Field(..., max_length=100, description="设备IP地址(必须与devices表中IP对应)") | ||||
|     type: str = Field(..., max_length=50, description="危险类型(如:病毒检测、端口异常、权限泄露等)") | ||||
|     result: str = Field(..., description="危险检测结果/处理结果(如:检测到木马病毒,已隔离;端口22异常开放,已关闭)") | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 响应模型 | ||||
| # ------------------------------ | ||||
| class DeviceDangerResponse(BaseModel): | ||||
|     """单条设备危险记录响应模型(与device_danger表字段对齐,updated_at允许为null)""" | ||||
|     id: int = Field(..., description="危险记录主键ID") | ||||
|     client_ip: str = Field(..., max_length=100, description="设备IP地址") | ||||
|     type: str = Field(..., max_length=50, description="危险类型") | ||||
|     result: str = Field(..., description="危险检测结果/处理结果") | ||||
|     created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)") | ||||
|     updated_at: Optional[datetime] = Field(None, description="记录更新时间(数据库中该字段当前为null)") | ||||
|     model_config = {"from_attributes": True} | ||||
|  | ||||
|  | ||||
| class DeviceDangerListResponse(BaseModel): | ||||
|     """设备危险记录列表响应模型(带分页)""" | ||||
|     total: int = Field(..., description="危险记录总数") | ||||
|     dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表") | ||||
| @ -28,7 +28,6 @@ class DeviceResponse(BaseModel): | ||||
|     params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") | ||||
|     created_at: datetime = Field(..., description="记录创建时间") | ||||
|     updated_at: datetime = Field(..., description="记录更新时间") | ||||
|  | ||||
|     model_config = {"from_attributes": True}  # 支持从数据库结果直接转换 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -12,7 +12,7 @@ from schema.response_schema import APIResponse | ||||
|  | ||||
| # 路由配置 | ||||
| router = APIRouter( | ||||
|     prefix="/device/actions", | ||||
|     prefix="/api/device/actions", | ||||
|     tags=["设备操作记录"] | ||||
| ) | ||||
|  | ||||
|  | ||||
							
								
								
									
										267
									
								
								service/device_danger_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										267
									
								
								service/device_danger_service.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,267 @@ | ||||
| import json | ||||
| from datetime import date | ||||
|  | ||||
| from fastapi import APIRouter, Query, HTTPException, Path | ||||
| from mysql.connector import Error as MySQLError | ||||
|  | ||||
| from ds.db import db | ||||
| from encryption.encrypt_decorator import encrypt_response | ||||
| from schema.device_danger_schema import ( | ||||
|     DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse | ||||
| ) | ||||
| from schema.response_schema import APIResponse | ||||
|  | ||||
| # 路由初始化(前缀与设备管理相关,标签区分功能) | ||||
| router = APIRouter( | ||||
|     prefix="/api/devices/dangers", | ||||
|     tags=["设备管理-危险记录"] | ||||
| ) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 内部工具方法 - 检查设备是否存在(复用设备表逻辑) | ||||
| # ------------------------------ | ||||
| def check_device_exist(client_ip: str) -> bool: | ||||
|     """ | ||||
|     检查指定IP的设备是否在devices表中存在 | ||||
|  | ||||
|     :param client_ip: 设备IP地址 | ||||
|     :return: 存在返回True,不存在返回False | ||||
|     """ | ||||
|     if not client_ip: | ||||
|         raise ValueError("设备IP不能为空") | ||||
|  | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) | ||||
|         return cursor.fetchone() is not None | ||||
|     except MySQLError as e: | ||||
|         raise Exception(f"检查设备存在性失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 内部工具方法 - 创建设备危险记录(核心插入逻辑) | ||||
| # ------------------------------ | ||||
| def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse: | ||||
|     """ | ||||
|     内部工具方法:向device_danger表插入新的危险记录 | ||||
|  | ||||
|     :param danger_data: 危险记录创建请求数据 | ||||
|     :return: 创建成功的危险记录模型对象 | ||||
|     """ | ||||
|     # 先检查设备是否存在 | ||||
|     if not check_device_exist(danger_data.client_ip): | ||||
|         raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录") | ||||
|  | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 插入危险记录(id自增,时间自动填充) | ||||
|         insert_query = """ | ||||
|             INSERT INTO device_danger  | ||||
|             (client_ip, type, result, created_at, updated_at) | ||||
|             VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||
|         """ | ||||
|         cursor.execute(insert_query, ( | ||||
|             danger_data.client_ip, | ||||
|             danger_data.type, | ||||
|             danger_data.result | ||||
|         )) | ||||
|         conn.commit() | ||||
|  | ||||
|         # 获取刚创建的记录(用自增ID查询) | ||||
|         danger_id = cursor.lastrowid | ||||
|         cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,)) | ||||
|         new_danger = cursor.fetchone() | ||||
|  | ||||
|         return DeviceDangerResponse(**new_danger) | ||||
|     except MySQLError as e: | ||||
|         if conn: | ||||
|             conn.rollback() | ||||
|         raise Exception(f"插入危险记录失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 接口1:创建设备危险记录 | ||||
| # ------------------------------ | ||||
| @router.post("/add", response_model=APIResponse, summary="创建设备危险记录") | ||||
| @encrypt_response() | ||||
| async def add_device_danger(danger_data: DeviceDangerCreateRequest): | ||||
|     try: | ||||
|         # 调用内部方法创建记录 | ||||
|         new_danger = create_danger_record(danger_data) | ||||
|  | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
|             message=f"设备[{danger_data.client_ip}]危险记录创建成功", | ||||
|             data=new_danger | ||||
|         ) | ||||
|     except ValueError as e: | ||||
|         # 设备不存在等业务异常 | ||||
|         raise HTTPException(status_code=400, detail=str(e)) from e | ||||
|     except Exception as e: | ||||
|         # 数据库异常等系统错误 | ||||
|         raise HTTPException(status_code=500, detail=str(e)) from e | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 接口2:获取危险记录列表(支持多条件筛选+分页) | ||||
| # ------------------------------ | ||||
| @router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)") | ||||
| @encrypt_response() | ||||
| async def get_danger_list( | ||||
|         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | ||||
|         client_ip: str = Query(None, max_length=100, description="按设备IP筛选"), | ||||
|         danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"), | ||||
|         start_date: date = Query(None, description="按创建时间筛选(开始日期,格式YYYY-MM-DD)"), | ||||
|         end_date: date = Query(None, description="按创建时间筛选(结束日期,格式YYYY-MM-DD)") | ||||
| ): | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 构建筛选条件 | ||||
|         where_clause = [] | ||||
|         params = [] | ||||
|  | ||||
|         if client_ip: | ||||
|             where_clause.append("client_ip = %s") | ||||
|             params.append(client_ip) | ||||
|         if danger_type: | ||||
|             where_clause.append("type = %s") | ||||
|             params.append(danger_type) | ||||
|         if start_date: | ||||
|             where_clause.append("DATE(created_at) >= %s") | ||||
|             params.append(start_date.strftime("%Y-%m-%d")) | ||||
|         if end_date: | ||||
|             where_clause.append("DATE(created_at) <= %s") | ||||
|             params.append(end_date.strftime("%Y-%m-%d")) | ||||
|  | ||||
|         # 1. 统计符合条件的总记录数 | ||||
|         count_query = "SELECT COUNT(*) AS total FROM device_danger" | ||||
|         if where_clause: | ||||
|             count_query += " WHERE " + " AND ".join(where_clause) | ||||
|         cursor.execute(count_query, params) | ||||
|         total = cursor.fetchone()["total"] | ||||
|  | ||||
|         # 2. 分页查询记录(按创建时间倒序,最新的在前) | ||||
|         offset = (page - 1) * page_size | ||||
|         list_query = "SELECT * FROM device_danger" | ||||
|         if where_clause: | ||||
|             list_query += " WHERE " + " AND ".join(where_clause) | ||||
|         list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s" | ||||
|         params.extend([page_size, offset])  # 追加分页参数 | ||||
|  | ||||
|         cursor.execute(list_query, params) | ||||
|         danger_list = cursor.fetchall() | ||||
|  | ||||
|         # 转换为响应模型 | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
|             message="获取危险记录列表成功", | ||||
|             data=DeviceDangerListResponse( | ||||
|                 total=total, | ||||
|                 dangers=[DeviceDangerResponse(**item) for item in danger_list] | ||||
|             ) | ||||
|         ) | ||||
|     except MySQLError as e: | ||||
|         raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 接口3:获取单个设备的所有危险记录 | ||||
| # ------------------------------ | ||||
| @router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录") | ||||
| # @encrypt_response() | ||||
| async def get_device_dangers( | ||||
|         client_ip: str = Path(..., max_length=100, description="设备IP地址"), | ||||
|         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间") | ||||
| ): | ||||
|     # 先检查设备是否存在 | ||||
|     if not check_device_exist(client_ip): | ||||
|         raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在") | ||||
|  | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 1. 统计该设备的危险记录总数 | ||||
|         count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s" | ||||
|         cursor.execute(count_query, (client_ip,)) | ||||
|         total = cursor.fetchone()["total"] | ||||
|  | ||||
|         # 2. 分页查询该设备的危险记录 | ||||
|         offset = (page - 1) * page_size | ||||
|         list_query = """ | ||||
|             SELECT * FROM device_danger  | ||||
|             WHERE client_ip = %s  | ||||
|             ORDER BY created_at DESC  | ||||
|             LIMIT %s OFFSET %s | ||||
|         """ | ||||
|         cursor.execute(list_query, (client_ip, page_size, offset)) | ||||
|         danger_list = cursor.fetchall() | ||||
|  | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
|             message=f"获取设备[{client_ip}]危险记录成功(共{total}条)", | ||||
|             data=DeviceDangerListResponse( | ||||
|                 total=total, | ||||
|                 dangers=[DeviceDangerResponse(**item) for item in danger_list] | ||||
|             ) | ||||
|         ) | ||||
|     except MySQLError as e: | ||||
|         raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 接口4:根据ID获取单个危险记录详情 | ||||
| # ------------------------------ | ||||
| @router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情") | ||||
| @encrypt_response() | ||||
| async def get_danger_detail( | ||||
|         danger_id: int = Path(..., ge=1, description="危险记录ID") | ||||
| ): | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 查询单个危险记录 | ||||
|         query = "SELECT * FROM device_danger WHERE id = %s" | ||||
|         cursor.execute(query, (danger_id,)) | ||||
|         danger = cursor.fetchone() | ||||
|  | ||||
|         if not danger: | ||||
|             raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在") | ||||
|  | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
|             message="获取危险记录详情成功", | ||||
|             data=DeviceDangerResponse(**danger) | ||||
|         ) | ||||
|     except MySQLError as e: | ||||
|         raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
| @ -13,7 +13,7 @@ from schema.device_schema import ( | ||||
| from schema.response_schema import APIResponse | ||||
|  | ||||
| router = APIRouter( | ||||
|     prefix="/devices", | ||||
|     prefix="/api/devices", | ||||
|     tags=["设备管理"] | ||||
| ) | ||||
|  | ||||
|  | ||||
| @ -17,7 +17,7 @@ from schema.response_schema import APIResponse | ||||
| from util.face_util import add_binary_data, get_average_feature | ||||
| from util.file_util import save_face_to_up_images | ||||
|  | ||||
| router = APIRouter(prefix="/faces", tags=["人脸管理"]) | ||||
| router = APIRouter(prefix="/api/faces", tags=["人脸管理"]) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
|  | ||||
| @ -1,276 +1,174 @@ | ||||
| from flask import Flask, send_from_directory, abort, request | ||||
| from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter | ||||
| from fastapi.responses import FileResponse | ||||
| import os | ||||
| import logging | ||||
| from functools import wraps | ||||
| from pathlib import Path | ||||
| from flask_cors import CORS | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from typing import Annotated | ||||
|  | ||||
| # 配置日志(保持原有格式) | ||||
| logging.basicConfig( | ||||
|     level=logging.INFO, | ||||
|     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||||
| ) | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| # 初始化 Flask 应用(供 main.py 导入) | ||||
| app = Flask(__name__) | ||||
|  | ||||
| # ------------------------------ | ||||
| # 核心修改:与 FastAPI 对齐的跨域配置 | ||||
| # ------------------------------ | ||||
| # 1. 允许的前端域名(根据实际环境修改,生产环境删除 "*") | ||||
| ALLOWED_ORIGINS = [ | ||||
|     # "http://localhost:8080",  # 本地前端开发地址 | ||||
|     # "http://127.0.0.1:8080", | ||||
|     # "http://服务器IP:8080",    # 部署后前端地址 | ||||
|     "*" | ||||
| ] | ||||
|  | ||||
| # 2. 配置 CORS(与 FastAPI 规则完全对齐) | ||||
| CORS( | ||||
|     app, | ||||
|     resources={ | ||||
|         r"/*": { | ||||
|             "origins": ALLOWED_ORIGINS, | ||||
|             "allow_credentials": True, | ||||
|             "methods": ["*"], | ||||
|             "allow_headers": ["*"], | ||||
|         } | ||||
|     }, | ||||
| router = APIRouter( | ||||
|     prefix="/api/file", | ||||
|     tags=["文件管理"] | ||||
| ) | ||||
|  | ||||
| # ------------------------------ | ||||
| # 核心路径配置(关键修改:修正 PROJECT_ROOT 计算) | ||||
| # 原问题:file_service.py 在 service 文件夹内,需向上跳一级到项目根目录 | ||||
| # 4. 路径配置 | ||||
| # ------------------------------ | ||||
| CURRENT_FILE_PATH = Path(__file__).resolve()  # 当前文件路径:service/file_service.py | ||||
| PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent  # 项目根目录(service 文件夹的父目录) | ||||
| # 资源目录(现在正确指向项目根目录下的文件夹) | ||||
| BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve())       # 根目录/resource/dect | ||||
| BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve())          # 根目录/up_images | ||||
| BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve())          # 根目录/resource/models | ||||
| CURRENT_FILE_PATH = Path(__file__).resolve() | ||||
| PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent  # 项目根目录 | ||||
|  | ||||
| # 资源目录定义 | ||||
| BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve())       # 检测图片目录 | ||||
| BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve())          # 人脸图片目录 | ||||
| BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve())          # 模型文件目录 | ||||
|  | ||||
| # 确保资源目录存在 | ||||
| for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]: | ||||
|     if not os.path.exists(dir_path): | ||||
|         os.makedirs(dir_path, exist_ok=True) | ||||
|         print(f"[创建目录] {dir_path}") | ||||
|  | ||||
| # ------------------------------ | ||||
| # 安全检查装饰器(不变) | ||||
| # 5. 安全依赖项(替代Flask装饰器) | ||||
| # ------------------------------ | ||||
| def safe_path_check(root_dir: str): | ||||
|     def decorator(func): | ||||
|         @wraps(func) | ||||
|         def wrapper(*args, **kwargs): | ||||
|             resource_path = kwargs.get('resource_path', '').strip() | ||||
|             # 统一路径分隔符(兼容 Windows \ 和 Linux /) | ||||
|     """ | ||||
|     安全路径校验依赖项: | ||||
|     1. 禁止路径遍历(确保请求文件在根目录内) | ||||
|     2. 校验文件存在且为有效文件(非目录) | ||||
|     3. 限制文件大小(模型200MB,图片10MB) | ||||
|     """ | ||||
|     async def dependency(request: Request, resource_path: str): | ||||
|         # 统一路径分隔符 | ||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||
|             # 拼接完整路径(防止路径遍历) | ||||
|         # 拼接完整路径 | ||||
|         full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) | ||||
|             logger.debug( | ||||
|                 f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}" | ||||
|             ) | ||||
|  | ||||
|             # 1. 禁止路径遍历(确保请求文件在根目录内) | ||||
|         # 校验1:禁止路径遍历 | ||||
|         if not full_file_path.startswith(root_dir): | ||||
|                 logger.warning( | ||||
|                     f"[Flask 安全拦截] 非法路径遍历!IP:{request.remote_addr} | 请求路径:{resource_path}" | ||||
|                 ) | ||||
|                 abort(403) | ||||
|             print(f"[安全检查] 禁止路径遍历!IP:{request.client.host} | 请求路径:{resource_path}") | ||||
|             raise HTTPException(status_code=403, detail="非法路径访问") | ||||
|  | ||||
|             # 2. 检查文件存在且为有效文件(非目录) | ||||
|         # 校验2:文件存在且为有效文件 | ||||
|         if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path): | ||||
|                 logger.warning( | ||||
|                     f"[Flask 资源错误] 文件不存在/非文件!IP:{request.remote_addr} | 路径:{full_file_path}" | ||||
|                 ) | ||||
|                 abort(404) | ||||
|             print(f"[资源错误] 文件不存在/非文件!IP:{request.client.host} | 路径:{full_file_path}") | ||||
|             raise HTTPException(status_code=404, detail="文件不存在") | ||||
|  | ||||
|             # 3. 限制文件大小(模型200MB,图片10MB) | ||||
|         # 校验3:文件大小限制 | ||||
|         max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024 | ||||
|         if os.path.getsize(full_file_path) > max_size: | ||||
|                 logger.warning( | ||||
|                     f"[Flask 大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.remote_addr} | 路径:{full_file_path}" | ||||
|                 ) | ||||
|                 abort(413) | ||||
|             print(f"[大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.client.host} | 路径:{full_file_path}") | ||||
|             raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)") | ||||
|  | ||||
|             # 安全检查通过,传递根目录给视图函数 | ||||
|             return func(*args, **kwargs, root_dir=root_dir) | ||||
|         return wrapper | ||||
|     return decorator | ||||
|         return full_file_path | ||||
|     return dependency | ||||
|  | ||||
| # ------------------------------ | ||||
| # 1. 模型下载接口(/model/download/*) | ||||
| # 6. 核心接口 | ||||
| # ------------------------------ | ||||
| @app.route('/model/download/<path:resource_path>') | ||||
| @safe_path_check(root_dir=BASE_MODEL_DIR) | ||||
| def download_model(resource_path, root_dir): | ||||
| @router.get("/model/download/{resource_path:path}", summary="模型下载接口") | ||||
| async def download_model( | ||||
|     resource_path: str, | ||||
|     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))], | ||||
|     request: Request | ||||
| ): | ||||
|     """模型下载接口(仅允许 .pt 格式,强制浏览器下载)""" | ||||
|     try: | ||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||
|         dir_path, file_name = os.path.split(resource_path) | ||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||
|         dir_path, file_name = os.path.split(full_file_path) | ||||
|  | ||||
|         # 仅允许 .pt 格式(YOLO 模型) | ||||
|         if not file_name.lower().endswith('.pt'): | ||||
|             logger.warning( | ||||
|                 f"[Flask 格式错误] 非 .pt 模型文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||
|         # 额外校验:仅允许 YOLO 模型格式(.pt) | ||||
|         if not file_name.lower().endswith(".pt"): | ||||
|             print(f"[格式错误] 非 .pt 模型文件!IP:{request.client.host} | 文件名:{file_name}") | ||||
|             raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件") | ||||
|  | ||||
|         print(f"[模型下载] 尝试下载!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||
|  | ||||
|         # 强制下载 | ||||
|         return FileResponse( | ||||
|             full_file_path, | ||||
|             filename=file_name, | ||||
|             media_type="application/octet-stream" | ||||
|         ) | ||||
|             abort(415) | ||||
|  | ||||
|         logger.info( | ||||
|             f"[Flask 模型下载] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||
|         ) | ||||
|  | ||||
|         # 强制浏览器下载(而非预览) | ||||
|         return send_from_directory( | ||||
|             full_dir, | ||||
|             file_name, | ||||
|             as_attachment=True, | ||||
|             mimetype="application/octet-stream" | ||||
|         ) | ||||
|  | ||||
|     except HTTPException: | ||||
|         raise | ||||
|     except Exception as e: | ||||
|         logger.error( | ||||
|             f"[Flask 模型下载异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||
|         ) | ||||
|         abort(500) | ||||
|         print(f"[下载异常] IP:{request.client.host} | 错误:{str(e)}") | ||||
|         raise HTTPException(status_code=500, detail="服务器内部错误") | ||||
|  | ||||
| # ------------------------------ | ||||
| # 2. 人脸图片访问接口(/up_images/*) | ||||
| # ------------------------------ | ||||
| @app.route('/up_images/<path:resource_path>') | ||||
| @safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES) | ||||
| def get_face_image(resource_path, root_dir): | ||||
|  | ||||
| @router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口") | ||||
| async def get_face_image( | ||||
|     resource_path: str, | ||||
|     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))], | ||||
|     request: Request | ||||
| ): | ||||
|     """人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)""" | ||||
|     try: | ||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||
|         dir_path, file_name = os.path.split(resource_path) | ||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||
|         dir_path, file_name = os.path.split(full_file_path) | ||||
|  | ||||
|         # 仅允许常见图片格式 | ||||
|         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') | ||||
|         # 图片格式校验 | ||||
|         allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") | ||||
|         if not file_name.lower().endswith(allowed_ext): | ||||
|             logger.warning( | ||||
|                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||
|             ) | ||||
|             abort(415) | ||||
|             print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") | ||||
|             raise HTTPException(status_code=415, detail="仅支持常见图片格式") | ||||
|  | ||||
|         logger.info( | ||||
|             f"[Flask 人脸图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||
|         ) | ||||
|  | ||||
|         # 允许浏览器预览图片 | ||||
|         return send_from_directory(full_dir, file_name, as_attachment=False) | ||||
|         print(f"[人脸图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||
|  | ||||
|         return FileResponse(full_file_path) | ||||
|     except HTTPException: | ||||
|         raise | ||||
|     except Exception as e: | ||||
|         logger.error( | ||||
|             f"[Flask 人脸图片异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||
|         ) | ||||
|         abort(500) | ||||
|         print(f"[人脸图片异常] IP:{request.client.host} | 错误:{str(e)}") | ||||
|  | ||||
| # ------------------------------ | ||||
| # 3. 检测图片访问接口(/resource/dect/*) | ||||
| # ------------------------------ | ||||
| @app.route('/resource/dect/<path:resource_path>') | ||||
| @safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) | ||||
| def get_dect_image(resource_path, root_dir): | ||||
|  | ||||
| @router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口") | ||||
| async def get_dect_image( | ||||
|     resource_path: str, | ||||
|     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], | ||||
|     request: Request | ||||
| ): | ||||
|     """检测图片访问接口(允许浏览器预览,仅支持常见图片格式)""" | ||||
|     try: | ||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||
|         dir_path, file_name = os.path.split(resource_path) | ||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||
|         dir_path, file_name = os.path.split(full_file_path) | ||||
|  | ||||
|         # 仅允许常见图片格式 | ||||
|         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') | ||||
|         # 图片格式校验 | ||||
|         allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") | ||||
|         if not file_name.lower().endswith(allowed_ext): | ||||
|             logger.warning( | ||||
|                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||
|             ) | ||||
|             abort(415) | ||||
|             print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") | ||||
|             raise HTTPException(status_code=415, detail="仅支持常见图片格式") | ||||
|  | ||||
|         logger.info( | ||||
|             f"[Flask 检测图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||
|         ) | ||||
|  | ||||
|         return send_from_directory(full_dir, file_name, as_attachment=False) | ||||
|         print(f"[检测图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||
|  | ||||
|         return FileResponse(full_file_path) | ||||
|     except HTTPException: | ||||
|         raise | ||||
|     except Exception as e: | ||||
|         logger.error( | ||||
|             f"[Flask 检测图片异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||
|         ) | ||||
|         abort(500) | ||||
|         print(f"[检测图片异常] IP:{request.client.host} | 错误:{str(e)}") | ||||
|         raise HTTPException(status_code=500, detail="服务器内部错误") | ||||
|  | ||||
| # ------------------------------ | ||||
| # 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*) | ||||
| # ------------------------------ | ||||
| @app.route('/images/<path:resource_path>') | ||||
| @safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) | ||||
| def get_compatible_image(resource_path, root_dir): | ||||
|  | ||||
| @router.get("/images/{resource_path:path}", summary="兼容旧接口") | ||||
| async def get_compatible_image( | ||||
|     resource_path: str, | ||||
|     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], | ||||
|     request: Request | ||||
| ): | ||||
|     """兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)""" | ||||
|     try: | ||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||
|         dir_path, file_name = os.path.split(resource_path) | ||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||
|         dir_path, file_name = os.path.split(full_file_path) | ||||
|  | ||||
|         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') | ||||
|         # 图片格式校验 | ||||
|         allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") | ||||
|         if not file_name.lower().endswith(allowed_ext): | ||||
|             logger.warning( | ||||
|                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||
|             ) | ||||
|             abort(415) | ||||
|  | ||||
|         logger.info( | ||||
|             f"[Flask 兼容图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||
|         ) | ||||
|  | ||||
|         return send_from_directory(full_dir, file_name, as_attachment=False) | ||||
|             print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") | ||||
|             raise HTTPException(status_code=415, detail="仅支持常见图片格式") | ||||
|         print(f"[兼容图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||
|  | ||||
|         return FileResponse(full_file_path) | ||||
|     except HTTPException: | ||||
|         raise | ||||
|     except Exception as e: | ||||
|         logger.error( | ||||
|             f"[Flask 兼容图片异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||
|         ) | ||||
|         abort(500) | ||||
|  | ||||
| # ------------------------------ | ||||
| # 全局错误处理器(不变) | ||||
| # ------------------------------ | ||||
| @app.errorhandler(403) | ||||
| def forbidden_error(error): | ||||
|     return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403 | ||||
|  | ||||
| @app.errorhandler(404) | ||||
| def not_found_error(error): | ||||
|     return "❌ 资源不存在:请检查URL路径(IP、目录、文件名)是否正确", 404 | ||||
|  | ||||
| @app.errorhandler(413) | ||||
| def too_large_error(error): | ||||
|     return "❌ 文件过大:图片最大10MB,模型最大200MB", 413 | ||||
|  | ||||
| @app.errorhandler(415) | ||||
| def unsupported_type_error(error): | ||||
|     return "❌ 不支持的文件类型:图片支持png/jpg/jpeg/gif/bmp,模型仅支持pt", 415 | ||||
|  | ||||
| @app.errorhandler(500) | ||||
| def server_error(error): | ||||
|     return "❌ 服务器内部错误:请联系管理员查看后台日志", 500 | ||||
|  | ||||
| # ------------------------------ | ||||
| # Flask 独立启动入口(供测试,实际由 main.py 子线程启动) | ||||
| # ------------------------------ | ||||
| if __name__ == '__main__': | ||||
|     # 确保所有资源目录存在 | ||||
|     required_dirs = [ | ||||
|         (BASE_IMAGE_DIR_DECT, "检测图片目录"), | ||||
|         (BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"), | ||||
|         (BASE_MODEL_DIR, "模型文件目录") | ||||
|     ] | ||||
|     for dir_path, dir_desc in required_dirs: | ||||
|         if not os.path.exists(dir_path): | ||||
|             logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}") | ||||
|             os.makedirs(dir_path, exist_ok=True) | ||||
|  | ||||
|     # 启动提示 | ||||
|     logger.info("\n[Flask 服务启动成功!] 支持的接口:") | ||||
|     logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt") | ||||
|     logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg") | ||||
|     logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n") | ||||
|  | ||||
|     # 启动服务(禁用 debug 和自动重载) | ||||
|     app.run( | ||||
|         host="0.0.0.0", | ||||
|         port=5000, | ||||
|         debug=False, | ||||
|         use_reloader=False | ||||
|     ) | ||||
|         print(f"[兼容图片异常] IP:{request.client.host} | 错误:{str(e)}") | ||||
|         raise HTTPException(status_code=500, detail="服务器内部错误") | ||||
|  | ||||
| @ -38,7 +38,7 @@ _yolo_model = None | ||||
| _current_model_version = None  # 模型版本标识 | ||||
| _current_conf_threshold = 0.8  # 默认置信度初始值 | ||||
|  | ||||
| router = APIRouter(prefix="/models", tags=["模型管理"]) | ||||
| router = APIRouter(prefix="/api/models", tags=["模型管理"]) | ||||
|  | ||||
|  | ||||
| # 服务重启核心工具函数(保持不变) | ||||
|  | ||||
| @ -16,7 +16,7 @@ from schema.user_schema import UserResponse | ||||
|  | ||||
| # 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类) | ||||
| router = APIRouter( | ||||
|     prefix="/sensitives", | ||||
|     prefix="/api/sensitives", | ||||
|     tags=["敏感信息管理"] | ||||
| ) | ||||
|  | ||||
|  | ||||
| @ -18,7 +18,7 @@ from middle.auth_middleware import ( | ||||
|  | ||||
| # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) | ||||
| router = APIRouter( | ||||
|     prefix="/users", | ||||
|     prefix="/api/users", | ||||
|     tags=["用户管理"] | ||||
| ) | ||||
|  | ||||
|  | ||||
| @ -12,7 +12,8 @@ def save_face_to_up_images( | ||||
| ) -> Dict[str, str]: | ||||
|     """ | ||||
|     保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 | ||||
|     确保db_path以up_images开头,且统一使用正斜杠 | ||||
|     确保db_path以 /api/file/up_images 开头,且统一使用正斜杠 | ||||
|     本地不创建/api/file/文件夹,仅URL访问时使用该前缀路由 | ||||
|  | ||||
|     参数: | ||||
|         client_ip: 客户端IP(原始格式,如192.168.1.101) | ||||
| @ -21,10 +22,10 @@ def save_face_to_up_images( | ||||
|         image_format: 图片格式(默认jpg) | ||||
|  | ||||
|     返回: | ||||
|         字典:success(是否成功)、db_path(存数据库的相对路径)、local_abs_path(本地绝对路径)、msg(提示) | ||||
|         字典:success(是否成功)、db_path(存数据库的路径,带/api/file/前缀)、local_abs_path(本地绝对路径)、msg(提示) | ||||
|     """ | ||||
|     try: | ||||
|         # 1. 基础参数校验 | ||||
|         # 1. 基础参数校验(不变) | ||||
|         if not client_ip.strip(): | ||||
|             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"} | ||||
|         if not image_bytes: | ||||
| @ -32,49 +33,50 @@ def save_face_to_up_images( | ||||
|         if image_format.lower() not in ["jpg", "jpeg", "png"]: | ||||
|             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"} | ||||
|  | ||||
|         # 2. 处理特殊字符(避免路径错误) | ||||
|         # 2. 处理特殊字符(避免路径错误)(不变) | ||||
|         safe_ip = client_ip.strip().replace(".", "_")  # IP中的.替换为_ | ||||
|         safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名" | ||||
|         safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|'])  # 过滤非法字符 | ||||
|  | ||||
|         # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) | ||||
|         root_dir = Path("up_images").resolve()  # 转为绝对路径(如D:/Git/bin/video/up_images) | ||||
|         root_dir = Path("up_images").resolve() | ||||
|         if not root_dir.exists(): | ||||
|             root_dir.mkdir(parents=True, exist_ok=True) | ||||
|             print(f"[FileUtil] 已创建up_images根目录:{root_dir}") | ||||
|  | ||||
|         # 4. 构建文件层级路径(确保在root_dir子目录下) | ||||
|         # 4. 构建文件层级路径(确保在root_dir子目录下)(不变) | ||||
|         ip_dir = root_dir / safe_ip | ||||
|         face_name_dir = ip_dir / safe_face_name | ||||
|         face_name_dir.mkdir(parents=True, exist_ok=True)  # 自动创建目录 | ||||
|         print(f"[FileUtil] 图片存储目录:{face_name_dir}") | ||||
|         face_name_dir.mkdir(parents=True, exist_ok=True) | ||||
|         print(f"[FileUtil] 图片存储目录(本地):{face_name_dir}") | ||||
|  | ||||
|         # 5. 生成唯一文件名(毫秒级时间戳) | ||||
|         # 5. 生成唯一文件名(毫秒级时间戳)(不变) | ||||
|         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] | ||||
|  | ||||
|         image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" | ||||
|         local_abs_path = face_name_dir / image_filename | ||||
|  | ||||
|         # 6. 计算路径(确保所有路径都是绝对路径且在root_dir下) | ||||
|         local_abs_path = face_name_dir / image_filename  # 绝对路径 | ||||
|  | ||||
|         # 验证路径是否在root_dir下(防止路径穿越攻击) | ||||
|         if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): | ||||
|             raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") | ||||
|  | ||||
|         # 数据库存储路径:强制包含up_images前缀,统一使用正斜杠 | ||||
|         relative_path = local_abs_path.relative_to(root_dir.parent)  # 相对于root_dir的父目录 | ||||
|         db_path = str(relative_path).replace("\\", "/")  # 此时会包含up_images部分 | ||||
|         # 数据库存储路径:核心修改——在原有relative_path前添加 /api/file/ 前缀 | ||||
|         relative_path = local_abs_path.relative_to(root_dir.parent) | ||||
|  | ||||
|         # 7. 写入图片文件 | ||||
|         relative_path_str = str(relative_path).replace("\\", "/") | ||||
|         # 2. 再拼接/api/file/前缀 | ||||
|         db_path = f"/api/file/{relative_path_str}" | ||||
|  | ||||
|         # 7. 写入图片文件(不变) | ||||
|         with open(local_abs_path, "wb") as f: | ||||
|             f.write(image_bytes) | ||||
|         print(f"[FileUtil] 图片保存成功:") | ||||
|         print(f"  数据库路径:{db_path}") | ||||
|         print(f"  本地绝对路径:{local_abs_path}") | ||||
|         print(f"  数据库路径(带/api/file/前缀):{db_path}") | ||||
|         print(f"  本地绝对路径(无/api/file/):{local_abs_path}") | ||||
|  | ||||
|         return { | ||||
|             "success": True, | ||||
|             "db_path": db_path,  # 格式为 up_images/192_168_110_31/小龙/xxx.jpg | ||||
|             "local_abs_path": str(local_abs_path),  # 本地绝对路径(完整路径) | ||||
|             "db_path": db_path, | ||||
|             "local_abs_path": str(local_abs_path), | ||||
|             "msg": "图片保存成功" | ||||
|         } | ||||
|  | ||||
|  | ||||
							
								
								
									
										127
									
								
								ws/ws.py
									
									
									
									
									
								
							
							
						
						
									
										127
									
								
								ws/ws.py
									
									
									
									
									
								
							| @ -17,17 +17,16 @@ from service.device_action_service import add_device_action | ||||
| from schema.device_action_schema import DeviceActionCreate | ||||
| from core.all import detect, load_model | ||||
|  | ||||
| # -------------------------- 1. AES 加密解密工具(固定密钥)-------------------------- | ||||
| # -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息)-------------------------- | ||||
| AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa"  # 约定密钥(32字节) | ||||
| AES_BLOCK_SIZE = 16  # AES固定块大小 | ||||
|  | ||||
|  | ||||
| def aes_encrypt(plaintext: str) -> dict: | ||||
|     """AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)""" | ||||
|     """AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)- 仅用于服务器发消息""" | ||||
|     try: | ||||
|         iv = os.urandom(AES_BLOCK_SIZE)  # 随机IV(16字节) | ||||
|         cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) | ||||
|         # 明文填充+加密+Base64编码 | ||||
|         padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE) | ||||
|         ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8") | ||||
|         iv_base64 = base64.b64encode(iv).decode("utf-8") | ||||
| @ -40,20 +39,6 @@ def aes_encrypt(plaintext: str) -> dict: | ||||
|         raise Exception(f"AES加密失败: {str(e)}") from e | ||||
|  | ||||
|  | ||||
| def aes_decrypt(encrypted_dict: dict) -> str: | ||||
|     """AES-CBC解密:输入加密字典,返回原始文本""" | ||||
|     try: | ||||
|         # Base64解码密文和IV | ||||
|         ciphertext = base64.b64decode(encrypted_dict["ciphertext"]) | ||||
|         iv = base64.b64decode(encrypted_dict["iv"]) | ||||
|         # 解密+去除填充 | ||||
|         cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) | ||||
|         decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8") | ||||
|         return decrypted | ||||
|     except Exception as e: | ||||
|         raise Exception(f"AES解密失败: {str(e)}") from e | ||||
|  | ||||
|  | ||||
| # -------------------------- 2. 配置常量(保持原有)-------------------------- | ||||
| HEARTBEAT_INTERVAL = 30  # 心跳检查间隔(秒) | ||||
| HEARTBEAT_TIMEOUT = 600  # 客户端超时阈值(秒) | ||||
| @ -72,7 +57,7 @@ def get_current_time_file_str() -> str: | ||||
|     return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") | ||||
|  | ||||
|  | ||||
| # -------------------------- 4. 客户端连接封装(新增消息加密)-------------------------- | ||||
| # -------------------------- 4. 客户端连接封装(服务器发消息仍加密,接收消息改明文)-------------------------- | ||||
| class ClientConnection: | ||||
|     def __init__(self, websocket: WebSocket, client_ip: str): | ||||
|         self.websocket = websocket | ||||
| @ -96,28 +81,25 @@ class ClientConnection: | ||||
|         return self.consumer_task | ||||
|  | ||||
|     async def send_frame_permit(self): | ||||
|         """发送加密的帧许可信号""" | ||||
|         """发送加密的帧许可信号(服务器→客户端:加密)""" | ||||
|         try: | ||||
|             # 1. 构建原始消息 | ||||
|             frame_permit_msg = { | ||||
|                 "type": "frame", | ||||
|                 "timestamp": get_current_time_str(), | ||||
|                 "client_ip": self.client_ip | ||||
|             } | ||||
|             # 2. AES加密消息 | ||||
|             encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) | ||||
|             # 3. 发送加密消息 | ||||
|             encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg))  # 保持加密 | ||||
|             await self.websocket.send_json(encrypted_msg) | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可") | ||||
|         except Exception as e: | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}") | ||||
|  | ||||
|     async def consume_frames(self) -> None: | ||||
|         """消费队列中的帧并处理""" | ||||
|         """消费队列中的明文图像帧并处理""" | ||||
|         try: | ||||
|             while True: | ||||
|                 frame_data = await self.frame_queue.get() | ||||
|                 await self.send_frame_permit()  # 发送下一帧许可 | ||||
|                 await self.send_frame_permit()  # 回复仍加密 | ||||
|                 try: | ||||
|                     await self.process_frame(frame_data) | ||||
|                 finally: | ||||
| @ -128,23 +110,22 @@ class ClientConnection: | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}") | ||||
|  | ||||
|     async def process_frame(self, frame_data: bytes) -> None: | ||||
|         """处理单帧图像(含加密危险通知)""" | ||||
|         # 二进制转OpenCV图像 | ||||
|         """处理明文图像帧(危险通知仍加密发送)""" | ||||
|         # 二进制转OpenCV图像(客户端发的是明文二进制,直接解析) | ||||
|         nparr = np.frombuffer(frame_data, np.uint8) | ||||
|         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | ||||
|         if img is None: | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像") | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像") | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             # 调用检测函数(client_ip + img 双参数) | ||||
|             has_violation, data, detector_type = await asyncio.to_thread( | ||||
|                 detect, self.client_ip, img | ||||
|             ) | ||||
|             print( | ||||
|                 f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}") | ||||
|  | ||||
|             # 处理违规逻辑(发送加密危险通知) | ||||
|             # 违规通知:服务器→客户端,仍加密 | ||||
|             if has_violation: | ||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}") | ||||
|                 # 违规次数+1 | ||||
| @ -154,19 +135,17 @@ class ClientConnection: | ||||
|                 except Exception as e: | ||||
|                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") | ||||
|  | ||||
|                 # 1. 构建原始危险通知 | ||||
|                 # 构建危险通知并加密发送 | ||||
|                 danger_msg = { | ||||
|                     "type": "danger", | ||||
|                     "timestamp": get_current_time_str(), | ||||
|                     "client_ip": self.client_ip, | ||||
|                     "detail": data | ||||
|                 } | ||||
|                 # 2. AES加密通知 | ||||
|                 encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) | ||||
|                 # 3. 发送加密通知 | ||||
|                 encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg))  # 保持加密 | ||||
|                 await self.websocket.send_json(encrypted_danger_msg) | ||||
|         except Exception as e: | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}") | ||||
|  | ||||
|  | ||||
| # -------------------------- 5. 全局状态与心跳管理(保持原有)-------------------------- | ||||
| @ -178,7 +157,6 @@ async def heartbeat_checker(): | ||||
|     """全局心跳检查任务""" | ||||
|     while True: | ||||
|         current_time = get_current_time_str() | ||||
|         # 筛选超时客户端 | ||||
|         timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] | ||||
|  | ||||
|         if timeout_ips: | ||||
| @ -186,11 +164,9 @@ async def heartbeat_checker(): | ||||
|             for ip in timeout_ips: | ||||
|                 try: | ||||
|                     conn = connected_clients[ip] | ||||
|                     # 取消消费任务+关闭连接 | ||||
|                     if conn.consumer_task and not conn.consumer_task.done(): | ||||
|                         conn.consumer_task.cancel() | ||||
|                     await conn.websocket.close(code=1008, reason="心跳超时") | ||||
|                     # 标记离线 | ||||
|                     await asyncio.to_thread(update_online_status_by_ip, ip, 0) | ||||
|                     action_data = DeviceActionCreate(client_ip=ip, action=0) | ||||
|                     await asyncio.to_thread(add_device_action, action_data) | ||||
| @ -205,19 +181,16 @@ async def heartbeat_checker(): | ||||
|         await asyncio.sleep(HEARTBEAT_INTERVAL) | ||||
|  | ||||
|  | ||||
| # -------------------------- 6. 消息处理工具(新增消息解密)-------------------------- | ||||
| # -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑)-------------------------- | ||||
| async def send_heartbeat_ack(conn: ClientConnection): | ||||
|     """发送加密的心跳确认""" | ||||
|     """发送加密的心跳确认(服务器→客户端:加密)""" | ||||
|     try: | ||||
|         # 1. 构建原始心跳确认 | ||||
|         heartbeat_ack_msg = { | ||||
|             "type": "heart", | ||||
|             "timestamp": get_current_time_str(), | ||||
|             "client_ip": conn.client_ip | ||||
|         } | ||||
|         # 2. AES加密 | ||||
|         encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) | ||||
|         # 3. 发送 | ||||
|         encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg))  # 保持加密 | ||||
|         await conn.websocket.send_json(encrypted_msg) | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认") | ||||
|         return True | ||||
| @ -228,44 +201,22 @@ async def send_heartbeat_ack(conn: ClientConnection): | ||||
|  | ||||
|  | ||||
| async def handle_text_msg(conn: ClientConnection, text: str): | ||||
|     """处理加密的文本消息(如心跳)""" | ||||
|     """处理客户端明文文本消息(如心跳)- 关键修改:无需解密,直接解析JSON""" | ||||
|     try: | ||||
|         # 1. 解析加密字典 | ||||
|         encrypted_dict = json.loads(text) | ||||
|         # 2. AES解密 | ||||
|         decrypted_text = aes_decrypt(encrypted_dict) | ||||
|         # 3. 解析业务消息 | ||||
|         msg = json.loads(decrypted_text) | ||||
|  | ||||
|         # 客户端发的是明文JSON,直接解析(删除原解密步骤) | ||||
|         msg = json.loads(text) | ||||
|         if msg.get("type") == "heart": | ||||
|             conn.update_heartbeat() | ||||
|             await send_heartbeat_ack(conn) | ||||
|             await send_heartbeat_ack(conn)  # 服务器回复仍加密 | ||||
|         else: | ||||
|             print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')})") | ||||
|             print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')})") | ||||
|     except json.JSONDecodeError: | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式") | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)") | ||||
|     except Exception as e: | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}") | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}") | ||||
|  | ||||
|  | ||||
| async def handle_binary_msg(conn: ClientConnection, data: str): | ||||
|     """处理加密的图像消息(客户端需先转Base64+加密)""" | ||||
|     try: | ||||
|         # 1. 解密得到Base64编码的图像 | ||||
|         encrypted_dict = json.loads(data) | ||||
|         decrypted_base64 = aes_decrypt(encrypted_dict) | ||||
|         # 2. Base64解码为二进制图像 | ||||
|         frame_data = base64.b64decode(decrypted_base64) | ||||
|         # 3. 加入帧队列 | ||||
|         conn.frame_queue.put_nowait(frame_data) | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队") | ||||
|     except asyncio.QueueFull: | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满,丢弃数据") | ||||
|     except Exception as e: | ||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}") | ||||
|  | ||||
|  | ||||
| # -------------------------- 7. WebSocket路由与生命周期(保持原有结构)-------------------------- | ||||
| # -------------------------- 7. WebSocket路由与生命周期(关键修改:处理明文二进制图像)-------------------------- | ||||
| ws_router = APIRouter() | ||||
|  | ||||
|  | ||||
| @ -276,7 +227,6 @@ async def lifespan(app: FastAPI): | ||||
|     heartbeat_task = asyncio.create_task(heartbeat_checker()) | ||||
|     print(f"[{get_current_time_str()}] 心跳检查任务启动(ID: {id(heartbeat_task)})") | ||||
|     yield | ||||
|     # 关闭时清理 | ||||
|     if heartbeat_task and not heartbeat_task.done(): | ||||
|         heartbeat_task.cancel() | ||||
|         await heartbeat_task | ||||
| @ -285,8 +235,8 @@ async def lifespan(app: FastAPI): | ||||
|  | ||||
| @ws_router.websocket(WS_ENDPOINT) | ||||
| async def websocket_endpoint(websocket: WebSocket): | ||||
|     """WebSocket连接处理入口""" | ||||
|     load_model()  # 加载检测模型(仅一次) | ||||
|     """WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像""" | ||||
|     load_model()  # 加载检测模型(建议移到全局,避免重复加载) | ||||
|     await websocket.accept() | ||||
|     client_ip = websocket.client.host if websocket.client else "unknown_ip" | ||||
|     current_time = get_current_time_str() | ||||
| @ -306,8 +256,8 @@ async def websocket_endpoint(websocket: WebSocket): | ||||
|         # 注册新连接 | ||||
|         new_conn = ClientConnection(websocket, client_ip) | ||||
|         connected_clients[client_ip] = new_conn | ||||
|         new_conn.start_consumer()  # 启动帧消费 | ||||
|         await new_conn.send_frame_permit()  # 发送首次帧许可 | ||||
|         new_conn.start_consumer() | ||||
|         await new_conn.send_frame_permit()  # 首次许可仍加密 | ||||
|  | ||||
|         # 标记客户端上线 | ||||
|         try: | ||||
| @ -321,28 +271,33 @@ async def websocket_endpoint(websocket: WebSocket): | ||||
|  | ||||
|         print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}") | ||||
|  | ||||
|         # 消息循环(接收客户端消息) | ||||
|         # 消息循环:接收客户端明文消息(关键修改) | ||||
|         while True: | ||||
|             data = await websocket.receive() | ||||
|             if "text" in data: | ||||
|                 # 处理加密文本消息(心跳、客户端指令) | ||||
|                 # 处理客户端明文文本(如心跳:{"type":"heart",...}) | ||||
|                 await handle_text_msg(new_conn, data["text"]) | ||||
|             elif "bytes" in data: | ||||
|                 # 兼容客户端发送二进制:先转Base64再处理 | ||||
|                 base64_data = base64.b64encode(data["bytes"]).decode("utf-8") | ||||
|                 await handle_binary_msg(new_conn, base64_data) | ||||
|                 # 处理客户端明文二进制图像(直接入队,无需解密) | ||||
|                 frame_data = data["bytes"] | ||||
|                 try: | ||||
|                     new_conn.frame_queue.put_nowait(frame_data) | ||||
|                     print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队") | ||||
|                 except asyncio.QueueFull: | ||||
|                     print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据") | ||||
|                 except Exception as e: | ||||
|                     print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}") | ||||
|  | ||||
|     except WebSocketDisconnect as e: | ||||
|         print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})") | ||||
|     except Exception as e: | ||||
|         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") | ||||
|     finally: | ||||
|         # 清理资源(断开后处理) | ||||
|         # 清理资源 | ||||
|         if client_ip in connected_clients: | ||||
|             conn = connected_clients[client_ip] | ||||
|             if conn.consumer_task and not conn.consumer_task.done(): | ||||
|                 conn.consumer_task.cancel() | ||||
|             # 仅上线成功的客户端,才标记离线 | ||||
|             if is_online_updated: | ||||
|                 try: | ||||
|                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user