Files
video_detect/core/detect.py
2025-09-30 17:17:20 +08:00

141 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import re
from mysql.connector import Error as MySQLError
from ds.config import BUSINESS_CONFIG
from ds.db import db
from service.face_service import detect as faceDetect,init_insightface
from service.model_service import load_yolo_model,detect as yoloDetect
from service.ocr_service import detect as ocrDetect,init_ocr_engine
from service.file_service import save_detect_file, save_detect_yolo_file, save_detect_face_file
import asyncio
from concurrent.futures import ThreadPoolExecutor
# Create a module-level thread pool executor (at most 10 worker threads),
# instantiated once at import time.
# NOTE(review): no function in this chunk submits work to it — presumably it is
# used by callers elsewhere; verify, otherwise it can be removed.
executor = ThreadPoolExecutor(max_workers=10)
def init():
    """Initialise every detection backend once, before any frame is processed.

    Runs, in this order: the InsightFace face model, the OCR engine, and the
    YOLO model.
    """
    for initializer in (init_insightface, init_ocr_engine, load_yolo_model):
        initializer()
def save_db(model_type, client_ip, result):
    """Insert one violation record into the ``device_danger`` table.

    Args:
        model_type: Value stored in the ``type`` column (the violation label,
            e.g. "色情", or the matched face names / OCR words).
        client_ip: IP address of the device the frame came from.
        result: Value stored in the ``result`` column (``detectFrame`` passes
            the stringified path returned by the ``save_detect_*`` helpers).

    Raises:
        Exception: wrapping the ``mysql.connector`` error when the insert or
            commit fails; the open transaction is rolled back first.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        # Plain cursor: an INSERT returns no rows, so a dictionary cursor adds nothing.
        cursor = conn.cursor()
        insert_query = """
            INSERT INTO device_danger (client_ip, type, result)
            VALUES (%s, %s, %s)
        """
        # Parameterized query — model_type/result can carry OCR/face text.
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        # Roll back so the connection is not handed back to close_connection
        # with a half-finished transaction. Best effort: a failing rollback
        # must not mask the original error.
        if conn is not None:
            try:
                conn.rollback()
            except MySQLError:
                pass
        # Message: "failed to save detection record".
        raise Exception(f"保存检测记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
def detectFrame(client_ip, frame):
    """Run the YOLO, face and OCR detectors on one frame and record every hit.

    For each detector that reports a violation this calls ``danger_handler``
    (client notification/lock), saves an evidence image and inserts a
    ``device_danger`` row via ``save_db``. All three detectors always run, so
    one frame can raise up to three alarms.

    Args:
        client_ip: IP address of the device that sent the frame.
        frame: Image handed unchanged to each detector and file saver
            (its format is defined by the service modules, not visible here).
    """
    # YOLO detection; a hit is recorded under the fixed label "色情".
    yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
    if yolo_flag:
        print("❌ 检测到违规内容,保存图片,YOLO")
        danger_handler(client_ip)
        path = save_detect_yolo_file(client_ip, frame, yolo_result, "yolo")
        save_db(model_type="色情", client_ip=client_ip, result=str(path))

    # Face recognition; the record type is the de-duplicated matched names.
    face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"]))
    if face_flag:
        print("❌ 检测到违规内容,保存图片,FACE")
        print("人脸识别内容:", face_result)
        model_type = extract_face_names(face_result)
        danger_handler(client_ip)
        path = save_detect_face_file(client_ip, frame, face_result, "face")
        save_db(model_type=model_type, client_ip=client_ip, result=str(path))

    # OCR; mirror the face branch and record the extracted prohibited words
    # instead of the whole raw OCR dump. Fall back to the raw text when no
    # "包含违禁词" fragment can be parsed, so the record is never left empty.
    ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"]))
    if ocr_flag:
        print("❌ 检测到违规内容,保存图片,OCR")
        print("ocr识别内容:", ocr_result)
        raw_ocr = str(ocr_result)
        model_type = extract_prohibited_words(raw_ocr) or raw_ocr
        danger_handler(client_ip)
        path = save_detect_file(client_ip, frame, "ocr")
        save_db(model_type=model_type, client_ip=client_ip, result=str(path))

    # Only report "nothing found" when no detector fired.
    if not (yolo_flag or face_flag or ocr_flag):
        print("所有模型未检测到任何违规内容")
def danger_handler(client_ip):
    """Alert and lock one client, then flag its device for handling.

    Sends two JSON messages to the client, in order: a ``"danger"``
    notification, then a ``"lock"`` command, each stamped with the current
    time. Afterwards it increments the device's alarm count and marks the
    device as unhandled (``update_is_need_handler_by_client_ip(..., 1)``).

    Both sends run inside a single ``asyncio.run`` call, so only one event
    loop is created per alarm. As with any ``asyncio.run`` caller, this raises
    ``RuntimeError`` if invoked from a thread whose event loop is running.
    """
    # Imported lazily — presumably to avoid a circular import with the ws /
    # service modules; TODO confirm before hoisting to module level.
    from ws.ws import send_message_to_client, get_current_time_str
    from service.device_service import increment_alarm_count_by_ip
    from service.device_service import update_is_need_handler_by_client_ip

    async def _notify():
        # Each message is built right before it is sent, so every timestamp
        # reflects its own send time; a failed send stops the sequence.
        for msg_type in ("danger", "lock"):
            msg = {
                "type": msg_type,
                "timestamp": get_current_time_str(),
                "client_ip": client_ip,
            }
            await send_message_to_client(
                client_ip=client_ip,
                json_data=json.dumps(msg),
            )

    asyncio.run(_notify())
    # Increase the device's danger/alarm record count.
    increment_alarm_count_by_ip(client_ip)
    # Mark the device status as unhandled.
    update_is_need_handler_by_client_ip(client_ip, 1)
def extract_prohibited_words(ocr_result: str) -> str:
    """Collect every prohibited word from a multi-block OCR result string.

    The input is expected to contain fragments such as
    ``"文本: ... 包含违禁词: w1(置信度: 1.00), w2(置信度: 0.95);"``. For each
    ``包含违禁词: ...`` fragment, parenthesised annotations (e.g. the
    confidence) are removed and the comma-separated words are collected.
    A final fragment that lacks its trailing ``;`` is still parsed.

    Returns:
        The unique words joined with ``","`` in first-seen order (stable
        across runs, unlike set ordering), or ``""`` when nothing is found or
        *ocr_result* is empty/None.
    """
    if not ocr_result:
        return ""
    # Non-greedy capture after "包含违禁词: " up to the next ";" — or up to the
    # end of the string, so a last fragment without ";" is not dropped.
    segments = re.findall(r"包含违禁词: (.*?)(?:;|\Z)", ocr_result, re.DOTALL)
    words = []
    for segment in segments:
        # Strip annotations such as "(置信度: 1.00)" together with leading spaces.
        cleaned = re.sub(r"\s*\([^)]*\)", "", segment.strip())
        words.extend(w.strip() for w in cleaned.split(",") if w.strip())
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return ",".join(dict.fromkeys(words))
def extract_face_names(face_result: str) -> str:
    """Extract the matched person names from a face-recognition result string.

    A name is the text between ``"匹配: "`` and the next ``" ("``.

    Returns:
        The unique, non-blank names joined with ``","`` in first-seen order
        (deterministic, unlike set ordering), or ``""`` when there are none
        or *face_result* is empty/None.
    """
    if not face_result:
        return ""
    names = (name.strip() for name in re.findall(r"匹配: (.*?) \(", face_result))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return ",".join(dict.fromkeys(name for name in names if name))