115 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			115 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import json | |||
|  | import re | |||
|  | 
 | |||
|  | from mysql.connector import Error as MySQLError | |||
|  | 
 | |||
|  | from ds.config import BUSINESS_CONFIG | |||
|  | from ds.db import db | |||
|  | from service.face_service import detect as faceDetect,init_insightface | |||
|  | from service.model_service import load_yolo_model,detect as yoloDetect | |||
|  | from service.ocr_service import detect as ocrDetect,init_ocr_engine | |||
|  | from service.file_service import save_detect_file, save_detect_yolo_file, save_detect_face_file | |||
|  | import asyncio | |||
|  | from concurrent.futures import ThreadPoolExecutor | |||
|  | 
 | |||
|  | 
 | |||
# Create a module-level thread pool executor with up to 10 worker threads.
# NOTE(review): nothing in the visible part of this module submits work to it;
# presumably it is used by other modules — verify before changing or removing.
executor = ThreadPoolExecutor(max_workers=10)
def init():
    """Initialise the detection backends used by ``detectFrame``.

    Calls, in order, the InsightFace initialiser, the OCR engine initialiser
    and the YOLO model loader from their respective service modules.
    """
    for initializer in (init_insightface, init_ocr_engine, load_yolo_model):
        initializer()
|  | 
 | |||
def save_db(model_type, client_ip, result):
    """Insert one detection record into the ``device_danger`` table.

    Args:
        model_type: Value for the ``type`` column. Callers in this module pass
            a category label ("色情"), matched face names, or OCR result text.
        client_ip: Value for the ``client_ip`` column.
        result: Value for the ``result`` column. Callers in this module pass
            the saved evidence file path as a string.

    Raises:
        Exception: Wraps the underlying MySQL error if the insert fails. The
            transaction is rolled back (best effort) before raising.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        # Plain cursor: this is a write, and no result rows are read back.
        cursor = conn.cursor()
        insert_query = """
            INSERT INTO device_danger (client_ip, type, result)
            VALUES (%s, %s, %s)
        """
        # Parameterised query: model_type may contain arbitrary OCR text.
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        if conn is not None:
            try:
                # Don't hand a connection with a half-finished transaction
                # back to close_connection / the pool.
                conn.rollback()
            except MySQLError:
                # Best effort: the connection may already be unusable; the
                # original error raised below is the one worth reporting.
                pass
        raise Exception(f"保存违规检测记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
|  | 
 | |||
|  | 
 | |||
def _report_violation(client_ip, model_type, save_evidence):
    """Alert the client, save the evidence image, then record it in the DB.

    Args:
        client_ip: IP address of the client whose frame was flagged.
        model_type: Value stored in the DB ``type`` column for this hit.
        save_evidence: Zero-argument callable that saves the evidence image
            and returns its path. It is called only after the alert is sent.
    """
    danger_handler(client_ip)
    path = save_evidence()
    save_db(model_type=model_type, client_ip=client_ip, result=str(path))


def detectFrame(client_ip, frame):
    """Run the YOLO, face and OCR detectors on one frame and report any hits.

    The detectors run in sequence. A hit from one does not skip the others,
    unless a call raises. For each detector that flags the frame, the client
    is alerted via ``danger_handler``, the frame is saved as evidence and a
    record is written with ``save_db``.

    Args:
        client_ip: IP address of the client that sent the frame.
        frame: Image passed unchanged to each detector and file saver.
    """
    # YOLO detection; every YOLO hit is recorded as the "色情" category.
    yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
    if yolo_flag:
        print("❌ 检测到违规内容,保存图片,YOLO")
        _report_violation(
            client_ip,
            "色情",
            lambda: save_detect_yolo_file(client_ip, frame, yolo_result, "yolo"),
        )

    # Face detection; the DB type column stores the matched person names.
    face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"]))
    if face_flag:
        print("❌ 检测到违规内容,保存图片,FACE")
        print("人脸识别内容:", face_result)
        model_type = extract_face_names(face_result)
        _report_violation(
            client_ip,
            model_type,
            lambda: save_detect_face_file(client_ip, frame, face_result, "face"),
        )

    # OCR detection; the DB type column stores the recognised text.
    ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"]))
    if ocr_flag:
        print("❌ 检测到违规内容,保存图片,OCR")
        print("ocr识别内容:", ocr_result)
        _report_violation(
            client_ip,
            str(ocr_result),
            lambda: save_detect_file(client_ip, frame, "ocr"),
        )

    # Only report "nothing found" when no detector flagged the frame.
    if not (face_flag or yolo_flag or ocr_flag):
        print("所有模型未检测到任何违规内容")
|  | 
 | |||
def danger_handler(client_ip):
    """Notify a client of a detected violation and flag its device.

    Sends a ``"danger"`` message and then a ``"lock"`` message through
    ``ws.ws.send_message_to_client``. It then increments the device's alarm
    count and sets its "needs handling" flag to 1.

    Args:
        client_ip: IP address used both to address the client message and to
            look up the device record.
    """
    # Imported inside the function; presumably to avoid a circular import
    # with the ws / device_service modules — TODO confirm.
    from ws.ws import send_message_to_client, get_current_time_str
    from service.device_service import increment_alarm_count_by_ip
    from service.device_service import update_is_need_handler_by_client_ip

    async def _notify():
        # Send "danger" first, then "lock". Each timestamp is taken just
        # before its own message is sent.
        for msg_type in ("danger", "lock"):
            message = {
                "type": msg_type,
                "timestamp": get_current_time_str(),
                "client_ip": client_ip,
            }
            await send_message_to_client(
                client_ip=client_ip,
                json_data=json.dumps(message),
            )

    # One event loop for both messages instead of creating and tearing down
    # a new loop per message.
    # NOTE(review): asyncio.run() raises RuntimeError when called from a
    # thread that already has a running event loop, so this must only be
    # called from a thread without one.
    asyncio.run(_notify())

    # Increment the device's alarm count.
    increment_alarm_count_by_ip(client_ip)
    # Mark the device as unhandled (needs-handling flag = 1).
    update_is_need_handler_by_client_ip(client_ip, 1)
|  | 
 | |||
def extract_face_names(face_result: str) -> str:
    """Return the distinct person names found in a face-detection result string.

    Names are taken from every ``"匹配: <name> ("`` occurrence and stripped of
    surrounding whitespace. Blank names are dropped and duplicates removed.
    Names keep the order of their first appearance, so the same input always
    gives the same string. (A ``set`` would make the order vary between runs
    because of string hash randomisation.)

    Args:
        face_result: Text produced by the face detector.

    Returns:
        Comma-separated names, or ``""`` if nothing matched.

    >>> extract_face_names("匹配: 张三 (0.9) 匹配: 李四 (0.8) 匹配: 张三 (0.7)")
    '张三,李四'
    """
    stripped = (name.strip() for name in re.findall(r"匹配: (.*?) \(", face_result))
    # dict.fromkeys deduplicates while preserving insertion order.
    unique_names = dict.fromkeys(name for name in stripped if name)
    return ",".join(unique_names)