内容安全审核
							
								
								
									
										8
									
								
								.idea/.gitignore
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,8 @@ | |||||||
|  | # Default ignored files | ||||||
|  | /shelf/ | ||||||
|  | /workspace.xml | ||||||
|  | # Editor-based HTTP Client requests | ||||||
|  | /httpRequests/ | ||||||
|  | # Datasource local storage ignored files | ||||||
|  | /dataSources/ | ||||||
|  | /dataSources.local.xml | ||||||
							
								
								
									
										8
									
								
								.idea/Detect.iml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,8 @@ | |||||||
|  | <?xml version="1.0" encoding="UTF-8"?> | ||||||
|  | <module type="PYTHON_MODULE" version="4"> | ||||||
|  |   <component name="NewModuleRootManager"> | ||||||
|  |     <content url="file://$MODULE_DIR$" /> | ||||||
|  |     <orderEntry type="jdk" jdkName="Python 3.10 (2)" jdkType="Python SDK" /> | ||||||
|  |     <orderEntry type="sourceFolder" forTests="false" /> | ||||||
|  |   </component> | ||||||
|  | </module> | ||||||
							
								
								
									
										98
									
								
								.idea/inspectionProfiles/Project_Default.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,98 @@ | |||||||
|  | <component name="InspectionProjectProfileManager"> | ||||||
|  |   <profile version="1.0"> | ||||||
|  |     <option name="myName" value="Project Default" /> | ||||||
|  |     <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" /> | ||||||
|  |     <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true"> | ||||||
|  |       <option name="ignoredPackages"> | ||||||
|  |         <value> | ||||||
|  |           <list size="76"> | ||||||
|  |             <item index="0" class="java.lang.String" itemvalue="scipy" /> | ||||||
|  |             <item index="1" class="java.lang.String" itemvalue="protobuf" /> | ||||||
|  |             <item index="2" class="java.lang.String" itemvalue="thop" /> | ||||||
|  |             <item index="3" class="java.lang.String" itemvalue="opencv-python" /> | ||||||
|  |             <item index="4" class="java.lang.String" itemvalue="PyYAML" /> | ||||||
|  |             <item index="5" class="java.lang.String" itemvalue="ipython" /> | ||||||
|  |             <item index="6" class="java.lang.String" itemvalue="torch" /> | ||||||
|  |             <item index="7" class="java.lang.String" itemvalue="numpy" /> | ||||||
|  |             <item index="8" class="java.lang.String" itemvalue="requests" /> | ||||||
|  |             <item index="9" class="java.lang.String" itemvalue="torchvision" /> | ||||||
|  |             <item index="10" class="java.lang.String" itemvalue="psutil" /> | ||||||
|  |             <item index="11" class="java.lang.String" itemvalue="tqdm" /> | ||||||
|  |             <item index="12" class="java.lang.String" itemvalue="pandas" /> | ||||||
|  |             <item index="13" class="java.lang.String" itemvalue="tensorboard" /> | ||||||
|  |             <item index="14" class="java.lang.String" itemvalue="seaborn" /> | ||||||
|  |             <item index="15" class="java.lang.String" itemvalue="matplotlib" /> | ||||||
|  |             <item index="16" class="java.lang.String" itemvalue="Pillow" /> | ||||||
|  |             <item index="17" class="java.lang.String" itemvalue="fastapi" /> | ||||||
|  |             <item index="18" class="java.lang.String" itemvalue="uvicorn" /> | ||||||
|  |             <item index="19" class="java.lang.String" itemvalue="python-jose" /> | ||||||
|  |             <item index="20" class="java.lang.String" itemvalue="passlib" /> | ||||||
|  |             <item index="21" class="java.lang.String" itemvalue="pydantic" /> | ||||||
|  |             <item index="22" class="java.lang.String" itemvalue="sqlalchemy" /> | ||||||
|  |             <item index="23" class="java.lang.String" itemvalue="imageio_ffmpeg" /> | ||||||
|  |             <item index="24" class="java.lang.String" itemvalue="ultralytics" /> | ||||||
|  |             <item index="25" class="java.lang.String" itemvalue="future" /> | ||||||
|  |             <item index="26" class="java.lang.String" itemvalue="jose" /> | ||||||
|  |             <item index="27" class="java.lang.String" itemvalue="ffmpeg-python" /> | ||||||
|  |             <item index="28" class="java.lang.String" itemvalue="setuptools" /> | ||||||
|  |             <item index="29" class="java.lang.String" itemvalue="opencv_python" /> | ||||||
|  |             <item index="30" class="java.lang.String" itemvalue="rsa" /> | ||||||
|  |             <item index="31" class="java.lang.String" itemvalue="greenlet" /> | ||||||
|  |             <item index="32" class="java.lang.String" itemvalue="networkx" /> | ||||||
|  |             <item index="33" class="java.lang.String" itemvalue="python-dateutil" /> | ||||||
|  |             <item index="34" class="java.lang.String" itemvalue="SQLAlchemy" /> | ||||||
|  |             <item index="35" class="java.lang.String" itemvalue="cffi" /> | ||||||
|  |             <item index="36" class="java.lang.String" itemvalue="python-dotenv" /> | ||||||
|  |             <item index="37" class="java.lang.String" itemvalue="h11" /> | ||||||
|  |             <item index="38" class="java.lang.String" itemvalue="py-cpuinfo" /> | ||||||
|  |             <item index="39" class="java.lang.String" itemvalue="cycler" /> | ||||||
|  |             <item index="40" class="java.lang.String" itemvalue="MarkupSafe" /> | ||||||
|  |             <item index="41" class="java.lang.String" itemvalue="pyasn1" /> | ||||||
|  |             <item index="42" class="java.lang.String" itemvalue="pycparser" /> | ||||||
|  |             <item index="43" class="java.lang.String" itemvalue="Jinja2" /> | ||||||
|  |             <item index="44" class="java.lang.String" itemvalue="sniffio" /> | ||||||
|  |             <item index="45" class="java.lang.String" itemvalue="ultralytics-thop" /> | ||||||
|  |             <item index="46" class="java.lang.String" itemvalue="fsspec" /> | ||||||
|  |             <item index="47" class="java.lang.String" itemvalue="filelock" /> | ||||||
|  |             <item index="48" class="java.lang.String" itemvalue="starlette" /> | ||||||
|  |             <item index="49" class="java.lang.String" itemvalue="certifi" /> | ||||||
|  |             <item index="50" class="java.lang.String" itemvalue="anyio" /> | ||||||
|  |             <item index="51" class="java.lang.String" itemvalue="urllib3" /> | ||||||
|  |             <item index="52" class="java.lang.String" itemvalue="pyparsing" /> | ||||||
|  |             <item index="53" class="java.lang.String" itemvalue="sympy" /> | ||||||
|  |             <item index="54" class="java.lang.String" itemvalue="annotated-types" /> | ||||||
|  |             <item index="55" class="java.lang.String" itemvalue="pydantic-settings" /> | ||||||
|  |             <item index="56" class="java.lang.String" itemvalue="six" /> | ||||||
|  |             <item index="57" class="java.lang.String" itemvalue="tzdata" /> | ||||||
|  |             <item index="58" class="java.lang.String" itemvalue="ecdsa" /> | ||||||
|  |             <item index="59" class="java.lang.String" itemvalue="kiwisolver" /> | ||||||
|  |             <item index="60" class="java.lang.String" itemvalue="packaging" /> | ||||||
|  |             <item index="61" class="java.lang.String" itemvalue="python-multipart" /> | ||||||
|  |             <item index="62" class="java.lang.String" itemvalue="click" /> | ||||||
|  |             <item index="63" class="java.lang.String" itemvalue="contourpy" /> | ||||||
|  |             <item index="64" class="java.lang.String" itemvalue="fonttools" /> | ||||||
|  |             <item index="65" class="java.lang.String" itemvalue="pydantic_core" /> | ||||||
|  |             <item index="66" class="java.lang.String" itemvalue="av" /> | ||||||
|  |             <item index="67" class="java.lang.String" itemvalue="colorama" /> | ||||||
|  |             <item index="68" class="java.lang.String" itemvalue="mpmath" /> | ||||||
|  |             <item index="69" class="java.lang.String" itemvalue="argon2-cffi-bindings" /> | ||||||
|  |             <item index="70" class="java.lang.String" itemvalue="typing_extensions" /> | ||||||
|  |             <item index="71" class="java.lang.String" itemvalue="charset-normalizer" /> | ||||||
|  |             <item index="72" class="java.lang.String" itemvalue="pillow" /> | ||||||
|  |             <item index="73" class="java.lang.String" itemvalue="argon2-cffi" /> | ||||||
|  |             <item index="74" class="java.lang.String" itemvalue="pytz" /> | ||||||
|  |             <item index="75" class="java.lang.String" itemvalue="idna" /> | ||||||
|  |           </list> | ||||||
|  |         </value> | ||||||
|  |       </option> | ||||||
|  |     </inspection_tool> | ||||||
|  |     <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true"> | ||||||
|  |       <option name="ignoredErrors"> | ||||||
|  |         <list> | ||||||
|  |           <option value="N802" /> | ||||||
|  |         </list> | ||||||
|  |       </option> | ||||||
|  |     </inspection_tool> | ||||||
|  |     <inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" /> | ||||||
|  |   </profile> | ||||||
|  | </component> | ||||||
							
								
								
									
										6
									
								
								.idea/inspectionProfiles/profiles_settings.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,6 @@ | |||||||
|  | <component name="InspectionProjectProfileManager"> | ||||||
|  |   <settings> | ||||||
|  |     <option name="USE_PROJECT_PROFILE" value="false" /> | ||||||
|  |     <version value="1.0" /> | ||||||
|  |   </settings> | ||||||
|  | </component> | ||||||
							
								
								
									
										7
									
								
								.idea/misc.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,7 @@ | |||||||
|  | <?xml version="1.0" encoding="UTF-8"?> | ||||||
|  | <project version="4"> | ||||||
|  |   <component name="Black"> | ||||||
|  |     <option name="sdkName" value="video" /> | ||||||
|  |   </component> | ||||||
|  |   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (2)" project-jdk-type="Python SDK" /> | ||||||
|  | </project> | ||||||
							
								
								
									
										8
									
								
								.idea/modules.xml
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,8 @@ | |||||||
|  | <?xml version="1.0" encoding="UTF-8"?> | ||||||
|  | <project version="4"> | ||||||
|  |   <component name="ProjectModuleManager"> | ||||||
|  |     <modules> | ||||||
|  |       <module fileurl="file://$PROJECT_DIR$/.idea/Detect.iml" filepath="$PROJECT_DIR$/.idea/Detect.iml" /> | ||||||
|  |     </modules> | ||||||
|  |   </component> | ||||||
|  | </project> | ||||||
							
								
								
									
										
											BIN
										
									
								
								__pycache__/main.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										20
									
								
								config.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,20 @@ | |||||||
|  | [server] | ||||||
|  | port = 8000 | ||||||
|  |  | ||||||
|  | [mysql] | ||||||
|  | host = 192.168.110.65 | ||||||
|  | port = 6975 | ||||||
|  | user = video_check | ||||||
|  | password = fsjPfhxCs8NrFGmL | ||||||
|  | database = video_check | ||||||
|  | charset = utf8mb4 | ||||||
|  |  | ||||||
|  | [jwt] | ||||||
|  | secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd | ||||||
|  | algorithm = HS256 | ||||||
|  | access_token_expire_minutes = 30 | ||||||
|  |  | ||||||
|  | [business] | ||||||
|  | ocr_conf = 0.6 | ||||||
|  | face_conf = 0.6 | ||||||
|  | yolo_conf = 0.7 | ||||||
							
								
								
									
										
											BIN
										
									
								
								core/__pycache__/detect.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										140
									
								
								core/detect.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,140 @@ | |||||||
|  | import json | ||||||
|  | import re | ||||||
|  |  | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.config import BUSINESS_CONFIG | ||||||
|  | from ds.db import db | ||||||
|  | from service.face_service import detect as faceDetect,init_insightface | ||||||
|  | from service.model_service import load_yolo_model,detect as yoloDetect | ||||||
|  | from service.ocr_service import detect as ocrDetect,init_ocr_engine | ||||||
|  | from service.file_service import save_detect_file, save_detect_yolo_file, save_detect_face_file | ||||||
|  | import asyncio | ||||||
|  | from concurrent.futures import ThreadPoolExecutor | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 创建线程池执行器 | ||||||
|  | executor = ThreadPoolExecutor(max_workers=10) | ||||||
|  | def init(): | ||||||
|  |     # # 人脸相关 | ||||||
|  |     init_insightface() | ||||||
|  |     # # 初始化OCR引擎 | ||||||
|  |     init_ocr_engine() | ||||||
|  |     #初始化YOLO模型 | ||||||
|  |     load_yolo_model() | ||||||
|  |  | ||||||
|  | def save_db(model_type, client_ip, result): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         # 连接数据库 | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         # 往表插入数据 | ||||||
|  |         cursor = conn.cursor(dictionary=True)  # 返回字典格式结果 | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO device_danger (client_ip, type, result) | ||||||
|  |             VALUES (%s, %s, %s) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, (client_ip, model_type, result)) | ||||||
|  |         conn.commit() | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取设备列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def detectFrame(client_ip, frame): | ||||||
|  |  | ||||||
|  |     # YOLO检测 | ||||||
|  |     yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"])) | ||||||
|  |     if yolo_flag: | ||||||
|  |         print(f"❌ 检测到违规内容,保存图片,YOLO") | ||||||
|  |         danger_handler(client_ip) | ||||||
|  |         path = save_detect_yolo_file(client_ip, frame, yolo_result, "yolo") | ||||||
|  |         save_db(model_type="色情", client_ip=client_ip, result=str(path)) | ||||||
|  |  | ||||||
|  |     # 人脸检测 | ||||||
|  |     face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"])) | ||||||
|  |     if face_flag: | ||||||
|  |         print(f"❌ 检测到违规内容,保存图片,FACE") | ||||||
|  |         print("人脸识别内容:", face_result) | ||||||
|  |         model_type = extract_face_names(face_result) | ||||||
|  |         danger_handler(client_ip) | ||||||
|  |         path = save_detect_face_file(client_ip, frame, face_result, "face") | ||||||
|  |         save_db(model_type=model_type, client_ip=client_ip, result=str(path)) | ||||||
|  |  | ||||||
|  |     # OCR检测部分(使用修正后的提取函数) | ||||||
|  |     ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"])) | ||||||
|  |     if ocr_flag: | ||||||
|  |         print(f"❌ 检测到违规内容,保存图片,OCR") | ||||||
|  |         print("ocr识别内容:", ocr_result) | ||||||
|  |         danger_handler(client_ip) | ||||||
|  |         path = save_detect_file(client_ip, frame, "ocr") | ||||||
|  |         save_db(model_type=str(ocr_result), client_ip=client_ip, result=str(path)) | ||||||
|  |  | ||||||
|  |     # 仅当所有检测均未发现违规时才提示 | ||||||
|  |     if not (face_flag or yolo_flag or ocr_flag): | ||||||
|  |         print(f"所有模型未检测到任何违规内容") | ||||||
|  |  | ||||||
|  | def danger_handler(client_ip): | ||||||
|  |     from ws.ws import send_message_to_client, get_current_time_str | ||||||
|  |     from service.device_service import increment_alarm_count_by_ip | ||||||
|  |     from service.device_service import update_is_need_handler_by_client_ip | ||||||
|  |  | ||||||
|  |     danger_msg = { | ||||||
|  |         "type": "danger", | ||||||
|  |         "timestamp": get_current_time_str(), | ||||||
|  |         "client_ip": client_ip, | ||||||
|  |     } | ||||||
|  |     asyncio.run( | ||||||
|  |         send_message_to_client( | ||||||
|  |             client_ip=client_ip, | ||||||
|  |             json_data=json.dumps(danger_msg) | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |     lock_msg = { | ||||||
|  |         "type": "lock", | ||||||
|  |         "timestamp": get_current_time_str(), | ||||||
|  |         "client_ip": client_ip | ||||||
|  |     } | ||||||
|  |     asyncio.run( | ||||||
|  |         send_message_to_client( | ||||||
|  |             client_ip=client_ip, | ||||||
|  |             json_data=json.dumps(lock_msg) | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # 增加危险记录次数 | ||||||
|  |     increment_alarm_count_by_ip(client_ip) | ||||||
|  |  | ||||||
|  |     # 更新设备状态为未处理 | ||||||
|  |     update_is_need_handler_by_client_ip(client_ip, 1) | ||||||
|  |  | ||||||
|  | def extract_prohibited_words(ocr_result: str) -> str: | ||||||
|  |     """ | ||||||
|  |     从多文本块的ocr_result中提取所有违禁词(去重后用逗号拼接) | ||||||
|  |     适配格式:多个"文本: ... 包含违禁词: ...;"片段 | ||||||
|  |     """ | ||||||
|  |     # 用正则匹配所有"包含违禁词: ...;"的片段(非贪婪匹配到分号) | ||||||
|  |     # 匹配规则:"包含违禁词: "后面的内容,直到遇到";"结束 | ||||||
|  |     pattern = r"包含违禁词: (.*?);" | ||||||
|  |     all_prohibited_segments = re.findall(pattern, ocr_result, re.DOTALL) | ||||||
|  |  | ||||||
|  |     all_words = [] | ||||||
|  |     for segment in all_prohibited_segments: | ||||||
|  |         # 去除每个片段中的置信度信息(如"(置信度: 1.00)") | ||||||
|  |         cleaned = re.sub(r"\s*\([^)]*\)", "", segment.strip()) | ||||||
|  |         # 分割词语并过滤空值 | ||||||
|  |         words = [word.strip() for word in cleaned.split(",") if word.strip()] | ||||||
|  |         all_words.extend(words) | ||||||
|  |  | ||||||
|  |     # 去重后用逗号拼接 | ||||||
|  |     unique_words = list(set(all_words)) | ||||||
|  |     return ",".join(unique_words) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def extract_face_names(face_result: str) -> str: | ||||||
|  |     pattern = r"匹配: (.*?) \(" | ||||||
|  |     all_names = re.findall(pattern, face_result) | ||||||
|  |     unique_names = list(set([name.strip() for name in all_names if name.strip()])) | ||||||
|  |     return ",".join(unique_names) | ||||||
							
								
								
									
										
											BIN
										
									
								
								ds/__pycache__/config.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								ds/__pycache__/db.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										17
									
								
								ds/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,17 @@ | |||||||
|  | import configparser | ||||||
|  | import os | ||||||
|  |  | ||||||
|  | # 读取配置文件路径 | ||||||
|  | config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../config.ini") | ||||||
|  |  | ||||||
|  | # 初始化配置解析器 | ||||||
|  | config = configparser.ConfigParser() | ||||||
|  |  | ||||||
|  | # 读取配置文件 | ||||||
|  | config.read(config_path, encoding="utf-8") | ||||||
|  |  | ||||||
|  | # 暴露配置项(方便其他文件调用) | ||||||
|  | SERVER_CONFIG = config["server"] | ||||||
|  | MYSQL_CONFIG = config["mysql"] | ||||||
|  | JWT_CONFIG = config["jwt"] | ||||||
|  | BUSINESS_CONFIG = config["business"] | ||||||
							
								
								
									
										59
									
								
								ds/db.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,59 @@ | |||||||
|  | import mysql.connector | ||||||
|  | from mysql.connector import Error | ||||||
|  | from .config import MYSQL_CONFIG | ||||||
|  |  | ||||||
|  | _connection_pool = None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Database: | ||||||
|  |     """MySQL 连接池管理类""" | ||||||
|  |     pool_config = { | ||||||
|  |         "host": MYSQL_CONFIG.get("host", "localhost"), | ||||||
|  |         "port": int(MYSQL_CONFIG.get("port", 3306)), | ||||||
|  |         "user": MYSQL_CONFIG.get("user", "root"), | ||||||
|  |         "password": MYSQL_CONFIG.get("password", ""), | ||||||
|  |         "database": MYSQL_CONFIG.get("database", ""), | ||||||
|  |         "charset": MYSQL_CONFIG.get("charset", "utf8mb4"), | ||||||
|  |         "pool_name": "fastapi_pool", | ||||||
|  |         "pool_size": 5, | ||||||
|  |         "pool_reset_session": True | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def get_connection(cls): | ||||||
|  |         """获取数据库连接""" | ||||||
|  |         try: | ||||||
|  |             # 从连接池获取连接 | ||||||
|  |             conn = mysql.connector.connect(**cls.pool_config) | ||||||
|  |             if conn.is_connected(): | ||||||
|  |                 return conn | ||||||
|  |         except Error as e: | ||||||
|  |             raise Exception(f"MySQL 连接失败: {str(e)}") from e | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def close_connection(cls, conn, cursor=None): | ||||||
|  |         """关闭连接和游标""" | ||||||
|  |         try: | ||||||
|  |             if cursor: | ||||||
|  |                 cursor.close() | ||||||
|  |             if conn and conn.is_connected(): | ||||||
|  |                 conn.close() | ||||||
|  |         except Error as e: | ||||||
|  |             raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def close_all_connections(cls): | ||||||
|  |         """清理连接池(服务重启前调用)""" | ||||||
|  |         try: | ||||||
|  |             # 先检查属性是否存在、再判断是否有值 | ||||||
|  |             if hasattr(cls, "_connection_pool") and cls._connection_pool: | ||||||
|  |                 cls._connection_pool = None  # 重置连接池 | ||||||
|  |                 print("[Database] 连接池已重置、旧连接将被自动清理") | ||||||
|  |             else: | ||||||
|  |                 print("[Database] 连接池未初始化或已重置、无需操作") | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"[Database] 重置连接池失败: {str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 暴露数据库操作工具 | ||||||
|  | db = Database() | ||||||
							
								
								
									
										
											BIN
										
									
								
								encryption/__pycache__/encrypt_decorator.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								encryption/__pycache__/encryption.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										40
									
								
								encryption/encrypt_decorator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,40 @@ | |||||||
|  | import json | ||||||
|  | from datetime import datetime | ||||||
|  | from functools import wraps | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
|  | from encryption.encryption import aes_encrypt | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  | from pydantic import BaseModel | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def encrypt_response(field: str = "data"): | ||||||
|  |     """接口返回值加密装饰器:正确序列化自定义对象为JSON""" | ||||||
|  |  | ||||||
|  |     def decorator(func): | ||||||
|  |         @wraps(func) | ||||||
|  |         async def wrapper(*args, **kwargs): | ||||||
|  |             original_response: APIResponse = await func(*args, **kwargs) | ||||||
|  |             field_value = getattr(original_response, field) | ||||||
|  |  | ||||||
|  |             if not field_value: | ||||||
|  |                 return original_response | ||||||
|  |  | ||||||
|  |             # 自定义JSON序列化函数:处理Pydantic模型和datetime | ||||||
|  |             def json_default(obj: Any) -> Any: | ||||||
|  |                 if isinstance(obj, BaseModel): | ||||||
|  |                     return obj.model_dump()  | ||||||
|  |                 if isinstance(obj, datetime): | ||||||
|  |                     return obj.isoformat() | ||||||
|  |                 return str(obj)  | ||||||
|  |  | ||||||
|  |             # 使用自定义序列化函数、确保生成标准JSON | ||||||
|  |             field_value_json = json.dumps(field_value, default=json_default) | ||||||
|  |             encrypted_data = aes_encrypt(field_value_json) | ||||||
|  |             setattr(original_response, field, encrypted_data) | ||||||
|  |  | ||||||
|  |             return original_response | ||||||
|  |  | ||||||
|  |         return wrapper | ||||||
|  |  | ||||||
|  |     return decorator | ||||||
							
								
								
									
										56
									
								
								encryption/encryption.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,56 @@ | |||||||
|  | import os | ||||||
|  | import base64 | ||||||
|  | from Crypto.Cipher import AES | ||||||
|  | from Crypto.Util.Padding import pad, unpad | ||||||
|  | from fastapi import HTTPException | ||||||
|  |  | ||||||
|  | # 硬编码AES密钥(32字节、AES-256) | ||||||
|  | AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" | ||||||
|  | AES_BLOCK_SIZE = 16  # AES固定块大小 | ||||||
|  |  | ||||||
|  | # 校验密钥长度 | ||||||
|  | valid_key_lengths = [16, 24, 32] | ||||||
|  | if len(AES_SECRET_KEY) not in valid_key_lengths: | ||||||
|  |     raise ValueError( | ||||||
|  |         f"AES密钥长度必须为{valid_key_lengths}字节、当前为{len(AES_SECRET_KEY)}字节" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def aes_encrypt(plaintext: str) -> dict: | ||||||
|  |     """AES-CBC模式加密(返回密文+IV、均为Base64编码)""" | ||||||
|  |     try: | ||||||
|  |         # 生成随机IV(16字节) | ||||||
|  |         iv = os.urandom(AES_BLOCK_SIZE) | ||||||
|  |  | ||||||
|  |         # 创建加密器 | ||||||
|  |         cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) | ||||||
|  |  | ||||||
|  |         # 明文填充并加密 | ||||||
|  |         padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE) | ||||||
|  |         ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8") | ||||||
|  |         iv_base64 = base64.b64encode(iv).decode("utf-8") | ||||||
|  |  | ||||||
|  |         return { | ||||||
|  |             "ciphertext": ciphertext, | ||||||
|  |             "iv": iv_base64, | ||||||
|  |             "algorithm": "AES-CBC" | ||||||
|  |         } | ||||||
|  |     except Exception as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"AES加密失败:{str(e)}") from e | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def aes_decrypt(ciphertext: str, iv: str) -> str: | ||||||
|  |     """AES-CBC模式解密""" | ||||||
|  |     try: | ||||||
|  |         # 解码Base64 | ||||||
|  |         ciphertext_bytes = base64.b64decode(ciphertext) | ||||||
|  |         iv_bytes = base64.b64decode(iv) | ||||||
|  |  | ||||||
|  |         # 创建解密器 | ||||||
|  |         cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv_bytes) | ||||||
|  |  | ||||||
|  |         # 解密并去填充 | ||||||
|  |         decrypted_bytes = unpad(cipher.decrypt(ciphertext_bytes), AES_BLOCK_SIZE) | ||||||
|  |         return decrypted_bytes.decode("utf-8") | ||||||
|  |     except Exception as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"AES解密失败:{str(e)}") from e | ||||||
							
								
								
									
										66
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,66 @@ | |||||||
|  | import uvicorn | ||||||
|  | import os | ||||||
|  | from fastapi import FastAPI | ||||||
|  | from fastapi.middleware.cors import CORSMiddleware | ||||||
|  |  | ||||||
|  | # 原有业务导入 | ||||||
|  | from ds.config import SERVER_CONFIG | ||||||
|  | from middle.error_handler import global_exception_handler | ||||||
|  | from router.user_router import router as user_router | ||||||
|  | from router.sensitive_router import router as sensitive_router | ||||||
|  | from router.face_router import router as face_router | ||||||
|  | from router.device_router import router as device_router | ||||||
|  | from router.model_router import router as model_router | ||||||
|  | from router.file_router import router as file_router | ||||||
|  | from router.device_danger_router import router as device_danger_router | ||||||
|  | from core.detect import init | ||||||
|  | from ws.ws import ws_router, lifespan | ||||||
|  |  | ||||||
|  | # 初始化 FastAPI 应用 | ||||||
|  | app = FastAPI( | ||||||
|  |     title="内容安全审核后台", | ||||||
|  |     description="含图片访问服务和动态模型管理", | ||||||
|  |     version="1.0.0", | ||||||
|  |     lifespan=lifespan | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | ALLOWED_ORIGINS = [ | ||||||
|  |     "*" | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | # 配置 CORS 中间件 | ||||||
|  | app.add_middleware( | ||||||
|  |     CORSMiddleware, | ||||||
|  |     allow_origins=ALLOWED_ORIGINS,        # 允许的前端域名 | ||||||
|  |     allow_credentials=True,               # 允许携带 Cookie | ||||||
|  |     allow_methods=["*"],                  # 允许所有 HTTP 方法 | ||||||
|  |     allow_headers=["*"],                  # 允许所有请求头 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | # 注册路由 | ||||||
|  | app.include_router(user_router) | ||||||
|  | app.include_router(device_router) | ||||||
|  | app.include_router(face_router) | ||||||
|  | app.include_router(sensitive_router) | ||||||
|  | app.include_router(model_router) | ||||||
|  | app.include_router(file_router) | ||||||
|  | app.include_router(device_danger_router) | ||||||
|  | app.include_router(ws_router) | ||||||
|  |  | ||||||
|  | # 注册全局异常处理器 | ||||||
|  | app.add_exception_handler(Exception, global_exception_handler) | ||||||
|  |  | ||||||
|  | # 主服务启动入口 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # 启动 FastAPI 主服务(仅使用8000端口) | ||||||
|  |     port = int(SERVER_CONFIG.get("port", 8000)) | ||||||
|  |     # 加载所有模型 | ||||||
|  |     init() | ||||||
|  |     uvicorn.run( | ||||||
|  |         app="main:app", | ||||||
|  |         host="0.0.0.0", | ||||||
|  |         port=port, | ||||||
|  |         workers=1, | ||||||
|  |         ws="websockets", | ||||||
|  |         reload=False | ||||||
|  |     ) | ||||||
							
								
								
									
										
											BIN
										
									
								
								middle/__pycache__/auth_middleware.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								middle/__pycache__/error_handler.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										102
									
								
								middle/auth_middleware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,102 @@ | |||||||
|  | from datetime import datetime, timedelta, timezone | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  | from fastapi import Depends, HTTPException, status | ||||||
|  | from fastapi.security import OAuth2PasswordBearer | ||||||
|  | from jose import JWTError, jwt | ||||||
|  | from passlib.context import CryptContext | ||||||
|  |  | ||||||
|  | from ds.config import JWT_CONFIG | ||||||
|  | from ds.db import db | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 密码加密配置 | ||||||
|  | # ------------------------------ | ||||||
|  | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # JWT 配置 | ||||||
|  | # ------------------------------ | ||||||
|  | SECRET_KEY = JWT_CONFIG["secret_key"] | ||||||
|  | ALGORITHM = JWT_CONFIG["algorithm"] | ||||||
|  | ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"]) | ||||||
|  |  | ||||||
|  | # OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>) | ||||||
|  | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 密码工具函数 | ||||||
|  | # ------------------------------ | ||||||
|  | def verify_password(plain_password: str, hashed_password: str) -> bool: | ||||||
|  |     """验证明文密码与加密密码是否匹配""" | ||||||
|  |     return pwd_context.verify(plain_password, hashed_password) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_password_hash(password: str) -> str: | ||||||
|  |     """对明文密码进行 bcrypt 加密""" | ||||||
|  |     return pwd_context.hash(password) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # JWT 工具函数 | ||||||
|  | # ------------------------------ | ||||||
|  | def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: | ||||||
|  |     """生成 JWT Token""" | ||||||
|  |     to_encode = data.copy() | ||||||
|  |     # 设置过期时间 | ||||||
|  |     if expires_delta: | ||||||
|  |         expire = datetime.now(timezone.utc) + expires_delta | ||||||
|  |     else: | ||||||
|  |         expire = datetime.now(timezone.utc) + timedelta(minutes=15) | ||||||
|  |     # 添加过期时间到 Token 数据 | ||||||
|  |     to_encode.update({"exp": expire}) | ||||||
|  |     # 生成 Token | ||||||
|  |     encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) | ||||||
|  |     return encoded_jwt | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 认证依赖(获取当前登录用户) | ||||||
|  | # ------------------------------ | ||||||
|  | def get_current_user(token: str = Depends(oauth2_scheme)):  # 移除返回类型注解 | ||||||
|  |     """从 Token 中解析用户信息、验证通过后返回当前用户""" | ||||||
|  |     # 延迟导入、打破循环依赖 | ||||||
|  |     from schema.user_schema import UserResponse  # 在这里导入 | ||||||
|  |  | ||||||
|  |     # 认证失败异常 | ||||||
|  |     credentials_exception = HTTPException( | ||||||
|  |         status_code=status.HTTP_401_UNAUTHORIZED, | ||||||
|  |         detail="Token 无效或已过期", | ||||||
|  |         headers={"WWW-Authenticate": "Bearer"}, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         # 解码 Token | ||||||
|  |         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) | ||||||
|  |         # 获取 Token 中的用户名 | ||||||
|  |         username: str = payload.get("sub") | ||||||
|  |         if username is None: | ||||||
|  |             raise credentials_exception | ||||||
|  |     except JWTError: | ||||||
|  |         raise credentials_exception | ||||||
|  |  | ||||||
|  |     # 从数据库查询用户(验证用户是否存在) | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s" | ||||||
|  |         cursor.execute(query, (username,)) | ||||||
|  |         user = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if user is None: | ||||||
|  |             raise credentials_exception | ||||||
|  |  | ||||||
|  |         # 转换为 UserResponse 模型(自动校验字段) | ||||||
|  |         return UserResponse(**user) | ||||||
|  |     except Exception as e: | ||||||
|  |         raise credentials_exception from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										68
									
								
								middle/error_handler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,68 @@ | |||||||
|  | from fastapi import Request, status | ||||||
|  | from fastapi.responses import JSONResponse | ||||||
|  | from fastapi.exceptions import HTTPException, RequestValidationError | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  | from jose import JWTError | ||||||
|  |  | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def global_exception_handler(request: Request, exc: Exception): | ||||||
|  |     """全局异常处理器: 所有未捕获的异常都会在这里统一处理""" | ||||||
|  |     # 请求参数验证错误(Pydantic 校验失败) | ||||||
|  |     if isinstance(exc, RequestValidationError): | ||||||
|  |         error_details = [] | ||||||
|  |         for err in exc.errors(): | ||||||
|  |             error_details.append(f"{err['loc'][1]}: {err['msg']}") | ||||||
|  |         return JSONResponse( | ||||||
|  |             status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|  |             content=APIResponse( | ||||||
|  |                 code=400, | ||||||
|  |                 message=f"请求参数错误: {'; '.join(error_details)}", | ||||||
|  |                 data=None | ||||||
|  |             ).model_dump() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # HTTP 异常(主动抛出的业务错误、如 401/404) | ||||||
|  |     if isinstance(exc, HTTPException): | ||||||
|  |         return JSONResponse( | ||||||
|  |             status_code=exc.status_code, | ||||||
|  |             content=APIResponse( | ||||||
|  |                 code=exc.status_code, | ||||||
|  |                 message=exc.detail, | ||||||
|  |                 data=None | ||||||
|  |             ).model_dump() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # JWT 相关错误(Token 无效/过期) | ||||||
|  |     if isinstance(exc, JWTError): | ||||||
|  |         return JSONResponse( | ||||||
|  |             status_code=status.HTTP_401_UNAUTHORIZED, | ||||||
|  |             content=APIResponse( | ||||||
|  |                 code=401, | ||||||
|  |                 message="Token 无效或已过期", | ||||||
|  |                 data=None | ||||||
|  |             ).model_dump(), | ||||||
|  |             headers={"WWW-Authenticate": "Bearer"}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # MySQL 数据库错误 | ||||||
|  |     if isinstance(exc, MySQLError): | ||||||
|  |         return JSONResponse( | ||||||
|  |             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||||
|  |             content=APIResponse( | ||||||
|  |                 code=500, | ||||||
|  |                 message=f"数据库错误: {str(exc)}", | ||||||
|  |                 data=None | ||||||
|  |             ).model_dump() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # 其他未知错误 | ||||||
|  |     return JSONResponse( | ||||||
|  |         status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||||
|  |         content=APIResponse( | ||||||
|  |             code=500, | ||||||
|  |             message=f"服务器内部错误: {str(exc)}", | ||||||
|  |             data=None | ||||||
|  |         ).model_dump() | ||||||
|  |     ) | ||||||
							
								
								
									
										21
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,21 @@ | |||||||
|  | fastapi==0.116.1 | ||||||
|  | insightface==0.7.3 | ||||||
|  | numpy==1.26.1 | ||||||
|  | opencv_contrib_python==4.6.0.66 | ||||||
|  | opencv_python==4.8.1.78 | ||||||
|  | opencv_python_headless==4.11.0.86 | ||||||
|  | paddleocr==2.6.0.1 | ||||||
|  | paddlepaddle-gpu==2.5.2 | ||||||
|  | passlib==1.7.4 | ||||||
|  | Pillow==11.3.0 | ||||||
|  | pycryptodome==3.23.0 | ||||||
|  | pydantic==2.11.7 | ||||||
|  | python_jose==3.5.0 | ||||||
|  | torch==2.3.1+cu118 | ||||||
|  | torchaudio==2.3.1+cu118 | ||||||
|  | torchvision==0.18.1+cu118 | ||||||
|  | ultralytics==8.3.198 | ||||||
|  | uvicorn==0.35.0 | ||||||
|  | insightface==0.7.3 | ||||||
|  | onnxruntime-gpu==1.15.1 | ||||||
|  | uvicorn[standard] | ||||||
							
								
								
									
										
											BIN
										
									
								
								router/__pycache__/device_danger_router.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								router/__pycache__/device_router.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								router/__pycache__/face_router.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								router/__pycache__/file_router.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								router/__pycache__/model_router.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								router/__pycache__/sensitive_router.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								router/__pycache__/user_router.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										112
									
								
								router/device_action_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,112 @@ | |||||||
|  | from fastapi import APIRouter, Query, Path | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from schema.device_action_schema import ( | ||||||
|  |     DeviceActionResponse, | ||||||
|  |     DeviceActionListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  |  | ||||||
|  | # 路由配置 | ||||||
|  | router = APIRouter( | ||||||
|  |     prefix="/api/device/actions", | ||||||
|  |     tags=["设备操作记录"] | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_device_action_list( | ||||||
|  |         page: int = Query(1, ge=1, description="页码、默认1"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"), | ||||||
|  |         client_ip: str = Query(None, description="按客户端IP筛选"), | ||||||
|  |         action: int = Query(None, ge=0, le=1, description="按状态筛选(0=离线、1=上线)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         # 构建筛选条件(参数化查询、避免注入) | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |         if client_ip: | ||||||
|  |             where_clause.append("client_ip = %s") | ||||||
|  |             params.append(client_ip) | ||||||
|  |         if action is not None: | ||||||
|  |             where_clause.append("action = %s") | ||||||
|  |             params.append(action) | ||||||
|  |  | ||||||
|  |         # 查询总记录数(用于返回 total) | ||||||
|  |         count_sql = "SELECT COUNT(*) AS total FROM device_action" | ||||||
|  |         if where_clause: | ||||||
|  |             count_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_sql, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 分页查询记录(按创建时间倒序、确保最新记录在前) | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_sql = "SELECT * FROM device_action" | ||||||
|  |         if where_clause: | ||||||
|  |             list_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |  | ||||||
|  |         cursor.execute(list_sql, params) | ||||||
|  |         action_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 仅返回 total + device_actions | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="查询成功", | ||||||
|  |             data=DeviceActionListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 device_actions=[DeviceActionResponse(**item) for item in action_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"查询记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  | @router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_device_actions_by_ip( | ||||||
|  |         client_ip: str = Path(..., description="客户端IP地址") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 1. 查询总记录数 | ||||||
|  |         count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s" | ||||||
|  |         cursor.execute(count_sql, (client_ip,)) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 2. 查询该IP的所有记录(按创建时间倒序) | ||||||
|  |         list_sql = """ | ||||||
|  |             SELECT * FROM device_action  | ||||||
|  |             WHERE client_ip = %s  | ||||||
|  |             ORDER BY created_at DESC | ||||||
|  |         """ | ||||||
|  |         cursor.execute(list_sql, (client_ip,)) | ||||||
|  |         action_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 3. 返回结果 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="查询成功", | ||||||
|  |             data=DeviceActionListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 device_actions=[DeviceActionResponse(**item) for item in action_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"查询记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										164
									
								
								router/device_danger_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,164 @@ | |||||||
|  | from datetime import date | ||||||
|  |  | ||||||
|  | from fastapi import APIRouter, Query, HTTPException, Path | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from schema.device_danger_schema import ( | ||||||
|  |      DeviceDangerResponse, DeviceDangerListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  |  | ||||||
|  | router = APIRouter( | ||||||
|  |     prefix="/api/devices/dangers", | ||||||
|  |     tags=["设备管理-危险记录"] | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | # 获取危险记录列表 | ||||||
|  | @router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_danger_list( | ||||||
|  |         page: int = Query(1, ge=1, description="页码、默认第1页"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"), | ||||||
|  |         client_ip: str = Query(None, max_length=100, description="按设备IP筛选"), | ||||||
|  |         danger_type: str = Query(None, max_length=255, alias="type", description="按危险类型筛选"), | ||||||
|  |         start_date: date = Query(None, description="按创建时间筛选(开始日期、格式YYYY-MM-DD)"), | ||||||
|  |         end_date: date = Query(None, description="按创建时间筛选(结束日期、格式YYYY-MM-DD)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 构建筛选条件 | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |  | ||||||
|  |         if client_ip: | ||||||
|  |             where_clause.append("client_ip = %s") | ||||||
|  |             params.append(client_ip) | ||||||
|  |         if danger_type: | ||||||
|  |             where_clause.append("type = %s") | ||||||
|  |             params.append(danger_type) | ||||||
|  |         if start_date: | ||||||
|  |             where_clause.append("DATE(created_at) >= %s") | ||||||
|  |             params.append(start_date.strftime("%Y-%m-%d")) | ||||||
|  |         if end_date: | ||||||
|  |             where_clause.append("DATE(created_at) <= %s") | ||||||
|  |             params.append(end_date.strftime("%Y-%m-%d")) | ||||||
|  |  | ||||||
|  |         # 1. 统计符合条件的总记录数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM device_danger" | ||||||
|  |         if where_clause: | ||||||
|  |             count_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_query, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 2. 分页查询记录(按创建时间倒序、最新的在前) | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = "SELECT * FROM device_danger" | ||||||
|  |         if where_clause: | ||||||
|  |             list_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset])  # 追加分页参数 | ||||||
|  |  | ||||||
|  |         cursor.execute(list_query, params) | ||||||
|  |         danger_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 转换为响应模型 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取危险记录列表成功", | ||||||
|  |             data=DeviceDangerListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 dangers=[DeviceDangerResponse(**item) for item in danger_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取单个设备的所有危险记录 | ||||||
|  | @router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录") | ||||||
|  | # @encrypt_response() | ||||||
|  | async def get_device_dangers( | ||||||
|  |         client_ip: str = Path(..., max_length=100, description="设备IP地址"), | ||||||
|  |         page: int = Query(1, ge=1, description="页码、默认第1页"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间") | ||||||
|  | ): | ||||||
|  |     # 先检查设备是否存在 | ||||||
|  |     from service.device_danger_service import check_device_exist | ||||||
|  |     if not check_device_exist(client_ip): | ||||||
|  |         raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 1. 统计该设备的危险记录总数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s" | ||||||
|  |         cursor.execute(count_query, (client_ip,)) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 2. 分页查询该设备的危险记录 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = """ | ||||||
|  |             SELECT * FROM device_danger  | ||||||
|  |             WHERE client_ip = %s  | ||||||
|  |             ORDER BY created_at DESC  | ||||||
|  |             LIMIT %s OFFSET %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(list_query, (client_ip, page_size, offset)) | ||||||
|  |         danger_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"获取设备[{client_ip}]危险记录成功(共{total}条)", | ||||||
|  |             data=DeviceDangerListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 dangers=[DeviceDangerResponse(**item) for item in danger_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 根据ID获取单个危险记录详情 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_danger_detail( | ||||||
|  |         danger_id: int = Path(..., ge=1, description="危险记录ID") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 查询单个危险记录 | ||||||
|  |         query = "SELECT * FROM device_danger WHERE id = %s" | ||||||
|  |         cursor.execute(query, (danger_id,)) | ||||||
|  |         danger = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not danger: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在") | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取危险记录详情成功", | ||||||
|  |             data=DeviceDangerResponse(**danger) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										329
									
								
								router/device_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,329 @@ | |||||||
|  | import asyncio | ||||||
|  | import json | ||||||
|  | from datetime import date | ||||||
|  |  | ||||||
|  | from fastapi import APIRouter, Query, HTTPException, Request, Path | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from schema.device_schema import ( | ||||||
|  |     DeviceCreateRequest, DeviceResponse, DeviceListResponse, | ||||||
|  |     DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  | from service.device_service import update_online_status_by_ip | ||||||
|  | from ws.ws import get_current_time_str, aes_encrypt, is_client_connected | ||||||
|  |  | ||||||
|  | router = APIRouter( | ||||||
|  |     prefix="/api/devices", | ||||||
|  |     tags=["设备管理"] | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 创建设备信息接口 | ||||||
|  | @router.post("/add", response_model=APIResponse, summary="创建设备信息") | ||||||
|  | @encrypt_response() | ||||||
|  | async def create_device(device_data: DeviceCreateRequest, request: Request): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 检查设备是否已存在 | ||||||
|  |         cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) | ||||||
|  |         existing_device = cursor.fetchone() | ||||||
|  |         if existing_device: | ||||||
|  |             # 更新设备为在线状态 | ||||||
|  |             from service.device_service import update_online_status_by_ip | ||||||
|  |             update_online_status_by_ip(client_ip=device_data.ip, online_status=1) | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message=f"设备IP {device_data.ip} 已存在、返回已有设备信息", | ||||||
|  |                 data=DeviceResponse(**existing_device) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 通过 User-Agent 判断设备类型 | ||||||
|  |         user_agent = request.headers.get("User-Agent", "").lower() | ||||||
|  |         device_type = "unknown" | ||||||
|  |         if user_agent == "default": | ||||||
|  |             device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown" | ||||||
|  |         elif "windows" in user_agent: | ||||||
|  |             device_type = "windows" | ||||||
|  |         elif "android" in user_agent: | ||||||
|  |             device_type = "android" | ||||||
|  |         elif "linux" in user_agent: | ||||||
|  |             device_type = "linux" | ||||||
|  |  | ||||||
|  |         device_params_json = json.dumps(device_data.params) if device_data.params else None | ||||||
|  |  | ||||||
|  |         # 插入新设备 | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO devices  | ||||||
|  |             (client_ip, hostname, device_online_status, device_type, alarm_count, params, is_need_handler) | ||||||
|  |             VALUES (%s, %s, %s, %s, %s, %s, %s) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, ( | ||||||
|  |             device_data.ip, | ||||||
|  |             device_data.hostname, | ||||||
|  |             0, | ||||||
|  |             device_type, | ||||||
|  |             0, | ||||||
|  |             device_params_json, | ||||||
|  |             0 | ||||||
|  |         )) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 获取新设备并返回 | ||||||
|  |         device_id = cursor.lastrowid | ||||||
|  |         cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) | ||||||
|  |         new_device = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="设备创建成功", | ||||||
|  |             data=DeviceResponse(**new_device) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"创建设备失败: {str(e)}") from e | ||||||
|  |     except json.JSONDecodeError as e: | ||||||
|  |         raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e | ||||||
|  |     except Exception as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 获取设备列表接口 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_device_list( | ||||||
|  |         page: int = Query(1, ge=1, description="页码、默认第1页"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"), | ||||||
|  |         device_type: str = Query(None, description="按设备类型筛选"), | ||||||
|  |         online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |  | ||||||
|  |         if device_type: | ||||||
|  |             where_clause.append("device_type = %s") | ||||||
|  |             params.append(device_type) | ||||||
|  |         if online_status is not None: | ||||||
|  |             where_clause.append("device_online_status = %s") | ||||||
|  |             params.append(online_status) | ||||||
|  |  | ||||||
|  |         # 统计总数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM devices" | ||||||
|  |         if where_clause: | ||||||
|  |             count_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_query, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 分页查询列表 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = "SELECT * FROM devices" | ||||||
|  |         if where_clause: | ||||||
|  |             list_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_query += " ORDER BY id DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |  | ||||||
|  |         cursor.execute(list_query, params) | ||||||
|  |         device_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取设备列表成功", | ||||||
|  |             data=DeviceListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 devices=[DeviceResponse(**device) for device in device_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取设备列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 获取设备上下线记录接口 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.get("/status-history", response_model=APIResponse, summary="获取设备上下线记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_device_status_history( | ||||||
|  |         client_ip: str = Query(None, description="客户端IP地址(非必填,为空时返回所有设备记录)"), | ||||||
|  |         page: int = Query(1, ge=1, description="页码、默认第1页"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"), | ||||||
|  |         start_date: date = Query(None, description="开始日期、格式YYYY-MM-DD"), | ||||||
|  |         end_date: date = Query(None, description="结束日期、格式YYYY-MM-DD") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 1. 检查设备是否存在(仅传IP时):强制指定Collation | ||||||
|  |         if client_ip is not None: | ||||||
|  |             # 关键调整1:WHERE条件中给d.client_ip指定Collation(与da一致或反之) | ||||||
|  |             check_query = """ | ||||||
|  |                 SELECT id FROM devices  | ||||||
|  |                 WHERE client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci | ||||||
|  |             """ | ||||||
|  |             cursor.execute(check_query, (client_ip,)) | ||||||
|  |             device = cursor.fetchone() | ||||||
|  |             if not device: | ||||||
|  |                 raise HTTPException(status_code=404, detail=f"客户端IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|  |         # 2. 构建WHERE条件 | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |  | ||||||
|  |         # 关键调整2:传IP时,强制指定da.client_ip的Collation | ||||||
|  |         if client_ip is not None: | ||||||
|  |             where_clause.append("da.client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci") | ||||||
|  |             params.append(client_ip) | ||||||
|  |         if start_date: | ||||||
|  |             where_clause.append("DATE(da.created_at) >= %s") | ||||||
|  |             params.append(start_date.strftime("%Y-%m-%d")) | ||||||
|  |         if end_date: | ||||||
|  |             where_clause.append("DATE(da.created_at) <= %s") | ||||||
|  |             params.append(end_date.strftime("%Y-%m-%d")) | ||||||
|  |  | ||||||
|  |         # 3. 统计总数:JOIN时强制统一Collation | ||||||
|  |         count_query = """ | ||||||
|  |             SELECT COUNT(*) AS total  | ||||||
|  |             FROM device_action da  | ||||||
|  |             LEFT JOIN devices d  | ||||||
|  |                 ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci | ||||||
|  |         """ | ||||||
|  |         if where_clause: | ||||||
|  |             count_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_query, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 4. 分页查询:JOIN时强制统一Collation | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = """ | ||||||
|  |             SELECT da.*, d.id AS device_id  | ||||||
|  |             FROM device_action da  | ||||||
|  |             LEFT JOIN devices d  | ||||||
|  |                 ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci | ||||||
|  |         """ | ||||||
|  |         if where_clause: | ||||||
|  |             list_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_query += " ORDER BY da.created_at DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |         cursor.execute(list_query, params) | ||||||
|  |         history_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 后续格式化响应逻辑不变... | ||||||
|  |         formatted_history = [] | ||||||
|  |         for item in history_list: | ||||||
|  |             formatted_item = { | ||||||
|  |                 "id": item["id"], | ||||||
|  |                 "device_id": item["device_id"],  # 可能为None(IP无对应设备) | ||||||
|  |                 "client_ip": item["client_ip"], | ||||||
|  |                 "status": item["action"], | ||||||
|  |                 "status_time": item["created_at"] | ||||||
|  |             } | ||||||
|  |             formatted_history.append(formatted_item) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取设备上下线记录成功", | ||||||
|  |             data=DeviceStatusHistoryListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 history=[DeviceStatusHistoryResponse(**item) for item in formatted_history] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取设备上下线记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 通过客户端IP设置设备is_need_handler为0接口 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.post("/need-handler/reset", response_model=APIResponse, summary="解封客户端") | ||||||
|  | @encrypt_response() | ||||||
|  | async def reset_device_need_handler( | ||||||
|  |     client_ip: str = Query(..., description="目标设备的客户端IP地址(必填)") | ||||||
|  | ): | ||||||
|  |     try: | ||||||
|  |         from service.device_service import update_is_need_handler_by_client_ip | ||||||
|  |         success = update_is_need_handler_by_client_ip( | ||||||
|  |             client_ip=client_ip, | ||||||
|  |             is_need_handler=0  # 固定设置为0(不需要处理) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         if success: | ||||||
|  |             online_status = is_client_connected(client_ip) | ||||||
|  |  | ||||||
|  |             # 如果设备在线,则发送消息给前端 | ||||||
|  |             if online_status: | ||||||
|  |                 # 调用 ws 发送一个消息给前端、告诉他已解锁 | ||||||
|  |                 unlock_msg = { | ||||||
|  |                     "type": "unlock", | ||||||
|  |                     "timestamp": get_current_time_str(), | ||||||
|  |                     "client_ip": client_ip | ||||||
|  |                 } | ||||||
|  |                 from ws.ws import send_message_to_client | ||||||
|  |                 await send_message_to_client(client_ip, json.dumps(unlock_msg)) | ||||||
|  |  | ||||||
|  |                 # 休眠 100 ms | ||||||
|  |                 await asyncio.sleep(0.1) | ||||||
|  |  | ||||||
|  |                 frame_permit_msg = { | ||||||
|  |                     "type": "frame", | ||||||
|  |                     "timestamp": get_current_time_str(), | ||||||
|  |                     "client_ip": client_ip | ||||||
|  |                 } | ||||||
|  |                 await send_message_to_client(client_ip, json.dumps(frame_permit_msg)) | ||||||
|  |  | ||||||
|  |                 # 更新设备在线状态为1 | ||||||
|  |                 update_online_status_by_ip(client_ip, 1) | ||||||
|  |  | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message=f"设备已解封", | ||||||
|  |                 data={ | ||||||
|  |                     "client_ip": client_ip, | ||||||
|  |                     "is_need_handler": 0, | ||||||
|  |                     "status_desc": "设备已解封" | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     # 捕获工具方法抛出的业务异常(如IP为空、设备不存在) | ||||||
|  |     except ValueError as e: | ||||||
|  |         # 业务异常返回400/404状态码(与现有接口异常规范一致) | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=404 if "设备不存在" in str(e) else 400, | ||||||
|  |             detail=str(e) | ||||||
|  |         ) from e | ||||||
|  |  | ||||||
|  |     # 捕获数据库层面异常(如连接失败、SQL执行错误) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"设置is_need_handler失败:数据库操作异常 - {str(e)}") from e | ||||||
|  |  | ||||||
|  |     # 捕获其他未知异常 | ||||||
|  |     except Exception as e: | ||||||
|  |         raise Exception(f"设置is_need_handler失败:未知错误 - {str(e)}") from e | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
							
								
								
									
										326
									
								
								router/face_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,326 @@ | |||||||
|  | from io import BytesIO | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from schema.face_schema import ( | ||||||
|  |     FaceCreateRequest, | ||||||
|  |     FaceResponse, | ||||||
|  |     FaceListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  | from service.face_service import update_face_data | ||||||
|  | from util.face_util import add_binary_data | ||||||
|  | from service.file_service import save_source_file | ||||||
|  |  | ||||||
|  | router = APIRouter(prefix="/api/faces", tags=["人脸管理"]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 创建人脸记录 | ||||||
|  | @router.post("", response_model=APIResponse, summary="创建人脸记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def create_face( | ||||||
|  |         request: Request, | ||||||
|  |         name: str = Form(None, max_length=255, description="名称(可选)"), | ||||||
|  |         file: UploadFile = File(..., description="人脸文件(必传)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         face_create = FaceCreateRequest(name=name) | ||||||
|  |         client_ip = request.client.host if request.client else "" | ||||||
|  |         if not client_ip: | ||||||
|  |             raise HTTPException(status_code=400, detail="无法获取客户端IP") | ||||||
|  |  | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         # 先读取文件内容 | ||||||
|  |         file_content = await file.read() | ||||||
|  |         # 将文件指针重置到开头 | ||||||
|  |         await file.seek(0) | ||||||
|  |         # 再保存文件 | ||||||
|  |         path = save_source_file(file, "face") | ||||||
|  |         # 提取人脸特征 | ||||||
|  |         detect_success, detect_result = add_binary_data(file_content) | ||||||
|  |         if not detect_success: | ||||||
|  |             raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}") | ||||||
|  |         eigenvalue = detect_result | ||||||
|  |  | ||||||
|  |         # 插入数据库 | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO face (name, eigenvalue, address) | ||||||
|  |             VALUES (%s, %s, %s) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, (face_create.name, str(eigenvalue), path)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 查询新记录 | ||||||
|  |         cursor.execute(""" | ||||||
|  |             SELECT id, name, address, created_at, updated_at  | ||||||
|  |             FROM face  | ||||||
|  |             WHERE id = LAST_INSERT_ID() | ||||||
|  |         """) | ||||||
|  |         created_face = cursor.fetchone() | ||||||
|  |         if not created_face: | ||||||
|  |             raise HTTPException(status_code=500, detail="创建成功但无法获取记录") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         # TODO 重新加载人脸模型 | ||||||
|  |         update_face_data() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"人脸记录创建成功(ID: {created_face['id']})", | ||||||
|  |             data=FaceResponse(**created_face) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e | ||||||
|  |     except Exception as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         await file.close() | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取单个人脸记录 | ||||||
|  | @router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_face(face_id: int): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         query = """ | ||||||
|  |             SELECT id, name, address, created_at, updated_at  | ||||||
|  |             FROM face  | ||||||
|  |             WHERE id = %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(query, (face_id,)) | ||||||
|  |         face = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not face: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在") | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="查询成功", | ||||||
|  |             data=FaceResponse(**face) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取人脸列表 | ||||||
|  | @router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_face_list( | ||||||
|  |         page: int = Query(1, ge=1), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100), | ||||||
|  |         name: str = Query(None), | ||||||
|  |         has_eigenvalue: bool = Query(None) | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |         if name: | ||||||
|  |             where_clause.append("name LIKE %s") | ||||||
|  |             params.append(f"%{name}%") | ||||||
|  |         if has_eigenvalue is not None: | ||||||
|  |             where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL") | ||||||
|  |  | ||||||
|  |         # 总记录数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM face" | ||||||
|  |         if where_clause: | ||||||
|  |             count_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_query, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 列表数据 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = """ | ||||||
|  |             SELECT id, name, address, created_at, updated_at  | ||||||
|  |             FROM face | ||||||
|  |         """ | ||||||
|  |         if where_clause: | ||||||
|  |             list_query += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_query += " ORDER BY id DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |  | ||||||
|  |         cursor.execute(list_query, params) | ||||||
|  |         face_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"获取成功(共{total}条)", | ||||||
|  |             data=FaceListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 faces=[FaceResponse(**face) for face in face_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 删除人脸记录 | ||||||
|  | @router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def delete_face(face_id: int): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,)) | ||||||
|  |         exist_face = cursor.fetchone() | ||||||
|  |         if not exist_face: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在") | ||||||
|  |         old_db_path = exist_face["address"] | ||||||
|  |  | ||||||
|  |         cursor.execute("DELETE FROM face WHERE id = %s", (face_id,)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 删除图片 | ||||||
|  |         if old_db_path: | ||||||
|  |             old_abs_path = Path(old_db_path).resolve() | ||||||
|  |             if old_abs_path.exists(): | ||||||
|  |                 try: | ||||||
|  |                     old_abs_path.unlink() | ||||||
|  |                     print(f"[FaceRouter] 已删除图片:{old_abs_path}") | ||||||
|  |                     extra_msg = "(已同步删除图片)" | ||||||
|  |                 except Exception as e: | ||||||
|  |                     print(f"[FaceRouter] 删除图片失败:{str(e)}") | ||||||
|  |                     extra_msg = "(图片删除失败)" | ||||||
|  |             else: | ||||||
|  |                 extra_msg = "(图片不存在)" | ||||||
|  |         else: | ||||||
|  |             extra_msg = "(无关联图片)" | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         # TODO 重新加载人脸模型 | ||||||
|  |         update_face_data() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"ID为 {face_id} 的记录删除成功 {extra_msg}", | ||||||
|  |             data=None | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片") | ||||||
|  | # @encrypt_response() | ||||||
|  | async def batch_import_faces( | ||||||
|  |     folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     success_count = 0  # 成功导入数量 | ||||||
|  |     fail_list = []     # 失败记录(包含文件名、错误原因) | ||||||
|  |     try: | ||||||
|  |         # 1. 验证文件夹有效性 | ||||||
|  |         folder = Path(folder_path) | ||||||
|  |         if not folder.exists() or not folder.is_dir(): | ||||||
|  |             raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录") | ||||||
|  |  | ||||||
|  |         # 2. 定义支持的图片格式 | ||||||
|  |         supported_extensions = {".png", ".jpg", ".jpeg", ".webp"} | ||||||
|  |  | ||||||
|  |         # 3. 数据库连接初始化 | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 4. 遍历文件夹内所有文件 | ||||||
|  |         for file_path in folder.iterdir(): | ||||||
|  |             if file_path.is_file() and file_path.suffix.lower() in supported_extensions: | ||||||
|  |                 file_name = file_path.stem  # 提取文件名(不含后缀)作为 `name` | ||||||
|  |                 try: | ||||||
|  |                     # 4.1 读取文件二进制内容 | ||||||
|  |                     with open(file_path, "rb") as f: | ||||||
|  |                         file_content = f.read() | ||||||
|  |  | ||||||
|  |                     # 4.2 构造模拟的 UploadFile 对象(用于兼容 `save_source_file`) | ||||||
|  |                     mock_file = UploadFile( | ||||||
|  |                         filename=file_path.name, | ||||||
|  |                         file=BytesIO(file_content) | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |                     # 4.3 保存文件到指定目录 | ||||||
|  |                     saved_path = save_source_file(mock_file, "face") | ||||||
|  |  | ||||||
|  |                     # 4.4 提取人脸特征 | ||||||
|  |                     detect_success, detect_result = add_binary_data(file_content) | ||||||
|  |                     if not detect_success: | ||||||
|  |                         fail_list.append({ | ||||||
|  |                             "name": file_name, | ||||||
|  |                             "file_path": str(file_path), | ||||||
|  |                             "error": f"人脸检测失败:{detect_result}" | ||||||
|  |                         }) | ||||||
|  |                         continue  # 跳过当前文件,处理下一个 | ||||||
|  |                     eigenvalue = detect_result | ||||||
|  |  | ||||||
|  |                     # 4.5 插入数据库 | ||||||
|  |                     insert_sql = """ | ||||||
|  |                         INSERT INTO face (name, eigenvalue, address) | ||||||
|  |                         VALUES (%s, %s, %s) | ||||||
|  |                     """ | ||||||
|  |                     cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path)) | ||||||
|  |                     conn.commit()  # 提交当前文件的插入操作 | ||||||
|  |  | ||||||
|  |                     success_count += 1 | ||||||
|  |  | ||||||
|  |                 except Exception as e: | ||||||
|  |                     # 捕获单文件处理的异常,记录后继续处理其他文件 | ||||||
|  |                     fail_list.append({ | ||||||
|  |                         "name": file_name, | ||||||
|  |                         "file_path": str(file_path), | ||||||
|  |                         "error": str(e) | ||||||
|  |                     }) | ||||||
|  |                     if conn: | ||||||
|  |                         conn.rollback()  # 回滚当前失败文件的插入 | ||||||
|  |  | ||||||
|  |         # 5. 重新加载人脸模型(确保新增数据生效) | ||||||
|  |         update_face_data() | ||||||
|  |  | ||||||
|  |         # 6. 构造返回结果 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)} 条", | ||||||
|  |             data={ | ||||||
|  |                 "success_count": success_count, | ||||||
|  |                 "fail_details": fail_list | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e | ||||||
|  |     except HTTPException: | ||||||
|  |         raise  # 直接抛出400等由业务逻辑触发的HTTP异常 | ||||||
|  |     except Exception as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										31
									
								
								router/file_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,31 @@ | |||||||
|  | import os | ||||||
|  | from fastapi import FastAPI, HTTPException, Path, APIRouter | ||||||
|  | from fastapi.responses import FileResponse | ||||||
|  | from service.file_service import UPLOAD_ROOT | ||||||
|  |  | ||||||
|  | router = APIRouter( | ||||||
|  |     prefix="/api/file", | ||||||
|  |     tags=["文件管理"] | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @router.get("/download/{relative_path:path}", summary="下载文件") | ||||||
|  | async def download_file( | ||||||
|  |         relative_path: str = Path(..., description="文件的相对路径") | ||||||
|  | ): | ||||||
|  |     file_path = os.path.abspath(os.path.join(UPLOAD_ROOT, relative_path)) | ||||||
|  |  | ||||||
|  |     if not os.path.exists(file_path): | ||||||
|  |         raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}") | ||||||
|  |  | ||||||
|  |     if not os.path.isfile(file_path): | ||||||
|  |         raise HTTPException(status_code=400, detail="路径指向的不是文件") | ||||||
|  |  | ||||||
|  |     if not file_path.startswith(os.path.abspath(UPLOAD_ROOT)): | ||||||
|  |         raise HTTPException(status_code=403, detail="无权访问该文件") | ||||||
|  |  | ||||||
|  |     return FileResponse( | ||||||
|  |         path=file_path, | ||||||
|  |         filename=os.path.basename(file_path), | ||||||
|  |         media_type="application/octet-stream" | ||||||
|  |     ) | ||||||
							
								
								
									
										269
									
								
								router/model_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,269 @@ | |||||||
|  | import os | ||||||
|  | from pathlib import Path | ||||||
|  | from service.file_service import save_source_file | ||||||
|  |  | ||||||
|  | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from schema.model_schema import ( | ||||||
|  |     ModelResponse, | ||||||
|  |     ModelListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  | from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model | ||||||
|  |  | ||||||
|  | router = APIRouter(prefix="/api/models", tags=["模型管理"]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 上传模型 | ||||||
|  | @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def upload_model( | ||||||
|  |         name: str = Form(..., description="模型名称"), | ||||||
|  |         description: str = Form(None, description="模型描述"), | ||||||
|  |         file: UploadFile = File(..., description=f"YOLO模型文件(.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         # 校验文件格式 | ||||||
|  |         file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "" | ||||||
|  |         if file_ext not in ALLOWED_MODEL_EXT: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 校验文件大小 | ||||||
|  |         if file.size > MAX_MODEL_SIZE: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB" | ||||||
|  |             ) | ||||||
|  |         # 保存文件 | ||||||
|  |         file_path = save_source_file(file, "model") | ||||||
|  |  | ||||||
|  |         # 数据库操作 | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         insert_sql = """ | ||||||
|  |             INSERT INTO model (name, path, is_default, description, file_size) | ||||||
|  |             VALUES (%s, %s, 0, %s, %s) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_sql, (name, file_path, description, file.size)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 获取新增记录 | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()") | ||||||
|  |         new_model = cursor.fetchone() | ||||||
|  |         if not new_model: | ||||||
|  |             raise HTTPException(status_code=500, detail="上传成功但无法获取记录") | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"模型上传成功", | ||||||
|  |             data=ModelResponse(**new_model) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     except Exception as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         await file.close() | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取模型列表 | ||||||
|  | @router.get("", response_model=APIResponse, summary="获取模型列表(分页)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_model_list( | ||||||
|  |         page: int = Query(1, ge=1), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100), | ||||||
|  |         name: str = Query(None), | ||||||
|  |         is_default: bool = Query(None) | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |         if name: | ||||||
|  |             where_clause.append("name LIKE %s") | ||||||
|  |             params.append(f"%{name}%") | ||||||
|  |         if is_default is not None: | ||||||
|  |             where_clause.append("is_default = %s") | ||||||
|  |             params.append(1 if is_default else 0) | ||||||
|  |  | ||||||
|  |         # 总记录数 | ||||||
|  |         count_sql = "SELECT COUNT(*) AS total FROM model" | ||||||
|  |         if where_clause: | ||||||
|  |             count_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_sql, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 分页数据 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_sql = "SELECT * FROM model" | ||||||
|  |         if where_clause: | ||||||
|  |             list_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s" | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |  | ||||||
|  |         cursor.execute(list_sql, params) | ||||||
|  |         model_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"获取成功!", | ||||||
|  |             data=ModelListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 models=[ModelResponse(**model) for model in model_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 更换默认模型 | ||||||
|  | @router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型") | ||||||
|  | @encrypt_response() | ||||||
|  | async def set_default_model( | ||||||
|  |         model_id: int | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         conn.autocommit = False | ||||||
|  |  | ||||||
|  |         # 校验目标模型是否存在 | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         target_model = cursor.fetchone() | ||||||
|  |         if not target_model: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"目标模型不存在!") | ||||||
|  |  | ||||||
|  |         # 检查是否已为默认模型 | ||||||
|  |         if target_model["is_default"]: | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message=f"已是默认模型、无需更换", | ||||||
|  |                 data=ModelResponse(**target_model) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 数据库事务:更新默认模型状态 | ||||||
|  |         try: | ||||||
|  |             cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP") | ||||||
|  |             cursor.execute( | ||||||
|  |                 "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s", | ||||||
|  |                 (model_id,) | ||||||
|  |             ) | ||||||
|  |             conn.commit() | ||||||
|  |         except MySQLError as e: | ||||||
|  |             conn.rollback() | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=500, | ||||||
|  |                 detail=f"更新默认模型状态失败(已回滚):{str(e)}" | ||||||
|  |             ) from e | ||||||
|  |  | ||||||
|  |         # 更新模型 | ||||||
|  |         load_yolo_model() | ||||||
|  |         # 返回成功响应 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"更换成功", | ||||||
|  |             data=None | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         if conn: | ||||||
|  |             conn.autocommit = True | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 路由文件(如 model_router.py)中的删除接口 | ||||||
|  | @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") | ||||||
|  | @encrypt_response() | ||||||
|  | async def delete_model(model_id: int): | ||||||
|  |     # 1. 正确导入 model_service 中的全局变量(关键修复:变量名匹配) | ||||||
|  |     from service.model_service import ( | ||||||
|  |         current_yolo_model, | ||||||
|  |         current_model_absolute_path, | ||||||
|  |         load_yolo_model  # 用于删除后重新加载模型(可选) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 2. 查询待删除模型信息 | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         exist_model = cursor.fetchone() | ||||||
|  |         if not exist_model: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"模型不存在!") | ||||||
|  |  | ||||||
|  |         # 3. 关键判断:①默认模型不可删 ②正在使用的模型不可删 | ||||||
|  |         if exist_model["is_default"]: | ||||||
|  |             raise HTTPException(status_code=400, detail="默认模型不可删除!") | ||||||
|  |  | ||||||
|  |         # 计算待删除模型的绝对路径(与 model_service 逻辑一致) | ||||||
|  |         from service.file_service import get_absolute_path | ||||||
|  |         del_model_abs_path = get_absolute_path(exist_model["path"]) | ||||||
|  |  | ||||||
|  |         # 判断是否正在使用(对比 current_model_absolute_path) | ||||||
|  |         if current_model_absolute_path and del_model_abs_path == current_model_absolute_path: | ||||||
|  |             raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!") | ||||||
|  |  | ||||||
|  |         # 4. 先删除数据库记录(避免文件删除失败导致数据不一致) | ||||||
|  |         cursor.execute("DELETE FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 5. 再删除本地文件(捕获文件删除异常,不影响数据库删除结果) | ||||||
|  |         extra_msg = "" | ||||||
|  |         try: | ||||||
|  |             if os.path.exists(del_model_abs_path): | ||||||
|  |                 os.remove(del_model_abs_path)  # 或用 Path(del_model_abs_path).unlink() | ||||||
|  |                 extra_msg = "(本地文件已同步删除)" | ||||||
|  |             else: | ||||||
|  |                 extra_msg = "(本地文件不存在,无需删除)" | ||||||
|  |         except Exception as e: | ||||||
|  |             extra_msg = f"(本地文件删除失败:{str(e)})" | ||||||
|  |  | ||||||
|  |         # 6. 若删除后当前模型为空(极端情况),重新加载默认模型(可选优化) | ||||||
|  |         if current_yolo_model is None: | ||||||
|  |             try: | ||||||
|  |                 load_yolo_model() | ||||||
|  |                 print(f"[模型删除后] 重新加载默认模型成功") | ||||||
|  |             except Exception as e: | ||||||
|  |                 print(f"[模型删除后] 重新加载默认模型失败:{str(e)}") | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"模型删除成功!", | ||||||
|  |             data=None | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										306
									
								
								router/sensitive_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,306 @@ | |||||||
|  | from fastapi import APIRouter, Depends, HTTPException, Query, File, UploadFile | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from schema.sensitive_schema import ( | ||||||
|  |     SensitiveCreateRequest, | ||||||
|  |     SensitiveResponse, | ||||||
|  |     SensitiveListResponse | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  | from middle.auth_middleware import get_current_user | ||||||
|  | from schema.user_schema import UserResponse | ||||||
|  | from service.ocr_service import  set_forbidden_words | ||||||
|  | from service.sensitive_service import get_all_sensitive_words | ||||||
|  |  | ||||||
|  | router = APIRouter( | ||||||
|  |     prefix="/api/sensitives", | ||||||
|  |     tags=["敏感信息管理"] | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 创建敏感信息记录 | ||||||
|  | @router.post("", response_model=APIResponse, summary="创建敏感信息记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def create_sensitive( | ||||||
|  |         sensitive: SensitiveCreateRequest, | ||||||
|  |         current_user: UserResponse = Depends(get_current_user) | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成) | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO sensitives (name, created_at, updated_at) | ||||||
|  |             VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, (sensitive.name,)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 获取刚插入记录的ID(使用LAST_INSERT_ID()函数) | ||||||
|  |         new_id = cursor.lastrowid | ||||||
|  |  | ||||||
|  |         # 查询刚创建的记录并返回 | ||||||
|  |         select_query = "SELECT * FROM sensitives WHERE id = %s" | ||||||
|  |         cursor.execute(select_query, (new_id,)) | ||||||
|  |         created_sensitive = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         # 重新加载最新的敏感词 | ||||||
|  |         set_forbidden_words(get_all_sensitive_words()) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="敏感信息记录创建成功", | ||||||
|  |             data=SensitiveResponse(**created_sensitive) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=500, | ||||||
|  |             detail=f"创建敏感信息记录失败: {str(e)}" | ||||||
|  |         ) from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取敏感信息分页列表 | ||||||
|  | @router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_sensitive_list( | ||||||
|  |         page: int = Query(1, ge=1, description="页码(默认1、最小1)"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10、1-100)"), | ||||||
|  |         name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 1. 构建查询条件(支持关键词搜索) | ||||||
|  |         where_clause = [] | ||||||
|  |         params = [] | ||||||
|  |         if name: | ||||||
|  |             where_clause.append("name LIKE %s") | ||||||
|  |             params.append(f"%{name}%")  # 模糊匹配关键词 | ||||||
|  |  | ||||||
|  |         # 2. 查询总记录数(用于分页计算) | ||||||
|  |         count_sql = "SELECT COUNT(*) AS total FROM sensitives" | ||||||
|  |         if where_clause: | ||||||
|  |             count_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_sql, params.copy())  # 复制参数列表、避免后续污染 | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 3. 计算分页偏移量 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |  | ||||||
|  |         # 4. 分页查询敏感词数据(按更新时间倒序、最新的在前) | ||||||
|  |         list_sql = "SELECT * FROM sensitives" | ||||||
|  |         if where_clause: | ||||||
|  |             list_sql += " WHERE " + " AND ".join(where_clause) | ||||||
|  |         # 排序+分页(LIMIT 条数 OFFSET 偏移量) | ||||||
|  |         list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s" | ||||||
|  |         # 补充分页参数(page_size和offset) | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |  | ||||||
|  |         cursor.execute(list_sql, params) | ||||||
|  |         sensitive_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 5. 构造分页响应数据 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"敏感信息列表查询成功(共{total}条记录、当前第{page}页)", | ||||||
|  |             data=SensitiveListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 sensitives=[SensitiveResponse(**item) for item in sensitive_list] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=500, | ||||||
|  |             detail=f"查询敏感信息列表失败: {str(e)}" | ||||||
|  |         ) from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 删除敏感信息记录 | ||||||
|  | @router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录") | ||||||
|  | @encrypt_response() | ||||||
|  | async def delete_sensitive( | ||||||
|  |         sensitive_id: int, | ||||||
|  |         current_user: UserResponse = Depends(get_current_user)  # 需登录认证 | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     删除敏感信息记录: | ||||||
|  |     - 需登录认证 | ||||||
|  |     - 根据ID删除敏感信息记录 | ||||||
|  |     - 返回删除成功信息 | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 1. 检查记录是否存在 | ||||||
|  |         check_query = "SELECT id FROM sensitives WHERE id = %s" | ||||||
|  |         cursor.execute(check_query, (sensitive_id,)) | ||||||
|  |         existing_sensitive = cursor.fetchone() | ||||||
|  |         if not existing_sensitive: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=404, | ||||||
|  |                 detail=f"ID为 {sensitive_id} 的敏感信息记录不存在" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 2. 执行删除操作 | ||||||
|  |         delete_query = "DELETE FROM sensitives WHERE id = %s" | ||||||
|  |         cursor.execute(delete_query, (sensitive_id,)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 重新加载最新的敏感词 | ||||||
|  |         set_forbidden_words(get_all_sensitive_words()) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"ID为 {sensitive_id} 的敏感信息记录删除成功", | ||||||
|  |             data=None | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=500, | ||||||
|  |             detail=f"删除敏感信息记录失败: {str(e)}" | ||||||
|  |         ) from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 批量导入敏感信息(从txt文件) | ||||||
|  | @router.post("/batch-import", response_model=APIResponse, summary="批量导入敏感信息(从txt文件)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def batch_import_sensitives( | ||||||
|  |         file: UploadFile = File(..., description="包含敏感词的txt文件,每行一个敏感词"), | ||||||
|  |         # current_user: UserResponse = Depends(get_current_user)  # 添加认证依赖 | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     批量导入敏感信息: | ||||||
|  |     - 需登录认证 | ||||||
|  |     - 接收txt文件,文件中每行一个敏感词 | ||||||
|  |     - 批量插入到数据库中(仅插入不存在的敏感词) | ||||||
|  |     - 返回导入结果统计 | ||||||
|  |     """ | ||||||
|  |     # 检查文件类型 | ||||||
|  |     filename = file.filename or "" | ||||||
|  |     if not filename.lower().endswith(".txt"): | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=400, | ||||||
|  |             detail=f"请上传txt格式的文件,当前文件格式: {filename.split('.')[-1] if '.' in filename else '未知'}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # 检查文件大小 | ||||||
|  |     file_size = await file.read(1)  # 读取1字节获取文件信息 | ||||||
|  |     await file.seek(0)  # 重置文件指针 | ||||||
|  |     if not file_size:  # 文件为空 | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=400, | ||||||
|  |             detail="上传的文件为空,请提供有效的敏感词文件" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         # 读取文件内容 | ||||||
|  |         contents = await file.read() | ||||||
|  |         # 按行分割内容,处理不同操作系统的换行符 | ||||||
|  |         lines = contents.decode("utf-8", errors="replace").splitlines() | ||||||
|  |  | ||||||
|  |         # 过滤空行和仅含空白字符的行 | ||||||
|  |         sensitive_words = [line.strip() for line in lines if line.strip()] | ||||||
|  |  | ||||||
|  |         if not sensitive_words: | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message="文件中没有有效的敏感词", | ||||||
|  |                 data={"imported": 0, "total": 0} | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 先查询数据库中已存在的敏感词 | ||||||
|  |         query = "SELECT name FROM sensitives WHERE name IN (%s)" | ||||||
|  |         # 处理参数,根据敏感词数量生成占位符 | ||||||
|  |         placeholders = ', '.join(['%s'] * len(sensitive_words)) | ||||||
|  |         cursor.execute(query % placeholders, sensitive_words) | ||||||
|  |         existing_words = {row['name'] for row in cursor.fetchall()} | ||||||
|  |  | ||||||
|  |         # 过滤掉已存在的敏感词 | ||||||
|  |         new_words = [word for word in sensitive_words if word not in existing_words] | ||||||
|  |  | ||||||
|  |         if not new_words: | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message="所有敏感词均已存在于数据库中", | ||||||
|  |                 data={ | ||||||
|  |                     "total": len(sensitive_words), | ||||||
|  |                     "imported": 0, | ||||||
|  |                     "duplicates": len(sensitive_words), | ||||||
|  |                     "message": f"共处理{len(sensitive_words)}个敏感词,全部已存在,未导入任何新敏感词" | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 批量插入新的敏感词 | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO sensitives (name, created_at, updated_at) | ||||||
|  |             VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # 准备参数列表 | ||||||
|  |         params = [(word,) for word in new_words] | ||||||
|  |  | ||||||
|  |         # 执行批量插入 | ||||||
|  |         cursor.executemany(insert_query, params) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 重新加载最新的敏感词 | ||||||
|  |         set_forbidden_words(get_all_sensitive_words()) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"敏感词批量导入成功", | ||||||
|  |             data={ | ||||||
|  |                 "total": len(sensitive_words), | ||||||
|  |                 "imported": len(new_words), | ||||||
|  |                 "duplicates": len(sensitive_words) - len(new_words), | ||||||
|  |                 "message": f"共处理{len(sensitive_words)}个敏感词,成功导入{len(new_words)}个,{len(sensitive_words) - len(new_words)}个已存在" | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except UnicodeDecodeError: | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=400, | ||||||
|  |             detail="文件编码格式错误,请使用UTF-8编码的txt文件" | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=500, | ||||||
|  |             detail=f"批量导入敏感词失败: {str(e)}" | ||||||
|  |         ) from e | ||||||
|  |     except Exception as e: | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=500, | ||||||
|  |             detail=f"处理文件时发生错误: {str(e)}" | ||||||
|  |         ) from e | ||||||
|  |     finally: | ||||||
|  |         await file.close() | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										246
									
								
								router/user_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,246 @@ | |||||||
|  | from datetime import timedelta | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  | from fastapi import APIRouter, Depends, HTTPException, Query | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from encryption.encrypt_decorator import encrypt_response | ||||||
|  | from middle.auth_middleware import ( | ||||||
|  |     get_password_hash, | ||||||
|  |     verify_password, | ||||||
|  |     create_access_token, | ||||||
|  |     ACCESS_TOKEN_EXPIRE_MINUTES, | ||||||
|  |     get_current_user | ||||||
|  | ) | ||||||
|  | from schema.response_schema import APIResponse | ||||||
|  | from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse | ||||||
|  |  | ||||||
|  | router = APIRouter( | ||||||
|  |     prefix="/api/users", | ||||||
|  |     tags=["用户管理"] | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 用户注册接口 | ||||||
|  | @router.post("/register", response_model=APIResponse, summary="用户注册") | ||||||
|  | @encrypt_response() | ||||||
|  | async def user_register(request: UserRegisterRequest): | ||||||
|  |     """ | ||||||
|  |     用户注册: | ||||||
|  |     - 校验用户名是否已存在 | ||||||
|  |     - 加密密码后插入数据库 | ||||||
|  |     - 返回注册成功信息 | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 1. 检查用户名是否已存在(唯一索引) | ||||||
|  |         check_query = "SELECT username FROM users WHERE username = %s" | ||||||
|  |         cursor.execute(check_query, (request.username,)) | ||||||
|  |         existing_user = cursor.fetchone() | ||||||
|  |         if existing_user: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"用户名 '{request.username}' 已存在、请更换其他用户名" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 2. 加密密码 | ||||||
|  |         hashed_password = get_password_hash(request.password) | ||||||
|  |  | ||||||
|  |         # 3. 插入新用户到数据库 | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO users (username, password) | ||||||
|  |             VALUES (%s, %s) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, (request.username, hashed_password)) | ||||||
|  |         conn.commit()  # 提交事务 | ||||||
|  |  | ||||||
|  |         # 4. 返回注册成功响应 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200,  # 200 表示资源创建成功 | ||||||
|  |             message=f"用户 '{request.username}' 注册成功", | ||||||
|  |             data=None | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         conn.rollback()  # 数据库错误时回滚事务 | ||||||
|  |         raise Exception(f"注册失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 用户登录接口 | ||||||
|  | @router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)") | ||||||
|  | @encrypt_response() | ||||||
|  | async def user_login(request: UserLoginRequest): | ||||||
|  |     """ | ||||||
|  |     用户登录: | ||||||
|  |     - 校验用户名是否存在 | ||||||
|  |     - 校验密码是否正确 | ||||||
|  |     - 生成 JWT Token 并返回 | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 修复: SQL查询添加 created_at 和 updated_at 字段 | ||||||
|  |         query = """ | ||||||
|  |             SELECT id, username, password, created_at, updated_at  | ||||||
|  |             FROM users  | ||||||
|  |             WHERE username = %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(query, (request.username,)) | ||||||
|  |         user = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         # 2. 校验用户名和密码 | ||||||
|  |         if not user or not verify_password(request.password, user["password"]): | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=401, | ||||||
|  |                 detail="用户名或密码错误", | ||||||
|  |                 headers={"WWW-Authenticate": "Bearer"}, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 3. 生成 Token(过期时间从配置读取) | ||||||
|  |         access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) | ||||||
|  |         access_token = create_access_token( | ||||||
|  |             data={"sub": user["username"]}, | ||||||
|  |             expires_delta=access_token_expires | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # 4. 返回 Token 和用户基本信息 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="登录成功", | ||||||
|  |             data={ | ||||||
|  |                 "access_token": access_token, | ||||||
|  |                 "token_type": "bearer", | ||||||
|  |                 "user": UserResponse( | ||||||
|  |                     id=user["id"], | ||||||
|  |                     username=user["username"], | ||||||
|  |                     created_at=user.get("created_at"), | ||||||
|  |                     updated_at=user.get("updated_at") | ||||||
|  |                 ) | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"登录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取当前登录用户信息(需认证) | ||||||
|  | @router.get("/me", response_model=APIResponse, summary="获取当前用户信息") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_current_user_info( | ||||||
|  |         current_user: UserResponse = Depends(get_current_user)  # 依赖认证中间件 | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     获取当前登录用户信息: | ||||||
|  |     - 需在请求头携带 Token(格式: Bearer <token>) | ||||||
|  |     - 认证通过后返回用户信息 | ||||||
|  |     """ | ||||||
|  |     return APIResponse( | ||||||
|  |         code=200, | ||||||
|  |         message="获取用户信息成功", | ||||||
|  |         data=current_user | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取用户列表(仅需登录权限) | ||||||
|  | @router.get("/list", response_model=APIResponse, summary="获取用户列表") | ||||||
|  | @encrypt_response() | ||||||
|  | async def get_user_list( | ||||||
|  |         page: int = Query(1, ge=1, description="页码、从1开始"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"), | ||||||
|  |         username: Optional[str] = Query(None, description="用户名模糊搜索"), | ||||||
|  |         current_user: UserResponse = Depends(get_current_user)  # 仅需登录即可访问(移除管理员校验) | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     获取用户列表: | ||||||
|  |     - 需登录权限(请求头携带 Token: Bearer <token>) | ||||||
|  |     - 支持分页查询(page=页码、page_size=每页条数) | ||||||
|  |     - 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等) | ||||||
|  |     - 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息) | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 计算分页偏移量(page从1开始、偏移量=(页码-1)*每页条数) | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |  | ||||||
|  |         # 基础查询(仅查非敏感字段) | ||||||
|  |         base_query = """ | ||||||
|  |             SELECT id, username, created_at, updated_at  | ||||||
|  |             FROM users | ||||||
|  |         """ | ||||||
|  |         # 总条数查询(用于分页计算) | ||||||
|  |         count_query = "SELECT COUNT(*) as total FROM users" | ||||||
|  |  | ||||||
|  |         # 条件拼接(支持用户名模糊搜索) | ||||||
|  |         conditions = [] | ||||||
|  |         params = [] | ||||||
|  |         if username: | ||||||
|  |             conditions.append("username LIKE %s") | ||||||
|  |             params.append(f"%{username}%")  # 模糊匹配:%表示任意字符 | ||||||
|  |  | ||||||
|  |         # 构建最终查询语句 | ||||||
|  |         if conditions: | ||||||
|  |             where_clause = " WHERE " + " AND ".join(conditions) | ||||||
|  |             final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s" | ||||||
|  |             final_count_query = f"{count_query}{where_clause}" | ||||||
|  |             params.extend([page_size, offset])  # 追加分页参数 | ||||||
|  |         else: | ||||||
|  |             final_query = f"{base_query} LIMIT %s OFFSET %s" | ||||||
|  |             final_count_query = count_query | ||||||
|  |             params = [page_size, offset] | ||||||
|  |  | ||||||
|  |         # 1. 查询用户列表数据 | ||||||
|  |         cursor.execute(final_query, params) | ||||||
|  |         users = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 2. 查询总条数(用于计算总页数) | ||||||
|  |         count_params = [f"%{username}%"] if username else [] | ||||||
|  |         cursor.execute(final_count_query, count_params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 3. 转换为UserResponse模型(确保字段匹配) | ||||||
|  |         user_list = [ | ||||||
|  |             UserResponse( | ||||||
|  |                 id=user["id"], | ||||||
|  |                 username=user["username"], | ||||||
|  |                 created_at=user["created_at"], | ||||||
|  |                 updated_at=user["updated_at"] | ||||||
|  |             ) | ||||||
|  |             for user in users | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         # 4. 计算总页数(向上取整、如11条数据每页10条=2页) | ||||||
|  |         total_pages = (total + page_size - 1) // page_size | ||||||
|  |  | ||||||
|  |         # 返回结果(包含列表和分页信息) | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取用户列表成功", | ||||||
|  |             data={ | ||||||
|  |                 "users": user_list, | ||||||
|  |                 "pagination": { | ||||||
|  |                     "page": page,  # 当前页码 | ||||||
|  |                     "page_size": page_size,  # 每页条数 | ||||||
|  |                     "total": total,  # 总数据量 | ||||||
|  |                     "total_pages": total_pages  # 总页数 | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取用户列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         # 无论成功失败、都关闭数据库连接 | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/device_action_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/device_danger_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/device_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/face_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/model_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/response_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/sensitive_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								schema/__pycache__/user_schema.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										36
									
								
								schema/device_action_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,36 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from typing import Optional, List | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 请求模型 | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceActionCreate(BaseModel): | ||||||
|  |     """设备操作记录创建模型(0=离线、1=上线)""" | ||||||
|  |     client_ip: str = Field(..., description="客户端IP") | ||||||
|  |     action: int = Field(..., ge=0, le=1, description="操作状态(0=离线、1=上线)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 响应模型(单条记录) | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceActionResponse(BaseModel): | ||||||
|  |     """设备操作记录响应模型(与自增表对齐)""" | ||||||
|  |     id: int = Field(..., description="自增主键ID") | ||||||
|  |     client_ip: Optional[str] = Field(None, description="客户端IP") | ||||||
|  |     action: Optional[int] = Field(None, description="操作状态(0=离线、1=上线)") | ||||||
|  |     created_at: datetime = Field(..., description="记录创建时间") | ||||||
|  |     updated_at: datetime = Field(..., description="记录更新时间") | ||||||
|  |  | ||||||
|  |     # 支持从数据库结果直接转换 | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 列表响应模型(仅含 total + device_actions) | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceActionListResponse(BaseModel): | ||||||
|  |     """设备操作记录列表(仅核心返回字段)""" | ||||||
|  |     total: int = Field(..., description="总记录数") | ||||||
|  |     device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表") | ||||||
							
								
								
									
										33
									
								
								schema/device_danger_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,33 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from typing import Optional, List | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 请求模型 | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceDangerCreateRequest(BaseModel): | ||||||
|  |     """设备危险记录创建请求模型""" | ||||||
|  |     client_ip: str = Field(..., max_length=100, description="设备IP地址(必须与devices表中IP对应)") | ||||||
|  |     type: str = Field(..., max_length=255, description="危险类型(如:病毒检测、端口异常、权限泄露等)") | ||||||
|  |     result: str = Field(..., description="危险检测结果/处理结果(如:检测到木马病毒、已隔离;端口22异常开放、已关闭)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 响应模型 | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceDangerResponse(BaseModel): | ||||||
|  |     """单条设备危险记录响应模型(与device_danger表字段对齐、updated_at允许为null)""" | ||||||
|  |     id: int = Field(..., description="危险记录主键ID") | ||||||
|  |     client_ip: str = Field(..., max_length=100, description="设备IP地址") | ||||||
|  |     type: str = Field(..., max_length=255, description="危险类型") | ||||||
|  |     result: str = Field(..., description="危险检测结果/处理结果") | ||||||
|  |     created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)") | ||||||
|  |     updated_at: Optional[datetime] = Field(None, description="记录更新时间(数据库中该字段当前为null)") | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DeviceDangerListResponse(BaseModel): | ||||||
|  |     """设备危险记录列表响应模型(带分页)""" | ||||||
|  |     total: int = Field(..., description="危险记录总数") | ||||||
|  |     dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表") | ||||||
							
								
								
									
										56
									
								
								schema/device_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,56 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from typing import Optional, List, Dict | ||||||
|  |  | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 请求模型 | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceCreateRequest(BaseModel): | ||||||
|  |     """设备创建请求模型""" | ||||||
|  |     ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") | ||||||
|  |     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") | ||||||
|  |     params: Optional[Dict] = Field(None, description="设备扩展参数(JSON格式)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 响应模型 | ||||||
|  | # ------------------------------ | ||||||
|  | class DeviceResponse(BaseModel): | ||||||
|  |     """单设备信息响应模型(与数据库表字段对齐)""" | ||||||
|  |     id: int = Field(..., description="设备主键ID") | ||||||
|  |     device_id: Optional[int] = Field(None, description="关联设备ID(若历史记录IP无对应设备则为None)") | ||||||
|  |     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") | ||||||
|  |     device_online_status: int = Field(..., description="在线状态(1-在线、0-离线)") | ||||||
|  |     device_type: Optional[str] = Field(None, description="设备类型") | ||||||
|  |     alarm_count: int = Field(..., description="报警次数") | ||||||
|  |     params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") | ||||||
|  |     client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址") | ||||||
|  |     is_need_handler: int = Field(..., description="是否需要处理(1-需要、0-不需要)") | ||||||
|  |     created_at: datetime = Field(..., description="记录创建时间") | ||||||
|  |     updated_at: datetime = Field(..., description="记录更新时间") | ||||||
|  |     model_config = {"from_attributes": True}  # 支持从数据库结果直接转换 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DeviceListResponse(BaseModel): | ||||||
|  |     """设备列表响应模型""" | ||||||
|  |     total: int = Field(..., description="设备总数") | ||||||
|  |     devices: List[DeviceResponse] = Field(..., description="设备列表") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DeviceStatusHistoryResponse(BaseModel): | ||||||
|  |     """设备上下线记录响应模型""" | ||||||
|  |     id: int = Field(..., description="记录ID") | ||||||
|  |     device_id: int = Field(..., description="关联设备ID") | ||||||
|  |     client_ip: Optional[str] = Field(None, description="设备IP地址") | ||||||
|  |     status: int = Field(..., description="状态(1-在线、0-离线)") | ||||||
|  |     status_time: datetime = Field(..., description="状态变更时间") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DeviceStatusHistoryListResponse(BaseModel): | ||||||
|  |     """设备上下线记录列表响应模型""" | ||||||
|  |     total: int = Field(..., description="记录总数") | ||||||
|  |     history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表") | ||||||
							
								
								
									
										41
									
								
								schema/face_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,41 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 请求模型(前端传参校验)- 保留update的eigenvalue(如需更新特征值) | ||||||
|  | # ------------------------------ | ||||||
|  | class FaceCreateRequest(BaseModel): | ||||||
|  |     """创建人脸记录请求模型(无需ID、由数据库自增)""" | ||||||
|  |     name: Optional[str] = Field(None, max_length=255, description="名称(可选、最长255字符)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FaceUpdateRequest(BaseModel): | ||||||
|  |     """更新人脸记录请求模型 - 保留eigenvalue(如需更新特征值、不影响返回)""" | ||||||
|  |     name: Optional[str] = Field(None, max_length=255, description="名称(可选)") | ||||||
|  |     eigenvalue: Optional[str] = Field(None, description="特征值(可选、文件处理后可更新)")  # 保留更新能力 | ||||||
|  |     address: Optional[str] = Field(None, description="图片完整路径(可选、更新图片时使用)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 响应模型(后端返回数据)- 核心修改:删除eigenvalue字段 | ||||||
|  | # ------------------------------ | ||||||
|  | class FaceResponse(BaseModel): | ||||||
|  |     """人脸记录响应模型(仅返回需要的字段、移除eigenvalue)""" | ||||||
|  |     id: int = Field(..., description="主键ID(数据库自增)") | ||||||
|  |     name: Optional[str] = Field(None, description="名称") | ||||||
|  |     address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)")  # 仅保留address | ||||||
|  |     created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)") | ||||||
|  |     updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)") | ||||||
|  |  | ||||||
|  |     # 关键配置:支持从数据库查询结果(字典)直接转换 | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FaceListResponse(BaseModel): | ||||||
|  |     """人脸列表分页响应模型(结构不变、内部FaceResponse已移除eigenvalue)""" | ||||||
|  |     total: int = Field(..., description="筛选后的总记录数") | ||||||
|  |     faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
							
								
								
									
										37
									
								
								schema/model_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,37 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 请求模型 | ||||||
|  | class ModelCreateRequest(BaseModel): | ||||||
|  |     name: str = Field(..., max_length=255, description="模型名称(必填、如:yolo-v8s-car)") | ||||||
|  |     description: Optional[str] = Field(None, description="模型描述(可选)") | ||||||
|  |     is_default: Optional[bool] = Field(False, description="是否设为默认模型") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ModelUpdateRequest(BaseModel): | ||||||
|  |     name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)") | ||||||
|  |     description: Optional[str] = Field(None, description="模型描述(可选修改)") | ||||||
|  |     is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 响应模型 | ||||||
|  | class ModelResponse(BaseModel): | ||||||
|  |     id: int = Field(..., description="模型ID") | ||||||
|  |     name: str = Field(..., description="模型名称") | ||||||
|  |     path: str = Field(..., description="模型文件相对路径") | ||||||
|  |     is_default: bool = Field(..., description="是否默认模型") | ||||||
|  |     description: Optional[str] = Field(None, description="模型描述") | ||||||
|  |     file_size: Optional[int] = Field(None, description="文件大小(字节)") | ||||||
|  |     created_at: datetime = Field(..., description="创建时间") | ||||||
|  |     updated_at: datetime = Field(..., description="更新时间") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ModelListResponse(BaseModel): | ||||||
|  |     total: int = Field(..., description="总记录数") | ||||||
|  |     models: List[ModelResponse] = Field(..., description="当前页模型列表") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
							
								
								
									
										13
									
								
								schema/response_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,13 @@ | |||||||
|  | from typing import Optional, Any | ||||||
|  |  | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class APIResponse(BaseModel): | ||||||
|  |     """统一 API 响应模型(所有接口必返此格式)""" | ||||||
|  |     code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误") | ||||||
|  |     message: str = Field(..., description="响应信息: 成功/错误描述") | ||||||
|  |     data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None") | ||||||
|  |  | ||||||
|  |     # Pydantic V2 配置(支持从 ORM 对象转换) | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
							
								
								
									
										38
									
								
								schema/sensitive_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,38 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 请求模型(前端传参校验) | ||||||
|  | # ------------------------------ | ||||||
|  | class SensitiveCreateRequest(BaseModel): | ||||||
|  |     """创建敏感信息记录请求模型""" | ||||||
|  |     name: str = Field(..., max_length=255, description="敏感词内容(必填)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SensitiveUpdateRequest(BaseModel): | ||||||
|  |     """更新敏感信息记录请求模型""" | ||||||
|  |     name: Optional[str] = Field(None, max_length=255, description="敏感词内容(可选修改)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 响应模型(后端返回数据) | ||||||
|  | # ------------------------------ | ||||||
|  | class SensitiveResponse(BaseModel): | ||||||
|  |     """敏感信息单条记录响应模型""" | ||||||
|  |     id: int = Field(..., description="主键ID") | ||||||
|  |     name: str = Field(..., description="敏感词内容") | ||||||
|  |     created_at: datetime = Field(..., description="记录创建时间") | ||||||
|  |     updated_at: datetime = Field(..., description="记录更新时间") | ||||||
|  |  | ||||||
|  |     # 支持从数据库查询结果(字典/对象)自动转换 | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SensitiveListResponse(BaseModel): | ||||||
|  |     """敏感信息分页列表响应模型(新增)""" | ||||||
|  |     total: int = Field(..., description="敏感词总记录数") | ||||||
|  |     sensitives: List[SensitiveResponse] = Field(..., description="当前页敏感词列表") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
							
								
								
									
										40
									
								
								schema/user_schema.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,40 @@ | |||||||
|  | from datetime import datetime | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 请求模型(前端传参校验) | ||||||
|  | # ------------------------------ | ||||||
|  | class UserRegisterRequest(BaseModel): | ||||||
|  |     """用户注册请求模型""" | ||||||
|  |     username: str = Field(..., min_length=3, max_length=50, description="用户名(3-50字符)") | ||||||
|  |     password: str = Field(..., min_length=6, max_length=100, description="密码(6-100字符)") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserLoginRequest(BaseModel): | ||||||
|  |     """用户登录请求模型""" | ||||||
|  |     username: str = Field(..., description="用户名") | ||||||
|  |     password: str = Field(..., description="密码") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 响应模型(后端返回用户数据) | ||||||
|  | # ------------------------------ | ||||||
|  | class UserResponse(BaseModel): | ||||||
|  |     """用户信息响应模型(隐藏密码等敏感字段)""" | ||||||
|  |     id: int = Field(..., description="用户ID") | ||||||
|  |     username: str = Field(..., description="用户名") | ||||||
|  |     created_at: datetime = Field(..., description="创建时间") | ||||||
|  |     updated_at: datetime = Field(..., description="更新时间") | ||||||
|  |  | ||||||
|  |     # Pydantic V2 配置(支持从数据库查询结果转换) | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UserListResponse(BaseModel): | ||||||
|  |     """用户列表分页响应模型(与设备/人脸列表结构对齐)""" | ||||||
|  |     total: int = Field(..., description="用户总数") | ||||||
|  |     users: List[UserResponse] = Field(..., description="当前页用户列表") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/device_action_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/device_danger_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/device_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/face_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/file_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/model_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/ocr_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										
											BIN
										
									
								
								service/__pycache__/sensitive_service.cpython-310.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
							
								
								
									
										43
									
								
								service/device_action_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,43 @@ | |||||||
|  | from ds.db import db | ||||||
|  | from schema.device_action_schema import DeviceActionCreate, DeviceActionResponse | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 新增设备操作记录 | ||||||
|  | def add_device_action(client_ip: str, action: int) -> DeviceActionResponse: | ||||||
|  |     """ | ||||||
|  |     新增设备操作记录(内部方法、非接口) | ||||||
|  |     :param action_data: 含client_ip和action(0/1) | ||||||
|  |     :return: 新增的完整记录 | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 插入SQL(id自增、依赖数据库自动生成) | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO device_action  | ||||||
|  |             (client_ip, action, created_at, updated_at) | ||||||
|  |             VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, ( | ||||||
|  |             client_ip, | ||||||
|  |             action | ||||||
|  |         )) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 获取新增记录(通过自增ID查询) | ||||||
|  |         new_id = cursor.lastrowid | ||||||
|  |         cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,)) | ||||||
|  |         new_action = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         return DeviceActionResponse(**new_action) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"新增记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										78
									
								
								service/device_danger_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,78 @@ | |||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from schema.device_danger_schema import DeviceDangerCreateRequest, DeviceDangerResponse | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 内部工具方法 - 检查设备是否存在(复用设备表逻辑) | ||||||
|  | # ------------------------------ | ||||||
|  | def check_device_exist(client_ip: str) -> bool: | ||||||
|  |     """ | ||||||
|  |     检查指定IP的设备是否在devices表中存在 | ||||||
|  |  | ||||||
|  |     :param client_ip: 设备IP地址 | ||||||
|  |     :return: 存在返回True、不存在返回False | ||||||
|  |     """ | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("设备IP不能为空") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) | ||||||
|  |         return cursor.fetchone() is not None | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"检查设备存在性失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 内部工具方法 - 创建设备危险记录(核心插入逻辑) | ||||||
|  | # ------------------------------ | ||||||
|  | def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse: | ||||||
|  |     """ | ||||||
|  |     内部工具方法:向device_danger表插入新的危险记录 | ||||||
|  |  | ||||||
|  |     :param danger_data: 危险记录创建请求数据 | ||||||
|  |     :return: 创建成功的危险记录模型对象 | ||||||
|  |     """ | ||||||
|  |     # 先检查设备是否存在 | ||||||
|  |     if not check_device_exist(danger_data.client_ip): | ||||||
|  |         raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在、无法创建危险记录") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 插入危险记录(id自增、时间自动填充) | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO device_danger  | ||||||
|  |             (client_ip, type, result, created_at, updated_at) | ||||||
|  |             VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, ( | ||||||
|  |             danger_data.client_ip, | ||||||
|  |             danger_data.type, | ||||||
|  |             danger_data.result | ||||||
|  |         )) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 获取刚创建的记录(用自增ID查询) | ||||||
|  |         danger_id = cursor.lastrowid | ||||||
|  |         cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,)) | ||||||
|  |         new_danger = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         return DeviceDangerResponse(**new_danger) | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"插入危险记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										250
									
								
								service/device_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,250 @@ | |||||||
|  | import threading | ||||||
|  | import time | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from service.device_action_service import add_device_action | ||||||
|  | _last_alarm_timestamps: dict[str, float] = {} | ||||||
|  | _timestamp_lock = threading.Lock() | ||||||
|  |  | ||||||
|  | # 获取所有去重的客户端IP列表 | ||||||
|  | def get_unique_client_ips() -> list[str]: | ||||||
|  |     """获取所有去重的客户端IP列表""" | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" | ||||||
|  |         cursor.execute(query) | ||||||
|  |         results = cursor.fetchall() | ||||||
|  |         return [item['client_ip'] for item in results] | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取客户端IP列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  | # 通过客户端IP更新设备是否需要处理 | ||||||
|  | def update_is_need_handler_by_client_ip(client_ip: str, is_need_handler: int) -> bool: | ||||||
|  |     """ | ||||||
|  |     通过客户端IP更新设备的「是否需要处理」状态(is_need_handler字段) | ||||||
|  |     """ | ||||||
|  |     # 参数合法性校验 | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("客户端IP不能为空") | ||||||
|  |  | ||||||
|  |     # 校验is_need_handler取值(需与数据库字段类型匹配、通常为0/1 tinyint) | ||||||
|  |     if is_need_handler not in (0, 1): | ||||||
|  |         raise ValueError("是否需要处理(is_need_handler)必须是0(不需要)或1(需要)") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         # 2. 获取数据库连接与游标 | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 3. 先校验设备是否存在(通过client_ip定位) | ||||||
|  |         cursor.execute( | ||||||
|  |             "SELECT id FROM devices WHERE client_ip = %s", | ||||||
|  |             (client_ip,) | ||||||
|  |         ) | ||||||
|  |         device = cursor.fetchone() | ||||||
|  |         if not device: | ||||||
|  |             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在、无法更新「是否需要处理」状态") | ||||||
|  |  | ||||||
|  |         # 4. 执行更新操作(同时更新时间戳、保持与其他更新逻辑一致性) | ||||||
|  |         update_query = """ | ||||||
|  |             UPDATE devices  | ||||||
|  |             SET is_need_handler = %s,  | ||||||
|  |                 updated_at = CURRENT_TIMESTAMP  | ||||||
|  |             WHERE client_ip = %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(update_query, (is_need_handler, client_ip)) | ||||||
|  |  | ||||||
|  |         # 5. 确认更新生效(判断影响行数、避免无意义更新) | ||||||
|  |         if cursor.rowcount <= 0: | ||||||
|  |             raise Exception(f"更新失败:客户端IP {client_ip} 的设备未发生状态变更(可能已为目标值)") | ||||||
|  |  | ||||||
|  |         # 6. 提交事务 | ||||||
|  |         conn.commit() | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         # 数据库异常时回滚事务 | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"数据库操作失败:更新设备「是否需要处理」状态时出错 - {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         # 无论成功失败、都关闭数据库连接(避免连接泄漏) | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  | def increment_alarm_count_by_ip(client_ip: str) -> bool: | ||||||
|  |     """ | ||||||
|  |     通过客户端IP增加设备的报警次数,相同IP 200ms内重复调用会被忽略 | ||||||
|  |  | ||||||
|  |     :param client_ip: 客户端IP地址 | ||||||
|  |     :return: 操作是否成功(是否实际执行了数据库更新) | ||||||
|  |     """ | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("客户端IP不能为空") | ||||||
|  |  | ||||||
|  |     current_time = time.time()  # 获取当前时间戳(秒,含小数) | ||||||
|  |     with _timestamp_lock:  # 确保线程安全的字典操作 | ||||||
|  |         last_time: Optional[float] = _last_alarm_timestamps.get(client_ip) | ||||||
|  |  | ||||||
|  |         # 如果存在最近记录且间隔小于200ms,直接返回False(不执行更新) | ||||||
|  |         if last_time is not None and (current_time - last_time) < 0.2: | ||||||
|  |             return False | ||||||
|  |  | ||||||
|  |         # 更新当前IP的最近调用时间 | ||||||
|  |         _last_alarm_timestamps[client_ip] = current_time | ||||||
|  |  | ||||||
|  |     # 2. 执行数据库更新操作(只有通过时间校验才会执行) | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 检查设备是否存在 | ||||||
|  |         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) | ||||||
|  |         device = cursor.fetchone() | ||||||
|  |         if not device: | ||||||
|  |             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|  |         # 报警次数加1、并更新时间戳 | ||||||
|  |         update_query = """ | ||||||
|  |             UPDATE devices  | ||||||
|  |             SET alarm_count = alarm_count + 1,  | ||||||
|  |                 updated_at = CURRENT_TIMESTAMP  | ||||||
|  |             WHERE client_ip = %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(update_query, (client_ip,)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         return True | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"更新报警次数失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  | # 通过客户端IP更新设备在线状态 | ||||||
|  | def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | ||||||
|  |     """ | ||||||
|  |     通过客户端IP更新设备的在线状态 | ||||||
|  |  | ||||||
|  |     :param client_ip: 客户端IP地址 | ||||||
|  |     :param online_status: 在线状态(1-在线、0-离线) | ||||||
|  |     :return: 操作是否成功 | ||||||
|  |     """ | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("客户端IP不能为空") | ||||||
|  |  | ||||||
|  |     if online_status not in (0, 1): | ||||||
|  |         raise ValueError("在线状态必须是0(离线)或1(在线)") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 检查设备是否存在并获取设备ID | ||||||
|  |         cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,)) | ||||||
|  |         device = cursor.fetchone() | ||||||
|  |         if not device: | ||||||
|  |             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|  |         # 状态无变化则不操作 | ||||||
|  |         if device['device_online_status'] == online_status: | ||||||
|  |             return True | ||||||
|  |  | ||||||
|  |         # 更新在线状态和时间戳 | ||||||
|  |         update_query = """ | ||||||
|  |             UPDATE devices  | ||||||
|  |             SET device_online_status = %s,  | ||||||
|  |                 updated_at = CURRENT_TIMESTAMP  | ||||||
|  |             WHERE client_ip = %s | ||||||
|  |         """ | ||||||
|  |         cursor.execute(update_query, (online_status, client_ip)) | ||||||
|  |  | ||||||
|  |         # 记录状态变更历史 | ||||||
|  |         add_device_action(client_ip, online_status) | ||||||
|  |  | ||||||
|  |         conn.commit() | ||||||
|  |         return True | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"更新设备在线状态失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  | # 通过客户端IP查询设备在数据库中存在 | ||||||
|  | def is_device_exist_by_ip(client_ip: str) -> bool: | ||||||
|  |     """ | ||||||
|  |     通过客户端IP查询设备在数据库中是否存在 | ||||||
|  |     """ | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("客户端IP不能为空") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 查询设备是否存在 | ||||||
|  |         cursor.execute( | ||||||
|  |             "SELECT id FROM devices WHERE client_ip = %s", | ||||||
|  |             (client_ip,) | ||||||
|  |         ) | ||||||
|  |         device = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         # 如果查询到结果则存在,否则不存在 | ||||||
|  |         return bool(device) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"查询设备是否存在失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  | # 根据客户端IP获取是否需要处理 | ||||||
|  | def get_is_need_handler_by_ip(client_ip: str) -> int: | ||||||
|  |     """ | ||||||
|  |     通过客户端IP查询设备的is_need_handler状态 | ||||||
|  |  | ||||||
|  |     :param client_ip: 客户端IP地址 | ||||||
|  |     :return: 设备的is_need_handler状态(0-不需要处理,1-需要处理) | ||||||
|  |     """ | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("客户端IP不能为空") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 查询设备的is_need_handler状态 | ||||||
|  |         cursor.execute( | ||||||
|  |             "SELECT is_need_handler FROM devices WHERE client_ip = %s", | ||||||
|  |             (client_ip,) | ||||||
|  |         ) | ||||||
|  |         device = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not device: | ||||||
|  |             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|  |         return device['is_need_handler'] | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"查询设备is_need_handler状态失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										339
									
								
								service/face_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,339 @@ | |||||||
|  | import cv2 | ||||||
|  | from io import BytesIO | ||||||
|  | from PIL import Image | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import threading | ||||||
|  | from insightface.app import FaceAnalysis | ||||||
|  |  | ||||||
|  | # 全局变量定义 | ||||||
|  | _insightface_app = None | ||||||
|  | _known_faces_embeddings = {}  # 存储已知人脸特征 {姓名: 特征向量} | ||||||
|  | _known_faces_names = []  # 存储已知人脸姓名列表 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def init_insightface(): | ||||||
|  |     """初始化InsightFace引擎""" | ||||||
|  |     global _insightface_app | ||||||
|  |     if _insightface_app is not None: | ||||||
|  |         print("InsightFace引擎已初始化,无需重复执行") | ||||||
|  |         return _insightface_app | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         print("正在初始化 InsightFace 引擎(模型:buffalo_l)...") | ||||||
|  |         # 初始化引擎,指定模型路径和计算 providers | ||||||
|  |         app = FaceAnalysis( | ||||||
|  |             name='buffalo_l', | ||||||
|  |             root='~/.insightface', | ||||||
|  |             providers=['CPUExecutionProvider']  # 如需GPU可添加'CUDAExecutionProvider' | ||||||
|  |         ) | ||||||
|  |         app.prepare(ctx_id=0, det_size=(640, 640))  # 调整检测尺寸 | ||||||
|  |         print("InsightFace 引擎初始化完成") | ||||||
|  |  | ||||||
|  |         # 初始化时加载人脸特征库 | ||||||
|  |         init_face_data() | ||||||
|  |  | ||||||
|  |         _insightface_app = app | ||||||
|  |         return app | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"InsightFace 初始化失败:{str(e)}") | ||||||
|  |         _insightface_app = None | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def init_face_data(): | ||||||
|  |     """初始化或更新人脸特征库(清空旧数据,避免重复)""" | ||||||
|  |     global _known_faces_embeddings, _known_faces_names | ||||||
|  |     # 清空原有数据,防止重复加载 | ||||||
|  |     _known_faces_embeddings.clear() | ||||||
|  |     _known_faces_names.clear() | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         face_data = get_all_face_name_with_eigenvalue()  # 假设该函数已定义 | ||||||
|  |         print(f"已加载 {len(face_data)} 个人脸数据") | ||||||
|  |         for person_name, eigenvalue_data in face_data.items(): | ||||||
|  |             # 解析特征值(支持numpy数组或字符串格式) | ||||||
|  |             if isinstance(eigenvalue_data, np.ndarray): | ||||||
|  |                 eigenvalue = eigenvalue_data.astype(np.float32) | ||||||
|  |             elif isinstance(eigenvalue_data, str): | ||||||
|  |                 # 增强字符串解析:支持逗号/空格分隔,清理特殊字符 | ||||||
|  |                 cleaned = (eigenvalue_data | ||||||
|  |                            .replace("[", "").replace("]", "") | ||||||
|  |                            .replace("\n", "").replace(",", " ") | ||||||
|  |                            .strip()) | ||||||
|  |                 values = [v for v in cleaned.split() if v]  # 过滤空字符串 | ||||||
|  |                 if not values: | ||||||
|  |                     print(f"特征值解析失败(空值),跳过 {person_name}") | ||||||
|  |                     continue | ||||||
|  |                 eigenvalue = np.array(list(map(float, values)), dtype=np.float32) | ||||||
|  |             else: | ||||||
|  |                 print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}") | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             # 特征值归一化(确保相似度计算一致性) | ||||||
|  |             norm = np.linalg.norm(eigenvalue) | ||||||
|  |             if norm == 0: | ||||||
|  |                 print(f"特征值为零向量,跳过 {person_name}") | ||||||
|  |                 continue | ||||||
|  |             eigenvalue = eigenvalue / norm | ||||||
|  |  | ||||||
|  |             # 更新全局特征库 | ||||||
|  |             _known_faces_embeddings[person_name] = eigenvalue | ||||||
|  |             _known_faces_names.append(person_name) | ||||||
|  |  | ||||||
|  |         print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库") | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"加载人脸特征库失败: {e}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def update_face_data(): | ||||||
|  |     """更新人脸特征库(清空旧数据,加载最新数据)""" | ||||||
|  |     global _known_faces_embeddings, _known_faces_names | ||||||
|  |  | ||||||
|  |     print("开始更新人脸特征库...") | ||||||
|  |  | ||||||
|  |     # 清空原有数据 | ||||||
|  |     _known_faces_embeddings.clear() | ||||||
|  |     _known_faces_names.clear() | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         # 获取最新人脸数据 | ||||||
|  |         face_data = get_all_face_name_with_eigenvalue() | ||||||
|  |         print(f"获取到 {len(face_data)} 条最新人脸数据") | ||||||
|  |  | ||||||
|  |         # 处理并加载新数据(复用原有解析逻辑) | ||||||
|  |         for person_name, eigenvalue_data in face_data.items(): | ||||||
|  |             # 解析特征值(支持numpy数组或字符串格式) | ||||||
|  |             if isinstance(eigenvalue_data, np.ndarray): | ||||||
|  |                 eigenvalue = eigenvalue_data.astype(np.float32) | ||||||
|  |             elif isinstance(eigenvalue_data, str): | ||||||
|  |                 # 增强字符串解析:支持逗号/空格分隔,清理特殊字符 | ||||||
|  |                 cleaned = (eigenvalue_data | ||||||
|  |                            .replace("[", "").replace("]", "") | ||||||
|  |                            .replace("\n", "").replace(",", " ") | ||||||
|  |                            .strip()) | ||||||
|  |                 values = [v for v in cleaned.split() if v]  # 过滤空字符串 | ||||||
|  |                 if not values: | ||||||
|  |                     print(f"特征值解析失败(空值),跳过 {person_name}") | ||||||
|  |                     continue | ||||||
|  |                 eigenvalue = np.array(list(map(float, values)), dtype=np.float32) | ||||||
|  |             else: | ||||||
|  |                 print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}") | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             # 特征值归一化(确保相似度计算一致性) | ||||||
|  |             norm = np.linalg.norm(eigenvalue) | ||||||
|  |             if norm == 0: | ||||||
|  |                 print(f"特征值为零向量,跳过 {person_name}") | ||||||
|  |                 continue | ||||||
|  |             eigenvalue = eigenvalue / norm | ||||||
|  |  | ||||||
|  |             # 更新全局特征库 | ||||||
|  |             _known_faces_embeddings[person_name] = eigenvalue | ||||||
|  |             _known_faces_names.append(person_name) | ||||||
|  |  | ||||||
|  |         print(f"人脸特征库更新完成,共加载 {len(_known_faces_names)} 个人脸数据") | ||||||
|  |         return True  # 更新成功 | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"人脸特征库更新失败: {e}") | ||||||
|  |         return False  # 更新失败 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def detect(frame, similarity_threshold=0.4): | ||||||
|  |     global _insightface_app, _known_faces_embeddings | ||||||
|  |  | ||||||
|  |     # 校验输入有效性 | ||||||
|  |     if frame is None or frame.size == 0: | ||||||
|  |         return (False, "无效的输入帧数据") | ||||||
|  |  | ||||||
|  |     # 校验引擎和特征库状态 | ||||||
|  |     if not _insightface_app: | ||||||
|  |         return (False, "人脸引擎未初始化") | ||||||
|  |     if not _known_faces_embeddings: | ||||||
|  |         return (False, "人脸特征库为空") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         # 执行人脸检测与特征提取 | ||||||
|  |         faces = _insightface_app.get(frame) | ||||||
|  |     except Exception as e: | ||||||
|  |         return (False, f"检测错误: {str(e)}") | ||||||
|  |  | ||||||
|  |     result_parts = [] | ||||||
|  |     has_matched_known_face = False  # 是否有匹配到已知人脸 | ||||||
|  |  | ||||||
|  |     for face in faces: | ||||||
|  |         # 归一化当前人脸特征 | ||||||
|  |         face_embedding = face.embedding.astype(np.float32) | ||||||
|  |         norm = np.linalg.norm(face_embedding) | ||||||
|  |         if norm == 0: | ||||||
|  |             result_parts.append("检测到人脸但特征值为零向量(忽略)") | ||||||
|  |             continue | ||||||
|  |         face_embedding = face_embedding / norm | ||||||
|  |  | ||||||
|  |         # 与已知特征库比对 | ||||||
|  |         max_similarity, best_match_name = -1.0, "Unknown" | ||||||
|  |         for name, known_emb in _known_faces_embeddings.items(): | ||||||
|  |             similarity = np.dot(face_embedding, known_emb)  # 余弦相似度 | ||||||
|  |             if similarity > max_similarity: | ||||||
|  |                 max_similarity = similarity | ||||||
|  |                 best_match_name = name | ||||||
|  |  | ||||||
|  |         # 判断是否匹配成功 | ||||||
|  |         is_matched = max_similarity >= similarity_threshold | ||||||
|  |  | ||||||
|  |         if is_matched: | ||||||
|  |             has_matched_known_face = True | ||||||
|  |  | ||||||
|  |         # 记录结果(边界框转为整数列表) | ||||||
|  |         bbox = face.bbox.astype(int).tolist() | ||||||
|  |         result_parts.append( | ||||||
|  |             f"{'匹配' if is_matched else '未匹配'}: {best_match_name} " | ||||||
|  |             f"(相似度: {max_similarity:.2f}, 边界框: {bbox})" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # 构建最终结果 | ||||||
|  |     result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts) | ||||||
|  |     return (has_matched_known_face, result_str) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 上传图片并提取特征 | ||||||
|  | def add_binary_data(binary_data): | ||||||
|  |     """ | ||||||
|  |     接收单张图片的二进制数据、提取特征并保存 | ||||||
|  |     返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串) | ||||||
|  |     """ | ||||||
|  |     global _insightface_app, _feature_list | ||||||
|  |  | ||||||
|  |     # 1. 先检查引擎是否初始化成功 | ||||||
|  |     if not _insightface_app: | ||||||
|  |         init_result = init_insightface()  # 尝试重新初始化 | ||||||
|  |         if not init_result: | ||||||
|  |             error_msg = "InsightFace引擎未初始化、无法检测人脸" | ||||||
|  |             print(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         # 2. 验证二进制数据有效性 | ||||||
|  |         if len(binary_data) < 1024:  # 过滤过小的无效图片(小于1KB) | ||||||
|  |             error_msg = f"图片过小({len(binary_data)}字节)、可能不是有效图片" | ||||||
|  |             print(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |         # 3. 二进制数据转CV2格式(关键步骤、避免通道错误) | ||||||
|  |         try: | ||||||
|  |             img = Image.open(BytesIO(binary_data)).convert("RGB")  # 强制转RGB | ||||||
|  |             frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)  # InsightFace需要BGR格式 | ||||||
|  |         except Exception as e: | ||||||
|  |             error_msg = f"图片格式转换失败:{str(e)}" | ||||||
|  |             print(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |         # 4. 检查图片尺寸(避免极端尺寸导致检测失败) | ||||||
|  |         height, width = frame.shape[:2] | ||||||
|  |         if height < 64 or width < 64:  # 人脸检测最小建议尺寸 | ||||||
|  |             error_msg = f"图片尺寸过小({width}x{height})、需至少64x64像素" | ||||||
|  |             print(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |         # 5. 调用InsightFace检测人脸 | ||||||
|  |         print(f"开始检测人脸(图片尺寸:{width}x{height}、格式:BGR)") | ||||||
|  |         faces = _insightface_app.get(frame) | ||||||
|  |  | ||||||
|  |         if not faces: | ||||||
|  |             error_msg = "未检测到人脸(请确保图片包含清晰正面人脸、无遮挡、不模糊)" | ||||||
|  |             print(error_msg) | ||||||
|  |             return False, error_msg | ||||||
|  |  | ||||||
|  |         # 6. 提取特征并保存 | ||||||
|  |         current_feature = faces[0].embedding | ||||||
|  |         _feature_list.append(current_feature) | ||||||
|  |         print(f"人脸检测成功、提取特征值(维度:{current_feature.shape[0]})、累计特征数:{len(_feature_list)}") | ||||||
|  |         return True, current_feature | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         error_msg = f"处理图片时发生异常:{str(e)}" | ||||||
|  |         print(error_msg) | ||||||
|  |         return False, error_msg | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取数据库最新的人脸及其特征 | ||||||
|  | def get_all_face_name_with_eigenvalue() -> dict: | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" | ||||||
|  |         cursor.execute(query) | ||||||
|  |         faces = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         name_to_eigenvalues = {} | ||||||
|  |         for face in faces: | ||||||
|  |             name = face["name"] | ||||||
|  |             eigenvalue = face["eigenvalue"] | ||||||
|  |             if name in name_to_eigenvalues: | ||||||
|  |                 name_to_eigenvalues[name].append(eigenvalue) | ||||||
|  |             else: | ||||||
|  |                 name_to_eigenvalues[name] = [eigenvalue] | ||||||
|  |  | ||||||
|  |         face_dict = {} | ||||||
|  |         for name, eigenvalues in name_to_eigenvalues.items(): | ||||||
|  |             if len(eigenvalues) > 1: | ||||||
|  |                 face_dict[name] = get_average_feature(eigenvalues) | ||||||
|  |             else: | ||||||
|  |                 face_dict[name] = eigenvalues[0] | ||||||
|  |  | ||||||
|  |         return face_dict | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取人脸特征失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 获取平均特征值 | ||||||
|  | def get_average_feature(features=None): | ||||||
|  |     global _feature_list | ||||||
|  |     try: | ||||||
|  |         if features is None: | ||||||
|  |             features = _feature_list | ||||||
|  |         if not isinstance(features, list) or len(features) == 0: | ||||||
|  |             print("输入必须是包含至少一个特征值的列表") | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         processed_features = [] | ||||||
|  |         for i, embedding in enumerate(features): | ||||||
|  |             try: | ||||||
|  |                 if isinstance(embedding, str): | ||||||
|  |                     embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip() | ||||||
|  |                     embedding_list = [float(num) for num in embedding_str.split() if num.strip()] | ||||||
|  |                     embedding_np = np.array(embedding_list, dtype=np.float32) | ||||||
|  |                 else: | ||||||
|  |                     embedding_np = np.array(embedding, dtype=np.float32) | ||||||
|  |  | ||||||
|  |                 if len(embedding_np.shape) == 1: | ||||||
|  |                     processed_features.append(embedding_np) | ||||||
|  |                     print(f"已添加第 {i + 1} 个特征值用于计算平均值") | ||||||
|  |                 else: | ||||||
|  |                     print(f"跳过第 {i + 1} 个特征值:不是一维数组") | ||||||
|  |             except Exception as e: | ||||||
|  |                 print(f"处理第 {i + 1} 个特征值时出错:{str(e)}") | ||||||
|  |  | ||||||
|  |         if not processed_features: | ||||||
|  |             print("没有有效的特征值用于计算平均值") | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         dims = {feat.shape[0] for feat in processed_features} | ||||||
|  |         if len(dims) > 1: | ||||||
|  |             print(f"特征值维度不一致:{dims}、无法计算平均值") | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         avg_feature = np.mean(processed_features, axis=0) | ||||||
|  |         print(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]})") | ||||||
|  |         return avg_feature | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"计算平均特征值出错:{str(e)}") | ||||||
|  |         return None | ||||||
							
								
								
									
										343
									
								
								service/file_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,343 @@ | |||||||
|  | import os | ||||||
|  | import re | ||||||
|  | import shutil | ||||||
|  | from datetime import datetime | ||||||
|  | from PIL import ImageDraw, ImageFont | ||||||
|  | from fastapi import UploadFile | ||||||
|  | import cv2 | ||||||
|  | from PIL import Image | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | # 上传根目录 | ||||||
|  | UPLOAD_ROOT = "upload" | ||||||
|  | PRE = "/api/file/download/" | ||||||
|  |  | ||||||
|  | # 确保上传根目录存在 | ||||||
|  | os.makedirs(UPLOAD_ROOT, exist_ok=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def save_detect_file(client_ip: str, image_np: np.ndarray, file_type: str) -> str: | ||||||
|  |     """保存numpy数组格式的PNG图片到detect目录,返回下载路径""" | ||||||
|  |     today = datetime.now() | ||||||
|  |     year = today.strftime("%Y") | ||||||
|  |     month = today.strftime("%m") | ||||||
|  |     day = today.strftime("%d") | ||||||
|  |  | ||||||
|  |     # 构建目录路径: upload/detect/客户端IP/type/年/月/日(包含UPLOAD_ROOT) | ||||||
|  |     file_dir = os.path.join( | ||||||
|  |         UPLOAD_ROOT, | ||||||
|  |         "detect", | ||||||
|  |         client_ip, | ||||||
|  |         file_type, | ||||||
|  |         year, | ||||||
|  |         month, | ||||||
|  |         day | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # 创建目录(确保目录存在) | ||||||
|  |     os.makedirs(file_dir, exist_ok=True) | ||||||
|  |  | ||||||
|  |     # 生成当前时间戳作为文件名,确保唯一性 | ||||||
|  |     timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f") | ||||||
|  |     filename = f"{timestamp}.png" | ||||||
|  |  | ||||||
|  |     # 1. 完整路径:用于实际保存文件(包含UPLOAD_ROOT) | ||||||
|  |     full_path = os.path.join(file_dir, filename) | ||||||
|  |     # 2. 相对路径:用于返回给前端(移除UPLOAD_ROOT前缀) | ||||||
|  |     relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | ||||||
|  |  | ||||||
|  |     # 保存numpy数组为PNG图片 | ||||||
|  |     try: | ||||||
|  |         # -------- 新增/修改:处理颜色通道和数据类型 -------- | ||||||
|  |         # 1. 数据类型转换:确保是uint8(若为float32且范围0-1,需转成0-255的uint8) | ||||||
|  |         if image_np.dtype != np.uint8: | ||||||
|  |             image_np = (image_np * 255).astype(np.uint8) | ||||||
|  |  | ||||||
|  |         # 2. 通道顺序转换:若为OpenCV的BGR格式,转成PIL需要的RGB格式 | ||||||
|  |         image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) | ||||||
|  |  | ||||||
|  |         # 3. 转换为PIL Image并保存 | ||||||
|  |         img = Image.fromarray(image_rgb) | ||||||
|  |         img.save(full_path, format='PNG') | ||||||
|  |     except Exception as e: | ||||||
|  |         # 处理可能的异常(如数组格式不正确) | ||||||
|  |         raise Exception(f"保存图片失败: {str(e)}") | ||||||
|  |  | ||||||
|  |     # 统一路径分隔符为/,拼接前缀返回 | ||||||
|  |     return PRE + relative_path.replace(os.sep, "/") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def save_detect_yolo_file( | ||||||
|  |         client_ip: str, | ||||||
|  |         image_np: np.ndarray, | ||||||
|  |         detection_results: list, | ||||||
|  |         file_type: str = "yolo" | ||||||
|  | ) -> str: | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     print("......................") | ||||||
|  |     """ | ||||||
|  |     保存YOLO检测结果图片(在原图上绘制边界框+标签),返回前端可访问的下载路径 | ||||||
|  |     """ | ||||||
|  |     # 输入参数验证 | ||||||
|  |     if not isinstance(image_np, np.ndarray): | ||||||
|  |         raise ValueError(f"输入image_np必须是numpy数组,当前类型:{type(image_np)}") | ||||||
|  |     if image_np.ndim != 3 or image_np.shape[-1] != 3: | ||||||
|  |         raise ValueError(f"输入图像必须是 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}") | ||||||
|  |  | ||||||
|  |     if not isinstance(detection_results, list): | ||||||
|  |         raise ValueError(f"detection_results必须是列表,当前类型:{type(detection_results)}") | ||||||
|  |     for idx, result in enumerate(detection_results): | ||||||
|  |         required_keys = {"class", "confidence", "bbox"} | ||||||
|  |         if not isinstance(result, dict) or not required_keys.issubset(result.keys()): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"detection_results第{idx}个元素格式错误,需包含键:{required_keys}," | ||||||
|  |                 f"当前元素:{result}" | ||||||
|  |             ) | ||||||
|  |         bbox = result["bbox"] | ||||||
|  |         if not (isinstance(bbox, (tuple, list)) and len(bbox) == 4 and all(isinstance(x, int) for x in bbox)): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"detection_results第{idx}个元素的bbox格式错误,需为(x1,y1,x2,y2)整数元组," | ||||||
|  |                 f"当前bbox:{bbox}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     #图像预处理(数据类型+通道) | ||||||
|  |     draw_image = image_np.copy() | ||||||
|  |     if draw_image.dtype != np.uint8: | ||||||
|  |         draw_image = np.clip(draw_image * 255, 0, 255).astype(np.uint8) | ||||||
|  |  | ||||||
|  |     #绘制边界框+标签 | ||||||
|  |     # 遍历所有检测结果,逐个绘制 | ||||||
|  |     for result in detection_results: | ||||||
|  |         class_name = result["class"] | ||||||
|  |         confidence = result["confidence"] | ||||||
|  |         x1, y1, x2, y2 = result["bbox"] | ||||||
|  |         cv2.rectangle(draw_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2) | ||||||
|  |         label = f"{class_name}: {confidence:.2f}" | ||||||
|  |         font = cv2.FONT_HERSHEY_SIMPLEX | ||||||
|  |         font_scale = 0.5 | ||||||
|  |         font_thickness = 2 | ||||||
|  |         (label_width, label_height), baseline = cv2.getTextSize( | ||||||
|  |             label, font, font_scale, font_thickness | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         bg_top_left = (x1, y1 - label_height - 10) | ||||||
|  |         bg_bottom_right = (x1 + label_width, y1) | ||||||
|  |         if bg_top_left[1] < 0: | ||||||
|  |             bg_top_left = (x1, 0) | ||||||
|  |             bg_bottom_right = (x1 + label_width, label_height + 10) | ||||||
|  |         cv2.rectangle(draw_image, bg_top_left, bg_bottom_right, color=(0, 0, 0), thickness=-1) | ||||||
|  |  | ||||||
|  |         text_origin = (x1, y1 - 5) | ||||||
|  |         if bg_top_left[1] == 0: | ||||||
|  |             text_origin = (x1, label_height + 5) | ||||||
|  |         cv2.putText( | ||||||
|  |             draw_image, label, text_origin, | ||||||
|  |             font, font_scale, color=(255, 255, 255), thickness=font_thickness | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     #保存图片 | ||||||
|  |     try: | ||||||
|  |         today = datetime.now() | ||||||
|  |         year = today.strftime("%Y") | ||||||
|  |         month = today.strftime("%m") | ||||||
|  |         day = today.strftime("%d") | ||||||
|  |         file_dir = os.path.join( | ||||||
|  |             UPLOAD_ROOT, "detect", client_ip, file_type, year, month, day | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         #创建目录(若不存在则创建,支持多级目录) | ||||||
|  |         os.makedirs(file_dir, exist_ok=True) | ||||||
|  |  | ||||||
|  |         #生成唯一文件名 | ||||||
|  |         timestamp = today.strftime("%Y%m%d%H%M%S%f") | ||||||
|  |         filename = f"{timestamp}.png" | ||||||
|  |  | ||||||
|  |         # 4.4 构建完整保存路径和前端访问路径 | ||||||
|  |         full_path = os.path.join(file_dir, filename)  # 本地完整路径 | ||||||
|  |         # 相对路径:移除UPLOAD_ROOT前缀,统一用"/"作为分隔符(兼容Windows/Linux) | ||||||
|  |         relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | ||||||
|  |         download_path = PRE + relative_path.replace(os.sep, "/") | ||||||
|  |  | ||||||
|  |         # 4.5 保存图片(CV2绘制的是BGR,需转RGB后用PIL保存,与原逻辑一致) | ||||||
|  |         image_rgb = cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB) | ||||||
|  |         img_pil = Image.fromarray(image_rgb) | ||||||
|  |         img_pil.save(full_path, format="PNG", quality=95)  # PNG格式无压缩,quality可忽略 | ||||||
|  |  | ||||||
|  |         print(f"YOLO检测图片保存成功 | 本地路径:{full_path} | 下载路径:{download_path}") | ||||||
|  |         return download_path | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         raise Exception(f"YOLO检测图片保存失败:{str(e)}") from e | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def save_detect_face_file( | ||||||
|  |     client_ip: str, | ||||||
|  |     image_np: np.ndarray, | ||||||
|  |     face_result: str, | ||||||
|  |     file_type: str = "face", | ||||||
|  |     matched_color: tuple = (0, 255, 0) | ||||||
|  | ) -> str: | ||||||
|  |     """ | ||||||
|  |     保存人脸识别结果图片(仅为「匹配成功」的人脸画框,标签不包含“匹配”二字) | ||||||
|  |     """ | ||||||
|  |     #输入参数验证 | ||||||
|  |     if not isinstance(image_np, np.ndarray) or image_np.ndim != 3 or image_np.shape[-1] != 3: | ||||||
|  |         raise ValueError(f"输入图像需为 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}") | ||||||
|  |     if not isinstance(face_result, str) or face_result.strip() == "": | ||||||
|  |         raise ValueError("face_result必须是非空字符串") | ||||||
|  |  | ||||||
|  |     # 解析face_result提取人脸信息 | ||||||
|  |     face_info_list = [] | ||||||
|  |     if face_result.strip() != "未检测到人脸": | ||||||
|  |         face_pattern = re.compile( | ||||||
|  |             r"(匹配|未匹配):\s*([^\s(]+)\s*\(相似度:\s*(\d+\.\d+),\s*边界框:\s*\[(\d+,\s*\d+,\s*\d+,\s*\d+)\]\)" | ||||||
|  |         ) | ||||||
|  |         for part in [p.strip() for p in face_result.split(";") if p.strip()]: | ||||||
|  |             match = face_pattern.match(part) | ||||||
|  |             if match: | ||||||
|  |                 status, name, similarity, bbox_str = match.groups() | ||||||
|  |                 bbox = list(map(int, bbox_str.replace(" ", "").split(","))) | ||||||
|  |                 if len(bbox) == 4: | ||||||
|  |                     face_info_list.append({ | ||||||
|  |                         "status": status, | ||||||
|  |                         "name": name, | ||||||
|  |                         "similarity": float(similarity), | ||||||
|  |                         "bbox": bbox | ||||||
|  |                     }) | ||||||
|  |  | ||||||
|  |     # 图像格式转换(OpenCV→PIL) | ||||||
|  |     image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) | ||||||
|  |     pil_img = Image.fromarray(image_rgb) | ||||||
|  |     draw = ImageDraw.Draw(pil_img) | ||||||
|  |  | ||||||
|  |     # 绘制边界框和标签 | ||||||
|  |     font_size = 12 | ||||||
|  |     try: | ||||||
|  |         font = ImageFont.truetype("simhei", font_size) | ||||||
|  |     except: | ||||||
|  |         try: | ||||||
|  |             font = ImageFont.truetype("simsun", font_size) | ||||||
|  |         except: | ||||||
|  |             font = ImageFont.load_default() | ||||||
|  |             print("警告:未找到指定中文字体,使用PIL默认字体(可能影响中文显示)") | ||||||
|  |  | ||||||
|  |     for face_info in face_info_list: | ||||||
|  |         status = face_info["status"] | ||||||
|  |         if status != "匹配": | ||||||
|  |             print(f"跳过未匹配人脸:{face_info['name']}(相似度:{face_info['similarity']:.2f})") | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         name = face_info["name"] | ||||||
|  |         similarity = face_info["similarity"] | ||||||
|  |         x1, y1, x2, y2 = face_info["bbox"] | ||||||
|  |  | ||||||
|  |         # 4.1 绘制边界框(绿色) | ||||||
|  |         img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) | ||||||
|  |         cv2.rectangle(img_cv, (x1, y1), (x2, y2), color=matched_color, thickness=2) | ||||||
|  |         pil_img = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)) | ||||||
|  |         draw = ImageDraw.Draw(pil_img) | ||||||
|  |  | ||||||
|  |         label = f"{name} (相似度: {similarity:.2f})" | ||||||
|  |  | ||||||
|  |         # 4.3 计算标签尺寸(文本变短后会自动适配,无需额外调整) | ||||||
|  |         label_bbox = draw.textbbox((0, 0), label, font=font) | ||||||
|  |         label_width = label_bbox[2] - label_bbox[0] | ||||||
|  |         label_height = label_bbox[3] - label_bbox[1] | ||||||
|  |  | ||||||
|  |         # 4.4 计算标签背景位置(避免超出图像) | ||||||
|  |         bg_x1, bg_y1 = x1, y1 - label_height - 10 | ||||||
|  |         bg_x2, bg_y2 = x1 + label_width, y1 | ||||||
|  |         if bg_y1 < 0: | ||||||
|  |             bg_y1, bg_y2 = y2 + 5, y2 + label_height + 15 | ||||||
|  |  | ||||||
|  |         # 4.5 绘制标签背景(黑色)和文本(白色) | ||||||
|  |         draw.rectangle([(bg_x1, bg_y1), (bg_x2, bg_y2)], fill=(0, 0, 0)) | ||||||
|  |         text_x = bg_x1 | ||||||
|  |         text_y = bg_y1 if bg_y1 >= 0 else bg_y1 + label_height | ||||||
|  |         draw.text((text_x, text_y), label, font=font, fill=(255, 255, 255)) | ||||||
|  |  | ||||||
|  |     #保存图片 | ||||||
|  |     try: | ||||||
|  |         today = datetime.now() | ||||||
|  |         file_dir = os.path.join( | ||||||
|  |             UPLOAD_ROOT, "detect", client_ip, file_type, | ||||||
|  |             today.strftime("%Y"), today.strftime("%m"), today.strftime("%d") | ||||||
|  |         ) | ||||||
|  |         os.makedirs(file_dir, exist_ok=True) | ||||||
|  |  | ||||||
|  |         timestamp = today.strftime("%Y%m%d%H%M%S%f") | ||||||
|  |         filename = f"{timestamp}.png" | ||||||
|  |         full_path = os.path.join(file_dir, filename) | ||||||
|  |  | ||||||
|  |         pil_img.save(full_path, format="PNG", quality=100) | ||||||
|  |  | ||||||
|  |         relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | ||||||
|  |         download_path = PRE + relative_path.replace(os.sep, "/") | ||||||
|  |  | ||||||
|  |         matched_count = sum(1 for info in face_info_list if info["status"] == "匹配") | ||||||
|  |         print(f"人脸检测图片保存成功 | 客户端IP:{client_ip} | 匹配人脸数:{matched_count} | 保存路径:{download_path}") | ||||||
|  |         return download_path | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         raise Exception(f"人脸检测图片保存失败(客户端IP:{client_ip}):{str(e)}") from e | ||||||
|  |  | ||||||
|  | def save_source_file(upload_file: UploadFile, file_type: str) -> str: | ||||||
|  |     """保存上传的文件到source目录,返回下载路径""" | ||||||
|  |     today = datetime.now() | ||||||
|  |     year = today.strftime("%Y") | ||||||
|  |     month = today.strftime("%m") | ||||||
|  |     day = today.strftime("%d") | ||||||
|  |  | ||||||
|  |     # 生成精确到微秒的时间戳,确保文件名唯一 | ||||||
|  |     timestamp = today.strftime("%Y%m%d%H%M%S%f") | ||||||
|  |     # 构建新文件名:时间戳_原文件名 | ||||||
|  |     unique_filename = f"{timestamp}_{upload_file.filename}" | ||||||
|  |  | ||||||
|  |     # 构建目录路径: upload/source/type/年/月/日(包含UPLOAD_ROOT) | ||||||
|  |     file_dir = os.path.join( | ||||||
|  |         UPLOAD_ROOT, | ||||||
|  |         "source", | ||||||
|  |         file_type, | ||||||
|  |         year, | ||||||
|  |         month, | ||||||
|  |         day | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # 创建目录(确保目录存在) | ||||||
|  |     os.makedirs(file_dir, exist_ok=True) | ||||||
|  |  | ||||||
|  |     # 1. 完整路径:用于实际保存文件(使用带时间戳的唯一文件名) | ||||||
|  |     full_path = os.path.join(file_dir, unique_filename) | ||||||
|  |     # 2. 相对路径:用于返回给前端 | ||||||
|  |     relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | ||||||
|  |  | ||||||
|  |     # 保存文件(使用完整路径) | ||||||
|  |     try: | ||||||
|  |         with open(full_path, "wb") as buffer: | ||||||
|  |             shutil.copyfileobj(upload_file.file, buffer) | ||||||
|  |     finally: | ||||||
|  |         upload_file.file.close() | ||||||
|  |  | ||||||
|  |     # 统一路径分隔符为/ | ||||||
|  |     return PRE + relative_path.replace(os.sep, "/") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_absolute_path(relative_path: str) -> str: | ||||||
|  |     """ | ||||||
|  |     根据相对路径获取服务器上的绝对路径 | ||||||
|  |     """ | ||||||
|  |     path_without_pre = relative_path.replace(PRE, "", 1) | ||||||
|  |  | ||||||
|  |     # 将相对路径转换为系统兼容的格式 | ||||||
|  |     normalized_path = os.path.normpath(path_without_pre) | ||||||
|  |  | ||||||
|  |     # 拼接基础路径和相对路径,得到绝对路径 | ||||||
|  |     absolute_path = os.path.abspath(os.path.join(UPLOAD_ROOT, normalized_path)) | ||||||
|  |  | ||||||
|  |     # 安全检查:确保生成的路径在UPLOAD_ROOT目录下,防止路径遍历 | ||||||
|  |     if not absolute_path.startswith(os.path.abspath(UPLOAD_ROOT)): | ||||||
|  |         raise ValueError("无效的相对路径,可能试图访问上传目录之外的内容") | ||||||
|  |  | ||||||
|  |     return absolute_path | ||||||
							
								
								
									
										131
									
								
								service/model_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,131 @@ | |||||||
|  | from http.client import HTTPException | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | from MySQLdb import MySQLError | ||||||
|  | from ultralytics import YOLO | ||||||
|  | import os | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  | from service.file_service import get_absolute_path | ||||||
|  |  | ||||||
|  | # 全局变量 | ||||||
|  | current_yolo_model = None | ||||||
|  | current_model_absolute_path = None  # 存储模型绝对路径,不依赖model实例 | ||||||
|  |  | ||||||
|  | ALLOWED_MODEL_EXT = {"pt"} | ||||||
|  | MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load_yolo_model(): | ||||||
|  |     """加载模型并存储绝对路径""" | ||||||
|  |     global current_yolo_model, current_model_absolute_path | ||||||
|  |     model_rel_path = get_enabled_model_rel_path() | ||||||
|  |     print(f"[模型初始化] 加载模型:{model_rel_path}") | ||||||
|  |  | ||||||
|  |     # 计算并存储绝对路径 | ||||||
|  |     current_model_absolute_path = get_absolute_path(model_rel_path) | ||||||
|  |     print(f"[模型初始化] 绝对路径:{current_model_absolute_path}") | ||||||
|  |  | ||||||
|  |     # 检查模型文件 | ||||||
|  |     if not os.path.exists(current_model_absolute_path): | ||||||
|  |         raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         new_model = YOLO(current_model_absolute_path) | ||||||
|  |         if torch.cuda.is_available(): | ||||||
|  |             new_model.to('cuda') | ||||||
|  |             print("模型已移动到GPU") | ||||||
|  |         else: | ||||||
|  |             print("使用CPU进行推理") | ||||||
|  |         current_yolo_model = new_model | ||||||
|  |         print(f"成功加载模型: {current_model_absolute_path}") | ||||||
|  |         return current_yolo_model | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"模型加载失败:{str(e)}") | ||||||
|  |         raise | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_current_model(): | ||||||
|  |     """获取当前模型实例""" | ||||||
|  |     if current_yolo_model is None: | ||||||
|  |         raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型") | ||||||
|  |     return current_yolo_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def detect(image_np, conf_threshold=0.8): | ||||||
|  |     # 1. 输入格式验证 | ||||||
|  |     if not isinstance(image_np, np.ndarray): | ||||||
|  |         raise ValueError("输入必须是numpy数组(BGR图像)") | ||||||
|  |     if image_np.ndim != 3 or image_np.shape[-1] != 3: | ||||||
|  |         raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}") | ||||||
|  |     detection_results = [] | ||||||
|  |     try: | ||||||
|  |         model = get_current_model() | ||||||
|  |         if not current_model_absolute_path: | ||||||
|  |             raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型") | ||||||
|  |         device = "cuda" if torch.cuda.is_available() else "cpu" | ||||||
|  |         print(f"检测设备:{device} | 置信度阈值:{conf_threshold}") | ||||||
|  |  | ||||||
|  |         # 图像尺寸信息 | ||||||
|  |         img_height, img_width = image_np.shape[:2] | ||||||
|  |         print(f"输入图像尺寸:{img_width}x{img_height}") | ||||||
|  |  | ||||||
|  |         # YOLO检测 | ||||||
|  |         print("执行YOLO检测") | ||||||
|  |         results = model.predict( | ||||||
|  |             image_np, | ||||||
|  |             conf=conf_threshold, | ||||||
|  |             device=device, | ||||||
|  |             show=False, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # 4. 整理检测结果(仅保留Chest类别,ID=2) | ||||||
|  |         for box in results[0].boxes: | ||||||
|  |             class_id = int(box.cls[0])  # 类别ID | ||||||
|  |             class_name = model.names[class_id] | ||||||
|  |             confidence = float(box.conf[0]) | ||||||
|  |             bbox = tuple(map(int, box.xyxy[0])) | ||||||
|  |  | ||||||
|  |             # 过滤条件:置信度达标 + 类别为Chest(class_id=2) | ||||||
|  |             # and class_id == 2 | ||||||
|  |             if confidence >= conf_threshold: | ||||||
|  |                 detection_results.append({ | ||||||
|  |                     "class": class_name, | ||||||
|  |                     "confidence": confidence, | ||||||
|  |                     "bbox": bbox | ||||||
|  |                 }) | ||||||
|  |  | ||||||
|  |         # 判断是否有目标 | ||||||
|  |         has_content = len(detection_results) > 0 | ||||||
|  |         return has_content, detection_results | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         error_msg = f"检测过程出错:{str(e)}" | ||||||
|  |         print(error_msg) | ||||||
|  |         return False, None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_enabled_model_rel_path(): | ||||||
|  |     """获取数据库中启用的模型相对路径""" | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1" | ||||||
|  |         cursor.execute(query) | ||||||
|  |         result = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         if not result or not result.get('path'): | ||||||
|  |             raise HTTPException(status_code=404, detail="未找到启用的默认模型") | ||||||
|  |  | ||||||
|  |         return result['path'] | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e | ||||||
|  |     except Exception as e: | ||||||
|  |         if isinstance(e, HTTPException): | ||||||
|  |             raise e | ||||||
|  |         raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										131
									
								
								service/ocr_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,131 @@ | |||||||
|  | # 首先添加NumPy兼容处理 | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | # 修复np.int已弃用的问题 | ||||||
|  | if not hasattr(np, 'int'): | ||||||
|  |     np.int = int | ||||||
|  |  | ||||||
|  | from paddleocr import PaddleOCR | ||||||
|  | from service.sensitive_service import get_all_sensitive_words | ||||||
|  |  | ||||||
|  | _ocr_engine = None | ||||||
|  | _forbidden_words = set() | ||||||
|  | _conf_threshold = 0.5 | ||||||
|  |  | ||||||
|  | def set_forbidden_words(new_words): | ||||||
|  |     global _forbidden_words | ||||||
|  |     if not isinstance(new_words, (set, list, tuple)): | ||||||
|  |         raise TypeError("新违禁词必须是集合、列表或元组类型") | ||||||
|  |     _forbidden_words = set(new_words)  # 确保是集合类型 | ||||||
|  |     print(f"已通过函数更新违禁词,当前数量: {len(_forbidden_words)}") | ||||||
|  |  | ||||||
|  | def load_forbidden_words(): | ||||||
|  |     global _forbidden_words | ||||||
|  |     try: | ||||||
|  |         _forbidden_words = get_all_sensitive_words() | ||||||
|  |         print(f"加载的违禁词数量: {len(_forbidden_words)}") | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"Forbidden words load error: {e}") | ||||||
|  |         return False | ||||||
|  |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def init_ocr_engine(): | ||||||
|  |     global _ocr_engine | ||||||
|  |     try: | ||||||
|  |         _ocr_engine = PaddleOCR( | ||||||
|  |             use_angle_cls=True, | ||||||
|  |             lang="ch", | ||||||
|  |             show_log=False, | ||||||
|  |             use_gpu=True, | ||||||
|  |             max_text_length=1024 | ||||||
|  |         ) | ||||||
|  |         load_result = load_forbidden_words() | ||||||
|  |         if not load_result: | ||||||
|  |             print("警告:违禁词加载失败,可能影响检测功能") | ||||||
|  |         print("OCR引擎初始化完成") | ||||||
|  |         return True | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"OCR引擎初始化错误: {e}") | ||||||
|  |         _ocr_engine = None | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def detect(frame, conf_threshold=0.8): | ||||||
|  |     print("开始进行OCR检测...") | ||||||
|  |     try: | ||||||
|  |         ocr_res = _ocr_engine.ocr(frame, cls=True) | ||||||
|  |         if not ocr_res or not isinstance(ocr_res, list): | ||||||
|  |             return (False, "无OCR结果") | ||||||
|  |  | ||||||
|  |         texts = [] | ||||||
|  |         confs = [] | ||||||
|  |         for line in ocr_res: | ||||||
|  |             if line is None: | ||||||
|  |                 continue | ||||||
|  |             if isinstance(line, list): | ||||||
|  |                 items_to_process = line | ||||||
|  |             else: | ||||||
|  |                 items_to_process = [line] | ||||||
|  |  | ||||||
|  |             for item in items_to_process: | ||||||
|  |                 if isinstance(item, list) and len(item) == 4: | ||||||
|  |                     is_coordinate = True | ||||||
|  |                     for point in item: | ||||||
|  |                         if not (isinstance(point, list) and len(point) == 2 and | ||||||
|  |                                 all(isinstance(coord, (int, float)) for coord in point)): | ||||||
|  |                             is_coordinate = False | ||||||
|  |                             break | ||||||
|  |                     if is_coordinate: | ||||||
|  |                         continue | ||||||
|  |                 if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item): | ||||||
|  |                     continue | ||||||
|  |                 if isinstance(item, tuple) and len(item) == 2: | ||||||
|  |                     text, conf = item | ||||||
|  |                     if isinstance(text, str) and isinstance(conf, (int, float)): | ||||||
|  |                         texts.append(text.strip()) | ||||||
|  |                         confs.append(float(conf)) | ||||||
|  |                         continue | ||||||
|  |                 if isinstance(item, list) and len(item) >= 2: | ||||||
|  |                     text_data = item[1] | ||||||
|  |                     if isinstance(text_data, tuple) and len(text_data) == 2: | ||||||
|  |                         text, conf = text_data | ||||||
|  |                         if isinstance(text, str) and isinstance(conf, (int, float)): | ||||||
|  |                             texts.append(text.strip()) | ||||||
|  |                             confs.append(float(conf)) | ||||||
|  |                             continue | ||||||
|  |                     elif isinstance(text_data, str): | ||||||
|  |                         texts.append(text_data.strip()) | ||||||
|  |                         confs.append(1.0) | ||||||
|  |                         continue | ||||||
|  |                 print(f"无法解析的OCR结果格式: {item}") | ||||||
|  |  | ||||||
|  |         if len(texts) != len(confs): | ||||||
|  |             return (False, "OCR结果格式异常") | ||||||
|  |  | ||||||
|  |         # 收集所有识别到的违禁词(去重且保持出现顺序) | ||||||
|  |         vio_words = [] | ||||||
|  |         for txt, conf in zip(texts, confs): | ||||||
|  |             if conf < _conf_threshold:  # 过滤低置信度结果 | ||||||
|  |                 continue | ||||||
|  |             # 提取当前文本中包含的违禁词 | ||||||
|  |             matched = [w for w in _forbidden_words if w in txt] | ||||||
|  |             # 仅添加未记录过的违禁词(去重) | ||||||
|  |             for word in matched: | ||||||
|  |                 if word not in vio_words: | ||||||
|  |                     vio_words.append(word) | ||||||
|  |  | ||||||
|  |         has_text = len(texts) > 0 | ||||||
|  |         has_violation = len(vio_words) > 0 | ||||||
|  |  | ||||||
|  |         if not has_text: | ||||||
|  |             return (False, "未识别到文本") | ||||||
|  |         elif has_violation: | ||||||
|  |             # 多个违禁词用逗号拼接 | ||||||
|  |             return (True, ", ".join(vio_words)) | ||||||
|  |         else: | ||||||
|  |             return (False, "未检测到违禁词") | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"OCR detect error: {e}") | ||||||
|  |         return (False, f"检测错误: {str(e)}") | ||||||
							
								
								
									
										36
									
								
								service/sensitive_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						| @ -0,0 +1,36 @@ | |||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
|  | from ds.db import db | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_all_sensitive_words() -> list[str]: | ||||||
|  |     """ | ||||||
|  |     获取所有敏感词(返回纯字符串列表、用于过滤业务) | ||||||
|  |  | ||||||
|  |     返回: | ||||||
|  |         list[str]: 包含所有敏感词的数组 | ||||||
|  |  | ||||||
|  |     异常: | ||||||
|  |         MySQLError: 数据库操作相关错误 | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         # 获取数据库连接 | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 执行查询(只获取敏感词字段、按ID排序) | ||||||
|  |         query = "SELECT name FROM sensitives ORDER BY id" | ||||||
|  |         cursor.execute(query) | ||||||
|  |         sensitive_records = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 提取敏感词到纯字符串数组 | ||||||
|  |         return [record['name'] for record in sensitive_records] | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         # 数据库错误向上抛出、由调用方处理 | ||||||
|  |         raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         # 确保数据库连接正确释放 | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
| After Width: | Height: | Size: 652 KiB | 
| After Width: | Height: | Size: 550 KiB | 
| After Width: | Height: | Size: 549 KiB | 
| After Width: | Height: | Size: 562 KiB | 
| After Width: | Height: | Size: 546 KiB | 
| After Width: | Height: | Size: 585 KiB | 
| After Width: | Height: | Size: 584 KiB | 
| After Width: | Height: | Size: 561 KiB | 
| After Width: | Height: | Size: 561 KiB | 
| After Width: | Height: | Size: 562 KiB | 
| After Width: | Height: | Size: 546 KiB | 
| After Width: | Height: | Size: 386 KiB | 
| After Width: | Height: | Size: 338 KiB | 
| After Width: | Height: | Size: 338 KiB | 
| After Width: | Height: | Size: 341 KiB | 
| After Width: | Height: | Size: 347 KiB | 
| After Width: | Height: | Size: 337 KiB | 
| After Width: | Height: | Size: 250 KiB | 
| After Width: | Height: | Size: 247 KiB | 
| After Width: | Height: | Size: 250 KiB | 
| After Width: | Height: | Size: 822 KiB | 
| After Width: | Height: | Size: 805 KiB | 
| After Width: | Height: | Size: 791 KiB | 
| After Width: | Height: | Size: 1.1 MiB | 
| After Width: | Height: | Size: 877 KiB | 
| After Width: | Height: | Size: 814 KiB | 
| After Width: | Height: | Size: 749 KiB | 
| After Width: | Height: | Size: 197 KiB | 
| After Width: | Height: | Size: 200 KiB |