"""Frame-level content checks: run YOLO, face and OCR detectors on a frame,
record each hit in the ``device_danger`` table and send danger/lock messages
to the originating client."""

import asyncio
import json
import re
from concurrent.futures import ThreadPoolExecutor

from mysql.connector import Error as MySQLError

from ds.config import BUSINESS_CONFIG
from ds.db import db
from service.face_service import detect as faceDetect, init_insightface
from service.model_service import load_yolo_model, detect as yoloDetect
from service.ocr_service import detect as ocrDetect, init_ocr_engine
from service.file_service import (
    save_detect_file,
    save_detect_yolo_file,
    save_detect_face_file,
)

# Shared thread-pool executor. It is not used inside this module; presumably
# other modules import it to run detectFrame off the event loop — TODO confirm.
executor = ThreadPoolExecutor(max_workers=10)

# Captures <name> from "匹配: <name> (" fragments in the face-service output.
_FACE_NAME_PATTERN = re.compile(r"匹配: (.*?) \(")


def init():
    """Initialise every detection backend: InsightFace, the OCR engine and YOLO."""
    init_insightface()
    init_ocr_engine()
    load_yolo_model()


def save_db(model_type, client_ip, result):
    """Insert one detection record into the ``device_danger`` table.

    Args:
        model_type: Value written to the ``type`` column (callers pass "色情",
            the matched face names, or the OCR result text).
        client_ip: IP address of the client the frame came from.
        result: Value written to the ``result`` column (callers pass the saved
            image path as a string).

    Raises:
        Exception: wrapping the underlying ``mysql.connector.Error``.
    """
    conn = None
    cursor = None
    insert_query = (
        "INSERT INTO device_danger (client_ip, type, result) "
        "VALUES (%s, %s, %s)"
    )
    try:
        conn = db.get_connection()
        # An INSERT returns no rows, so a plain cursor is sufficient.
        cursor = conn.cursor()
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        if conn is not None:
            try:
                conn.rollback()
            except MySQLError:
                # Best-effort rollback; the original error is re-raised below.
                pass
        raise Exception(f"保存检测记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)


def detectFrame(client_ip, frame):
    """Run YOLO, face and OCR detection on ``frame`` and handle each hit.

    The three detectors run one after another and independently. For every
    detector that flags the frame, ``danger_handler`` is called, the frame is
    saved through the matching ``save_detect_*`` helper, and a record is
    inserted with ``save_db``. A single frame can therefore trigger
    ``danger_handler`` up to three times.

    Args:
        client_ip: IP address of the client that sent the frame.
        frame: Image passed unchanged to each detector and save helper.
    """
    # YOLO detection
    yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
    if yolo_flag:
        print("❌ 检测到违规内容,保存图片,YOLO")
        danger_handler(client_ip)
        path = save_detect_yolo_file(client_ip, frame, yolo_result, "yolo")
        save_db(model_type="色情", client_ip=client_ip, result=str(path))

    # Face detection: the stored type is the comma-joined list of matched names.
    face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"]))
    if face_flag:
        print("❌ 检测到违规内容,保存图片,FACE")
        print("人脸识别内容:", face_result)
        model_type = extract_face_names(face_result)
        danger_handler(client_ip)
        path = save_detect_face_file(client_ip, frame, face_result, "face")
        save_db(model_type=model_type, client_ip=client_ip, result=str(path))

    # OCR detection: the stored type is the stringified OCR result.
    ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"]))
    if ocr_flag:
        print("❌ 检测到违规内容,保存图片,OCR")
        print("ocr识别内容:", ocr_result)
        danger_handler(client_ip)
        path = save_detect_file(client_ip, frame, "ocr")
        save_db(model_type=str(ocr_result), client_ip=client_ip, result=str(path))

    # Only report "clean" when none of the detectors flagged the frame.
    if not (face_flag or yolo_flag or ocr_flag):
        print("所有模型未检测到任何违规内容")


def danger_handler(client_ip):
    """Notify and lock the client at ``client_ip``, then flag its device record.

    Sends a "danger" message followed by a "lock" message via
    ``send_message_to_client``, increments the device's alarm count, and sets
    its need-handling flag to 1.

    NOTE(review): this uses ``asyncio.run``, which raises RuntimeError when an
    event loop is already running in the calling thread. It must therefore be
    called from a thread without a running loop (e.g. a worker thread).
    """
    # Imported lazily, presumably to avoid a circular import with ws.ws — confirm.
    from ws.ws import send_message_to_client, get_current_time_str
    from service.device_service import (
        increment_alarm_count_by_ip,
        update_is_need_handler_by_client_ip,
    )

    async def _notify():
        # Both messages go out sequentially on a single event loop instead of
        # creating one loop per message; each gets its own send-time timestamp.
        for msg_type in ("danger", "lock"):
            message = {
                "type": msg_type,
                "timestamp": get_current_time_str(),
                "client_ip": client_ip,
            }
            await send_message_to_client(
                client_ip=client_ip, json_data=json.dumps(message)
            )

    asyncio.run(_notify())

    # Increase the alarm counter for this device.
    increment_alarm_count_by_ip(client_ip)
    # Mark the device as needing handling.
    update_is_need_handler_by_client_ip(client_ip, 1)


def extract_face_names(face_result: str) -> str:
    """Return the distinct names found in "匹配: <name> (" fragments, comma-joined.

    Names are stripped of surrounding whitespace and empty names are dropped.
    Duplicates are removed while keeping first-occurrence order, so the output
    is deterministic. Returns "" when ``face_result`` is empty/None or nothing
    matches.
    """
    if not face_result:
        return ""
    stripped = (name.strip() for name in _FACE_NAME_PATTERN.findall(face_result))
    # dict.fromkeys keeps insertion order, unlike set().
    return ",".join(dict.fromkeys(name for name in stripped if name))