95 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			95 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import cv2
 | ||
| import numpy as np
 | ||
| from PIL.Image import Image
 | ||
| 
 | ||
| from core.establish import get_image_save_path
 | ||
| from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
 | ||
| from core.face import load_model as faceLoadModel, detect as faceDetect
 | ||
| from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
 | ||
| # 导入保存路径函数(根据实际文件位置调整导入路径)
 | ||
| import numpy as np
 | ||
| import base64
 | ||
| from io import BytesIO
 | ||
| from PIL import Image
 | ||
| from ds.db import db
 | ||
| from mysql.connector import Error as MySQLError
 | ||
| 
 | ||
| # 模型加载状态标记(避免重复加载)
 | ||
| 
 | ||
| 
 | ||
| _model_loaded = False
 | ||
| 
 | ||
| 
 | ||
| def load_model():
 | ||
|     """加载所有检测模型(仅首次调用时执行)"""
 | ||
|     global _model_loaded
 | ||
|     if _model_loaded:
 | ||
|         print("模型已加载,无需重复执行")
 | ||
|         return
 | ||
| 
 | ||
|     # 依次加载OCR、人脸、YOLO模型
 | ||
|     ocrLoadModel()
 | ||
|     faceLoadModel()
 | ||
|     yoloLoadModel()
 | ||
| 
 | ||
|     _model_loaded = True
 | ||
|     print("所有检测模型加载完成")
 | ||
| 
 | ||
| 
 | ||
| def save_db(model_type, client_ip, result):
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         # 连接数据库
 | ||
|         conn = db.get_connection()
 | ||
|         # 往表插入数据
 | ||
|         cursor = conn.cursor(dictionary=True)  # 返回字典格式结果
 | ||
|         insert_query = """
 | ||
|             INSERT INTO device_danger (client_ip, type, result)
 | ||
|             VALUES (%s, %s, %s)
 | ||
|         """
 | ||
|         cursor.execute(insert_query, (client_ip, model_type, result))
 | ||
|         conn.commit()
 | ||
|     except MySQLError as e:
 | ||
|         raise Exception(f"获取设备列表失败: {str(e)}") from e
 | ||
|     finally:
 | ||
|         db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
| 
 | ||
| # 修正后的 detect 函数关键部分
 | ||
| def detect(client_ip, frame):
 | ||
|     # 1. YOLO检测
 | ||
|     yolo_flag, yolo_result = yoloDetect(frame)
 | ||
|     if yolo_flag:
 | ||
|         # model_type 传入 "yolo"(正确)
 | ||
|         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
 | ||
|         if full_save_path:
 | ||
|             cv2.imwrite(full_save_path, frame)
 | ||
|             print(f"✅ yolo违规图片已保存:{display_path}")  # 日志也修正
 | ||
|         save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
 | ||
|         return (True, yolo_result, "yolo")
 | ||
| 
 | ||
|     # 2. 人脸检测
 | ||
|     face_flag, face_result = faceDetect(frame)
 | ||
|     if face_flag:
 | ||
|         full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip)  # 这里改了
 | ||
|         if full_save_path:
 | ||
|             cv2.imwrite(full_save_path, frame)
 | ||
|             print(f"✅ face违规图片已保存:{display_path}")  # 日志也修正
 | ||
|         save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
 | ||
|         return (True, face_result, "face")
 | ||
| 
 | ||
|     # 3. OCR检测
 | ||
|     ocr_flag, ocr_result = ocrDetect(frame)
 | ||
|     if ocr_flag:
 | ||
|         full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)  # 这里改了
 | ||
|         if full_save_path:
 | ||
|             cv2.imwrite(full_save_path, frame)
 | ||
|             print(f"✅ ocr违规图片已保存:{display_path}")  # 日志也修正
 | ||
|         save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
 | ||
|         return (True, ocr_result, "ocr")
 | ||
| 
 | ||
|     # 4. 无违规内容(不保存图片)
 | ||
|     print(f"❌ 未检测到任何违规内容,不保存图片")
 | ||
|     return (False, "未检测到任何内容", "none") |