| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  | import cv2 | 
					
						
							| 
									
										
										
										
											2025-09-10 08:57:56 +08:00
										 |  |  |  | import numpy as np | 
					
						
							|  |  |  |  | from PIL.Image import Image | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-10 10:53:07 +08:00
										 |  |  |  | from core.establish import get_image_save_path | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | from core.ocr import load_model as ocrLoadModel, detect as ocrDetect | 
					
						
							|  |  |  |  | from core.face import load_model as faceLoadModel, detect as faceDetect | 
					
						
							|  |  |  |  | from core.yolo import load_model as yoloLoadModel, detect as yoloDetect | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  | # 导入保存路径函数(根据实际文件位置调整导入路径) | 
					
						
							| 
									
										
										
										
											2025-09-10 08:57:56 +08:00
										 |  |  |  | import numpy as np | 
					
						
							|  |  |  |  | import base64 | 
					
						
							|  |  |  |  | from io import BytesIO | 
					
						
							|  |  |  |  | from PIL import Image | 
					
						
							|  |  |  |  | from ds.db import db | 
					
						
							|  |  |  |  | from mysql.connector import Error as MySQLError | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  | # 模型加载状态标记(避免重复加载) | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  | _model_loaded = False | 
					
						
							| 
									
										
										
										
											2025-09-05 17:23:50 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | def load_model(): | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     """加载所有检测模型(仅首次调用时执行)""" | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  |     global _model_loaded | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     if _model_loaded: | 
					
						
							|  |  |  |  |         print("模型已加载,无需重复执行") | 
					
						
							|  |  |  |  |         return | 
					
						
							| 
									
										
										
										
											2025-09-05 17:23:50 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     # 依次加载OCR、人脸、YOLO模型 | 
					
						
							|  |  |  |  |     ocrLoadModel() | 
					
						
							|  |  |  |  |     faceLoadModel() | 
					
						
							|  |  |  |  |     yoloLoadModel() | 
					
						
							| 
									
										
										
										
											2025-09-05 17:23:50 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     _model_loaded = True | 
					
						
							|  |  |  |  |     print("所有检测模型加载完成") | 
					
						
							| 
									
										
										
										
											2025-09-05 17:23:50 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-10 08:57:56 +08:00
										 |  |  |  | def save_db(model_type, client_ip, result): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         # 连接数据库 | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         # 往表插入数据 | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True)  # 返回字典格式结果 | 
					
						
							|  |  |  |  |         insert_query = """
 | 
					
						
							|  |  |  |  |             INSERT INTO device_danger (client_ip, type, result) | 
					
						
							|  |  |  |  |             VALUES (%s, %s, %s) | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         cursor.execute(insert_query, (client_ip, model_type, result)) | 
					
						
							|  |  |  |  |         conn.commit() | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         raise Exception(f"获取设备列表失败: {str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | # 修正后的 detect 函数关键部分 | 
					
						
							| 
									
										
										
										
											2025-09-10 08:57:56 +08:00
										 |  |  |  | def detect(client_ip, frame): | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |     # 1. YOLO检测 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     yolo_flag, yolo_result = yoloDetect(frame) | 
					
						
							|  |  |  |  |     if yolo_flag: | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         # model_type 传入 "yolo"(正确) | 
					
						
							| 
									
										
										
										
											2025-09-10 10:53:07 +08:00
										 |  |  |  |         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         if full_save_path: | 
					
						
							| 
									
										
										
										
											2025-09-10 10:53:07 +08:00
										 |  |  |  |             cv2.imwrite(full_save_path, frame) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |             print(f"✅ yolo违规图片已保存:{display_path}")  # 日志也修正 | 
					
						
							| 
									
										
										
										
											2025-09-16 20:17:48 +08:00
										 |  |  |  |         save_db(model_type="yolo", client_ip=client_ip, result=str(display_path)) | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |         return (True, yolo_result, "yolo") | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     # 2. 人脸检测 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     face_flag, face_result = faceDetect(frame) | 
					
						
							|  |  |  |  |     if face_flag: | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip)  # 这里改了 | 
					
						
							|  |  |  |  |         if full_save_path: | 
					
						
							| 
									
										
										
										
											2025-09-10 10:53:07 +08:00
										 |  |  |  |             cv2.imwrite(full_save_path, frame) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |             print(f"✅ face违规图片已保存:{display_path}")  # 日志也修正 | 
					
						
							| 
									
										
										
										
											2025-09-16 20:17:48 +08:00
										 |  |  |  |         save_db(model_type="face", client_ip=client_ip, result=str(display_path)) | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |         return (True, face_result, "face") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |     # 3. OCR检测 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     ocr_flag, ocr_result = ocrDetect(frame) | 
					
						
							|  |  |  |  |     if ocr_flag: | 
					
						
							| 
									
										
										
										
											2025-09-16 20:17:48 +08:00
										 |  |  |  |         full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) | 
					
						
							|  |  |  |  |         print(f"✅ ocr违规图片已保存:{display_path}") | 
					
						
							|  |  |  |  |         # 这里改了 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         if full_save_path: | 
					
						
							| 
									
										
										
										
											2025-09-10 10:53:07 +08:00
										 |  |  |  |             cv2.imwrite(full_save_path, frame) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |             print(f"✅ ocr违规图片已保存:{display_path}")  # 日志也修正 | 
					
						
							| 
									
										
										
										
											2025-09-16 20:17:48 +08:00
										 |  |  |  |         save_db(model_type="ocr", client_ip=client_ip, result=str(display_path)) | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |         return (True, ocr_result, "ocr") | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-09 16:30:12 +08:00
										 |  |  |  |     # 4. 无违规内容(不保存图片) | 
					
						
							|  |  |  |  |     print(f"❌ 未检测到任何违规内容,不保存图片") | 
					
						
							| 
									
										
										
										
											2025-09-10 10:53:07 +08:00
										 |  |  |  |     return (False, "未检测到任何内容", "none") |