2025-09-09 16:30:12 +08:00
|
|
|
|
import cv2
|
2025-09-10 08:57:56 +08:00
|
|
|
|
import numpy as np
|
|
|
|
|
from PIL.Image import Image
|
|
|
|
|
|
2025-09-10 10:53:07 +08:00
|
|
|
|
from core.establish import get_image_save_path
|
2025-09-04 22:59:27 +08:00
|
|
|
|
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
|
|
|
|
|
from core.face import load_model as faceLoadModel, detect as faceDetect
|
|
|
|
|
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
|
2025-09-09 16:30:12 +08:00
|
|
|
|
# 导入保存路径函数(根据实际文件位置调整导入路径)
|
2025-09-10 08:57:56 +08:00
|
|
|
|
import numpy as np
|
|
|
|
|
import base64
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from ds.db import db
|
|
|
|
|
from mysql.connector import Error as MySQLError
|
|
|
|
|
|
2025-09-09 16:30:12 +08:00
|
|
|
|
# 模型加载状态标记(避免重复加载)
|
2025-09-04 22:59:27 +08:00
|
|
|
|
|
|
|
|
|
|
2025-09-09 16:30:12 +08:00
|
|
|
|
_model_loaded = False
|
2025-09-05 17:23:50 +08:00
|
|
|
|
|
|
|
|
|
|
2025-09-04 22:59:27 +08:00
|
|
|
|
def load_model():
|
2025-09-09 16:30:12 +08:00
|
|
|
|
"""加载所有检测模型(仅首次调用时执行)"""
|
2025-09-04 22:59:27 +08:00
|
|
|
|
global _model_loaded
|
2025-09-09 16:30:12 +08:00
|
|
|
|
if _model_loaded:
|
|
|
|
|
print("模型已加载,无需重复执行")
|
|
|
|
|
return
|
2025-09-05 17:23:50 +08:00
|
|
|
|
|
2025-09-09 16:30:12 +08:00
|
|
|
|
# 依次加载OCR、人脸、YOLO模型
|
|
|
|
|
ocrLoadModel()
|
|
|
|
|
faceLoadModel()
|
|
|
|
|
yoloLoadModel()
|
2025-09-05 17:23:50 +08:00
|
|
|
|
|
2025-09-09 16:30:12 +08:00
|
|
|
|
_model_loaded = True
|
|
|
|
|
print("所有检测模型加载完成")
|
2025-09-05 17:23:50 +08:00
|
|
|
|
|
|
|
|
|
|
2025-09-10 08:57:56 +08:00
|
|
|
|
def save_db(model_type, client_ip, result):
|
|
|
|
|
conn = None
|
|
|
|
|
cursor = None
|
|
|
|
|
try:
|
|
|
|
|
# 连接数据库
|
|
|
|
|
conn = db.get_connection()
|
|
|
|
|
# 往表插入数据
|
|
|
|
|
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
|
|
|
|
|
insert_query = """
|
|
|
|
|
INSERT INTO device_danger (client_ip, type, result)
|
|
|
|
|
VALUES (%s, %s, %s)
|
|
|
|
|
"""
|
|
|
|
|
cursor.execute(insert_query, (client_ip, model_type, result))
|
|
|
|
|
conn.commit()
|
|
|
|
|
except MySQLError as e:
|
|
|
|
|
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
|
|
|
|
finally:
|
|
|
|
|
db.close_connection(conn, cursor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect(client_ip, frame):
|
2025-09-05 17:23:50 +08:00
|
|
|
|
"""
|
2025-09-09 16:30:12 +08:00
|
|
|
|
执行模型检测,检测到违规时按指定格式保存图片
|
|
|
|
|
参数:
|
|
|
|
|
frame: 待检测的图像帧(OpenCV格式,numpy.ndarray类型)
|
|
|
|
|
返回:
|
|
|
|
|
(检测结果布尔值, 检测详情, 检测模型类型)
|
2025-09-05 17:23:50 +08:00
|
|
|
|
"""
|
2025-09-09 16:30:12 +08:00
|
|
|
|
# 1. YOLO检测(优先级1)
|
|
|
|
|
yolo_flag, yolo_result = yoloDetect(frame)
|
|
|
|
|
print(f"YOLO检测结果:{yolo_result}")
|
|
|
|
|
if yolo_flag:
|
2025-09-10 10:53:07 +08:00
|
|
|
|
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
|
|
|
|
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
|
|
|
|
cv2.imwrite(full_save_path, frame)
|
|
|
|
|
# 打印时使用「显示用短路径」,符合需求格式
|
|
|
|
|
print(f"✅ YOLO违规图片已保存:{display_path}")
|
|
|
|
|
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
|
2025-09-09 16:30:12 +08:00
|
|
|
|
return (True, yolo_result, "yolo")
|
2025-09-10 08:57:56 +08:00
|
|
|
|
#
|
|
|
|
|
# # 2. 人脸检测(优先级2)
|
2025-09-09 16:30:12 +08:00
|
|
|
|
face_flag, face_result = faceDetect(frame)
|
|
|
|
|
print(f"人脸检测结果:{face_result}")
|
|
|
|
|
if face_flag:
|
2025-09-10 10:53:07 +08:00
|
|
|
|
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
|
|
|
|
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
|
|
|
|
cv2.imwrite(full_save_path, frame)
|
|
|
|
|
# 打印时使用「显示用短路径」,符合需求格式
|
|
|
|
|
print(f"✅ face违规图片已保存:{display_path}")
|
|
|
|
|
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
|
2025-09-09 16:30:12 +08:00
|
|
|
|
return (True, face_result, "face")
|
|
|
|
|
|
|
|
|
|
# 3. OCR检测(优先级3)
|
|
|
|
|
ocr_flag, ocr_result = ocrDetect(frame)
|
|
|
|
|
print(f"OCR检测结果:{ocr_result}")
|
|
|
|
|
if ocr_flag:
|
2025-09-09 17:09:34 +08:00
|
|
|
|
# 解构元组,保存用完整路径,打印用短路径
|
2025-09-10 10:53:07 +08:00
|
|
|
|
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
|
|
|
|
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
|
|
|
|
cv2.imwrite(full_save_path, frame)
|
|
|
|
|
# 打印时使用「显示用短路径」,符合需求格式
|
|
|
|
|
print(f"✅ ocr违规图片已保存:{display_path}")
|
|
|
|
|
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
|
2025-09-09 16:30:12 +08:00
|
|
|
|
return (True, ocr_result, "ocr")
|
|
|
|
|
# 4. 无违规内容(不保存图片)
|
|
|
|
|
print(f"❌ 未检测到任何违规内容,不保存图片")
|
2025-09-10 10:53:07 +08:00
|
|
|
|
return (False, "未检测到任何内容", "none")
|