目前可以成功动态更换模型运行的
This commit is contained in:
		
							
								
								
									
										283
									
								
								app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										283
									
								
								app.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,283 @@ | |||||||
|  | from flask import Flask, send_from_directory, abort, request | ||||||
|  | import os | ||||||
|  | import logging | ||||||
|  | from functools import wraps | ||||||
|  | from pathlib import Path | ||||||
|  | # 跨域依赖(必须安装:pip install flask-cors) | ||||||
|  | from flask_cors import CORS | ||||||
|  |  | ||||||
|  | # 配置日志(保持原有格式) | ||||||
|  | logging.basicConfig( | ||||||
|  |     level=logging.INFO, | ||||||
|  |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||||||
|  | ) | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  | # 初始化 Flask 应用(供 main.py 导入) | ||||||
|  | app = Flask(__name__) | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 核心修改:与 FastAPI 对齐的跨域配置 | ||||||
|  | # ------------------------------ | ||||||
|  | # 1. 允许的前端域名(完全复制 FastAPI 的 ALLOWED_ORIGINS,确保前后端一致) | ||||||
|  | ALLOWED_ORIGINS = [ | ||||||
|  |     # "http://localhost:8080",  # 本地前端开发地址(必改:替换为你的前端实际地址) | ||||||
|  |     # "http://127.0.0.1:8080", | ||||||
|  |     # "http://服务器IP:8080",    # 部署后前端地址(替换为你的服务器IP/域名) | ||||||
|  |     # # "*" 仅开发环境临时使用,生产环境必须删除(安全风险) | ||||||
|  |     "*" | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | # 2. 配置 CORS(与 FastAPI 规则完全对齐) | ||||||
|  | CORS( | ||||||
|  |     app, | ||||||
|  |     resources={ | ||||||
|  |         r"/*": {  # 对所有 Flask 路由生效(覆盖图片、模型下载所有接口) | ||||||
|  |             "origins": ALLOWED_ORIGINS,        # 允许的前端域名(与 FastAPI 一致) | ||||||
|  |             "allow_credentials": True,         # 允许携带 Cookie(与 FastAPI 一致,需登录态必开) | ||||||
|  |             "methods": ["*"],                  # 允许所有 HTTP 方法(FastAPI 用 "*",此处同步) | ||||||
|  |             "allow_headers": ["*"],            # 允许所有请求头(与 FastAPI 一致) | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 核心路径配置(不变,确保资源目录正确) | ||||||
|  | # ------------------------------ | ||||||
|  | CURRENT_FILE_PATH = Path(__file__).resolve() | ||||||
|  | PROJECT_ROOT = CURRENT_FILE_PATH.parent  # 项目根目录(video/) | ||||||
|  | # 资源目录(图片、模型) | ||||||
|  | BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve())       # 检测图片目录 | ||||||
|  | BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve())          # 人脸图片目录 | ||||||
|  | BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve())          # 模型文件目录 | ||||||
|  |  | ||||||
|  | # 打印路径配置(调试用,确认目录正确) | ||||||
|  | logger.info(f"[Flask 配置] 项目根目录:{PROJECT_ROOT}") | ||||||
|  | logger.info(f"[Flask 配置] 模型目录:{BASE_MODEL_DIR}") | ||||||
|  | logger.info(f"[Flask 配置] 人脸图片目录:{BASE_IMAGE_DIR_UP_IMAGES}") | ||||||
|  | logger.info(f"[Flask 配置] 检测图片目录:{BASE_IMAGE_DIR_DECT}") | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 安全检查装饰器(不变,防路径遍历/非法文件) | ||||||
|  | # ------------------------------ | ||||||
|  | def safe_path_check(root_dir: str): | ||||||
|  |     def decorator(func): | ||||||
|  |         @wraps(func) | ||||||
|  |         def wrapper(*args, **kwargs): | ||||||
|  |             resource_path = kwargs.get('resource_path', '').strip() | ||||||
|  |             # 统一路径分隔符(兼容 Windows \ 和 Linux /) | ||||||
|  |             resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||||
|  |             # 拼接完整路径(防止路径遍历) | ||||||
|  |             full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) | ||||||
|  |             logger.debug( | ||||||
|  |                 f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             # 1. 禁止路径遍历(确保请求文件在根目录内) | ||||||
|  |             if not full_file_path.startswith(root_dir): | ||||||
|  |                 logger.warning( | ||||||
|  |                     f"[Flask 安全拦截] 非法路径遍历!IP:{request.remote_addr} | 请求路径:{resource_path}" | ||||||
|  |                 ) | ||||||
|  |                 abort(403) | ||||||
|  |  | ||||||
|  |             # 2. 检查文件存在且为有效文件(非目录) | ||||||
|  |             if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path): | ||||||
|  |                 logger.warning( | ||||||
|  |                     f"[Flask 资源错误] 文件不存在/非文件!IP:{request.remote_addr} | 路径:{full_file_path}" | ||||||
|  |                 ) | ||||||
|  |                 abort(404) | ||||||
|  |  | ||||||
|  |             # 3. 限制文件大小(模型200MB,图片10MB,避免超大文件攻击) | ||||||
|  |             max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024 | ||||||
|  |             if os.path.getsize(full_file_path) > max_size: | ||||||
|  |                 logger.warning( | ||||||
|  |                     f"[Flask 大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.remote_addr} | 路径:{full_file_path}" | ||||||
|  |                 ) | ||||||
|  |                 abort(413) | ||||||
|  |  | ||||||
|  |             # 安全检查通过,传递根目录给视图函数 | ||||||
|  |             return func(*args, **kwargs, root_dir=root_dir) | ||||||
|  |         return wrapper | ||||||
|  |     return decorator | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 1. 模型下载接口(/model/download/*) | ||||||
|  | # ------------------------------ | ||||||
|  | @app.route('/model/download/<path:resource_path>') | ||||||
|  | @safe_path_check(root_dir=BASE_MODEL_DIR) | ||||||
|  | def download_model(resource_path, root_dir): | ||||||
|  |     try: | ||||||
|  |         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||||
|  |         dir_path, file_name = os.path.split(resource_path) | ||||||
|  |         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||||
|  |  | ||||||
|  |         # 仅允许 .pt 格式(YOLO 模型) | ||||||
|  |         if not file_name.lower().endswith('.pt'): | ||||||
|  |             logger.warning( | ||||||
|  |                 f"[Flask 格式错误] 非 .pt 模型文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||||
|  |             ) | ||||||
|  |             abort(415) | ||||||
|  |  | ||||||
|  |         logger.info( | ||||||
|  |             f"[Flask 模型下载] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # 强制浏览器下载(而非预览),设置二进制文件类型 | ||||||
|  |         return send_from_directory( | ||||||
|  |             full_dir, | ||||||
|  |             file_name, | ||||||
|  |             as_attachment=True, | ||||||
|  |             mimetype="application/octet-stream" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error( | ||||||
|  |             f"[Flask 模型下载异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||||
|  |         ) | ||||||
|  |         abort(500) | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 2. 人脸图片访问接口(/up_images/*) | ||||||
|  | # ------------------------------ | ||||||
|  | @app.route('/up_images/<path:resource_path>') | ||||||
|  | @safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES) | ||||||
|  | def get_face_image(resource_path, root_dir): | ||||||
|  |     try: | ||||||
|  |         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||||
|  |         dir_path, file_name = os.path.split(resource_path) | ||||||
|  |         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||||
|  |  | ||||||
|  |         # 仅允许常见图片格式 | ||||||
|  |         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') | ||||||
|  |         if not file_name.lower().endswith(allowed_ext): | ||||||
|  |             logger.warning( | ||||||
|  |                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||||
|  |             ) | ||||||
|  |             abort(415) | ||||||
|  |  | ||||||
|  |         logger.info( | ||||||
|  |             f"[Flask 人脸图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # 允许浏览器预览图片(而非下载) | ||||||
|  |         return send_from_directory(full_dir, file_name, as_attachment=False) | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error( | ||||||
|  |             f"[Flask 人脸图片异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||||
|  |         ) | ||||||
|  |         abort(500) | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 3. 检测图片访问接口(/resource/dect/*) | ||||||
|  | # ------------------------------ | ||||||
|  | @app.route('/resource/dect/<path:resource_path>') | ||||||
|  | @safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) | ||||||
|  | def get_dect_image(resource_path, root_dir): | ||||||
|  |     try: | ||||||
|  |         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||||
|  |         dir_path, file_name = os.path.split(resource_path) | ||||||
|  |         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||||
|  |  | ||||||
|  |         # 仅允许常见图片格式 | ||||||
|  |         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') | ||||||
|  |         if not file_name.lower().endswith(allowed_ext): | ||||||
|  |             logger.warning( | ||||||
|  |                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||||
|  |             ) | ||||||
|  |             abort(415) | ||||||
|  |  | ||||||
|  |         logger.info( | ||||||
|  |             f"[Flask 检测图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return send_from_directory(full_dir, file_name, as_attachment=False) | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error( | ||||||
|  |             f"[Flask 检测图片异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||||
|  |         ) | ||||||
|  |         abort(500) | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*) | ||||||
|  | # ------------------------------ | ||||||
|  | @app.route('/images/<path:resource_path>') | ||||||
|  | @safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) | ||||||
|  | def get_compatible_image(resource_path, root_dir): | ||||||
|  |     try: | ||||||
|  |         # 逻辑与检测图片接口一致,仅URL前缀不同(兼容旧前端) | ||||||
|  |         resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) | ||||||
|  |         dir_path, file_name = os.path.split(resource_path) | ||||||
|  |         full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) | ||||||
|  |  | ||||||
|  |         allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') | ||||||
|  |         if not file_name.lower().endswith(allowed_ext): | ||||||
|  |             logger.warning( | ||||||
|  |                 f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" | ||||||
|  |             ) | ||||||
|  |             abort(415) | ||||||
|  |  | ||||||
|  |         logger.info( | ||||||
|  |             f"[Flask 兼容图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return send_from_directory(full_dir, file_name, as_attachment=False) | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error( | ||||||
|  |             f"[Flask 兼容图片异常] IP:{request.remote_addr} | 错误:{str(e)}" | ||||||
|  |         ) | ||||||
|  |         abort(500) | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 全局错误处理器(友好提示,与 FastAPI 错误信息风格一致) | ||||||
|  | # ------------------------------ | ||||||
|  | @app.errorhandler(403) | ||||||
|  | def forbidden_error(error): | ||||||
|  |     return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403 | ||||||
|  |  | ||||||
|  | @app.errorhandler(404) | ||||||
|  | def not_found_error(error): | ||||||
|  |     return "❌ 资源不存在:请检查URL路径(IP、目录、文件名)是否正确", 404 | ||||||
|  |  | ||||||
|  | @app.errorhandler(413) | ||||||
|  | def too_large_error(error): | ||||||
|  |     return "❌ 文件过大:图片最大10MB,模型最大200MB", 413 | ||||||
|  |  | ||||||
|  | @app.errorhandler(415) | ||||||
|  | def unsupported_type_error(error): | ||||||
|  |     return "❌ 不支持的文件类型:图片支持png/jpg/jpeg/gif/bmp,模型仅支持pt", 415 | ||||||
|  |  | ||||||
|  | @app.errorhandler(500) | ||||||
|  | def server_error(error): | ||||||
|  |     return "❌ 服务器内部错误:请联系管理员查看后台日志", 500 | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # Flask 独立启动入口(供测试,实际由 main.py 子线程启动) | ||||||
|  | # ------------------------------ | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     # 确保所有资源目录存在(防止初始化失败) | ||||||
|  |     required_dirs = [ | ||||||
|  |         (BASE_IMAGE_DIR_DECT, "检测图片目录"), | ||||||
|  |         (BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"), | ||||||
|  |         (BASE_MODEL_DIR, "模型文件目录") | ||||||
|  |     ] | ||||||
|  |     for dir_path, dir_desc in required_dirs: | ||||||
|  |         if not os.path.exists(dir_path): | ||||||
|  |             logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}") | ||||||
|  |             os.makedirs(dir_path, exist_ok=True) | ||||||
|  |  | ||||||
|  |     # 启动提示(含访问示例) | ||||||
|  |     logger.info("\n[Flask 服务启动成功!] 支持的接口:") | ||||||
|  |     logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt") | ||||||
|  |     logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg") | ||||||
|  |     logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n") | ||||||
|  |  | ||||||
|  |     # 启动服务(禁用 debug 和自动重载,避免多线程冲突) | ||||||
|  |     app.run( | ||||||
|  |         host="0.0.0.0",  # 允许外部IP访问 | ||||||
|  |         port=5000,        # 与 main.py 中 Flask 端口一致 | ||||||
|  |         debug=False, | ||||||
|  |         use_reloader=False | ||||||
|  |     ) | ||||||
							
								
								
									
										41
									
								
								core/all.py
									
									
									
									
									
								
							
							
						
						
									
										41
									
								
								core/all.py
									
									
									
									
									
								
							| @ -57,50 +57,39 @@ def save_db(model_type, client_ip, result): | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 修正后的 detect 函数关键部分 | ||||||
| def detect(client_ip, frame): | def detect(client_ip, frame): | ||||||
|     """ |     # 1. YOLO检测 | ||||||
|     执行模型检测,检测到违规时按指定格式保存图片 |  | ||||||
|     参数: |  | ||||||
|         frame: 待检测的图像帧(OpenCV格式,numpy.ndarray类型) |  | ||||||
|     返回: |  | ||||||
|         (检测结果布尔值, 检测详情, 检测模型类型) |  | ||||||
|     """ |  | ||||||
|     # 1. YOLO检测(优先级1) |  | ||||||
|     yolo_flag, yolo_result = yoloDetect(frame) |     yolo_flag, yolo_result = yoloDetect(frame) | ||||||
|     print(f"YOLO检测结果:{yolo_result}") |  | ||||||
|     if yolo_flag: |     if yolo_flag: | ||||||
|  |         # model_type 传入 "yolo"(正确) | ||||||
|         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) |         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) | ||||||
|         if full_save_path:  # 只判断完整路径是否有效(用于保存) |         if full_save_path: | ||||||
|             cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|             # 打印时使用「显示用短路径」,符合需求格式 |             print(f"✅ yolo违规图片已保存:{display_path}")  # 日志也修正 | ||||||
|             print(f"✅ YOLO违规图片已保存:{display_path}") |  | ||||||
|         save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) |         save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) | ||||||
|         return (True, yolo_result, "yolo") |         return (True, yolo_result, "yolo") | ||||||
|     # |  | ||||||
|     # # 2. 人脸检测(优先级2) |     # 2. 人脸检测 | ||||||
|     face_flag, face_result = faceDetect(frame) |     face_flag, face_result = faceDetect(frame) | ||||||
|     print(f"人脸检测结果:{face_result}") |  | ||||||
|     if face_flag: |     if face_flag: | ||||||
|         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) |         full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip)  # 这里改了 | ||||||
|         if full_save_path:  # 只判断完整路径是否有效(用于保存) |         if full_save_path: | ||||||
|             cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|             # 打印时使用「显示用短路径」,符合需求格式 |             print(f"✅ face违规图片已保存:{display_path}")  # 日志也修正 | ||||||
|             print(f"✅ face违规图片已保存:{display_path}") |  | ||||||
|         save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) |         save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) | ||||||
|         return (True, face_result, "face") |         return (True, face_result, "face") | ||||||
|  |  | ||||||
|     # 3. OCR检测(优先级3) |     # 3. OCR检测 | ||||||
|     ocr_flag, ocr_result = ocrDetect(frame) |     ocr_flag, ocr_result = ocrDetect(frame) | ||||||
|     print(f"OCR检测结果:{ocr_result}") |  | ||||||
|     if ocr_flag: |     if ocr_flag: | ||||||
|         # 解构元组,保存用完整路径,打印用短路径 |         full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)  # 这里改了 | ||||||
|         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) |         if full_save_path: | ||||||
|         if full_save_path:  # 只判断完整路径是否有效(用于保存) |  | ||||||
|             cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|             # 打印时使用「显示用短路径」,符合需求格式 |             print(f"✅ ocr违规图片已保存:{display_path}")  # 日志也修正 | ||||||
|             print(f"✅ ocr违规图片已保存:{display_path}") |  | ||||||
|         save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) |         save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) | ||||||
|         return (True, ocr_result, "ocr") |         return (True, ocr_result, "ocr") | ||||||
|  |  | ||||||
|     # 4. 无违规内容(不保存图片) |     # 4. 无违规内容(不保存图片) | ||||||
|     print(f"❌ 未检测到任何违规内容,不保存图片") |     print(f"❌ 未检测到任何违规内容,不保存图片") | ||||||
|     return (False, "未检测到任何内容", "none") |     return (False, "未检测到任何内容", "none") | ||||||
							
								
								
									
										47
									
								
								core/yolo.py
									
									
									
									
									
								
							
							
						
						
									
										47
									
								
								core/yolo.py
									
									
									
									
									
								
							| @ -1,37 +1,43 @@ | |||||||
| import os | import os | ||||||
|  | import numpy as np | ||||||
| from ultralytics import YOLO | from ultralytics import YOLO | ||||||
|  | from service.model_service import get_current_yolo_model  # 从模型管理模块获取模型 | ||||||
|  |  | ||||||
| # 全局变量 | # 全局模型变量 | ||||||
| _yolo_model = None | _yolo_model = None | ||||||
|  |  | ||||||
|  |  | ||||||
| model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt") | def load_model(model_path=None): | ||||||
|  |     """加载YOLO模型(优先使用模型管理模块的默认模型)""" | ||||||
|  |  | ||||||
| def load_model(): |  | ||||||
|     """加载YOLO目标检测模型""" |  | ||||||
|     global _yolo_model |     global _yolo_model | ||||||
|  |  | ||||||
|  |     if model_path is None: | ||||||
|  |         _yolo_model = get_current_yolo_model() | ||||||
|  |         return _yolo_model is not None | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         _yolo_model = YOLO(model_path) |         _yolo_model = YOLO(model_path) | ||||||
|  |         return True | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"YOLO model load failed: {e}") |         print(f"YOLO模型加载失败(指定路径):{str(e)}") | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     return True if _yolo_model else False |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def detect(frame, conf_threshold=0.2): | def detect(frame, conf_threshold=0.2): | ||||||
|     """YOLO目标检测、返回(是否识别到, 结果字符串)""" |     """执行目标检测,返回(是否成功, 结果字符串)""" | ||||||
|     global _yolo_model |     global _yolo_model | ||||||
|  |  | ||||||
|     if not _yolo_model or frame is None: |     # 确保模型已加载 | ||||||
|         return (False, "未初始化或无效帧") |     if not _yolo_model: | ||||||
|  |         if not load_model(): | ||||||
|  |             return (False, "模型未初始化") | ||||||
|  |  | ||||||
|  |     if frame is None: | ||||||
|  |         return (False, "无效输入帧") | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         results = _yolo_model(frame, conf=conf_threshold) |         # 执行检测(frame应为numpy数组) | ||||||
|         # 检查是否有检测结果 |         results = _yolo_model(frame, conf=conf_threshold, verbose=False) | ||||||
|         has_results = len(results[0].boxes) > 0 if results else False |         has_results = len(results[0].boxes) > 0 if results else False | ||||||
|  |  | ||||||
|         if not has_results: |         if not has_results: | ||||||
| @ -42,13 +48,12 @@ def detect(frame, conf_threshold=0.2): | |||||||
|         for box in results[0].boxes: |         for box in results[0].boxes: | ||||||
|             cls = int(box.cls[0]) |             cls = int(box.cls[0]) | ||||||
|             conf = float(box.conf[0]) |             conf = float(box.conf[0]) | ||||||
|             bbox = [float(x) for x in box.xyxy[0]] |             bbox = [round(x, 2) for x in box.xyxy[0].tolist()]  # 保留两位小数 | ||||||
|             class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" |             class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" | ||||||
|             result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})") |             result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})") | ||||||
|  |  | ||||||
|         result_str = "; ".join(result_parts) |         return (True, "; ".join(result_parts)) | ||||||
|         return (has_results, result_str) |  | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"YOLO detect error: {e}") |         print(f"检测过程出错:{str(e)}") | ||||||
|         return (False, f"检测错误: {str(e)}") |         return (False, f"检测错误:{str(e)}") | ||||||
|  | |||||||
							
								
								
									
										128
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										128
									
								
								main.py
									
									
									
									
									
								
							| @ -1,9 +1,17 @@ | |||||||
| from PIL import Image  # 正确导入 |  | ||||||
| import numpy as np |  | ||||||
|  |  | ||||||
| import uvicorn |  | ||||||
| from PIL import Image | from PIL import Image | ||||||
|  | import numpy as np | ||||||
|  | import uvicorn | ||||||
|  | import threading | ||||||
|  | import time | ||||||
|  | import os | ||||||
| from fastapi import FastAPI | from fastapi import FastAPI | ||||||
|  | # 新增:导入 CORS 相关依赖 | ||||||
|  | from fastapi.middleware.cors import CORSMiddleware | ||||||
|  |  | ||||||
|  | # 导入 Flask 服务实例 | ||||||
|  | from app import app as flask_app | ||||||
|  |  | ||||||
|  | # 原有业务导入 | ||||||
| from core.all import load_model, detect | from core.all import load_model, detect | ||||||
| from ds.config import SERVER_CONFIG | from ds.config import SERVER_CONFIG | ||||||
| from middle.error_handler import global_exception_handler | from middle.error_handler import global_exception_handler | ||||||
| @ -11,52 +19,124 @@ from service.user_service import router as user_router | |||||||
| from service.sensitive_service import router as sensitive_router | from service.sensitive_service import router as sensitive_router | ||||||
| from service.face_service import router as face_router | from service.face_service import router as face_router | ||||||
| from service.device_service import router as device_router | from service.device_service import router as device_router | ||||||
|  | from service.model_service import router as model_router  # 模型管理路由 | ||||||
| from ws.ws import ws_router, lifespan | from ws.ws import ws_router, lifespan | ||||||
| from core.establish import create_directory_structure | from core.establish import create_directory_structure | ||||||
|  |  | ||||||
| # ------------------------------ |  | ||||||
| # 初始化 FastAPI 应用、指定生命周期管理 | # Flask 服务启动函数(不变) | ||||||
| # ------------------------------ | def start_flask_service(): | ||||||
|  |     try: | ||||||
|  |         print(f"\n[Flask 服务] 准备启动,端口:5000") | ||||||
|  |         print(f"[Flask 服务] 访问示例:http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n") | ||||||
|  |  | ||||||
|  |         BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect")) | ||||||
|  |         if not os.path.exists(BASE_IMAGE_DIR): | ||||||
|  |             print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}") | ||||||
|  |             os.makedirs(BASE_IMAGE_DIR, exist_ok=True) | ||||||
|  |  | ||||||
|  |         flask_app.run( | ||||||
|  |             host="0.0.0.0", | ||||||
|  |             port=5000, | ||||||
|  |             debug=False, | ||||||
|  |             use_reloader=False | ||||||
|  |         ) | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"[Flask 服务] 启动失败:{str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 初始化 FastAPI 应用(新增 CORS 配置) | ||||||
| app = FastAPI( | app = FastAPI( | ||||||
|     title="内容安全审核后台", |     title="内容安全审核后台", | ||||||
|     description="内容安全审核后台", |     description="含图片访问服务和动态模型管理", | ||||||
|     version="1.0.0", |     version="1.0.0", | ||||||
|     lifespan=lifespan |     lifespan=lifespan | ||||||
| ) | ) | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 注册路由 | # 新增:完整 CORS 配置(解决跨域问题) | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
|  | # 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等) | ||||||
|  | ALLOWED_ORIGINS = [ | ||||||
|  |     # "http://localhost:8080",  # 前端本地开发地址(必改,填实际前端地址) | ||||||
|  |     # "http://127.0.0.1:8080", | ||||||
|  |     # "http://服务器IP:8080",    # 部署后前端地址(如适用) | ||||||
|  |     "*" #表示允许所有域名(开发环境可用,生产环境不推荐) | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | # 2. 配置 CORS 中间件 | ||||||
|  | app.add_middleware( | ||||||
|  |     CORSMiddleware, | ||||||
|  |     allow_origins=ALLOWED_ORIGINS,        # 允许的前端域名 | ||||||
|  |     allow_credentials=True,               # 允许携带 Cookie(如需登录态则必开) | ||||||
|  |     allow_methods=["*"],                  # 允许所有 HTTP 方法(包括 PUT/DELETE) | ||||||
|  |     allow_headers=["*"],                  # 允许所有请求头(包括 Content-Type) | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | # 注册路由(不变) | ||||||
| app.include_router(user_router) | app.include_router(user_router) | ||||||
| app.include_router(device_router) | app.include_router(device_router) | ||||||
| app.include_router(face_router) | app.include_router(face_router) | ||||||
| app.include_router(sensitive_router) | app.include_router(sensitive_router) | ||||||
|  | app.include_router(model_router)  # 模型管理路由 | ||||||
| app.include_router(ws_router) | app.include_router(ws_router) | ||||||
|  |  | ||||||
| # ------------------------------ | # 注册全局异常处理器(不变) | ||||||
| # 注册全局异常处理器 |  | ||||||
| # ------------------------------ |  | ||||||
| app.add_exception_handler(Exception, global_exception_handler) | app.add_exception_handler(Exception, global_exception_handler) | ||||||
|  |  | ||||||
| # ------------------------------ | # 主服务启动入口(不变) | ||||||
| # 启动服务 |  | ||||||
| # ------------------------------ |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # -------------------------- 配置调整 -------------------------- |     # 1. 初始化资源 | ||||||
|     # 模型配置路径(建议改为环境变量) |  | ||||||
|     YOLO_MODEL_PATH = r"/core/models\best.pt" |  | ||||||
|     OCR_CONFIG_PATH = r"/core/config\config.yaml" |  | ||||||
|  |  | ||||||
|     create_directory_structure() |     create_directory_structure() | ||||||
|  |     print(f"[初始化] 目录结构创建完成") | ||||||
|  |  | ||||||
|     # 初始化项目(默认端口设为8000、避免初始化失败时port未定义) |     # 创建模型保存目录 | ||||||
|  |     MODEL_SAVE_DIR = os.path.join("core", "models") | ||||||
|  |     os.makedirs(MODEL_SAVE_DIR, exist_ok=True) | ||||||
|  |     print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}") | ||||||
|  |  | ||||||
|  |     # # 模型路径配置 | ||||||
|  |     # YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt") | ||||||
|  |     # OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml") | ||||||
|  |     # print(f"[初始化] 默认YOLO模型路径:{YOLO_MODEL_PATH}") | ||||||
|  |     # print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}") | ||||||
|  |  | ||||||
|  |     # 加载检测模型 | ||||||
|  |     try: | ||||||
|  |         load_success = load_model() | ||||||
|  |         if load_success: | ||||||
|  |             print(f"[初始化] 检测模型加载完成") | ||||||
|  |         else: | ||||||
|  |             print(f"[初始化] 未找到默认模型,可通过API上传并设置默认模型") | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     # 2. 启动 Flask 服务(子线程) | ||||||
|  |     flask_thread = threading.Thread( | ||||||
|  |         target=start_flask_service, | ||||||
|  |         daemon=True | ||||||
|  |     ) | ||||||
|  |     flask_thread.start() | ||||||
|  |  | ||||||
|  |     # 等待 Flask 初始化 | ||||||
|  |     time.sleep(1) | ||||||
|  |     if flask_thread.is_alive(): | ||||||
|  |         print(f"[Flask 服务] 启动成功(运行中)") | ||||||
|  |     else: | ||||||
|  |         print(f"[Flask 服务] 启动失败!图片访问不可用") | ||||||
|  |  | ||||||
|  |     # 3. 启动 FastAPI 主服务 | ||||||
|     port = int(SERVER_CONFIG.get("port", 8000)) |     port = int(SERVER_CONFIG.get("port", 8000)) | ||||||
|  |     print(f"\n[FastAPI 服务] 准备启动,端口:{port}") | ||||||
|  |     print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n") | ||||||
|  |  | ||||||
|     # 启动 UVicorn 服务 |  | ||||||
|     uvicorn.run( |     uvicorn.run( | ||||||
|         app="main:app", |         app="main:app", | ||||||
|         host="0.0.0.0", |         host="0.0.0.0", | ||||||
|         port=port, |         port=port, | ||||||
|         workers=8, |         workers=1, | ||||||
|         ws="websockets" |         ws="websockets", | ||||||
|  |         reload=False | ||||||
|     ) |     ) | ||||||
| @ -1,30 +1,41 @@ | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 请求模型(前端传参校验) | # 请求模型(前端传参校验)- 保留update的eigenvalue(如需更新特征值) | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| class FaceCreateRequest(BaseModel): | class FaceCreateRequest(BaseModel): | ||||||
|     """创建人脸记录请求模型(无需ID、由数据库自增)""" |     """创建人脸记录请求模型(无需ID、由数据库自增)""" | ||||||
|     name: str = Field(None, max_length=255, description="名称(可选、最长255字符)") |     name: Optional[str] = Field(None, max_length=255, description="名称(可选、最长255字符)") | ||||||
|  |  | ||||||
|  |  | ||||||
| class FaceUpdateRequest(BaseModel): | class FaceUpdateRequest(BaseModel): | ||||||
|     """更新人脸记录请求模型(不变)""" |     """更新人脸记录请求模型 - 保留eigenvalue(如需更新特征值,不影响返回)""" | ||||||
|     name: str = Field(None, max_length=255, description="名称") |     name: Optional[str] = Field(None, max_length=255, description="名称(可选)") | ||||||
|     eigenvalue: str = Field(None, max_length=255, description="特征(文件处理后可更新)") |     eigenvalue: Optional[str] = Field(None, description="特征值(可选,文件处理后可更新)")  # 保留更新能力 | ||||||
|  |     address: Optional[str] = Field(None, description="图片完整路径(可选,更新图片时使用)") | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 响应模型(后端返回数据) | # 响应模型(后端返回数据)- 核心修改:删除eigenvalue字段 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| class FaceResponse(BaseModel): | class FaceResponse(BaseModel): | ||||||
|     """人脸记录响应模型(仍包含ID、由数据库生成后返回)""" |     """人脸记录响应模型(仅返回需要的字段,移除eigenvalue)""" | ||||||
|     id: int = Field(..., description="主键ID(数据库自增)") |     id: int = Field(..., description="主键ID(数据库自增)") | ||||||
|     name: str = Field(None, description="名称") |     name: Optional[str] = Field(None, description="名称") | ||||||
|     eigenvalue: str | None = Field(None, description="特征(可为空)") |     address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)")  # 仅保留address | ||||||
|     created_at: datetime = Field(..., description="记录创建时间") |     created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)") | ||||||
|     updated_at: datetime = Field(..., description="记录更新时间") |     updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)") | ||||||
|  |  | ||||||
|  |     # 关键配置:支持从数据库查询结果(字典)直接转换 | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FaceListResponse(BaseModel): | ||||||
|  |     """人脸列表分页响应模型(结构不变,内部FaceResponse已移除eigenvalue)""" | ||||||
|  |     total: int = Field(..., description="筛选后的总记录数") | ||||||
|  |     faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表") | ||||||
|  |  | ||||||
|     model_config = {"from_attributes": True} |     model_config = {"from_attributes": True} | ||||||
							
								
								
									
										37
									
								
								schema/model_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								schema/model_schema.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,37 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 请求模型 | ||||||
|  | class ModelCreateRequest(BaseModel): | ||||||
|  |     name: str = Field(..., max_length=255, description="模型名称(必填,如:yolo-v8s-car)") | ||||||
|  |     description: Optional[str] = Field(None, description="模型描述(可选)") | ||||||
|  |     is_default: Optional[bool] = Field(False, description="是否设为默认模型") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ModelUpdateRequest(BaseModel): | ||||||
|  |     name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)") | ||||||
|  |     description: Optional[str] = Field(None, description="模型描述(可选修改)") | ||||||
|  |     is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 响应模型 | ||||||
|  | class ModelResponse(BaseModel): | ||||||
|  |     id: int = Field(..., description="模型ID") | ||||||
|  |     name: str = Field(..., description="模型名称") | ||||||
|  |     path: str = Field(..., description="模型文件相对路径") | ||||||
|  |     is_default: bool = Field(..., description="是否默认模型") | ||||||
|  |     description: Optional[str] = Field(None, description="模型描述") | ||||||
|  |     file_size: Optional[int] = Field(None, description="文件大小(字节)") | ||||||
|  |     created_at: datetime = Field(..., description="创建时间") | ||||||
|  |     updated_at: datetime = Field(..., description="更新时间") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ModelListResponse(BaseModel): | ||||||
|  |     total: int = Field(..., description="总记录数") | ||||||
|  |     models: List[ModelResponse] = Field(..., description="当前页模型列表") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
| @ -1,6 +1,6 @@ | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
|  |  | ||||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @ -30,3 +30,11 @@ class UserResponse(BaseModel): | |||||||
|  |  | ||||||
|     # Pydantic V2 配置(支持从数据库查询结果转换) |     # Pydantic V2 配置(支持从数据库查询结果转换) | ||||||
|     model_config = {"from_attributes": True} |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserListResponse(BaseModel): | ||||||
|  |     """用户列表分页响应模型(与设备/人脸列表结构对齐)""" | ||||||
|  |     total: int = Field(..., description="用户总数") | ||||||
|  |     users: List[UserResponse] = Field(..., description="当前页用户列表") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
| @ -1,162 +1,140 @@ | |||||||
| from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request | ||||||
|  | from fastapi.responses import FileResponse | ||||||
| from mysql.connector import Error as MySQLError | from mysql.connector import Error as MySQLError | ||||||
|  | import os | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
| from ds.db import db | from ds.db import db | ||||||
| from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse | from schema.face_schema import ( | ||||||
|  |     FaceCreateRequest, | ||||||
|  |     FaceUpdateRequest, | ||||||
|  |     FaceResponse, | ||||||
|  |     FaceListResponse | ||||||
|  | ) | ||||||
| from schema.response_schema import APIResponse | from schema.response_schema import APIResponse | ||||||
| from middle.auth_middleware import get_current_user |  | ||||||
| from schema.user_schema import UserResponse |  | ||||||
|  |  | ||||||
| from util.face_util import add_binary_data, get_average_feature | from util.face_util import add_binary_data, get_average_feature | ||||||
| #初始化实例 | from util.file_util import save_face_to_up_images | ||||||
|  |  | ||||||
| router = APIRouter( |  | ||||||
|     prefix="/faces", |  | ||||||
|     tags=["人脸管理"] |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  | router = APIRouter(prefix="/faces", tags=["人脸管理"]) | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传) | # 1. 创建人脸记录(使用修复后的路径) | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增)") | @router.post("", response_model=APIResponse, summary="创建人脸记录") | ||||||
| async def create_face( | async def create_face( | ||||||
|         # 前端仅需传: name(可选、Form格式)、file(必传、文件) |         request: Request, | ||||||
|         name: str = Form(None, max_length=255, description="名称(可选)"), |         name: str = Form(None, max_length=255, description="名称(可选)"), | ||||||
|         file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容)") |         file: UploadFile = File(..., description="人脸文件(必传)") | ||||||
| ): | ): | ||||||
|     """ |  | ||||||
|     创建人脸记录:  |  | ||||||
|     - 需登录认证 |  | ||||||
|     - 前端传参: multipart/form-data 表单(name 可选、file 必传) |  | ||||||
|     - ID 由数据库自动生成、无需前端传入 |  | ||||||
|     - 暂不处理文件内容、eigenvalue 设为 None |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     # 调用你的方法 |  | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
|         # 1. 用模型校验 name(仅校验长度、无需ID) |  | ||||||
|         face_create = FaceCreateRequest(name=name) |         face_create = FaceCreateRequest(name=name) | ||||||
|  |         client_ip = request.client.host if request.client else "" | ||||||
|  |         if not client_ip: | ||||||
|  |             raise HTTPException(status_code=400, detail="无法获取客户端IP") | ||||||
|  |  | ||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         # 把文件转为二进制数组 |         # 读取图片并保存(使用修复后的路径逻辑) | ||||||
|         file_content = await file.read() |         file_content = await file.read() | ||||||
|  |         file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg" | ||||||
|         # 计算特征值 |         save_result = save_face_to_up_images( | ||||||
|         flag, eigenvalue = add_binary_data(file_content) |             client_ip=client_ip, | ||||||
|  |             face_name=name, | ||||||
|         if flag == False: |             image_bytes=file_content, | ||||||
|             raise HTTPException( |             image_format=file_ext | ||||||
|                 status_code=500, |  | ||||||
|                 detail="未检测到人脸" |  | ||||||
|         ) |         ) | ||||||
|  |         if not save_result["success"]: | ||||||
|  |             raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}") | ||||||
|  |         db_image_path = save_result["db_path"]  # 从修复后的方法获取路径 | ||||||
|  |  | ||||||
|         # 打印数组长度 |         # 提取人脸特征 | ||||||
|         print(f"文件大小: {len(file_content)} 字节") |         detect_success, detect_result = add_binary_data(file_content) | ||||||
|  |         if not detect_success: | ||||||
|  |             raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}") | ||||||
|  |         eigenvalue = detect_result | ||||||
|  |  | ||||||
|         # 2. 插入数据库: 无需传 ID(自增)、只传 name 和 eigenvalue(None) |         # 插入数据库 | ||||||
|         insert_query = """ |         insert_query = """ | ||||||
|             INSERT INTO face (name, eigenvalue) |             INSERT INTO face (name, eigenvalue, address) | ||||||
|             VALUES (%s, %s) |             VALUES (%s, %s, %s) | ||||||
|         """ |         """ | ||||||
|         cursor.execute(insert_query, (face_create.name, str(eigenvalue))) |         cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path)) | ||||||
|         conn.commit() |         conn.commit() | ||||||
|  |  | ||||||
|         # 3. 获取数据库自动生成的 ID(关键: 用 LAST_INSERT_ID() 查刚插入的记录) |         # 查询新记录 | ||||||
|         select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()" |         cursor.execute(""" | ||||||
|         cursor.execute(select_new_query) |             SELECT id, name, address, created_at, updated_at  | ||||||
|  |             FROM face  | ||||||
|  |             WHERE id = LAST_INSERT_ID() | ||||||
|  |         """) | ||||||
|         created_face = cursor.fetchone() |         created_face = cursor.fetchone() | ||||||
|  |  | ||||||
|         if not created_face: |         if not created_face: | ||||||
|             raise HTTPException( |             raise HTTPException(status_code=500, detail="创建成功但无法获取记录") | ||||||
|                 status_code=500, |  | ||||||
|                 detail="创建人脸记录成功、但无法获取新创建的记录" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         return APIResponse( |         return APIResponse( | ||||||
|             code=201, |             code=201, | ||||||
|             message=f"人脸记录创建成功(ID: {created_face['id']}、文件名: {file.filename})", |             message=f"人脸记录创建成功(ID: {created_face['id']})", | ||||||
|             data=FaceResponse(**created_face) |             data=FaceResponse(**created_face) | ||||||
|         ) |         ) | ||||||
|     except MySQLError as e: |     except MySQLError as e: | ||||||
|         if conn: |         if conn: | ||||||
|             conn.rollback() |             conn.rollback() | ||||||
|         # 改为使用HTTPException |         raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=500, |  | ||||||
|             detail=f"创建人脸记录失败: {str(e)}" |  | ||||||
|         ) from e |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         # 捕获其他可能的异常 |         raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=500, |  | ||||||
|             detail=f"服务器错误: {str(e)}" |  | ||||||
|         ) from e |  | ||||||
|     finally: |     finally: | ||||||
|         await file.close()  # 关闭文件流 |         await file.close() | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|         # 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑) |  | ||||||
|         flag, eigenvalue = add_binary_data(file_content) |  | ||||||
|         if flag == False: |  | ||||||
|             raise HTTPException( |  | ||||||
|                 status_code=500, |  | ||||||
|                 detail="未检测到人脸" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         # 将 eigenvalue 转为 str |  | ||||||
|         eigenvalue = str(eigenvalue) |  | ||||||
|  |  | ||||||
|  | # 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 2. 获取单个人脸记录(不变、用自增ID查询) | # 2. 获取单个人脸记录 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录") | @router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录") | ||||||
| async def get_face( | async def get_face(face_id: int): | ||||||
|         face_id: int,  # 这里的 ID 是数据库自增的、前端从创建响应中获取 |  | ||||||
|         current_user: UserResponse = Depends(get_current_user) |  | ||||||
| ): |  | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         query = "SELECT * FROM face WHERE id = %s" |         query = """ | ||||||
|  |             SELECT id, name, address, created_at, updated_at  | ||||||
|  |             FROM face  | ||||||
|  |             WHERE id = %s | ||||||
|  |         """ | ||||||
|         cursor.execute(query, (face_id,)) |         cursor.execute(query, (face_id,)) | ||||||
|         face = cursor.fetchone() |         face = cursor.fetchone() | ||||||
|  |  | ||||||
|         if not face: |         if not face: | ||||||
|             raise HTTPException( |             raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在") | ||||||
|                 status_code=404, |  | ||||||
|                 detail=f"ID为 {face_id} 的人脸记录不存在" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         return APIResponse( |         return APIResponse( | ||||||
|             code=200, |             code=200, | ||||||
|             message="人脸记录查询成功", |             message="查询成功", | ||||||
|             data=FaceResponse(**face) |             data=FaceResponse(**face) | ||||||
|         ) |         ) | ||||||
|     except MySQLError as e: |     except MySQLError as e: | ||||||
|         # 改为使用HTTPException |         raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=500, |  | ||||||
|             detail=f"查询人脸记录失败: {str(e)}" |  | ||||||
|         ) from e |  | ||||||
|     finally: |     finally: | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| # 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理) |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 3. 获取所有人脸记录(不变) | # 3. 获取人脸列表 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @router.get("", response_model=APIResponse, summary="获取所有人脸记录") | @router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)") | ||||||
| async def get_all_faces( | async def get_face_list( | ||||||
|  |         page: int = Query(1, ge=1), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100), | ||||||
|  |         name: str = Query(None), | ||||||
|  |         has_eigenvalue: bool = Query(None) | ||||||
| ): | ): | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
| @ -164,50 +142,66 @@ async def get_all_faces( | |||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         query = "SELECT * FROM face ORDER BY id"  # 按自增ID排序 |         where_clause = [] | ||||||
|         cursor.execute(query) |         params = [] | ||||||
|         faces = cursor.fetchall() |         if name: | ||||||
|  |             where_clause.append("name LIKE %s") | ||||||
|  |             params.append(f"%{name}%") | ||||||
|  |         if has_eigenvalue is not None: | ||||||
|  |             where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL") | ||||||
|  |  | ||||||
|  |         # 总记录数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM face" | ||||||
|  |         if where_clause: | ||||||
|  |             count_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_query, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 列表数据 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = """ | ||||||
|  |             SELECT id, name, address, created_at, updated_at  | ||||||
|  |             FROM face | ||||||
|  |         """ | ||||||
|  |         if where_clause: | ||||||
|  |             list_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_query += " ORDER BY id DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |  | ||||||
|  |         cursor.execute(list_query, params) | ||||||
|  |         face_list = cursor.fetchall() | ||||||
|  |  | ||||||
|         return APIResponse( |         return APIResponse( | ||||||
|             code=200, |             code=200, | ||||||
|             message="所有人脸记录查询成功", |             message=f"获取成功(共{total}条)", | ||||||
|             data=[FaceResponse(** face) for face in faces] |             data=FaceListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 faces=[FaceResponse(**face) for face in face_list] | ||||||
|  |             ) | ||||||
|         ) |         ) | ||||||
|     except MySQLError as e: |     except MySQLError as e: | ||||||
|         raise HTTPException( |         raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e | ||||||
|             status_code=500, |  | ||||||
|             detail=f"查询所有人脸记录失败: {str(e)}" |  | ||||||
|         ) from e |  | ||||||
|     finally: |     finally: | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 4. 更新人脸记录(不变、用自增ID更新) | # 4. 更新人脸记录 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录") | @router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录") | ||||||
| async def update_face( | async def update_face(face_id: int, face_update: FaceUpdateRequest): | ||||||
|         face_id: int, |  | ||||||
|         face_update: FaceUpdateRequest, |  | ||||||
|         current_user: UserResponse = Depends(get_current_user) |  | ||||||
| ): |  | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         # 检查记录是否存在 |         cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,)) | ||||||
|         check_query = "SELECT id FROM face WHERE id = %s" |         exist_face = cursor.fetchone() | ||||||
|         cursor.execute(check_query, (face_id,)) |         if not exist_face: | ||||||
|         existing_face = cursor.fetchone() |             raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在") | ||||||
|         if not existing_face: |         old_db_path = exist_face["address"] | ||||||
|             raise HTTPException( |  | ||||||
|                 status_code=404, |  | ||||||
|                 detail=f"ID为 {face_id} 的人脸记录不存在" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         # 构建更新语句 |  | ||||||
|         update_fields = [] |         update_fields = [] | ||||||
|         params = [] |         params = [] | ||||||
|         if face_update.name is not None: |         if face_update.name is not None: | ||||||
| @ -216,6 +210,18 @@ async def update_face( | |||||||
|         if face_update.eigenvalue is not None: |         if face_update.eigenvalue is not None: | ||||||
|             update_fields.append("eigenvalue = %s") |             update_fields.append("eigenvalue = %s") | ||||||
|             params.append(face_update.eigenvalue) |             params.append(face_update.eigenvalue) | ||||||
|  |         if face_update.address is not None: | ||||||
|  |             # 删除旧图片(相对路径转绝对路径) | ||||||
|  |             if old_db_path: | ||||||
|  |                 old_abs_path = Path(old_db_path).resolve() | ||||||
|  |                 if old_abs_path.exists(): | ||||||
|  |                     try: | ||||||
|  |                         old_abs_path.unlink()  # 使用Path方法删除更安全 | ||||||
|  |                         print(f"[FaceRouter] 已删除旧图片:{old_abs_path}") | ||||||
|  |                     except Exception as e: | ||||||
|  |                         print(f"[FaceRouter] 删除旧图片失败:{str(e)}") | ||||||
|  |             update_fields.append("address = %s") | ||||||
|  |             params.append(face_update.address) | ||||||
|  |  | ||||||
|         if not update_fields: |         if not update_fields: | ||||||
|             raise HTTPException(status_code=400, detail="至少需提供一个更新字段") |             raise HTTPException(status_code=400, detail="至少需提供一个更新字段") | ||||||
| @ -225,117 +231,143 @@ async def update_face( | |||||||
|         cursor.execute(update_query, params) |         cursor.execute(update_query, params) | ||||||
|         conn.commit() |         conn.commit() | ||||||
|  |  | ||||||
|         # 查询更新后记录 |         cursor.execute(""" | ||||||
|         select_query = "SELECT * FROM face WHERE id = %s" |             SELECT id, name, address, created_at, updated_at  | ||||||
|         cursor.execute(select_query, (face_id,)) |             FROM face  | ||||||
|  |             WHERE id = %s | ||||||
|  |         """, (face_id,)) | ||||||
|         updated_face = cursor.fetchone() |         updated_face = cursor.fetchone() | ||||||
|  |  | ||||||
|         return APIResponse( |         return APIResponse( | ||||||
|             code=200, |             code=200, | ||||||
|             message="人脸记录更新成功", |             message="更新成功", | ||||||
|             data=FaceResponse(**updated_face) |             data=FaceResponse(**updated_face) | ||||||
|         ) |         ) | ||||||
|     except MySQLError as e: |     except MySQLError as e: | ||||||
|         if conn: |         if conn: | ||||||
|             conn.rollback() |             conn.rollback() | ||||||
|         raise HTTPException( |         raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e | ||||||
|             status_code=500, |  | ||||||
|             detail=f"更新人脸记录失败: {str(e)}" |  | ||||||
|         ) from e |  | ||||||
|     finally: |     finally: | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 5. 删除人脸记录(不变、用自增ID删除) | # 5. 删除人脸记录 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录") | @router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录") | ||||||
| async def delete_face( | async def delete_face(face_id: int): | ||||||
|         face_id: int, |  | ||||||
|         current_user: UserResponse = Depends(get_current_user) |  | ||||||
| ): |  | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         check_query = "SELECT id FROM face WHERE id = %s" |         cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,)) | ||||||
|         cursor.execute(check_query, (face_id,)) |         exist_face = cursor.fetchone() | ||||||
|         existing_face = cursor.fetchone() |         if not exist_face: | ||||||
|         if not existing_face: |             raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在") | ||||||
|             raise HTTPException( |         old_db_path = exist_face["address"] | ||||||
|                 status_code=404, |  | ||||||
|                 detail=f"ID为 {face_id} 的人脸记录不存在" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         delete_query = "DELETE FROM face WHERE id = %s" |         cursor.execute("DELETE FROM face WHERE id = %s", (face_id,)) | ||||||
|         cursor.execute(delete_query, (face_id,)) |  | ||||||
|         conn.commit() |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 删除图片 | ||||||
|  |         extra_msg = "" | ||||||
|  |         if old_db_path: | ||||||
|  |             old_abs_path = Path(old_db_path).resolve() | ||||||
|  |             if old_abs_path.exists(): | ||||||
|  |                 try: | ||||||
|  |                     old_abs_path.unlink() | ||||||
|  |                     print(f"[FaceRouter] 已删除图片:{old_abs_path}") | ||||||
|  |                     extra_msg = "(已同步删除图片)" | ||||||
|  |                 except Exception as e: | ||||||
|  |                     print(f"[FaceRouter] 删除图片失败:{str(e)}") | ||||||
|  |                     extra_msg = "(图片删除失败)" | ||||||
|  |             else: | ||||||
|  |                 extra_msg = "(图片不存在)" | ||||||
|  |         else: | ||||||
|  |             extra_msg = "(无关联图片)" | ||||||
|  |  | ||||||
|         return APIResponse( |         return APIResponse( | ||||||
|             code=200, |             code=200, | ||||||
|             message=f"ID为 {face_id} 的人脸记录删除成功", |             message=f"ID为 {face_id} 的记录删除成功 {extra_msg}", | ||||||
|             data=None |             data=None | ||||||
|         ) |         ) | ||||||
|     except MySQLError as e: |     except MySQLError as e: | ||||||
|         if conn: |         if conn: | ||||||
|             conn.rollback() |             conn.rollback() | ||||||
|         raise HTTPException( |         raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e | ||||||
|             status_code=500, |  | ||||||
|             detail=f"删除人脸记录失败: {str(e)}" |  | ||||||
|         ) from e |  | ||||||
|     finally: |     finally: | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_all_face_name_with_eigenvalue() -> dict: | # ------------------------------ | ||||||
|     """ | # 6. 获取人脸图片 | ||||||
|     获取所有人脸的名称及其对应的特征值、组成字典返回 | # ------------------------------ | ||||||
|     key: 人脸名称(name) | @router.get("/{face_id}/image", summary="获取人脸图片") | ||||||
|     value: 人脸特征值(eigenvalue)、若名称重复则返回平均特征值 | async def get_face_image(face_id: int): | ||||||
|     注: 过滤掉name为None的记录、避免字典key为None的情况 |     conn = None | ||||||
|     """ |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         query = "SELECT address, name FROM face WHERE id = %s" | ||||||
|  |         cursor.execute(query, (face_id,)) | ||||||
|  |         face = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not face: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在") | ||||||
|  |  | ||||||
|  |         db_path = face["address"] | ||||||
|  |         abs_path = Path(db_path).resolve()  # 转为绝对路径 | ||||||
|  |         if not db_path or not abs_path.exists(): | ||||||
|  |             raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path})") | ||||||
|  |  | ||||||
|  |         return FileResponse( | ||||||
|  |             path=abs_path, | ||||||
|  |             filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}", | ||||||
|  |             media_type=f"image/{db_path.split('.')[-1]}" | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 内部工具方法 | ||||||
|  | # ------------------------------ | ||||||
|  | def get_all_face_name_with_eigenvalue() -> dict: | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
|         # 1. 建立数据库连接并获取游标(dictionary=True使结果以字典形式返回) |  | ||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         # 2. 执行SQL查询: 只获取name非空的记录、减少数据传输 |  | ||||||
|         query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" |         query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" | ||||||
|         cursor.execute(query) |         cursor.execute(query) | ||||||
|         faces = cursor.fetchall()  # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...] |         faces = cursor.fetchall() | ||||||
|  |  | ||||||
|         # 3. 收集同一名称对应的所有特征值(处理名称重复场景) |  | ||||||
|         name_to_eigenvalues = {} |         name_to_eigenvalues = {} | ||||||
|         for face in faces: |         for face in faces: | ||||||
|             name = face["name"] |             name = face["name"] | ||||||
|             eigenvalue = face["eigenvalue"] |             eigenvalue = face["eigenvalue"] | ||||||
|             # 若名称已存在、追加特征值;否则新建列表存储 |  | ||||||
|             if name in name_to_eigenvalues: |             if name in name_to_eigenvalues: | ||||||
|                 name_to_eigenvalues[name].append(eigenvalue) |                 name_to_eigenvalues[name].append(eigenvalue) | ||||||
|             else: |             else: | ||||||
|                 name_to_eigenvalues[name] = [eigenvalue] |                 name_to_eigenvalues[name] = [eigenvalue] | ||||||
|  |  | ||||||
|         # 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值 |  | ||||||
|         face_dict = {} |         face_dict = {} | ||||||
|         for name, eigenvalues in name_to_eigenvalues.items(): |         for name, eigenvalues in name_to_eigenvalues.items(): | ||||||
|  |  | ||||||
|             # 处理特征值: 多个则求平均、单个则直接使用 |  | ||||||
|             if len(eigenvalues) > 1: |             if len(eigenvalues) > 1: | ||||||
|                 # 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入) |  | ||||||
|                 face_dict[name] = get_average_feature(eigenvalues) |                 face_dict[name] = get_average_feature(eigenvalues) | ||||||
|             else: |             else: | ||||||
|                 # 取列表中唯一的特征值(避免value为列表类型) |  | ||||||
|                 face_dict[name] = eigenvalues[0] |                 face_dict[name] = eigenvalues[0] | ||||||
|  |  | ||||||
|         return face_dict |         return face_dict | ||||||
|  |  | ||||||
|     except MySQLError as e: |     except MySQLError as e: | ||||||
|         # 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题) |         raise Exception(f"获取人脸特征失败: {str(e)}") from e | ||||||
|         raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e |  | ||||||
|     finally: |     finally: | ||||||
|         # 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏) |  | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										497
									
								
								service/model_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										497
									
								
								service/model_service.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,497 @@ | |||||||
|  | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | ||||||
|  | from fastapi.responses import FileResponse | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  | import os | ||||||
|  | import shutil | ||||||
|  | from pathlib import Path | ||||||
|  | from datetime import datetime | ||||||
|  |  | ||||||
|  | # 复用项目依赖 | ||||||
|  | from ds.db import db | ||||||
|  | from schema.model_schema import ( | ||||||
|  |     ModelCreateRequest, | ||||||
|  |     ModelUpdateRequest, | ||||||
|  |     ModelResponse, | ||||||
|  |     ModelListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  | from util.model_util import load_yolo_model  # 使用修复后的模型加载工具 | ||||||
|  |  | ||||||
|  | # 路径配置 | ||||||
|  | CURRENT_FILE_PATH = Path(__file__).resolve() | ||||||
|  | PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent | ||||||
|  | MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models" | ||||||
|  | MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True) | ||||||
|  | DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep | ||||||
|  |  | ||||||
|  | # 模型限制 | ||||||
|  | ALLOWED_MODEL_EXT = {"pt"} | ||||||
|  | MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | ||||||
|  |  | ||||||
|  | # 全局模型变量 | ||||||
|  | global _yolo_model | ||||||
|  | _yolo_model = None | ||||||
|  |  | ||||||
|  | router = APIRouter(prefix="/models", tags=["模型管理"]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 工具函数:验证模型路径 | ||||||
|  | def get_valid_model_abs_path(relative_path: str) -> str: | ||||||
|  |     try: | ||||||
|  |         relative_path = relative_path.replace("/", os.sep) | ||||||
|  |         model_abs_path = PROJECT_ROOT / relative_path | ||||||
|  |         model_abs_path = model_abs_path.resolve() | ||||||
|  |         model_abs_path_str = str(model_abs_path) | ||||||
|  |  | ||||||
|  |         if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)): | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if not model_abs_path.exists(): | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=404, | ||||||
|  |                 detail=f"模型文件不存在!路径:{model_abs_path_str}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if not model_abs_path.is_file(): | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"路径不是文件!路径:{model_abs_path_str}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         file_size = model_abs_path.stat().st_size | ||||||
|  |         if file_size > MAX_MODEL_SIZE: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"模型文件过大({file_size // 1024 // 1024}MB),超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         file_ext = model_abs_path.suffix.lower() | ||||||
|  |         if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB") | ||||||
|  |         return model_abs_path_str | ||||||
|  |  | ||||||
|  |     except HTTPException as e: | ||||||
|  |         raise e | ||||||
|  |     except Exception as e: | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=500, | ||||||
|  |             detail=f"路径处理失败:{str(e)}" | ||||||
|  |         ) from e | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 1. 上传模型 | ||||||
|  | @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") | ||||||
|  | async def upload_model( | ||||||
|  |         name: str = Form(..., description="模型名称"), | ||||||
|  |         description: str = Form(None, description="模型描述"), | ||||||
|  |         is_default: bool = Form(False, description="是否设为默认模型"), | ||||||
|  |         file: UploadFile = File(..., description=f"YOLO模型文件(.pt,最大{MAX_MODEL_SIZE // 1024 // 1024}MB)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     saved_file_path = None | ||||||
|  |     try: | ||||||
|  |         # 校验文件 | ||||||
|  |         file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "" | ||||||
|  |         if file_ext not in ALLOWED_MODEL_EXT: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}" | ||||||
|  |             ) | ||||||
|  |         if file.size > MAX_MODEL_SIZE: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB,当前{file.size // 1024 // 1024}MB" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 保存文件 | ||||||
|  |         timestamp = datetime.now().strftime("%Y%m%d%H%M%S") | ||||||
|  |         safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}" | ||||||
|  |         saved_file_path = MODEL_SAVE_ROOT / safe_filename | ||||||
|  |         with open(saved_file_path, "wb") as f: | ||||||
|  |             shutil.copyfileobj(file.file, f) | ||||||
|  |         saved_file_path.chmod(0o644)  # 设置权限 | ||||||
|  |  | ||||||
|  |         # 数据库路径处理 | ||||||
|  |         db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/") | ||||||
|  |  | ||||||
|  |         # 数据库操作 | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         if is_default: | ||||||
|  |             cursor.execute("UPDATE model SET is_default = 0") | ||||||
|  |  | ||||||
|  |         insert_sql = """ | ||||||
|  |             INSERT INTO model (name, path, is_default, description, file_size) | ||||||
|  |             VALUES (%s, %s, %s, %s, %s) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()") | ||||||
|  |         new_model = cursor.fetchone() | ||||||
|  |         if not new_model: | ||||||
|  |             raise HTTPException(status_code=500, detail="上传成功但无法获取记录") | ||||||
|  |  | ||||||
|  |         # 加载默认模型 | ||||||
|  |         global _yolo_model | ||||||
|  |         if is_default: | ||||||
|  |             valid_abs_path = get_valid_model_abs_path(db_relative_path) | ||||||
|  |             _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|  |             if not _yolo_model: | ||||||
|  |                 raise HTTPException( | ||||||
|  |                     status_code=500, | ||||||
|  |                     detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=201, | ||||||
|  |             message=f"模型上传成功!ID:{new_model['id']}", | ||||||
|  |             data=ModelResponse(**new_model) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         if saved_file_path and saved_file_path.exists(): | ||||||
|  |             saved_file_path.unlink() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     except Exception as e: | ||||||
|  |         if saved_file_path and saved_file_path.exists(): | ||||||
|  |             saved_file_path.unlink() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         await file.close() | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 2. 获取模型列表 | ||||||
|  | @router.get("", response_model=APIResponse, summary="获取模型列表(分页)") | ||||||
|  | async def get_model_list( | ||||||
|  |         page: int = Query(1, ge=1), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100), | ||||||
|  |         name: str = Query(None), | ||||||
|  |         is_default: bool = Query(None) | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |         if name: | ||||||
|  |             where_clause.append("name LIKE %s") | ||||||
|  |             params.append(f"%{name}%") | ||||||
|  |         if is_default is not None: | ||||||
|  |             where_clause.append("is_default = %s") | ||||||
|  |             params.append(1 if is_default else 0) | ||||||
|  |  | ||||||
|  |         # 总记录数 | ||||||
|  |         count_sql = "SELECT COUNT(*) AS total FROM model" | ||||||
|  |         if where_clause: | ||||||
|  |             count_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_sql, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 分页数据 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_sql = "SELECT * FROM model" | ||||||
|  |         if where_clause: | ||||||
|  |             list_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |  | ||||||
|  |         cursor.execute(list_sql, params) | ||||||
|  |         model_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"获取成功!共{total}条记录", | ||||||
|  |             data=ModelListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 models=[ModelResponse(**model) for model in model_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 3. 获取默认模型 | ||||||
|  | @router.get("/default", response_model=APIResponse, summary="获取当前默认模型") | ||||||
|  | async def get_default_model(): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE is_default = 1") | ||||||
|  |         default_model = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not default_model: | ||||||
|  |             raise HTTPException(status_code=404, detail="暂无默认模型") | ||||||
|  |  | ||||||
|  |         valid_abs_path = get_valid_model_abs_path(default_model["path"]) | ||||||
|  |         global _yolo_model | ||||||
|  |  | ||||||
|  |         if not _yolo_model: | ||||||
|  |             _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|  |             if not _yolo_model: | ||||||
|  |                 raise HTTPException( | ||||||
|  |                     status_code=500, | ||||||
|  |                     detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="默认模型查询成功", | ||||||
|  |             data=ModelResponse(**default_model) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 4. 获取单个模型详情 | ||||||
|  | @router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情") | ||||||
|  | async def get_model(model_id: int): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         model = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not model: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             model_abs_path = get_valid_model_abs_path(model["path"]) | ||||||
|  |         except HTTPException as e: | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message=f"查询成功,但路径异常:{e.detail}", | ||||||
|  |                 data=ModelResponse(**model) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="查询成功", | ||||||
|  |             data=ModelResponse(**model) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 5. 更新模型信息 | ||||||
|  | @router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息") | ||||||
|  | async def update_model(model_id: int, model_update: ModelUpdateRequest): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         exist_model = cursor.fetchone() | ||||||
|  |         if not exist_model: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | ||||||
|  |  | ||||||
|  |         update_fields = [] | ||||||
|  |         params = [] | ||||||
|  |         if model_update.name is not None: | ||||||
|  |             update_fields.append("name = %s") | ||||||
|  |             params.append(model_update.name) | ||||||
|  |         if model_update.description is not None: | ||||||
|  |             update_fields.append("description = %s") | ||||||
|  |             params.append(model_update.description) | ||||||
|  |  | ||||||
|  |         need_load_default = False | ||||||
|  |         if model_update.is_default is not None: | ||||||
|  |             if model_update.is_default: | ||||||
|  |                 cursor.execute("UPDATE model SET is_default = 0") | ||||||
|  |                 update_fields.append("is_default = 1") | ||||||
|  |                 need_load_default = True | ||||||
|  |             else: | ||||||
|  |                 cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1") | ||||||
|  |                 default_count = cursor.fetchone()["cnt"] | ||||||
|  |                 if default_count == 1 and exist_model["is_default"]: | ||||||
|  |                     raise HTTPException( | ||||||
|  |                         status_code=400, | ||||||
|  |                         detail="当前是唯一默认模型,不可取消!" | ||||||
|  |                     ) | ||||||
|  |                 update_fields.append("is_default = 0") | ||||||
|  |  | ||||||
|  |         if not update_fields: | ||||||
|  |             raise HTTPException(status_code=400, detail="至少需提供一个更新字段") | ||||||
|  |  | ||||||
|  |         params.append(model_id) | ||||||
|  |         update_sql = f""" | ||||||
|  |             UPDATE model  | ||||||
|  |             SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP  | ||||||
|  |             WHERE id = %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(update_sql, params) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         updated_model = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         global _yolo_model | ||||||
|  |         if need_load_default: | ||||||
|  |             valid_abs_path = get_valid_model_abs_path(updated_model["path"]) | ||||||
|  |             _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|  |             if not _yolo_model: | ||||||
|  |                 raise HTTPException( | ||||||
|  |                     status_code=500, | ||||||
|  |                     detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="模型更新成功", | ||||||
|  |             data=ModelResponse(**updated_model) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 6. 删除模型 | ||||||
|  | @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") | ||||||
|  | async def delete_model(model_id: int): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         exist_model = cursor.fetchone() | ||||||
|  |         if not exist_model: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | ||||||
|  |         if exist_model["is_default"]: | ||||||
|  |             raise HTTPException(status_code=400, detail="默认模型不可删除!") | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             model_abs_path_str = get_valid_model_abs_path(exist_model["path"]) | ||||||
|  |             model_abs_path = Path(model_abs_path_str) | ||||||
|  |         except HTTPException as e: | ||||||
|  |             cursor.execute("DELETE FROM model WHERE id = %s", (model_id,)) | ||||||
|  |             conn.commit() | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message=f"记录删除成功,文件异常:{e.detail}", | ||||||
|  |                 data=None | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         cursor.execute("DELETE FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         extra_msg = "" | ||||||
|  |         try: | ||||||
|  |             model_abs_path.unlink() | ||||||
|  |             extra_msg = f"(已删除文件)" | ||||||
|  |         except Exception as e: | ||||||
|  |             extra_msg = f"(文件删除失败:{str(e)})" | ||||||
|  |  | ||||||
|  |         global _yolo_model | ||||||
|  |         if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str: | ||||||
|  |             _yolo_model = None | ||||||
|  |             print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})") | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"模型删除成功!ID:{model_id} {extra_msg}", | ||||||
|  |             data=None | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 7. 下载模型文件 | ||||||
|  | @router.get("/{model_id}/download", summary="下载模型文件") | ||||||
|  | async def download_model(model_id: int): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         model = cursor.fetchone() | ||||||
|  |         if not model: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | ||||||
|  |  | ||||||
|  |         valid_abs_path = get_valid_model_abs_path(model["path"]) | ||||||
|  |         model_abs_path = Path(valid_abs_path) | ||||||
|  |  | ||||||
|  |         return FileResponse( | ||||||
|  |             path=model_abs_path, | ||||||
|  |             filename=f"model_{model_id}_{model['name']}.pt", | ||||||
|  |             media_type="application/octet-stream" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 对外提供当前模型 | ||||||
|  | def get_current_yolo_model(): | ||||||
|  |     """供检测模块获取当前加载的模型""" | ||||||
|  |     global _yolo_model | ||||||
|  |     if not _yolo_model: | ||||||
|  |         conn = None | ||||||
|  |         cursor = None | ||||||
|  |         try: | ||||||
|  |             conn = db.get_connection() | ||||||
|  |             cursor = conn.cursor(dictionary=True) | ||||||
|  |             cursor.execute("SELECT path FROM model WHERE is_default = 1") | ||||||
|  |             default_model = cursor.fetchone() | ||||||
|  |             if not default_model: | ||||||
|  |                 print("[get_current_yolo_model] 暂无默认模型") | ||||||
|  |                 return None | ||||||
|  |  | ||||||
|  |             valid_abs_path = get_valid_model_abs_path(default_model["path"]) | ||||||
|  |             _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|  |             if _yolo_model: | ||||||
|  |                 print(f"[get_current_yolo_model] 自动加载默认模型成功") | ||||||
|  |             else: | ||||||
|  |                 print(f"[get_current_yolo_model] 自动加载默认模型失败") | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"[get_current_yolo_model] 加载失败:{str(e)}") | ||||||
|  |         finally: | ||||||
|  |             db.close_connection(conn, cursor) | ||||||
|  |     return _yolo_model | ||||||
| @ -1,6 +1,7 @@ | |||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
| from fastapi import APIRouter, Depends, HTTPException | from fastapi import APIRouter, Depends, HTTPException, Query | ||||||
| from mysql.connector import Error as MySQLError | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
| from ds.db import db | from ds.db import db | ||||||
| @ -11,7 +12,7 @@ from middle.auth_middleware import ( | |||||||
|     verify_password, |     verify_password, | ||||||
|     create_access_token, |     create_access_token, | ||||||
|     ACCESS_TOKEN_EXPIRE_MINUTES, |     ACCESS_TOKEN_EXPIRE_MINUTES, | ||||||
|     get_current_user |     get_current_user  # 仅保留登录用户校验,移除is_admin导入 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) | # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) | ||||||
| @ -152,3 +153,98 @@ async def get_current_user_info( | |||||||
|         data=current_user |         data=current_user | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 4. 获取用户列表(仅需登录权限) | ||||||
|  | # ------------------------------ | ||||||
|  | @router.get("/list", response_model=APIResponse, summary="获取用户列表") | ||||||
|  | async def get_user_list( | ||||||
|  |         page: int = Query(1, ge=1, description="页码,从1开始"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | ||||||
|  |         username: Optional[str] = Query(None, description="用户名模糊搜索"), | ||||||
|  |         current_user: UserResponse = Depends(get_current_user)  # 仅需登录即可访问(移除管理员校验) | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     获取用户列表: | ||||||
|  |     - 需登录权限(请求头携带 Token: Bearer <token>) | ||||||
|  |     - 支持分页查询(page=页码,page_size=每页条数) | ||||||
|  |     - 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等) | ||||||
|  |     - 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息) | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 计算分页偏移量(page从1开始,偏移量=(页码-1)*每页条数) | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |  | ||||||
|  |         # 基础查询(仅查非敏感字段) | ||||||
|  |         base_query = """ | ||||||
|  |             SELECT id, username, created_at, updated_at  | ||||||
|  |             FROM users | ||||||
|  |         """ | ||||||
|  |         # 总条数查询(用于分页计算) | ||||||
|  |         count_query = "SELECT COUNT(*) as total FROM users" | ||||||
|  |  | ||||||
|  |         # 条件拼接(支持用户名模糊搜索) | ||||||
|  |         conditions = [] | ||||||
|  |         params = [] | ||||||
|  |         if username: | ||||||
|  |             conditions.append("username LIKE %s") | ||||||
|  |             params.append(f"%{username}%")  # 模糊匹配:%表示任意字符 | ||||||
|  |  | ||||||
|  |         # 构建最终查询语句 | ||||||
|  |         if conditions: | ||||||
|  |             where_clause = " WHERE " + " AND ".join(conditions) | ||||||
|  |             final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s" | ||||||
|  |             final_count_query = f"{count_query}{where_clause}" | ||||||
|  |             params.extend([page_size, offset])  # 追加分页参数 | ||||||
|  |         else: | ||||||
|  |             final_query = f"{base_query} LIMIT %s OFFSET %s" | ||||||
|  |             final_count_query = count_query | ||||||
|  |             params = [page_size, offset] | ||||||
|  |  | ||||||
|  |         # 1. 查询用户列表数据 | ||||||
|  |         cursor.execute(final_query, params) | ||||||
|  |         users = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 2. 查询总条数(用于计算总页数) | ||||||
|  |         count_params = [f"%{username}%"] if username else [] | ||||||
|  |         cursor.execute(final_count_query, count_params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 3. 转换为UserResponse模型(确保字段匹配) | ||||||
|  |         user_list = [ | ||||||
|  |             UserResponse( | ||||||
|  |                 id=user["id"], | ||||||
|  |                 username=user["username"], | ||||||
|  |                 created_at=user["created_at"], | ||||||
|  |                 updated_at=user["updated_at"] | ||||||
|  |             ) | ||||||
|  |             for user in users | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         # 4. 计算总页数(向上取整,如11条数据每页10条=2页) | ||||||
|  |         total_pages = (total + page_size - 1) // page_size | ||||||
|  |  | ||||||
|  |         # 返回结果(包含列表和分页信息) | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取用户列表成功", | ||||||
|  |             data={ | ||||||
|  |                 "users": user_list, | ||||||
|  |                 "pagination": { | ||||||
|  |                     "page": page,          # 当前页码 | ||||||
|  |                     "page_size": page_size, # 每页条数 | ||||||
|  |                     "total": total,         # 总数据量 | ||||||
|  |                     "total_pages": total_pages  # 总页数 | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取用户列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         # 无论成功失败,都关闭数据库连接 | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
| @ -4,6 +4,11 @@ import insightface | |||||||
| from insightface.app import FaceAnalysis | from insightface.app import FaceAnalysis | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from PIL import Image | from PIL import Image | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  | # 配置日志(便于排查) | ||||||
|  | logging.basicConfig(level=logging.INFO, format='%(asctime)s - [FaceUtil] - %(levelname)s - %(message)s') | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # 全局变量存储InsightFace引擎和特征列表 | # 全局变量存储InsightFace引擎和特征列表 | ||||||
| _insightface_app = None | _insightface_app = None | ||||||
| @ -11,135 +16,141 @@ _feature_list = [] | |||||||
|  |  | ||||||
|  |  | ||||||
| def init_insightface(): | def init_insightface(): | ||||||
|     """初始化InsightFace引擎""" |     """初始化InsightFace引擎(确保成功后再使用)""" | ||||||
|     global _insightface_app |     global _insightface_app | ||||||
|     try: |     try: | ||||||
|         print("正在初始化InsightFace引擎...") |         if _insightface_app is not None: | ||||||
|         app = FaceAnalysis(name='buffalo_l', root='~/.insightface') |             logger.info("InsightFace引擎已初始化,无需重复执行") | ||||||
|         app.prepare(ctx_id=0, det_size=(640, 640)) |             return _insightface_app | ||||||
|         print("InsightFace引擎初始化完成") |  | ||||||
|  |         logger.info("正在初始化InsightFace引擎(模型:buffalo_l)...") | ||||||
|  |         # 手动指定模型下载路径(避免权限问题,可选) | ||||||
|  |         app = FaceAnalysis( | ||||||
|  |             name='buffalo_l', | ||||||
|  |             root='~/.insightface',  # 模型默认下载路径 | ||||||
|  |             providers=['CPUExecutionProvider']  # 强制用CPU(若有GPU可加'CUDAExecutionProvider') | ||||||
|  |         ) | ||||||
|  |         app.prepare(ctx_id=0, det_size=(640, 640))  # det_size越大,小人脸检测越准 | ||||||
|  |         logger.info("InsightFace引擎初始化完成") | ||||||
|         _insightface_app = app |         _insightface_app = app | ||||||
|         return app |         return app | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"InsightFace初始化失败: {e}") |         logger.error(f"InsightFace初始化失败:{str(e)}", exc_info=True)  # 打印详细堆栈 | ||||||
|  |         _insightface_app = None | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| def add_binary_data(binary_data): | def add_binary_data(binary_data): | ||||||
|     """ |     """ | ||||||
|     接收单张图片的二进制数据、提取特征并保存 |     接收单张图片的二进制数据、提取特征并保存 | ||||||
|  |     返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串) | ||||||
|     参数: |  | ||||||
|         binary_data: 图片的二进制数据(bytes类型) |  | ||||||
|  |  | ||||||
|     返回: |  | ||||||
|         成功提取特征时返回 (True, 特征值numpy数组) |  | ||||||
|         失败时返回 (False, None) |  | ||||||
|     """ |     """ | ||||||
|     global _insightface_app, _feature_list |     global _insightface_app, _feature_list | ||||||
|  |  | ||||||
|  |     # 1. 先检查引擎是否初始化成功 | ||||||
|     if not _insightface_app: |     if not _insightface_app: | ||||||
|         print("引擎未初始化、无法处理") |         init_result = init_insightface()  # 尝试重新初始化 | ||||||
|         return False, None |         if not init_result: | ||||||
|  |             error_msg = "InsightFace引擎未初始化,无法检测人脸" | ||||||
|  |             logger.error(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 直接处理二进制数据: 转换为图像格式 |         # 2. 验证二进制数据有效性 | ||||||
|         img = Image.open(BytesIO(binary_data)) |         if len(binary_data) < 1024:  # 过滤过小的无效图片(小于1KB) | ||||||
|         frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) |             error_msg = f"图片过小({len(binary_data)}字节),可能不是有效图片" | ||||||
|  |             logger.warning(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|         # 提取特征 |         # 3. 二进制数据转CV2格式(关键步骤,避免通道错误) | ||||||
|         faces = _insightface_app.get(frame) |         try: | ||||||
|         if faces: |             img = Image.open(BytesIO(binary_data)).convert("RGB")  # 强制转RGB | ||||||
|             # 获取当前提取的特征值 |             frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)  # InsightFace需要BGR格式 | ||||||
|             current_feature = faces[0].embedding |  | ||||||
|             # 添加到特征列表 |  | ||||||
|             _feature_list.append(current_feature) |  | ||||||
|             print(f"已累计 {len(_feature_list)} 个特征") |  | ||||||
|             # 返回成功标志和当前特征值 |  | ||||||
|             return True, current_feature |  | ||||||
|         else: |  | ||||||
|             print("二进制数据中未检测到人脸") |  | ||||||
|             return False, None |  | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|         print(f"处理二进制数据出错: {e}") |             error_msg = f"图片格式转换失败:{str(e)}" | ||||||
|         return False, None |             logger.error(error_msg, exc_info=True) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |         # 4. 检查图片尺寸(避免极端尺寸导致检测失败) | ||||||
|  |         height, width = frame.shape[:2] | ||||||
|  |         if height < 64 or width < 64:  # 人脸检测最小建议尺寸 | ||||||
|  |             error_msg = f"图片尺寸过小({width}x{height}),需至少64x64像素" | ||||||
|  |             logger.warning(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |         # 5. 调用InsightFace检测人脸 | ||||||
|  |         logger.info(f"开始检测人脸(图片尺寸:{width}x{height},格式:BGR)") | ||||||
|  |         faces = _insightface_app.get(frame) | ||||||
|  |  | ||||||
|  |         if not faces: | ||||||
|  |             error_msg = "未检测到人脸(请确保图片包含清晰正面人脸,无遮挡、不模糊)" | ||||||
|  |             logger.warning(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |         # 6. 提取特征并保存 | ||||||
|  |         current_feature = faces[0].embedding | ||||||
|  |         _feature_list.append(current_feature) | ||||||
|  |         logger.info(f"人脸检测成功,提取特征值(维度:{current_feature.shape[0]}),累计特征数:{len(_feature_list)}") | ||||||
|  |         return True, current_feature | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         error_msg = f"处理图片时发生异常:{str(e)}" | ||||||
|  |         logger.error(error_msg, exc_info=True) | ||||||
|  |         return False, error_msg | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 以下函数保持不变(get_average_feature/clear_features/get_feature_list) | ||||||
| def get_average_feature(features=None): | def get_average_feature(features=None): | ||||||
|     """ |  | ||||||
|     计算多个特征向量的平均值 |  | ||||||
|  |  | ||||||
|     参数: |  | ||||||
|         features: 可选、特征值列表。如果未提供、则使用全局存储的_feature_list |  | ||||||
|                   每个元素可以是字符串格式或numpy数组 |  | ||||||
|  |  | ||||||
|     返回: |  | ||||||
|         单一平均特征向量的numpy数组、若无可计算数据则返回None |  | ||||||
|     """ |  | ||||||
|     global _feature_list |     global _feature_list | ||||||
|  |     try: | ||||||
|     # 如果未提供features参数、则使用全局特征列表 |  | ||||||
|         if features is None: |         if features is None: | ||||||
|             features = _feature_list |             features = _feature_list | ||||||
|  |  | ||||||
|     try: |  | ||||||
|         # 验证输入是否为列表且不为空 |  | ||||||
|         if not isinstance(features, list) or len(features) == 0: |         if not isinstance(features, list) or len(features) == 0: | ||||||
|             print("输入必须是包含至少一个特征值的列表") |             logger.warning("输入必须是包含至少一个特征值的列表") | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         # 处理每个特征值 |  | ||||||
|         processed_features = [] |         processed_features = [] | ||||||
|         for i, embedding in enumerate(features): |         for i, embedding in enumerate(features): | ||||||
|             try: |             try: | ||||||
|                 if isinstance(embedding, str): |                 if isinstance(embedding, str): | ||||||
|                     # 处理包含括号和逗号的字符串格式 |  | ||||||
|                     embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip() |                     embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip() | ||||||
|                     embedding_list = [float(num) for num in embedding_str.split() if num.strip()] |                     embedding_list = [float(num) for num in embedding_str.split() if num.strip()] | ||||||
|                     embedding_np = np.array(embedding_list, dtype=np.float32) |                     embedding_np = np.array(embedding_list, dtype=np.float32) | ||||||
|                 else: |                 else: | ||||||
|                     embedding_np = np.array(embedding, dtype=np.float32) |                     embedding_np = np.array(embedding, dtype=np.float32) | ||||||
|  |  | ||||||
|                 # 验证特征值格式 |  | ||||||
|                 if len(embedding_np.shape) == 1: |                 if len(embedding_np.shape) == 1: | ||||||
|                     processed_features.append(embedding_np) |                     processed_features.append(embedding_np) | ||||||
|                     print(f"已添加第 {i + 1} 个特征值用于计算平均值") |                     logger.info(f"已添加第 {i + 1} 个特征值用于计算平均值") | ||||||
|                 else: |                 else: | ||||||
|                     print(f"跳过第 {i + 1} 个特征值、不是一维数组") |                     logger.warning(f"跳过第 {i + 1} 个特征值:不是一维数组") | ||||||
|  |  | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|                 print(f"处理第 {i + 1} 个特征值时出错: {e}") |                 logger.error(f"处理第 {i + 1} 个特征值时出错:{str(e)}") | ||||||
|  |  | ||||||
|         # 确保有有效的特征值 |  | ||||||
|         if not processed_features: |         if not processed_features: | ||||||
|             print("没有有效的特征值用于计算平均值") |             logger.warning("没有有效的特征值用于计算平均值") | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         # 检查所有特征向量维度是否相同 |  | ||||||
|         dims = {feat.shape[0] for feat in processed_features} |         dims = {feat.shape[0] for feat in processed_features} | ||||||
|         if len(dims) > 1: |         if len(dims) > 1: | ||||||
|             print(f"特征值维度不一致、无法计算平均值。检测到的维度: {dims}") |             logger.error(f"特征值维度不一致:{dims},无法计算平均值") | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         # 计算平均值 |  | ||||||
|         avg_feature = np.mean(processed_features, axis=0) |         avg_feature = np.mean(processed_features, axis=0) | ||||||
|         print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量、维度: {avg_feature.shape[0]}") |         logger.info(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]})") | ||||||
|  |  | ||||||
|         return avg_feature |         return avg_feature | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"计算平均特征值时出错: {e}") |         logger.error(f"计算平均特征值出错:{str(e)}", exc_info=True) | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| def clear_features(): | def clear_features(): | ||||||
|     """清空已存储的特征数据""" |  | ||||||
|     global _feature_list |     global _feature_list | ||||||
|     _feature_list = [] |     _feature_list = [] | ||||||
|     print("已清空所有特征数据") |     logger.info("已清空所有特征数据") | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_feature_list(): | def get_feature_list(): | ||||||
|     """获取当前存储的特征列表""" |  | ||||||
|     global _feature_list |     global _feature_list | ||||||
|     return _feature_list.copy()  # 返回副本防止外部直接修改 |     logger.info(f"当前特征列表长度:{len(_feature_list)}") | ||||||
|  |     return _feature_list.copy() | ||||||
							
								
								
									
										83
									
								
								util/file_util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								util/file_util.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,83 @@ | |||||||
|  | import os | ||||||
|  | import datetime | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import Dict | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def save_face_to_up_images( | ||||||
|  |         client_ip: str, | ||||||
|  |         face_name: str, | ||||||
|  |         image_bytes: bytes, | ||||||
|  |         image_format: str = "jpg" | ||||||
|  | ) -> Dict[str, str]: | ||||||
|  |     """ | ||||||
|  |     保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 | ||||||
|  |     修复路径计算错误,确保所有路径在up_images根目录下 | ||||||
|  |  | ||||||
|  |     参数: | ||||||
|  |         client_ip: 客户端IP(原始格式,如192.168.1.101) | ||||||
|  |         face_name: 人脸名称(用户输入,可为空) | ||||||
|  |         image_bytes: 人脸图片二进制数据 | ||||||
|  |         image_format: 图片格式(默认jpg) | ||||||
|  |  | ||||||
|  |     返回: | ||||||
|  |         字典:success(是否成功)、db_path(存数据库的相对路径)、local_abs_path(本地绝对路径)、msg(提示) | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         # 1. 基础参数校验 | ||||||
|  |         if not client_ip.strip(): | ||||||
|  |             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"} | ||||||
|  |         if not image_bytes: | ||||||
|  |             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "图片二进制数据为空"} | ||||||
|  |         if image_format.lower() not in ["jpg", "jpeg", "png"]: | ||||||
|  |             return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"} | ||||||
|  |  | ||||||
|  |         # 2. 处理特殊字符(避免路径错误) | ||||||
|  |         safe_ip = client_ip.strip().replace(".", "_")  # IP中的.替换为_ | ||||||
|  |         safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名" | ||||||
|  |         safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|'])  # 过滤非法字符 | ||||||
|  |  | ||||||
|  |         # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) | ||||||
|  |         root_dir = Path("up_images").resolve()  # 转为绝对路径(关键修复!) | ||||||
|  |         if not root_dir.exists(): | ||||||
|  |             root_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |             print(f"[FileUtil] 已创建up_images根目录:{root_dir}") | ||||||
|  |  | ||||||
|  |         # 4. 构建文件层级路径(确保在root_dir子目录下) | ||||||
|  |         ip_dir = root_dir / safe_ip | ||||||
|  |         face_name_dir = ip_dir / safe_face_name | ||||||
|  |         face_name_dir.mkdir(parents=True, exist_ok=True)  # 自动创建目录 | ||||||
|  |         print(f"[FileUtil] 图片存储目录:{face_name_dir}") | ||||||
|  |  | ||||||
|  |         # 5. 生成唯一文件名(毫秒级时间戳) | ||||||
|  |         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] | ||||||
|  |         image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" | ||||||
|  |  | ||||||
|  |         # 6. 计算路径(关键修复:确保所有路径都是绝对路径且在root_dir下) | ||||||
|  |         local_abs_path = face_name_dir / image_filename  # 绝对路径 | ||||||
|  |  | ||||||
|  |         # 验证路径是否在root_dir下(防止路径穿越攻击) | ||||||
|  |         if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): | ||||||
|  |             raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") | ||||||
|  |  | ||||||
|  |         # 数据库存储路径:从root_dir开始的相对路径(如 up_images/192_168_110_31/小王/xxx.jpg) | ||||||
|  |         db_path = str(root_dir.name / local_abs_path.relative_to(root_dir)) | ||||||
|  |  | ||||||
|  |         # 7. 写入图片文件 | ||||||
|  |         with open(local_abs_path, "wb") as f: | ||||||
|  |             f.write(image_bytes) | ||||||
|  |         print(f"[FileUtil] 图片保存成功:") | ||||||
|  |         print(f"  数据库路径:{db_path}") | ||||||
|  |         print(f"  本地绝对路径:{local_abs_path}") | ||||||
|  |  | ||||||
|  |         return { | ||||||
|  |             "success": True, | ||||||
|  |             "db_path": db_path,  # 存数据库的相对路径(up_images开头) | ||||||
|  |             "local_abs_path": str(local_abs_path),  # 本地绝对路径 | ||||||
|  |             "msg": "图片保存成功" | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         error_msg = f"图片保存失败:{str(e)}" | ||||||
|  |         print(f"[FileUtil] 错误:{error_msg}") | ||||||
|  |         return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg} | ||||||
							
								
								
									
										61
									
								
								util/model_util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								util/model_util.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,61 @@ | |||||||
|  | import os | ||||||
|  | import numpy as np | ||||||
|  | import traceback | ||||||
|  | from ultralytics import YOLO | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load_yolo_model(model_path: str) -> Optional[YOLO]: | ||||||
|  |     """ | ||||||
|  |     加载YOLO模型(支持v5/v8),并校验模型有效性 | ||||||
|  |     :param model_path: 模型文件的绝对路径 | ||||||
|  |     :return: 加载成功返回YOLO模型实例,失败返回None | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         # 加载前的基础信息检查 | ||||||
|  |         print(f"\n[模型工具] 开始加载模型:{model_path}") | ||||||
|  |         print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}") | ||||||
|  |         if os.path.exists(model_path): | ||||||
|  |             print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB") | ||||||
|  |  | ||||||
|  |         # 强制重新加载模型,避免缓存问题 | ||||||
|  |         model = YOLO(model_path) | ||||||
|  |  | ||||||
|  |         # 兼容性校验:使用numpy空数组测试模型 | ||||||
|  |         dummy_image = np.zeros((640, 640, 3), dtype=np.uint8) | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             # 优先使用新版本参数 | ||||||
|  |             model.predict( | ||||||
|  |                 source=dummy_image, | ||||||
|  |                 imgsz=640, | ||||||
|  |                 conf=0.25, | ||||||
|  |                 verbose=False, | ||||||
|  |                 stream=False | ||||||
|  |             ) | ||||||
|  |         except Exception as pred_e: | ||||||
|  |             print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}") | ||||||
|  |             # 兼容旧版本YOLO参数 | ||||||
|  |             model.predict( | ||||||
|  |                 img=dummy_image, | ||||||
|  |                 imgsz=640, | ||||||
|  |                 conf=0.25, | ||||||
|  |                 verbose=False | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 验证模型基本属性 | ||||||
|  |         if not hasattr(model, 'names'): | ||||||
|  |             print("[模型工具] 警告:模型缺少类别名称属性") | ||||||
|  |         else: | ||||||
|  |             print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...")  # 显示前5个类别 | ||||||
|  |  | ||||||
|  |         print(f"[模型工具] 模型加载成功!") | ||||||
|  |         return model | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         # 详细错误信息输出 | ||||||
|  |         print(f"\n[模型工具] 加载模型失败!路径:{model_path}") | ||||||
|  |         print(f"[模型工具] 异常类型:{type(e).__name__}") | ||||||
|  |         print(f"[模型工具] 异常详情:{str(e)}") | ||||||
|  |         print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}") | ||||||
|  |         return None | ||||||
		Reference in New Issue
	
	Block a user