115 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			115 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import json
 | ||
| import re
 | ||
| 
 | ||
| from mysql.connector import Error as MySQLError
 | ||
| 
 | ||
| from ds.config import BUSINESS_CONFIG
 | ||
| from ds.db import db
 | ||
| from service.face_service import detect as faceDetect,init_insightface
 | ||
| from service.model_service import load_yolo_model,detect as yoloDetect
 | ||
| from service.ocr_service import detect as ocrDetect,init_ocr_engine
 | ||
| from service.file_service import save_detect_file, save_detect_yolo_file, save_detect_face_file
 | ||
| import asyncio
 | ||
| from concurrent.futures import ThreadPoolExecutor
 | ||
| 
 | ||
| 
 | ||
# Module-level thread pool capped at 10 worker threads.
# NOTE(review): not used in the visible part of this file; presumably shared
# with callers elsewhere — confirm before resizing or removing.
executor = ThreadPoolExecutor(10)
 | ||
def init():
    """Load every detection backend once, before any frame is processed.

    Loaders run in this order: InsightFace face models, the OCR engine,
    then the YOLO model.
    """
    for loader in (init_insightface, init_ocr_engine, load_yolo_model):
        loader()
 | ||
| 
 | ||
def save_db(model_type, client_ip, result):
    """Insert one detection record into the ``device_danger`` table.

    Args:
        model_type: Value stored in the ``type`` column (callers in this file
            pass a fixed category, recognised face names, or OCR text).
        client_ip: IP of the client whose frame triggered the detection.
        result: Value stored in the ``result`` column (callers in this file
            pass the saved image path as a string).

    Raises:
        Exception: wrapping the underlying ``mysql.connector`` error when the
            insert or commit fails.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        insert_query = """
            INSERT INTO device_danger (client_ip, type, result)
            VALUES (%s, %s, %s)
        """
        # Parameterised query: the driver escapes the values; nothing is
        # interpolated into the SQL text.
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        # Best-effort rollback so the connection is not closed/returned with a
        # half-finished transaction; a rollback failure must not mask the
        # original error.
        if conn is not None:
            try:
                conn.rollback()
            except MySQLError:
                pass
        # Message previously said "failed to get device list", which is wrong
        # for an INSERT; it now says "failed to save detection record".
        raise Exception(f"保存检测记录失败: {e}") from e
    finally:
        db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
def detectFrame(client_ip, frame):
    """Run YOLO, face and OCR detection on one frame and record each hit.

    The detectors run in turn, each with a confidence threshold read from
    ``BUSINESS_CONFIG``. For every detector that flags the frame, the client
    is alerted via ``danger_handler``, an image is saved through the matching
    ``save_detect_*`` helper, and a row is written with ``save_db``.

    Args:
        client_ip: IP of the client that sent the frame.
        frame: Image handed unchanged to every detector and file helper
            (assumed to be a decoded image array — TODO confirm with callers).
    """
    # NOTE(review): danger_handler runs once per flagging detector, so a
    # single frame can send up to three danger/lock pairs and bump the alarm
    # count three times — confirm this is intended.

    # YOLO detection; hits are stored under the fixed category "色情".
    yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
    if yolo_flag:
        print("❌ 检测到违规内容,保存图片,YOLO")
        danger_handler(client_ip)
        path = save_detect_yolo_file(client_ip, frame, yolo_result, "yolo")
        save_db(model_type="色情", client_ip=client_ip, result=str(path))

    # Face detection; the recognised names become the stored category.
    face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"]))
    if face_flag:
        print("❌ 检测到违规内容,保存图片,FACE")
        print("人脸识别内容:", face_result)
        model_type = extract_face_names(face_result)
        danger_handler(client_ip)
        path = save_detect_face_file(client_ip, frame, face_result, "face")
        save_db(model_type=model_type, client_ip=client_ip, result=str(path))

    # OCR detection; the recognised text (stringified) becomes the category.
    ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"]))
    if ocr_flag:
        print("❌ 检测到违规内容,保存图片,OCR")
        print("ocr识别内容:", ocr_result)
        danger_handler(client_ip)
        path = save_detect_file(client_ip, frame, "ocr")
        save_db(model_type=str(ocr_result), client_ip=client_ip, result=str(path))

    # Report a clean frame only when no detector flagged anything.
    if not (face_flag or yolo_flag or ocr_flag):
        print("所有模型未检测到任何违规内容")
 | ||
| 
 | ||
def danger_handler(client_ip):
    """Alert and lock a client after a violation, then update its device record.

    Sends a ``danger`` message followed by a ``lock`` message to the client,
    increments the device's alarm counter and sets its "needs handling" flag
    to 1.

    Args:
        client_ip: IP of the client to notify.
    """
    # Function-level imports, presumably to avoid circular imports with the
    # ws / device_service modules.
    from ws.ws import send_message_to_client, get_current_time_str
    from service.device_service import increment_alarm_count_by_ip
    from service.device_service import update_is_need_handler_by_client_ip

    async def _notify():
        # Sequential sends: the lock message (and its timestamp) is built only
        # after the danger message has gone out, matching the original order.
        danger_msg = {
            "type": "danger",
            "timestamp": get_current_time_str(),
            "client_ip": client_ip,
        }
        await send_message_to_client(
            client_ip=client_ip,
            json_data=json.dumps(danger_msg),
        )
        lock_msg = {
            "type": "lock",
            "timestamp": get_current_time_str(),
            "client_ip": client_ip,
        }
        await send_message_to_client(
            client_ip=client_ip,
            json_data=json.dumps(lock_msg),
        )

    # A single event loop for both messages instead of two back-to-back
    # asyncio.run() calls, each of which created and closed its own loop.
    # NOTE(review): asyncio.run raises RuntimeError if this thread already
    # runs an event loop; this assumes callers invoke it from a worker
    # thread — confirm.
    asyncio.run(_notify())

    # Increment the alarm counter for this device.
    increment_alarm_count_by_ip(client_ip)
    # Mark the device as needing handling (flag value 1).
    update_is_need_handler_by_client_ip(client_ip, 1)
 | ||
| 
 | ||
def extract_face_names(face_result: str) -> str:
    """Extract the distinct matched names from a face-recognition report.

    Every name is taken from a ``"匹配: <name> ("`` fragment, whitespace is
    stripped, blank names are dropped, and duplicates are removed while
    keeping first-appearance order. This makes the result deterministic:
    the previous ``set``-based version could order names differently
    between runs.

    Args:
        face_result: Text report produced by the face detector.

    Returns:
        Comma-separated names, or ``""`` if none are found (including for an
        empty or ``None`` input).
    """
    if not face_result:
        return ""
    pattern = r"匹配: (.*?) \("
    stripped = (name.strip() for name in re.findall(pattern, face_result))
    # dict.fromkeys dedupes while preserving insertion order.
    unique_names = dict.fromkeys(name for name in stripped if name)
    return ",".join(unique_names)
 |