yolo模型识别不到
This commit is contained in:
		
							
								
								
									
										10
									
								
								core/all.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								core/all.py
									
									
									
									
									
								
							| @ -67,7 +67,7 @@ def detect(client_ip, frame): | |||||||
|         if full_save_path: |         if full_save_path: | ||||||
|             cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|             print(f"✅ yolo违规图片已保存:{display_path}")  # 日志也修正 |             print(f"✅ yolo违规图片已保存:{display_path}")  # 日志也修正 | ||||||
|         save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) |         save_db(model_type="yolo", client_ip=client_ip, result=str(display_path)) | ||||||
|         return (True, yolo_result, "yolo") |         return (True, yolo_result, "yolo") | ||||||
|  |  | ||||||
|     # 2. 人脸检测 |     # 2. 人脸检测 | ||||||
| @ -77,17 +77,19 @@ def detect(client_ip, frame): | |||||||
|         if full_save_path: |         if full_save_path: | ||||||
|             cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|             print(f"✅ face违规图片已保存:{display_path}")  # 日志也修正 |             print(f"✅ face违规图片已保存:{display_path}")  # 日志也修正 | ||||||
|         save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) |         save_db(model_type="face", client_ip=client_ip, result=str(display_path)) | ||||||
|         return (True, face_result, "face") |         return (True, face_result, "face") | ||||||
|  |  | ||||||
|     # 3. OCR检测 |     # 3. OCR检测 | ||||||
|     ocr_flag, ocr_result = ocrDetect(frame) |     ocr_flag, ocr_result = ocrDetect(frame) | ||||||
|     if ocr_flag: |     if ocr_flag: | ||||||
|         full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)  # 这里改了 |         full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) | ||||||
|  |         print(f"✅ ocr违规图片已保存:{display_path}") | ||||||
|  |         # 这里改了 | ||||||
|         if full_save_path: |         if full_save_path: | ||||||
|             cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|             print(f"✅ ocr违规图片已保存:{display_path}")  # 日志也修正 |             print(f"✅ ocr违规图片已保存:{display_path}")  # 日志也修正 | ||||||
|         save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) |         save_db(model_type="ocr", client_ip=client_ip, result=str(display_path)) | ||||||
|         return (True, ocr_result, "ocr") |         return (True, ocr_result, "ocr") | ||||||
|  |  | ||||||
|     # 4. 无违规内容(不保存图片) |     # 4. 无违规内容(不保存图片) | ||||||
|  | |||||||
| @ -1,30 +1,29 @@ | |||||||
| import datetime | import datetime | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from typing import List, Tuple | ||||||
|  |  | ||||||
| from service.device_service import get_unique_client_ips | from service.device_service import get_unique_client_ips | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_directory_structure(): | def create_directory_structure(): | ||||||
|     """创建项目所需的目录结构,为所有客户端IP预创建基础目录""" |     """创建项目所需的目录结构,为所有客户端IP预创建基础目录""" | ||||||
|     try: |     try: | ||||||
|         # 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容) |         # 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容) | ||||||
|         resource_dir = Path("resource") |         resource_dir = Path("resource") | ||||||
|         resource_dir.mkdir(exist_ok=True) |         resource_dir.mkdir(exist_ok=True) | ||||||
|         # print(f"确保resource目录存在: {resource_dir.absolute()}") |  | ||||||
|  |  | ||||||
|         # 2. 在resource下创建dect文件夹 |         # 2. 在resource下创建dect文件夹 | ||||||
|         dect_dir = resource_dir / "dect" |         dect_dir = resource_dir / "dect" | ||||||
|         dect_dir.mkdir(exist_ok=True) |         dect_dir.mkdir(exist_ok=True) | ||||||
|         # print(f"确保dect目录存在: {dect_dir.absolute()}") |  | ||||||
|  |  | ||||||
|         # 3. 在dect下创建三个模型文件夹 |         # 3. 在dect下创建三个模型文件夹 | ||||||
|         model_dirs = ["ocr", "face", "yolo"] |         model_dirs = ["ocr", "face", "yolo"] | ||||||
|         for model in model_dirs: |         for model in model_dirs: | ||||||
|             model_dir = dect_dir / model |             model_dir = dect_dir / model | ||||||
|             model_dir.mkdir(exist_ok=True) |             model_dir.mkdir(exist_ok=True) | ||||||
|             # print(f"确保{model}模型目录存在: {model_dir.absolute()}") |  | ||||||
|  |  | ||||||
|         # 4. 调用外部方法获取所有客户端IP地址 |         # 4. 调用外部方法获取所有客户端IP地址 | ||||||
|         try: |         try: | ||||||
|             # 调用外部ip_read()方法获取所有客户端IP地址列表 |  | ||||||
|             all_ip_addresses = get_unique_client_ips() |             all_ip_addresses = get_unique_client_ips() | ||||||
|  |  | ||||||
|             # 确保返回的是列表类型 |             # 确保返回的是列表类型 | ||||||
| @ -58,7 +57,6 @@ def create_directory_structure(): | |||||||
|  |  | ||||||
|                     # 递归创建目录(存在则跳过,不覆盖) |                     # 递归创建目录(存在则跳过,不覆盖) | ||||||
|                     month_dir.mkdir(parents=True, exist_ok=True) |                     month_dir.mkdir(parents=True, exist_ok=True) | ||||||
|                     # print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}") |  | ||||||
|  |  | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"处理客户端IP和日期目录时发生错误: {str(e)}") |             print(f"处理客户端IP和日期目录时发生错误: {str(e)}") | ||||||
| @ -67,52 +65,68 @@ def create_directory_structure(): | |||||||
|         print(f"创建基础目录结构时发生错误: {str(e)}") |         print(f"创建基础目录结构时发生错误: {str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_image_save_path(model_type: str, client_ip: str) -> tuple: | def get_image_save_path(model_type: str, client_ip: str) -> Tuple[str, str]: | ||||||
|     """ |     """ | ||||||
|     获取图片保存的「完整路径」和「显示用短路径」 |     获取图片保存的「本地完整路径」和「带路由前缀的显示路径」 | ||||||
|  |  | ||||||
|     参数: |     参数: | ||||||
|         model_type: 模型类型,应为"ocr"、"face"或"yolo" |         model_type: 模型类型,应为"ocr"、"face"或"yolo" | ||||||
|         client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101) |         client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101) | ||||||
|  |  | ||||||
|     返回: |     返回: | ||||||
|         元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "") |         元组 (本地完整保存路径, 带/api/file/前缀的显示路径);若出错则返回 ("", "") | ||||||
|     """ |     """ | ||||||
|     try: |     try: | ||||||
|  |         # 验证模型类型有效性 | ||||||
|  |         valid_models = ["ocr", "face", "yolo"] | ||||||
|  |         if model_type not in valid_models: | ||||||
|  |             raise ValueError(f"无效的模型类型: {model_type},必须是{valid_models}之一") | ||||||
|  |  | ||||||
|         # 1. 验证客户端IP有效性(检查是否在已知IP列表中) |         # 1. 验证客户端IP有效性(检查是否在已知IP列表中) | ||||||
|         all_ip_addresses = get_unique_client_ips() |         all_ip_addresses = get_unique_client_ips() | ||||||
|         if not isinstance(all_ip_addresses, list): |         if not isinstance(all_ip_addresses, list): | ||||||
|             all_ip_addresses = [all_ip_addresses] |             all_ip_addresses = [all_ip_addresses] | ||||||
|         valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] |         valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] | ||||||
|  |  | ||||||
|         if client_ip.strip() not in valid_ips: |         client_ip_stripped = client_ip.strip() | ||||||
|  |         if client_ip_stripped not in valid_ips: | ||||||
|             raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件") |             raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件") | ||||||
|  |  | ||||||
|         # 2. 处理IP地址(与目录创建逻辑一致,将.替换为_) |         # 2. 处理IP地址(将.替换为_,避免路径问题) | ||||||
|         safe_ip = client_ip.strip().replace(".", "_") |         safe_ip = client_ip_stripped.replace(".", "_") | ||||||
|  |  | ||||||
|         # 3. 获取当前日期和毫秒级时间戳(确保文件名唯一) |         # 3. 获取当前日期和毫秒级时间戳(确保文件名唯一) | ||||||
|         now = datetime.datetime.now() |         now = datetime.datetime.now() | ||||||
|         current_year = str(now.year) |         current_year = str(now.year) | ||||||
|         current_month = str(now.month) |         current_month = str(now.month).zfill(2)  # 确保月份为两位数 | ||||||
|         current_day = str(now.day) |         current_day = str(now.day).zfill(2)      # 确保日期为两位数 | ||||||
|         # 时间戳格式:年月日时分秒毫秒(如20250910143050123) |         timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]  # 取毫秒级时间戳 | ||||||
|         timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] |  | ||||||
|  |  | ||||||
|         # 4. 定义基础目录(用于生成相对路径) |         # 4. 定义基础目录(用于生成相对路径) | ||||||
|         base_dir = Path("resource") / "dect"  # 显示路径会去掉这个前缀 |         base_dir = Path("resource") / "dect" | ||||||
|         # 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日}) |         # 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日}) | ||||||
|         day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day |         day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day | ||||||
|         day_dir.mkdir(parents=True, exist_ok=True)  # 确保日目录存在 |         day_dir.mkdir(parents=True, exist_ok=True)  # 确保目录存在 | ||||||
|  |  | ||||||
|         # 5. 构建唯一文件名 |         # 5. 构建唯一文件名 | ||||||
|         image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg" |         image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg" | ||||||
|  |  | ||||||
|         # 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印) |         # 6. 生成「本地完整路径」(使用系统路径,但在字符串表示时统一为正斜杠) | ||||||
|         full_path = day_dir / image_filename  # 完整路径:resource/dect/.../xxx.jpg |         local_full_path = day_dir / image_filename | ||||||
|         display_path = full_path.relative_to(base_dir)  # 短路径:{model}/.../xxx.jpg(去掉resource/dect) |         # 转换为字符串并统一使用正斜杠 | ||||||
|  |         local_full_path_str = str(local_full_path).replace("\\", "/") | ||||||
|  |  | ||||||
|         return str(full_path), str(display_path) |         # 7. 生成带路由前缀的显示路径(核心修改部分) | ||||||
|  |         # 获取项目根目录(base_dir是resource/dect,向上两级即为项目根目录) | ||||||
|  |         project_root = base_dir.parents[1] | ||||||
|  |         # 计算相对于项目根目录的路径(包含resource/dect层级) | ||||||
|  |         relative_path = local_full_path.relative_to(project_root) | ||||||
|  |         # 转换为字符串并统一使用正斜杠 | ||||||
|  |         relative_path_str = str(relative_path).replace("\\", "/") | ||||||
|  |         # 拼接路由前缀 | ||||||
|  |         routed_display_path = f"/api/file/{relative_path_str}" | ||||||
|  |  | ||||||
|  |         return local_full_path_str, routed_display_path | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"获取图片保存路径时发生错误: {str(e)}") |         print(f"获取图片保存路径时发生错误: {str(e)}") | ||||||
|  | |||||||
							
								
								
									
										83
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										83
									
								
								main.py
									
									
									
									
									
								
							| @ -1,10 +1,8 @@ | |||||||
| import uvicorn | import uvicorn | ||||||
| import threading |  | ||||||
| import time | import time | ||||||
| import os | import os | ||||||
| from fastapi import FastAPI | from fastapi import FastAPI | ||||||
| from fastapi.middleware.cors import CORSMiddleware | from fastapi.middleware.cors import CORSMiddleware | ||||||
| from service.file_service import app as flask_app |  | ||||||
|  |  | ||||||
| # 原有业务导入 | # 原有业务导入 | ||||||
| from core.all import load_model | from core.all import load_model | ||||||
| @ -15,32 +13,12 @@ from service.sensitive_service import router as sensitive_router | |||||||
| from service.face_service import router as face_router | from service.face_service import router as face_router | ||||||
| from service.device_service import router as device_router | from service.device_service import router as device_router | ||||||
| from service.model_service import router as model_router | from service.model_service import router as model_router | ||||||
|  | from service.file_service import router as file_router | ||||||
|  | from service.device_danger_service import router as device_danger_router | ||||||
| from ws.ws import ws_router, lifespan | from ws.ws import ws_router, lifespan | ||||||
| from core.establish import create_directory_structure | from core.establish import create_directory_structure | ||||||
|  |  | ||||||
|  | # 初始化 FastAPI 应用 | ||||||
| # Flask 服务启动函数(不变) |  | ||||||
| def start_flask_service(): |  | ||||||
|     try: |  | ||||||
|         print(f"\n[Flask 服务] 准备启动,端口:5000") |  | ||||||
|         print(f"[Flask 服务] 访问示例:http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n") |  | ||||||
|  |  | ||||||
|         BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect")) |  | ||||||
|         if not os.path.exists(BASE_IMAGE_DIR): |  | ||||||
|             print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}") |  | ||||||
|             os.makedirs(BASE_IMAGE_DIR, exist_ok=True) |  | ||||||
|  |  | ||||||
|         flask_app.run( |  | ||||||
|             host="0.0.0.0", |  | ||||||
|             port=5000, |  | ||||||
|             debug=False, |  | ||||||
|             use_reloader=False |  | ||||||
|         ) |  | ||||||
|     except Exception as e: |  | ||||||
|         print(f"[Flask 服务] 启动失败:{str(e)}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # 初始化 FastAPI 应用(新增 CORS 配置) |  | ||||||
| app = FastAPI( | app = FastAPI( | ||||||
|     title="内容安全审核后台", |     title="内容安全审核后台", | ||||||
|     description="含图片访问服务和动态模型管理", |     description="含图片访问服务和动态模型管理", | ||||||
| @ -48,38 +26,33 @@ app = FastAPI( | |||||||
|     lifespan=lifespan |     lifespan=lifespan | ||||||
| ) | ) | ||||||
|  |  | ||||||
| # ------------------------------ |  | ||||||
| # 新增:完整 CORS 配置(解决跨域问题) |  | ||||||
| # ------------------------------ |  | ||||||
| # 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等) |  | ||||||
| ALLOWED_ORIGINS = [ | ALLOWED_ORIGINS = [ | ||||||
|     # "http://localhost:8080",  # 前端本地开发地址(必改,填实际前端地址) |     "*" | ||||||
|     # "http://127.0.0.1:8080", |  | ||||||
|     # "http://服务器IP:8080",    # 部署后前端地址(如适用) |  | ||||||
|     "*" #表示允许所有域名(开发环境可用,生产环境不推荐) |  | ||||||
| ] | ] | ||||||
|  |  | ||||||
| # 2. 配置 CORS 中间件 | # 配置 CORS 中间件 | ||||||
| app.add_middleware( | app.add_middleware( | ||||||
|     CORSMiddleware, |     CORSMiddleware, | ||||||
|     allow_origins=ALLOWED_ORIGINS,        # 允许的前端域名 |     allow_origins=ALLOWED_ORIGINS,        # 允许的前端域名 | ||||||
|     allow_credentials=True,               # 允许携带 Cookie(如需登录态则必开) |     allow_credentials=True,               # 允许携带 Cookie | ||||||
|     allow_methods=["*"],                  # 允许所有 HTTP 方法(包括 PUT/DELETE) |     allow_methods=["*"],                  # 允许所有 HTTP 方法 | ||||||
|     allow_headers=["*"],                  # 允许所有请求头(包括 Content-Type) |     allow_headers=["*"],                  # 允许所有请求头 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| # 注册路由(不变) | # 注册路由 | ||||||
| app.include_router(user_router) | app.include_router(user_router) | ||||||
| app.include_router(device_router) | app.include_router(device_router) | ||||||
| app.include_router(face_router) | app.include_router(face_router) | ||||||
| app.include_router(sensitive_router) | app.include_router(sensitive_router) | ||||||
| app.include_router(model_router)  # 模型管理路由 | app.include_router(model_router) | ||||||
|  | app.include_router(file_router) | ||||||
|  | app.include_router(device_danger_router) | ||||||
| app.include_router(ws_router) | app.include_router(ws_router) | ||||||
|  |  | ||||||
| # 注册全局异常处理器(不变) | # 注册全局异常处理器 | ||||||
| app.add_exception_handler(Exception, global_exception_handler) | app.add_exception_handler(Exception, global_exception_handler) | ||||||
|  |  | ||||||
| # 主服务启动入口(不变) | # 主服务启动入口 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     create_directory_structure() |     create_directory_structure() | ||||||
|     print(f"[初始化] 目录结构创建完成") |     print(f"[初始化] 目录结构创建完成") | ||||||
| @ -89,11 +62,11 @@ if __name__ == "__main__": | |||||||
|     os.makedirs(MODEL_SAVE_DIR, exist_ok=True) |     os.makedirs(MODEL_SAVE_DIR, exist_ok=True) | ||||||
|     print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}") |     print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}") | ||||||
|  |  | ||||||
|     # # 模型路径配置 |     # 确保图片目录存在(原Flask服务负责的目录) | ||||||
|     # YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt") |     BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect")) | ||||||
|     # OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml") |     if not os.path.exists(BASE_IMAGE_DIR): | ||||||
|     # print(f"[初始化] 默认YOLO模型路径:{YOLO_MODEL_PATH}") |         print(f"[初始化] 图片根目录不存在,创建:{BASE_IMAGE_DIR}") | ||||||
|     # print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}") |         os.makedirs(BASE_IMAGE_DIR, exist_ok=True) | ||||||
|  |  | ||||||
|     # 加载检测模型 |     # 加载检测模型 | ||||||
|     try: |     try: | ||||||
| @ -105,23 +78,7 @@ if __name__ == "__main__": | |||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)") |         print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)") | ||||||
|  |  | ||||||
|  |     # 启动 FastAPI 主服务(仅使用8000端口) | ||||||
|  |  | ||||||
|     # 2. 启动 Flask 服务(子线程) |  | ||||||
|     flask_thread = threading.Thread( |  | ||||||
|         target=start_flask_service, |  | ||||||
|         daemon=True |  | ||||||
|     ) |  | ||||||
|     flask_thread.start() |  | ||||||
|  |  | ||||||
|     # 等待 Flask 初始化 |  | ||||||
|     time.sleep(1) |  | ||||||
|     if flask_thread.is_alive(): |  | ||||||
|         print(f"[Flask 服务] 启动成功(运行中)") |  | ||||||
|     else: |  | ||||||
|         print(f"[Flask 服务] 启动失败!图片访问不可用") |  | ||||||
|  |  | ||||||
|     # 3. 启动 FastAPI 主服务 |  | ||||||
|     port = int(SERVER_CONFIG.get("port", 8000)) |     port = int(SERVER_CONFIG.get("port", 8000)) | ||||||
|     print(f"\n[FastAPI 服务] 准备启动,端口:{port}") |     print(f"\n[FastAPI 服务] 准备启动,端口:{port}") | ||||||
|     print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n") |     print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n") | ||||||
|  | |||||||
							
								
								
									
										33
									
								
								schema/device_danger_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								schema/device_danger_schema.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,33 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from typing import Optional, List | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 请求模型 | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceDangerCreateRequest(BaseModel): | ||||||
|  |     """设备危险记录创建请求模型""" | ||||||
|  |     client_ip: str = Field(..., max_length=100, description="设备IP地址(必须与devices表中IP对应)") | ||||||
|  |     type: str = Field(..., max_length=50, description="危险类型(如:病毒检测、端口异常、权限泄露等)") | ||||||
|  |     result: str = Field(..., description="危险检测结果/处理结果(如:检测到木马病毒,已隔离;端口22异常开放,已关闭)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 响应模型 | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceDangerResponse(BaseModel): | ||||||
|  |     """单条设备危险记录响应模型(与device_danger表字段对齐,updated_at允许为null)""" | ||||||
|  |     id: int = Field(..., description="危险记录主键ID") | ||||||
|  |     client_ip: str = Field(..., max_length=100, description="设备IP地址") | ||||||
|  |     type: str = Field(..., max_length=50, description="危险类型") | ||||||
|  |     result: str = Field(..., description="危险检测结果/处理结果") | ||||||
|  |     created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)") | ||||||
|  |     updated_at: Optional[datetime] = Field(None, description="记录更新时间(数据库中该字段当前为null)") | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DeviceDangerListResponse(BaseModel): | ||||||
|  |     """设备危险记录列表响应模型(带分页)""" | ||||||
|  |     total: int = Field(..., description="危险记录总数") | ||||||
|  |     dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表") | ||||||
| @ -28,7 +28,6 @@ class DeviceResponse(BaseModel): | |||||||
|     params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") |     params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") | ||||||
|     created_at: datetime = Field(..., description="记录创建时间") |     created_at: datetime = Field(..., description="记录创建时间") | ||||||
|     updated_at: datetime = Field(..., description="记录更新时间") |     updated_at: datetime = Field(..., description="记录更新时间") | ||||||
|  |  | ||||||
|     model_config = {"from_attributes": True}  # 支持从数据库结果直接转换 |     model_config = {"from_attributes": True}  # 支持从数据库结果直接转换 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ from schema.response_schema import APIResponse | |||||||
|  |  | ||||||
| # 路由配置 | # 路由配置 | ||||||
| router = APIRouter( | router = APIRouter( | ||||||
|     prefix="/device/actions", |     prefix="/api/device/actions", | ||||||
|     tags=["设备操作记录"] |     tags=["设备操作记录"] | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										267
									
								
								service/device_danger_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										267
									
								
								service/device_danger_service.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,267 @@ | |||||||
|  | import json | ||||||
|  | from datetime import date | ||||||
|  |  | ||||||
|  | from fastapi import APIRouter, Query, HTTPException, Path | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from schema.device_danger_schema import ( | ||||||
|  |     DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  |  | ||||||
|  | # 路由初始化(前缀与设备管理相关,标签区分功能) | ||||||
|  | router = APIRouter( | ||||||
|  |     prefix="/api/devices/dangers", | ||||||
|  |     tags=["设备管理-危险记录"] | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 内部工具方法 - 检查设备是否存在(复用设备表逻辑) | ||||||
|  | # ------------------------------ | ||||||
|  | def check_device_exist(client_ip: str) -> bool: | ||||||
|  |     """ | ||||||
|  |     检查指定IP的设备是否在devices表中存在 | ||||||
|  |  | ||||||
|  |     :param client_ip: 设备IP地址 | ||||||
|  |     :return: 存在返回True,不存在返回False | ||||||
|  |     """ | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("设备IP不能为空") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) | ||||||
|  |         return cursor.fetchone() is not None | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"检查设备存在性失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 内部工具方法 - 创建设备危险记录(核心插入逻辑) | ||||||
|  | # ------------------------------ | ||||||
|  | def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse: | ||||||
|  |     """ | ||||||
|  |     内部工具方法:向device_danger表插入新的危险记录 | ||||||
|  |  | ||||||
|  |     :param danger_data: 危险记录创建请求数据 | ||||||
|  |     :return: 创建成功的危险记录模型对象 | ||||||
|  |     """ | ||||||
|  |     # 先检查设备是否存在 | ||||||
|  |     if not check_device_exist(danger_data.client_ip): | ||||||
|  |         raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 插入危险记录(id自增,时间自动填充) | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO device_danger  | ||||||
|  |             (client_ip, type, result, created_at, updated_at) | ||||||
|  |             VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, ( | ||||||
|  |             danger_data.client_ip, | ||||||
|  |             danger_data.type, | ||||||
|  |             danger_data.result | ||||||
|  |         )) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 获取刚创建的记录(用自增ID查询) | ||||||
|  |         danger_id = cursor.lastrowid | ||||||
|  |         cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,)) | ||||||
|  |         new_danger = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         return DeviceDangerResponse(**new_danger) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"插入危险记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 接口1:创建设备危险记录 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.post("/add", response_model=APIResponse, summary="创建设备危险记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def add_device_danger(danger_data: DeviceDangerCreateRequest): | ||||||
|  |     try: | ||||||
|  |         # 调用内部方法创建记录 | ||||||
|  |         new_danger = create_danger_record(danger_data) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"设备[{danger_data.client_ip}]危险记录创建成功", | ||||||
|  |             data=new_danger | ||||||
|  |         ) | ||||||
|  |     except ValueError as e: | ||||||
|  |         # 设备不存在等业务异常 | ||||||
|  |         raise HTTPException(status_code=400, detail=str(e)) from e | ||||||
|  |     except Exception as e: | ||||||
|  |         # 数据库异常等系统错误 | ||||||
|  |         raise HTTPException(status_code=500, detail=str(e)) from e | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 接口2:获取危险记录列表(支持多条件筛选+分页) | ||||||
|  | # ------------------------------ | ||||||
|  | @router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_danger_list( | ||||||
|  |         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | ||||||
|  |         client_ip: str = Query(None, max_length=100, description="按设备IP筛选"), | ||||||
|  |         danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"), | ||||||
|  |         start_date: date = Query(None, description="按创建时间筛选(开始日期,格式YYYY-MM-DD)"), | ||||||
|  |         end_date: date = Query(None, description="按创建时间筛选(结束日期,格式YYYY-MM-DD)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 构建筛选条件 | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |  | ||||||
|  |         if client_ip: | ||||||
|  |             where_clause.append("client_ip = %s") | ||||||
|  |             params.append(client_ip) | ||||||
|  |         if danger_type: | ||||||
|  |             where_clause.append("type = %s") | ||||||
|  |             params.append(danger_type) | ||||||
|  |         if start_date: | ||||||
|  |             where_clause.append("DATE(created_at) >= %s") | ||||||
|  |             params.append(start_date.strftime("%Y-%m-%d")) | ||||||
|  |         if end_date: | ||||||
|  |             where_clause.append("DATE(created_at) <= %s") | ||||||
|  |             params.append(end_date.strftime("%Y-%m-%d")) | ||||||
|  |  | ||||||
|  |         # 1. 统计符合条件的总记录数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM device_danger" | ||||||
|  |         if where_clause: | ||||||
|  |             count_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_query, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 2. 分页查询记录(按创建时间倒序,最新的在前) | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = "SELECT * FROM device_danger" | ||||||
|  |         if where_clause: | ||||||
|  |             list_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset])  # 追加分页参数 | ||||||
|  |  | ||||||
|  |         cursor.execute(list_query, params) | ||||||
|  |         danger_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 转换为响应模型 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取危险记录列表成功", | ||||||
|  |             data=DeviceDangerListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 dangers=[DeviceDangerResponse(**item) for item in danger_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 接口3:获取单个设备的所有危险记录 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录") | ||||||
|  | # @encrypt_response() | ||||||
|  | async def get_device_dangers( | ||||||
|  |         client_ip: str = Path(..., max_length=100, description="设备IP地址"), | ||||||
|  |         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间") | ||||||
|  | ): | ||||||
|  |     # 先检查设备是否存在 | ||||||
|  |     if not check_device_exist(client_ip): | ||||||
|  |         raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 1. 统计该设备的危险记录总数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s" | ||||||
|  |         cursor.execute(count_query, (client_ip,)) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 2. 分页查询该设备的危险记录 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = """ | ||||||
|  |             SELECT * FROM device_danger  | ||||||
|  |             WHERE client_ip = %s  | ||||||
|  |             ORDER BY created_at DESC  | ||||||
|  |             LIMIT %s OFFSET %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(list_query, (client_ip, page_size, offset)) | ||||||
|  |         danger_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"获取设备[{client_ip}]危险记录成功(共{total}条)", | ||||||
|  |             data=DeviceDangerListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 dangers=[DeviceDangerResponse(**item) for item in danger_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 接口4:根据ID获取单个危险记录详情 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_danger_detail( | ||||||
|  |         danger_id: int = Path(..., ge=1, description="危险记录ID") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 查询单个危险记录 | ||||||
|  |         query = "SELECT * FROM device_danger WHERE id = %s" | ||||||
|  |         cursor.execute(query, (danger_id,)) | ||||||
|  |         danger = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not danger: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在") | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取危险记录详情成功", | ||||||
|  |             data=DeviceDangerResponse(**danger) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
| @ -13,7 +13,7 @@ from schema.device_schema import ( | |||||||
| from schema.response_schema import APIResponse | from schema.response_schema import APIResponse | ||||||
|  |  | ||||||
| router = APIRouter( | router = APIRouter( | ||||||
|     prefix="/devices", |     prefix="/api/devices", | ||||||
|     tags=["设备管理"] |     tags=["设备管理"] | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ from schema.response_schema import APIResponse | |||||||
| from util.face_util import add_binary_data, get_average_feature | from util.face_util import add_binary_data, get_average_feature | ||||||
| from util.file_util import save_face_to_up_images | from util.file_util import save_face_to_up_images | ||||||
|  |  | ||||||
| router = APIRouter(prefix="/faces", tags=["人脸管理"]) | router = APIRouter(prefix="/api/faces", tags=["人脸管理"]) | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
|  | |||||||
| @ -1,276 +1,174 @@ | |||||||
| from flask import Flask, send_from_directory, abort, request | from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter | ||||||
|  | from fastapi.responses import FileResponse | ||||||
| import os | import os | ||||||
| import logging | import logging | ||||||
| from functools import wraps | from functools import wraps | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from flask_cors import CORS | from fastapi.middleware.cors import CORSMiddleware | ||||||
|  | from typing import Annotated | ||||||
|  |  | ||||||
| # 配置日志(保持原有格式) |  | ||||||
| logging.basicConfig( |  | ||||||
|     level=logging.INFO, |  | ||||||
|     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |  | ||||||
| ) |  | ||||||
| logger = logging.getLogger(__name__) |  | ||||||
|  |  | ||||||
| # 初始化 Flask 应用(供 main.py 导入) | router = APIRouter( | ||||||
| app = Flask(__name__) |     prefix="/api/file", | ||||||
|  |     tags=["文件管理"] | ||||||
| # ------------------------------ |  | ||||||
| # 核心修改:与 FastAPI 对齐的跨域配置 |  | ||||||
| # ------------------------------ |  | ||||||
| # 1. 允许的前端域名(根据实际环境修改,生产环境删除 "*") |  | ||||||
| ALLOWED_ORIGINS = [ |  | ||||||
|     # "http://localhost:8080",  # 本地前端开发地址 |  | ||||||
|     # "http://127.0.0.1:8080", |  | ||||||
|     # "http://服务器IP:8080",    # 部署后前端地址 |  | ||||||
|     "*" |  | ||||||
| ] |  | ||||||
|  |  | ||||||
| # 2. 配置 CORS(与 FastAPI 规则完全对齐) |  | ||||||
| CORS( |  | ||||||
|     app, |  | ||||||
|     resources={ |  | ||||||
|         r"/*": { |  | ||||||
|             "origins": ALLOWED_ORIGINS, |  | ||||||
|             "allow_credentials": True, |  | ||||||
|             "methods": ["*"], |  | ||||||
|             "allow_headers": ["*"], |  | ||||||
|         } |  | ||||||
|     }, |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 核心路径配置(关键修改:修正 PROJECT_ROOT 计算) | # 4. 路径配置 | ||||||
| # 原问题:file_service.py 在 service 文件夹内,需向上跳一级到项目根目录 |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| CURRENT_FILE_PATH = Path(__file__).resolve()  # 当前文件路径:service/file_service.py | CURRENT_FILE_PATH = Path(__file__).resolve() | ||||||
| PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent  # 项目根目录(service 文件夹的父目录) | PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent  # 项目根目录 | ||||||
| # 资源目录(现在正确指向项目根目录下的文件夹) |  | ||||||
| BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve())       # 根目录/resource/dect |  | ||||||
| BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve())          # 根目录/up_images |  | ||||||
| BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve())          # 根目录/resource/models |  | ||||||
|  |  | ||||||
|  | # 资源目录定义 | ||||||
|  | BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve())       # 检测图片目录 | ||||||
|  | BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve())          # 人脸图片目录 | ||||||
|  | BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve())          # 模型文件目录 | ||||||
|  |  | ||||||
|  | # 确保资源目录存在 | ||||||
|  | for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]: | ||||||
|  |     if not os.path.exists(dir_path): | ||||||
|  |         os.makedirs(dir_path, exist_ok=True) | ||||||
|  |         print(f"[创建目录] {dir_path}") | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 安全检查装饰器(不变) | # 5. 安全依赖项(替代Flask装饰器) | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| def safe_path_check(root_dir: str): | def safe_path_check(root_dir: str): | ||||||
|     def decorator(func): |     """ | ||||||
|         @wraps(func) |     安全路径校验依赖项: | ||||||
|         def wrapper(*args, **kwargs): |     1. 禁止路径遍历(确保请求文件在根目录内) | ||||||
|             resource_path = kwargs.get('resource_path', '').strip() |     2. 校验文件存在且为有效文件(非目录) | ||||||
|             # 统一路径分隔符(兼容 Windows \ 和 Linux /) |     3. 限制文件大小(模型200MB,图片10MB) | ||||||
|  |     """ | ||||||
|  |     async def dependency(request: Request, resource_path: str): | ||||||
|  |         # 统一路径分隔符 | ||||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) |         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||||
|             # 拼接完整路径(防止路径遍历) |         # 拼接完整路径 | ||||||
|         full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) |         full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) | ||||||
|             logger.debug( |  | ||||||
|                 f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             # 1. 禁止路径遍历(确保请求文件在根目录内) |         # 校验1:禁止路径遍历 | ||||||
|         if not full_file_path.startswith(root_dir): |         if not full_file_path.startswith(root_dir): | ||||||
|                 logger.warning( |             print(f"[安全检查] 禁止路径遍历!IP:{request.client.host} | 请求路径:{resource_path}") | ||||||
|                     f"[Flask 安全拦截] 非法路径遍历!IP:{request.remote_addr} | 请求路径:{resource_path}" |             raise HTTPException(status_code=403, detail="非法路径访问") | ||||||
|                 ) |  | ||||||
|                 abort(403) |  | ||||||
|  |  | ||||||
|             # 2. 检查文件存在且为有效文件(非目录) |         # 校验2:文件存在且为有效文件 | ||||||
|         if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path): |         if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path): | ||||||
|                 logger.warning( |             print(f"[资源错误] 文件不存在/非文件!IP:{request.client.host} | 路径:{full_file_path}") | ||||||
|                     f"[Flask 资源错误] 文件不存在/非文件!IP:{request.remote_addr} | 路径:{full_file_path}" |             raise HTTPException(status_code=404, detail="文件不存在") | ||||||
|                 ) |  | ||||||
|                 abort(404) |  | ||||||
|  |  | ||||||
|             # 3. 限制文件大小(模型200MB,图片10MB) |         # 校验3:文件大小限制 | ||||||
|         max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024 |         max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024 | ||||||
|         if os.path.getsize(full_file_path) > max_size: |         if os.path.getsize(full_file_path) > max_size: | ||||||
|                 logger.warning( |             print(f"[大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.client.host} | 路径:{full_file_path}") | ||||||
|                     f"[Flask 大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.remote_addr} | 路径:{full_file_path}" |             raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)") | ||||||
|                 ) |  | ||||||
|                 abort(413) |  | ||||||
|  |  | ||||||
|             # 安全检查通过,传递根目录给视图函数 |         return full_file_path | ||||||
|             return func(*args, **kwargs, root_dir=root_dir) |     return dependency | ||||||
|         return wrapper |  | ||||||
|     return decorator |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 1. 模型下载接口(/model/download/*) | # 6. 核心接口 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @app.route('/model/download/<path:resource_path>') | @router.get("/model/download/{resource_path:path}", summary="模型下载接口") | ||||||
| @safe_path_check(root_dir=BASE_MODEL_DIR) | async def download_model( | ||||||
| def download_model(resource_path, root_dir): |     resource_path: str, | ||||||
|  |     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))], | ||||||
|  |     request: Request | ||||||
|  | ): | ||||||
|  |     """模型下载接口(仅允许 .pt 格式,强制浏览器下载)""" | ||||||
|     try: |     try: | ||||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) |         dir_path, file_name = os.path.split(full_file_path) | ||||||
|         dir_path, file_name = os.path.split(resource_path) |  | ||||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) |  | ||||||
|  |  | ||||||
|         # 仅允许 .pt 格式(YOLO 模型) |         # 额外校验:仅允许 YOLO 模型格式(.pt) | ||||||
|         if not file_name.lower().endswith('.pt'): |         if not file_name.lower().endswith(".pt"): | ||||||
|             logger.warning( |             print(f"[格式错误] 非 .pt 模型文件!IP:{request.client.host} | 文件名:{file_name}") | ||||||
|                 f"[Flask 格式错误] 非 .pt 模型文件!IP:{request.remote_addr} | 文件名:{file_name}" |             raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件") | ||||||
|  |  | ||||||
|  |         print(f"[模型下载] 尝试下载!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||||
|  |  | ||||||
|  |         # 强制下载 | ||||||
|  |         return FileResponse( | ||||||
|  |             full_file_path, | ||||||
|  |             filename=file_name, | ||||||
|  |             media_type="application/octet-stream" | ||||||
|         ) |         ) | ||||||
|             abort(415) |     except HTTPException: | ||||||
|  |         raise | ||||||
|         logger.info( |  | ||||||
|             f"[Flask 模型下载] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # 强制浏览器下载(而非预览) |  | ||||||
|         return send_from_directory( |  | ||||||
|             full_dir, |  | ||||||
|             file_name, |  | ||||||
|             as_attachment=True, |  | ||||||
|             mimetype="application/octet-stream" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         logger.error( |         print(f"[下载异常] IP:{request.client.host} | 错误:{str(e)}") | ||||||
|             f"[Flask 模型下载异常] IP:{request.remote_addr} | 错误:{str(e)}" |         raise HTTPException(status_code=500, detail="服务器内部错误") | ||||||
|         ) |  | ||||||
|         abort(500) |  | ||||||
|  |  | ||||||
| # ------------------------------ |  | ||||||
| # 2. 人脸图片访问接口(/up_images/*) | @router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口") | ||||||
| # ------------------------------ | async def get_face_image( | ||||||
| @app.route('/up_images/<path:resource_path>') |     resource_path: str, | ||||||
| @safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES) |     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))], | ||||||
| def get_face_image(resource_path, root_dir): |     request: Request | ||||||
|  | ): | ||||||
|  |     """人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)""" | ||||||
|     try: |     try: | ||||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) |         dir_path, file_name = os.path.split(full_file_path) | ||||||
|         dir_path, file_name = os.path.split(resource_path) |  | ||||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) |  | ||||||
|  |  | ||||||
|         # 仅允许常见图片格式 |         # 图片格式校验 | ||||||
|         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') |         allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") | ||||||
|         if not file_name.lower().endswith(allowed_ext): |         if not file_name.lower().endswith(allowed_ext): | ||||||
|             logger.warning( |             print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") | ||||||
|                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" |             raise HTTPException(status_code=415, detail="仅支持常见图片格式") | ||||||
|             ) |  | ||||||
|             abort(415) |  | ||||||
|  |  | ||||||
|         logger.info( |         print(f"[人脸图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||||
|             f"[Flask 人脸图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # 允许浏览器预览图片 |  | ||||||
|         return send_from_directory(full_dir, file_name, as_attachment=False) |  | ||||||
|  |  | ||||||
|  |         return FileResponse(full_file_path) | ||||||
|  |     except HTTPException: | ||||||
|  |         raise | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         logger.error( |         print(f"[人脸图片异常] IP:{request.client.host} | 错误:{str(e)}") | ||||||
|             f"[Flask 人脸图片异常] IP:{request.remote_addr} | 错误:{str(e)}" |  | ||||||
|         ) |  | ||||||
|         abort(500) |  | ||||||
|  |  | ||||||
| # ------------------------------ |  | ||||||
| # 3. 检测图片访问接口(/resource/dect/*) | @router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口") | ||||||
| # ------------------------------ | async def get_dect_image( | ||||||
| @app.route('/resource/dect/<path:resource_path>') |     resource_path: str, | ||||||
| @safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) |     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], | ||||||
| def get_dect_image(resource_path, root_dir): |     request: Request | ||||||
|  | ): | ||||||
|  |     """检测图片访问接口(允许浏览器预览,仅支持常见图片格式)""" | ||||||
|     try: |     try: | ||||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) |         dir_path, file_name = os.path.split(full_file_path) | ||||||
|         dir_path, file_name = os.path.split(resource_path) |  | ||||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) |  | ||||||
|  |  | ||||||
|         # 仅允许常见图片格式 |         # 图片格式校验 | ||||||
|         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') |         allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") | ||||||
|         if not file_name.lower().endswith(allowed_ext): |         if not file_name.lower().endswith(allowed_ext): | ||||||
|             logger.warning( |             print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") | ||||||
|                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" |             raise HTTPException(status_code=415, detail="仅支持常见图片格式") | ||||||
|             ) |  | ||||||
|             abort(415) |  | ||||||
|  |  | ||||||
|         logger.info( |         print(f"[检测图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||||
|             f"[Flask 检测图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         return send_from_directory(full_dir, file_name, as_attachment=False) |  | ||||||
|  |  | ||||||
|  |         return FileResponse(full_file_path) | ||||||
|  |     except HTTPException: | ||||||
|  |         raise | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         logger.error( |         print(f"[检测图片异常] IP:{request.client.host} | 错误:{str(e)}") | ||||||
|             f"[Flask 检测图片异常] IP:{request.remote_addr} | 错误:{str(e)}" |         raise HTTPException(status_code=500, detail="服务器内部错误") | ||||||
|         ) |  | ||||||
|         abort(500) |  | ||||||
|  |  | ||||||
| # ------------------------------ |  | ||||||
| # 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*) | @router.get("/images/{resource_path:path}", summary="兼容旧接口") | ||||||
| # ------------------------------ | async def get_compatible_image( | ||||||
| @app.route('/images/<path:resource_path>') |     resource_path: str, | ||||||
| @safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) |     full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], | ||||||
| def get_compatible_image(resource_path, root_dir): |     request: Request | ||||||
|  | ): | ||||||
|  |     """兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)""" | ||||||
|     try: |     try: | ||||||
|         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) |         dir_path, file_name = os.path.split(full_file_path) | ||||||
|         dir_path, file_name = os.path.split(resource_path) |  | ||||||
|         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) |  | ||||||
|  |  | ||||||
|         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') |         # 图片格式校验 | ||||||
|  |         allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") | ||||||
|         if not file_name.lower().endswith(allowed_ext): |         if not file_name.lower().endswith(allowed_ext): | ||||||
|             logger.warning( |             print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") | ||||||
|                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" |             raise HTTPException(status_code=415, detail="仅支持常见图片格式") | ||||||
|             ) |         print(f"[兼容图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") | ||||||
|             abort(415) |  | ||||||
|  |  | ||||||
|         logger.info( |  | ||||||
|             f"[Flask 兼容图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         return send_from_directory(full_dir, file_name, as_attachment=False) |  | ||||||
|  |  | ||||||
|  |         return FileResponse(full_file_path) | ||||||
|  |     except HTTPException: | ||||||
|  |         raise | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         logger.error( |         print(f"[兼容图片异常] IP:{request.client.host} | 错误:{str(e)}") | ||||||
|             f"[Flask 兼容图片异常] IP:{request.remote_addr} | 错误:{str(e)}" |         raise HTTPException(status_code=500, detail="服务器内部错误") | ||||||
|         ) |  | ||||||
|         abort(500) |  | ||||||
|  |  | ||||||
| # ------------------------------ |  | ||||||
| # 全局错误处理器(不变) |  | ||||||
| # ------------------------------ |  | ||||||
| @app.errorhandler(403) |  | ||||||
| def forbidden_error(error): |  | ||||||
|     return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403 |  | ||||||
|  |  | ||||||
| @app.errorhandler(404) |  | ||||||
| def not_found_error(error): |  | ||||||
|     return "❌ 资源不存在:请检查URL路径(IP、目录、文件名)是否正确", 404 |  | ||||||
|  |  | ||||||
| @app.errorhandler(413) |  | ||||||
| def too_large_error(error): |  | ||||||
|     return "❌ 文件过大:图片最大10MB,模型最大200MB", 413 |  | ||||||
|  |  | ||||||
| @app.errorhandler(415) |  | ||||||
| def unsupported_type_error(error): |  | ||||||
|     return "❌ 不支持的文件类型:图片支持png/jpg/jpeg/gif/bmp,模型仅支持pt", 415 |  | ||||||
|  |  | ||||||
| @app.errorhandler(500) |  | ||||||
| def server_error(error): |  | ||||||
|     return "❌ 服务器内部错误:请联系管理员查看后台日志", 500 |  | ||||||
|  |  | ||||||
| # ------------------------------ |  | ||||||
| # Flask 独立启动入口(供测试,实际由 main.py 子线程启动) |  | ||||||
| # ------------------------------ |  | ||||||
| if __name__ == '__main__': |  | ||||||
|     # 确保所有资源目录存在 |  | ||||||
|     required_dirs = [ |  | ||||||
|         (BASE_IMAGE_DIR_DECT, "检测图片目录"), |  | ||||||
|         (BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"), |  | ||||||
|         (BASE_MODEL_DIR, "模型文件目录") |  | ||||||
|     ] |  | ||||||
|     for dir_path, dir_desc in required_dirs: |  | ||||||
|         if not os.path.exists(dir_path): |  | ||||||
|             logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}") |  | ||||||
|             os.makedirs(dir_path, exist_ok=True) |  | ||||||
|  |  | ||||||
|     # 启动提示 |  | ||||||
|     logger.info("\n[Flask 服务启动成功!] 支持的接口:") |  | ||||||
|     logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt") |  | ||||||
|     logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg") |  | ||||||
|     logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n") |  | ||||||
|  |  | ||||||
|     # 启动服务(禁用 debug 和自动重载) |  | ||||||
|     app.run( |  | ||||||
|         host="0.0.0.0", |  | ||||||
|         port=5000, |  | ||||||
|         debug=False, |  | ||||||
|         use_reloader=False |  | ||||||
|     ) |  | ||||||
|  | |||||||
| @ -38,7 +38,7 @@ _yolo_model = None | |||||||
| _current_model_version = None  # 模型版本标识 | _current_model_version = None  # 模型版本标识 | ||||||
| _current_conf_threshold = 0.8  # 默认置信度初始值 | _current_conf_threshold = 0.8  # 默认置信度初始值 | ||||||
|  |  | ||||||
| router = APIRouter(prefix="/models", tags=["模型管理"]) | router = APIRouter(prefix="/api/models", tags=["模型管理"]) | ||||||
|  |  | ||||||
|  |  | ||||||
| # 服务重启核心工具函数(保持不变) | # 服务重启核心工具函数(保持不变) | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ from schema.user_schema import UserResponse | |||||||
|  |  | ||||||
| # 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类) | # 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类) | ||||||
| router = APIRouter( | router = APIRouter( | ||||||
|     prefix="/sensitives", |     prefix="/api/sensitives", | ||||||
|     tags=["敏感信息管理"] |     tags=["敏感信息管理"] | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from middle.auth_middleware import ( | |||||||
|  |  | ||||||
| # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) | # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) | ||||||
| router = APIRouter( | router = APIRouter( | ||||||
|     prefix="/users", |     prefix="/api/users", | ||||||
|     tags=["用户管理"] |     tags=["用户管理"] | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | |||||||
| @ -12,7 +12,8 @@ def save_face_to_up_images( | |||||||
| ) -> Dict[str, str]: | ) -> Dict[str, str]: | ||||||
|     """ |     """ | ||||||
|     保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 |     保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 | ||||||
|     确保db_path以up_images开头,且统一使用正斜杠 |     确保db_path以 /api/file/up_images 开头,且统一使用正斜杠 | ||||||
|  |     本地不创建/api/file/文件夹,仅URL访问时使用该前缀路由 | ||||||
|  |  | ||||||
|     参数: |     参数: | ||||||
|         client_ip: 客户端IP(原始格式,如192.168.1.101) |         client_ip: 客户端IP(原始格式,如192.168.1.101) | ||||||
| @ -21,10 +22,10 @@ def save_face_to_up_images( | |||||||
|         image_format: 图片格式(默认jpg) |         image_format: 图片格式(默认jpg) | ||||||
|  |  | ||||||
|     返回: |     返回: | ||||||
|         字典:success(是否成功)、db_path(存数据库的相对路径)、local_abs_path(本地绝对路径)、msg(提示) |         字典:success(是否成功)、db_path(存数据库的路径,带/api/file/前缀)、local_abs_path(本地绝对路径)、msg(提示) | ||||||
|     """ |     """ | ||||||
|     try: |     try: | ||||||
|         # 1. 基础参数校验 |         # 1. 基础参数校验(不变) | ||||||
|         if not client_ip.strip(): |         if not client_ip.strip(): | ||||||
|             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"} |             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"} | ||||||
|         if not image_bytes: |         if not image_bytes: | ||||||
| @ -32,49 +33,50 @@ def save_face_to_up_images( | |||||||
|         if image_format.lower() not in ["jpg", "jpeg", "png"]: |         if image_format.lower() not in ["jpg", "jpeg", "png"]: | ||||||
|             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"} |             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"} | ||||||
|  |  | ||||||
|         # 2. 处理特殊字符(避免路径错误) |         # 2. 处理特殊字符(避免路径错误)(不变) | ||||||
|         safe_ip = client_ip.strip().replace(".", "_")  # IP中的.替换为_ |         safe_ip = client_ip.strip().replace(".", "_")  # IP中的.替换为_ | ||||||
|         safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名" |         safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名" | ||||||
|         safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|'])  # 过滤非法字符 |         safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|'])  # 过滤非法字符 | ||||||
|  |  | ||||||
|         # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) |         # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) | ||||||
|         root_dir = Path("up_images").resolve()  # 转为绝对路径(如D:/Git/bin/video/up_images) |         root_dir = Path("up_images").resolve() | ||||||
|         if not root_dir.exists(): |         if not root_dir.exists(): | ||||||
|             root_dir.mkdir(parents=True, exist_ok=True) |             root_dir.mkdir(parents=True, exist_ok=True) | ||||||
|             print(f"[FileUtil] 已创建up_images根目录:{root_dir}") |             print(f"[FileUtil] 已创建up_images根目录:{root_dir}") | ||||||
|  |  | ||||||
|         # 4. 构建文件层级路径(确保在root_dir子目录下) |         # 4. 构建文件层级路径(确保在root_dir子目录下)(不变) | ||||||
|         ip_dir = root_dir / safe_ip |         ip_dir = root_dir / safe_ip | ||||||
|         face_name_dir = ip_dir / safe_face_name |         face_name_dir = ip_dir / safe_face_name | ||||||
|         face_name_dir.mkdir(parents=True, exist_ok=True)  # 自动创建目录 |         face_name_dir.mkdir(parents=True, exist_ok=True) | ||||||
|         print(f"[FileUtil] 图片存储目录:{face_name_dir}") |         print(f"[FileUtil] 图片存储目录(本地):{face_name_dir}") | ||||||
|  |  | ||||||
|         # 5. 生成唯一文件名(毫秒级时间戳) |         # 5. 生成唯一文件名(毫秒级时间戳)(不变) | ||||||
|         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] |         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] | ||||||
|  |  | ||||||
|         image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" |         image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" | ||||||
|  |         local_abs_path = face_name_dir / image_filename | ||||||
|  |  | ||||||
|         # 6. 计算路径(确保所有路径都是绝对路径且在root_dir下) |  | ||||||
|         local_abs_path = face_name_dir / image_filename  # 绝对路径 |  | ||||||
|  |  | ||||||
|         # 验证路径是否在root_dir下(防止路径穿越攻击) |  | ||||||
|         if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): |         if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): | ||||||
|             raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") |             raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") | ||||||
|  |  | ||||||
|         # 数据库存储路径:强制包含up_images前缀,统一使用正斜杠 |         # 数据库存储路径:核心修改——在原有relative_path前添加 /api/file/ 前缀 | ||||||
|         relative_path = local_abs_path.relative_to(root_dir.parent)  # 相对于root_dir的父目录 |         relative_path = local_abs_path.relative_to(root_dir.parent) | ||||||
|         db_path = str(relative_path).replace("\\", "/")  # 此时会包含up_images部分 |  | ||||||
|  |  | ||||||
|         # 7. 写入图片文件 |         relative_path_str = str(relative_path).replace("\\", "/") | ||||||
|  |         # 2. 再拼接/api/file/前缀 | ||||||
|  |         db_path = f"/api/file/{relative_path_str}" | ||||||
|  |  | ||||||
|  |         # 7. 写入图片文件(不变) | ||||||
|         with open(local_abs_path, "wb") as f: |         with open(local_abs_path, "wb") as f: | ||||||
|             f.write(image_bytes) |             f.write(image_bytes) | ||||||
|         print(f"[FileUtil] 图片保存成功:") |         print(f"[FileUtil] 图片保存成功:") | ||||||
|         print(f"  数据库路径:{db_path}") |         print(f"  数据库路径(带/api/file/前缀):{db_path}") | ||||||
|         print(f"  本地绝对路径:{local_abs_path}") |         print(f"  本地绝对路径(无/api/file/):{local_abs_path}") | ||||||
|  |  | ||||||
|         return { |         return { | ||||||
|             "success": True, |             "success": True, | ||||||
|             "db_path": db_path,  # 格式为 up_images/192_168_110_31/小龙/xxx.jpg |             "db_path": db_path, | ||||||
|             "local_abs_path": str(local_abs_path),  # 本地绝对路径(完整路径) |             "local_abs_path": str(local_abs_path), | ||||||
|             "msg": "图片保存成功" |             "msg": "图片保存成功" | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										127
									
								
								ws/ws.py
									
									
									
									
									
								
							
							
						
						
									
										127
									
								
								ws/ws.py
									
									
									
									
									
								
							| @ -17,17 +17,16 @@ from service.device_action_service import add_device_action | |||||||
| from schema.device_action_schema import DeviceActionCreate | from schema.device_action_schema import DeviceActionCreate | ||||||
| from core.all import detect, load_model | from core.all import detect, load_model | ||||||
|  |  | ||||||
| # -------------------------- 1. AES 加密解密工具(固定密钥)-------------------------- | # -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息)-------------------------- | ||||||
| AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa"  # 约定密钥(32字节) | AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa"  # 约定密钥(32字节) | ||||||
| AES_BLOCK_SIZE = 16  # AES固定块大小 | AES_BLOCK_SIZE = 16  # AES固定块大小 | ||||||
|  |  | ||||||
|  |  | ||||||
| def aes_encrypt(plaintext: str) -> dict: | def aes_encrypt(plaintext: str) -> dict: | ||||||
|     """AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)""" |     """AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)- 仅用于服务器发消息""" | ||||||
|     try: |     try: | ||||||
|         iv = os.urandom(AES_BLOCK_SIZE)  # 随机IV(16字节) |         iv = os.urandom(AES_BLOCK_SIZE)  # 随机IV(16字节) | ||||||
|         cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) |         cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) | ||||||
|         # 明文填充+加密+Base64编码 |  | ||||||
|         padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE) |         padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE) | ||||||
|         ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8") |         ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8") | ||||||
|         iv_base64 = base64.b64encode(iv).decode("utf-8") |         iv_base64 = base64.b64encode(iv).decode("utf-8") | ||||||
| @ -40,20 +39,6 @@ def aes_encrypt(plaintext: str) -> dict: | |||||||
|         raise Exception(f"AES加密失败: {str(e)}") from e |         raise Exception(f"AES加密失败: {str(e)}") from e | ||||||
|  |  | ||||||
|  |  | ||||||
| def aes_decrypt(encrypted_dict: dict) -> str: |  | ||||||
|     """AES-CBC解密:输入加密字典,返回原始文本""" |  | ||||||
|     try: |  | ||||||
|         # Base64解码密文和IV |  | ||||||
|         ciphertext = base64.b64decode(encrypted_dict["ciphertext"]) |  | ||||||
|         iv = base64.b64decode(encrypted_dict["iv"]) |  | ||||||
|         # 解密+去除填充 |  | ||||||
|         cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) |  | ||||||
|         decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8") |  | ||||||
|         return decrypted |  | ||||||
|     except Exception as e: |  | ||||||
|         raise Exception(f"AES解密失败: {str(e)}") from e |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 2. 配置常量(保持原有)-------------------------- | # -------------------------- 2. 配置常量(保持原有)-------------------------- | ||||||
| HEARTBEAT_INTERVAL = 30  # 心跳检查间隔(秒) | HEARTBEAT_INTERVAL = 30  # 心跳检查间隔(秒) | ||||||
| HEARTBEAT_TIMEOUT = 600  # 客户端超时阈值(秒) | HEARTBEAT_TIMEOUT = 600  # 客户端超时阈值(秒) | ||||||
| @ -72,7 +57,7 @@ def get_current_time_file_str() -> str: | |||||||
|     return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") |     return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 4. 客户端连接封装(新增消息加密)-------------------------- | # -------------------------- 4. 客户端连接封装(服务器发消息仍加密,接收消息改明文)-------------------------- | ||||||
| class ClientConnection: | class ClientConnection: | ||||||
|     def __init__(self, websocket: WebSocket, client_ip: str): |     def __init__(self, websocket: WebSocket, client_ip: str): | ||||||
|         self.websocket = websocket |         self.websocket = websocket | ||||||
| @ -96,28 +81,25 @@ class ClientConnection: | |||||||
|         return self.consumer_task |         return self.consumer_task | ||||||
|  |  | ||||||
|     async def send_frame_permit(self): |     async def send_frame_permit(self): | ||||||
|         """发送加密的帧许可信号""" |         """发送加密的帧许可信号(服务器→客户端:加密)""" | ||||||
|         try: |         try: | ||||||
|             # 1. 构建原始消息 |  | ||||||
|             frame_permit_msg = { |             frame_permit_msg = { | ||||||
|                 "type": "frame", |                 "type": "frame", | ||||||
|                 "timestamp": get_current_time_str(), |                 "timestamp": get_current_time_str(), | ||||||
|                 "client_ip": self.client_ip |                 "client_ip": self.client_ip | ||||||
|             } |             } | ||||||
|             # 2. AES加密消息 |             encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg))  # 保持加密 | ||||||
|             encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) |  | ||||||
|             # 3. 发送加密消息 |  | ||||||
|             await self.websocket.send_json(encrypted_msg) |             await self.websocket.send_json(encrypted_msg) | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}") | ||||||
|  |  | ||||||
|     async def consume_frames(self) -> None: |     async def consume_frames(self) -> None: | ||||||
|         """消费队列中的帧并处理""" |         """消费队列中的明文图像帧并处理""" | ||||||
|         try: |         try: | ||||||
|             while True: |             while True: | ||||||
|                 frame_data = await self.frame_queue.get() |                 frame_data = await self.frame_queue.get() | ||||||
|                 await self.send_frame_permit()  # 发送下一帧许可 |                 await self.send_frame_permit()  # 回复仍加密 | ||||||
|                 try: |                 try: | ||||||
|                     await self.process_frame(frame_data) |                     await self.process_frame(frame_data) | ||||||
|                 finally: |                 finally: | ||||||
| @ -128,23 +110,22 @@ class ClientConnection: | |||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}") | ||||||
|  |  | ||||||
|     async def process_frame(self, frame_data: bytes) -> None: |     async def process_frame(self, frame_data: bytes) -> None: | ||||||
|         """处理单帧图像(含加密危险通知)""" |         """处理明文图像帧(危险通知仍加密发送)""" | ||||||
|         # 二进制转OpenCV图像 |         # 二进制转OpenCV图像(客户端发的是明文二进制,直接解析) | ||||||
|         nparr = np.frombuffer(frame_data, np.uint8) |         nparr = np.frombuffer(frame_data, np.uint8) | ||||||
|         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) |         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | ||||||
|         if img is None: |         if img is None: | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像") | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             # 调用检测函数(client_ip + img 双参数) |  | ||||||
|             has_violation, data, detector_type = await asyncio.to_thread( |             has_violation, data, detector_type = await asyncio.to_thread( | ||||||
|                 detect, self.client_ip, img |                 detect, self.client_ip, img | ||||||
|             ) |             ) | ||||||
|             print( |             print( | ||||||
|                 f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}") |                 f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}") | ||||||
|  |  | ||||||
|             # 处理违规逻辑(发送加密危险通知) |             # 违规通知:服务器→客户端,仍加密 | ||||||
|             if has_violation: |             if has_violation: | ||||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}") |                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}") | ||||||
|                 # 违规次数+1 |                 # 违规次数+1 | ||||||
| @ -154,19 +135,17 @@ class ClientConnection: | |||||||
|                 except Exception as e: |                 except Exception as e: | ||||||
|                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") |                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") | ||||||
|  |  | ||||||
|                 # 1. 构建原始危险通知 |                 # 构建危险通知并加密发送 | ||||||
|                 danger_msg = { |                 danger_msg = { | ||||||
|                     "type": "danger", |                     "type": "danger", | ||||||
|                     "timestamp": get_current_time_str(), |                     "timestamp": get_current_time_str(), | ||||||
|                     "client_ip": self.client_ip, |                     "client_ip": self.client_ip, | ||||||
|                     "detail": data |                     "detail": data | ||||||
|                 } |                 } | ||||||
|                 # 2. AES加密通知 |                 encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg))  # 保持加密 | ||||||
|                 encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) |  | ||||||
|                 # 3. 发送加密通知 |  | ||||||
|                 await self.websocket.send_json(encrypted_danger_msg) |                 await self.websocket.send_json(encrypted_danger_msg) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 5. 全局状态与心跳管理(保持原有)-------------------------- | # -------------------------- 5. 全局状态与心跳管理(保持原有)-------------------------- | ||||||
| @ -178,7 +157,6 @@ async def heartbeat_checker(): | |||||||
|     """全局心跳检查任务""" |     """全局心跳检查任务""" | ||||||
|     while True: |     while True: | ||||||
|         current_time = get_current_time_str() |         current_time = get_current_time_str() | ||||||
|         # 筛选超时客户端 |  | ||||||
|         timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] |         timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] | ||||||
|  |  | ||||||
|         if timeout_ips: |         if timeout_ips: | ||||||
| @ -186,11 +164,9 @@ async def heartbeat_checker(): | |||||||
|             for ip in timeout_ips: |             for ip in timeout_ips: | ||||||
|                 try: |                 try: | ||||||
|                     conn = connected_clients[ip] |                     conn = connected_clients[ip] | ||||||
|                     # 取消消费任务+关闭连接 |  | ||||||
|                     if conn.consumer_task and not conn.consumer_task.done(): |                     if conn.consumer_task and not conn.consumer_task.done(): | ||||||
|                         conn.consumer_task.cancel() |                         conn.consumer_task.cancel() | ||||||
|                     await conn.websocket.close(code=1008, reason="心跳超时") |                     await conn.websocket.close(code=1008, reason="心跳超时") | ||||||
|                     # 标记离线 |  | ||||||
|                     await asyncio.to_thread(update_online_status_by_ip, ip, 0) |                     await asyncio.to_thread(update_online_status_by_ip, ip, 0) | ||||||
|                     action_data = DeviceActionCreate(client_ip=ip, action=0) |                     action_data = DeviceActionCreate(client_ip=ip, action=0) | ||||||
|                     await asyncio.to_thread(add_device_action, action_data) |                     await asyncio.to_thread(add_device_action, action_data) | ||||||
| @ -205,19 +181,16 @@ async def heartbeat_checker(): | |||||||
|         await asyncio.sleep(HEARTBEAT_INTERVAL) |         await asyncio.sleep(HEARTBEAT_INTERVAL) | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 6. 消息处理工具(新增消息解密)-------------------------- | # -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑)-------------------------- | ||||||
| async def send_heartbeat_ack(conn: ClientConnection): | async def send_heartbeat_ack(conn: ClientConnection): | ||||||
|     """发送加密的心跳确认""" |     """发送加密的心跳确认(服务器→客户端:加密)""" | ||||||
|     try: |     try: | ||||||
|         # 1. 构建原始心跳确认 |  | ||||||
|         heartbeat_ack_msg = { |         heartbeat_ack_msg = { | ||||||
|             "type": "heart", |             "type": "heart", | ||||||
|             "timestamp": get_current_time_str(), |             "timestamp": get_current_time_str(), | ||||||
|             "client_ip": conn.client_ip |             "client_ip": conn.client_ip | ||||||
|         } |         } | ||||||
|         # 2. AES加密 |         encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg))  # 保持加密 | ||||||
|         encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) |  | ||||||
|         # 3. 发送 |  | ||||||
|         await conn.websocket.send_json(encrypted_msg) |         await conn.websocket.send_json(encrypted_msg) | ||||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认") |         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认") | ||||||
|         return True |         return True | ||||||
| @ -228,44 +201,22 @@ async def send_heartbeat_ack(conn: ClientConnection): | |||||||
|  |  | ||||||
|  |  | ||||||
| async def handle_text_msg(conn: ClientConnection, text: str): | async def handle_text_msg(conn: ClientConnection, text: str): | ||||||
|     """处理加密的文本消息(如心跳)""" |     """处理客户端明文文本消息(如心跳)- 关键修改:无需解密,直接解析JSON""" | ||||||
|     try: |     try: | ||||||
|         # 1. 解析加密字典 |         # 客户端发的是明文JSON,直接解析(删除原解密步骤) | ||||||
|         encrypted_dict = json.loads(text) |         msg = json.loads(text) | ||||||
|         # 2. AES解密 |  | ||||||
|         decrypted_text = aes_decrypt(encrypted_dict) |  | ||||||
|         # 3. 解析业务消息 |  | ||||||
|         msg = json.loads(decrypted_text) |  | ||||||
|  |  | ||||||
|         if msg.get("type") == "heart": |         if msg.get("type") == "heart": | ||||||
|             conn.update_heartbeat() |             conn.update_heartbeat() | ||||||
|             await send_heartbeat_ack(conn) |             await send_heartbeat_ack(conn)  # 服务器回复仍加密 | ||||||
|         else: |         else: | ||||||
|             print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')})") |             print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')})") | ||||||
|     except json.JSONDecodeError: |     except json.JSONDecodeError: | ||||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式") |         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)") | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}") |         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
| async def handle_binary_msg(conn: ClientConnection, data: str): | # -------------------------- 7. WebSocket路由与生命周期(关键修改:处理明文二进制图像)-------------------------- | ||||||
|     """处理加密的图像消息(客户端需先转Base64+加密)""" |  | ||||||
|     try: |  | ||||||
|         # 1. 解密得到Base64编码的图像 |  | ||||||
|         encrypted_dict = json.loads(data) |  | ||||||
|         decrypted_base64 = aes_decrypt(encrypted_dict) |  | ||||||
|         # 2. Base64解码为二进制图像 |  | ||||||
|         frame_data = base64.b64decode(decrypted_base64) |  | ||||||
|         # 3. 加入帧队列 |  | ||||||
|         conn.frame_queue.put_nowait(frame_data) |  | ||||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队") |  | ||||||
|     except asyncio.QueueFull: |  | ||||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满,丢弃数据") |  | ||||||
|     except Exception as e: |  | ||||||
|         print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 7. WebSocket路由与生命周期(保持原有结构)-------------------------- |  | ||||||
| ws_router = APIRouter() | ws_router = APIRouter() | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -276,7 +227,6 @@ async def lifespan(app: FastAPI): | |||||||
|     heartbeat_task = asyncio.create_task(heartbeat_checker()) |     heartbeat_task = asyncio.create_task(heartbeat_checker()) | ||||||
|     print(f"[{get_current_time_str()}] 心跳检查任务启动(ID: {id(heartbeat_task)})") |     print(f"[{get_current_time_str()}] 心跳检查任务启动(ID: {id(heartbeat_task)})") | ||||||
|     yield |     yield | ||||||
|     # 关闭时清理 |  | ||||||
|     if heartbeat_task and not heartbeat_task.done(): |     if heartbeat_task and not heartbeat_task.done(): | ||||||
|         heartbeat_task.cancel() |         heartbeat_task.cancel() | ||||||
|         await heartbeat_task |         await heartbeat_task | ||||||
| @ -285,8 +235,8 @@ async def lifespan(app: FastAPI): | |||||||
|  |  | ||||||
| @ws_router.websocket(WS_ENDPOINT) | @ws_router.websocket(WS_ENDPOINT) | ||||||
| async def websocket_endpoint(websocket: WebSocket): | async def websocket_endpoint(websocket: WebSocket): | ||||||
|     """WebSocket连接处理入口""" |     """WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像""" | ||||||
|     load_model()  # 加载检测模型(仅一次) |     load_model()  # 加载检测模型(建议移到全局,避免重复加载) | ||||||
|     await websocket.accept() |     await websocket.accept() | ||||||
|     client_ip = websocket.client.host if websocket.client else "unknown_ip" |     client_ip = websocket.client.host if websocket.client else "unknown_ip" | ||||||
|     current_time = get_current_time_str() |     current_time = get_current_time_str() | ||||||
| @ -306,8 +256,8 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|         # 注册新连接 |         # 注册新连接 | ||||||
|         new_conn = ClientConnection(websocket, client_ip) |         new_conn = ClientConnection(websocket, client_ip) | ||||||
|         connected_clients[client_ip] = new_conn |         connected_clients[client_ip] = new_conn | ||||||
|         new_conn.start_consumer()  # 启动帧消费 |         new_conn.start_consumer() | ||||||
|         await new_conn.send_frame_permit()  # 发送首次帧许可 |         await new_conn.send_frame_permit()  # 首次许可仍加密 | ||||||
|  |  | ||||||
|         # 标记客户端上线 |         # 标记客户端上线 | ||||||
|         try: |         try: | ||||||
| @ -321,28 +271,33 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|  |  | ||||||
|         print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}") |         print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}") | ||||||
|  |  | ||||||
|         # 消息循环(接收客户端消息) |         # 消息循环:接收客户端明文消息(关键修改) | ||||||
|         while True: |         while True: | ||||||
|             data = await websocket.receive() |             data = await websocket.receive() | ||||||
|             if "text" in data: |             if "text" in data: | ||||||
|                 # 处理加密文本消息(心跳、客户端指令) |                 # 处理客户端明文文本(如心跳:{"type":"heart",...}) | ||||||
|                 await handle_text_msg(new_conn, data["text"]) |                 await handle_text_msg(new_conn, data["text"]) | ||||||
|             elif "bytes" in data: |             elif "bytes" in data: | ||||||
|                 # 兼容客户端发送二进制:先转Base64再处理 |                 # 处理客户端明文二进制图像(直接入队,无需解密) | ||||||
|                 base64_data = base64.b64encode(data["bytes"]).decode("utf-8") |                 frame_data = data["bytes"] | ||||||
|                 await handle_binary_msg(new_conn, base64_data) |                 try: | ||||||
|  |                     new_conn.frame_queue.put_nowait(frame_data) | ||||||
|  |                     print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队") | ||||||
|  |                 except asyncio.QueueFull: | ||||||
|  |                     print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据") | ||||||
|  |                 except Exception as e: | ||||||
|  |                     print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}") | ||||||
|  |  | ||||||
|     except WebSocketDisconnect as e: |     except WebSocketDisconnect as e: | ||||||
|         print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})") |         print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})") | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") |         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") | ||||||
|     finally: |     finally: | ||||||
|         # 清理资源(断开后处理) |         # 清理资源 | ||||||
|         if client_ip in connected_clients: |         if client_ip in connected_clients: | ||||||
|             conn = connected_clients[client_ip] |             conn = connected_clients[client_ip] | ||||||
|             if conn.consumer_task and not conn.consumer_task.done(): |             if conn.consumer_task and not conn.consumer_task.done(): | ||||||
|                 conn.consumer_task.cancel() |                 conn.consumer_task.cancel() | ||||||
|             # 仅上线成功的客户端,才标记离线 |  | ||||||
|             if is_online_updated: |             if is_online_updated: | ||||||
|                 try: |                 try: | ||||||
|                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) |                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user