Files
video/core/all.py
2025-09-10 10:53:07 +08:00

106 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from PIL.Image import Image
from core.establish import get_image_save_path
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# 导入保存路径函数(根据实际文件位置调整导入路径)
import numpy as np
import base64
from io import BytesIO
from PIL import Image
from ds.db import db
from mysql.connector import Error as MySQLError
# 模型加载状态标记(避免重复加载)
_model_loaded = False
def load_model():
"""加载所有检测模型(仅首次调用时执行)"""
global _model_loaded
if _model_loaded:
print("模型已加载,无需重复执行")
return
# 依次加载OCR、人脸、YOLO模型
ocrLoadModel()
faceLoadModel()
yoloLoadModel()
_model_loaded = True
print("所有检测模型加载完成")
def save_db(model_type, client_ip, result):
conn = None
cursor = None
try:
# 连接数据库
conn = db.get_connection()
# 往表插入数据
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
insert_query = """
INSERT INTO device_danger (client_ip, type, result)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (client_ip, model_type, result))
conn.commit()
except MySQLError as e:
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def detect(client_ip, frame):
"""
执行模型检测,检测到违规时按指定格式保存图片
参数:
frame: 待检测的图像帧OpenCV格式numpy.ndarray类型
返回:
(检测结果布尔值, 检测详情, 检测模型类型)
"""
# 1. YOLO检测优先级1
yolo_flag, yolo_result = yoloDetect(frame)
print(f"YOLO检测结果{yolo_result}")
if yolo_flag:
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ YOLO违规图片已保存{display_path}")
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
return (True, yolo_result, "yolo")
#
# # 2. 人脸检测优先级2
face_flag, face_result = faceDetect(frame)
print(f"人脸检测结果:{face_result}")
if face_flag:
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ face违规图片已保存{display_path}")
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
return (True, face_result, "face")
# 3. OCR检测优先级3
ocr_flag, ocr_result = ocrDetect(frame)
print(f"OCR检测结果{ocr_result}")
if ocr_flag:
# 解构元组,保存用完整路径,打印用短路径
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ ocr违规图片已保存{display_path}")
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片")
return (False, "未检测到任何内容", "none")