115 lines
3.9 KiB
Python
115 lines
3.9 KiB
Python
import json
|
||
import re
|
||
|
||
from mysql.connector import Error as MySQLError
|
||
|
||
from ds.config import BUSINESS_CONFIG
|
||
from ds.db import db
|
||
from service.face_service import detect as faceDetect,init_insightface
|
||
from service.model_service import load_yolo_model,detect as yoloDetect
|
||
from service.ocr_service import detect as ocrDetect,init_ocr_engine
|
||
from service.file_service import save_detect_file, save_detect_yolo_file, save_detect_face_file
|
||
import asyncio
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
|
||
# Module-level thread pool executor with up to 10 worker threads.
# NOTE(review): not referenced anywhere in this module; presumably imported by
# other modules to run blocking work (e.g. detectFrame) off the event loop — confirm.
executor = ThreadPoolExecutor(max_workers=10)
|
||
def init():
    """Initialise every detection backend once, at start-up.

    Runs, in this order: the InsightFace face-recognition initialiser, the
    OCR engine initialiser, and the YOLO model loader. Their return values
    are ignored.
    """
    startup_steps = (
        init_insightface,  # face recognition
        init_ocr_engine,   # OCR engine
        load_yolo_model,   # YOLO model
    )
    for step in startup_steps:
        step()
|
||
|
||
def save_db(model_type, client_ip, result):
    """Insert one detection record into the ``device_danger`` table.

    Args:
        model_type: Value stored in the ``type`` column (detectFrame passes a
            detector label, matched face names, or the OCR text).
        client_ip: IP address of the client whose frame triggered detection.
        result: Value stored in the ``result`` column (detectFrame passes the
            saved evidence file path as a string).

    Raises:
        Exception: wraps any ``mysql.connector.Error`` raised while saving.
    """
    conn = None
    cursor = None
    try:
        # Connect to the database
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)  # rows as dicts (no rows are read here)
        insert_query = """
            INSERT INTO device_danger (client_ip, type, result)
            VALUES (%s, %s, %s)
        """
        # Parameterised query: values are escaped by the driver, never interpolated.
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        # Undo the failed insert so the transaction is not left pending on
        # the connection before it is handed back to db.close_connection.
        if conn is not None:
            try:
                conn.rollback()
            except MySQLError:
                pass  # best effort; the original error below is what matters
        raise Exception(f"保存检测记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
|
||
|
||
|
||
def detectFrame(client_ip, frame):
    """Run YOLO, face and OCR detection on one frame and record violations.

    For every detector whose flag is truthy, this alerts the client via
    ``danger_handler``, saves an evidence image and inserts a DB row via
    ``save_db``. Confidence thresholds are read from ``BUSINESS_CONFIG``
    (``yolo_conf``, ``face_conf``, ``ocr_conf``).

    Args:
        client_ip: IP address of the client that produced the frame.
        frame: Image handed unchanged to each detector and file saver
            (presumably a decoded image array — confirm against callers).
    """
    # NOTE(review): each detector that fires calls danger_handler separately,
    # so one frame can send up to three alarms and bump the alarm count up to
    # three times — confirm this is intended.

    # YOLO detection
    yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
    if yolo_flag:
        print("❌ 检测到违规内容,保存图片,YOLO")
        _report_violation(
            client_ip,
            "色情",
            lambda: save_detect_yolo_file(client_ip, frame, yolo_result, "yolo"),
        )

    # Face detection
    face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"]))
    if face_flag:
        print("❌ 检测到违规内容,保存图片,FACE")
        print("人脸识别内容:", face_result)
        # The matched person names become the record's type column.
        model_type = extract_face_names(face_result)
        _report_violation(
            client_ip,
            model_type,
            lambda: save_detect_face_file(client_ip, frame, face_result, "face"),
        )

    # OCR detection
    ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"]))
    if ocr_flag:
        print("❌ 检测到违规内容,保存图片,OCR")
        print("ocr识别内容:", ocr_result)
        _report_violation(
            client_ip,
            str(ocr_result),
            lambda: save_detect_file(client_ip, frame, "ocr"),
        )

    # Only report "nothing found" when no detector fired.
    if not (face_flag or yolo_flag or ocr_flag):
        print("所有模型未检测到任何违规内容")


def _report_violation(client_ip, model_type, save_evidence):
    """Alert the client, save the evidence image, then record it in the DB.

    Args:
        client_ip: IP address of the offending client.
        model_type: Value for the record's ``type`` column.
        save_evidence: Zero-argument callable that saves the evidence image
            and returns its path; invoked after the alert has been sent.
    """
    danger_handler(client_ip)
    path = save_evidence()
    save_db(model_type=model_type, client_ip=client_ip, result=str(path))
|
||
|
||
def danger_handler(client_ip):
    """Notify a client about a violation and flag its device record.

    Sends a ``danger`` message and then a ``lock`` message (JSON with
    ``type``, ``timestamp`` and ``client_ip`` keys) to the client, then
    increments the device's alarm count and sets its "needs handling"
    flag to 1.

    Args:
        client_ip: IP address identifying the client/device.
    """
    # Imported lazily, presumably to avoid a circular import with these modules.
    from ws.ws import send_message_to_client, get_current_time_str
    from service.device_service import increment_alarm_count_by_ip
    from service.device_service import update_is_need_handler_by_client_ip

    async def _notify():
        # "danger" first, then "lock"; each timestamp is taken just before its
        # own message is sent, exactly as with two separate sends.
        for msg_type in ("danger", "lock"):
            message = {
                "type": msg_type,
                "timestamp": get_current_time_str(),
                "client_ip": client_ip,
            }
            await send_message_to_client(
                client_ip=client_ip,
                json_data=json.dumps(message),
            )

    # One event loop for both sends instead of creating and tearing down a new
    # loop per message.
    # NOTE(review): asyncio.run() raises RuntimeError if this thread already has
    # a running loop; this presumably runs in a worker thread — confirm.
    asyncio.run(_notify())

    # Increment the device's alarm (danger record) count
    increment_alarm_count_by_ip(client_ip)
    # Mark the device as needing handling (unhandled)
    update_is_need_handler_by_client_ip(client_ip, 1)
|
||
|
||
def extract_face_names(face_result: str) -> str:
    """Return the distinct matched names found in a face-detection result string.

    Names are captured from every ``匹配: <name> (`` fragment, stripped of
    surrounding whitespace (blank names are dropped), de-duplicated in order
    of first appearance and joined with commas. Returns ``""`` when
    *face_result* is empty/None or contains no matches.

    >>> extract_face_names("匹配: 张三 (0.91) 匹配: 李四 (0.88) 匹配: 张三 (0.75)")
    '张三,李四'
    """
    if not face_result:
        return ""
    stripped = (name.strip() for name in re.findall(r"匹配: (.*?) \(", face_result))
    # dict.fromkeys keeps first-seen order; a set would make the joined string's
    # order vary between interpreter runs (str hash randomisation), so the same
    # detection could be stored differently in the DB.
    unique_names = dict.fromkeys(name for name in stripped if name)
    return ",".join(unique_names)
|