Files
video/core/all.py

73 lines
2.8 KiB
Python
Raw Normal View History

2025-09-09 16:30:12 +08:00
import cv2
2025-09-04 22:59:27 +08:00
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
2025-09-09 16:30:12 +08:00
# 导入保存路径函数(根据实际文件位置调整导入路径)
from core.establish import get_image_save_path
# 模型加载状态标记(避免重复加载)
2025-09-04 22:59:27 +08:00
2025-09-09 16:30:12 +08:00
_model_loaded = False
2025-09-05 17:23:50 +08:00
2025-09-04 22:59:27 +08:00
def load_model():
2025-09-09 16:30:12 +08:00
"""加载所有检测模型(仅首次调用时执行)"""
2025-09-04 22:59:27 +08:00
global _model_loaded
2025-09-09 16:30:12 +08:00
if _model_loaded:
print("模型已加载,无需重复执行")
return
2025-09-05 17:23:50 +08:00
2025-09-09 16:30:12 +08:00
# 依次加载OCR、人脸、YOLO模型
ocrLoadModel()
faceLoadModel()
yoloLoadModel()
2025-09-05 17:23:50 +08:00
2025-09-09 16:30:12 +08:00
_model_loaded = True
print("所有检测模型加载完成")
2025-09-05 17:23:50 +08:00
2025-09-09 16:30:12 +08:00
def detect(frame):
2025-09-05 17:23:50 +08:00
"""
2025-09-09 16:30:12 +08:00
执行模型检测检测到违规时按指定格式保存图片
参数
frame: 待检测的图像帧OpenCV格式numpy.ndarray类型
返回
(检测结果布尔值, 检测详情, 检测模型类型)
2025-09-05 17:23:50 +08:00
"""
2025-09-09 16:30:12 +08:00
# 1. YOLO检测优先级1
yolo_flag, yolo_result = yoloDetect(frame)
print(f"YOLO检测结果{yolo_result}")
if yolo_flag:
# 元组解构:获取「完整保存路径」和「显示用短路径」
full_save_path, display_path = get_image_save_path(model_type="yolo")
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ YOLO违规图片已保存{display_path}")
2025-09-09 16:30:12 +08:00
return (True, yolo_result, "yolo")
# 2. 人脸检测优先级2
face_flag, face_result = faceDetect(frame)
print(f"人脸检测结果:{face_result}")
if face_flag:
# 同样解构元组,分离保存路径和显示路径
full_save_path, display_path = get_image_save_path(model_type="face")
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ 人脸违规图片已保存:{display_path}")
2025-09-09 16:30:12 +08:00
return (True, face_result, "face")
# 3. OCR检测优先级3
ocr_flag, ocr_result = ocrDetect(frame)
print(f"OCR检测结果{ocr_result}")
if ocr_flag:
# 解构元组,保存用完整路径,打印用短路径
full_save_path, display_path = get_image_save_path(model_type="ocr")
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ OCR违规图片已保存{display_path}")
2025-09-09 16:30:12 +08:00
return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片")
return (False, "未检测到任何内容", "none")