Files
video/core/all.py

95 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from PIL.Image import Image
from core.establish import get_image_save_path
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# 导入保存路径函数(根据实际文件位置调整导入路径)
import numpy as np
import base64
from io import BytesIO
from PIL import Image
from ds.db import db
from mysql.connector import Error as MySQLError
# Set to True once every detection model has been loaded, so repeat calls are no-ops.
_model_loaded = False


def load_model():
    """Load all detection models, doing the work only on the first call.

    The loaders run in the order OCR, face, YOLO. After all three return,
    the module-level ``_model_loaded`` flag is set and a completion message is
    printed. Any later call prints a notice and returns without reloading.
    If a loader raises, the exception propagates and the flag stays ``False``,
    so the next call will run all three loaders again.
    """
    global _model_loaded
    if not _model_loaded:
        for loader in (ocrLoadModel, faceLoadModel, yoloLoadModel):
            loader()
        _model_loaded = True
        print("所有检测模型加载完成")
    else:
        print("模型已加载,无需重复执行")
def save_db(model_type, client_ip, result):
    """Insert one detection record into the ``device_danger`` table.

    Args:
        model_type: Name of the detector that flagged the frame; stored in the
            ``type`` column (``detect`` passes "yolo", "face" or "ocr").
        client_ip: IP of the client the frame came from; stored in ``client_ip``.
        result: Value stored in the ``result`` column (``detect`` passes the
            saved image path as a string).

    Raises:
        Exception: Wraps the underlying ``MySQLError`` (chained as ``__cause__``)
            when the insert or commit fails. The transaction is rolled back first.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        # A plain cursor suffices: an INSERT returns no rows to format as dicts.
        cursor = conn.cursor()
        insert_query = """
        INSERT INTO device_danger (client_ip, type, result)
        VALUES (%s, %s, %s)
        """
        # Values are bound as parameters by the driver, never interpolated into SQL.
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        # Do not leave the connection with an open, uncommitted transaction.
        if conn is not None:
            try:
                conn.rollback()
            except MySQLError:
                pass  # Best effort only; the original failure is reported below.
        # Message previously said "failed to get device list", which is wrong here.
        raise Exception(f"保存检测记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
def _record_violation(model_type, client_ip, frame):
    """Save *frame* as evidence for *model_type* and record it in the database.

    Does nothing when ``get_image_save_path`` yields no path. The database row
    is written only if ``cv2.imwrite`` reports success, so the table never
    references an image file that was not actually written.
    """
    full_save_path, display_path = get_image_save_path(model_type=model_type, client_ip=client_ip)
    if not full_save_path:
        return
    # cv2.imwrite signals many failures by returning False rather than raising.
    if not cv2.imwrite(full_save_path, frame):
        print(f"⚠️ {model_type}违规图片保存失败{display_path}")
        return
    print(f"✅ {model_type}违规图片已保存{display_path}")
    save_db(model_type=model_type, client_ip=client_ip, result=str(full_save_path))


def detect(client_ip, frame):
    """Run the detectors on *frame* in priority order and report the first hit.

    Detectors run in the order YOLO, face, OCR. As soon as one flags the frame,
    the frame is saved and recorded through ``_record_violation`` and the
    remaining detectors are skipped.

    Args:
        client_ip: IP of the client that sent the frame; used for the save path
            and the database record.
        frame: Image passed unchanged to each detector and to ``cv2.imwrite``
            (presumably a BGR numpy array; verify against callers).

    Returns:
        ``(True, detector_result, model_type)`` for the first detector that
        flags the frame, where ``model_type`` is "yolo", "face" or "ocr".
        Otherwise ``(False, "未检测到任何内容", "none")``.

    Raises:
        Whatever the detectors or ``save_db`` raise; nothing is caught here.
    """
    # Order matters: the first detector to flag the frame wins.
    detectors = (
        ("yolo", yoloDetect),
        ("face", faceDetect),
        ("ocr", ocrDetect),
    )
    for model_type, run_detector in detectors:
        flagged, result = run_detector(frame)
        if flagged:
            _record_violation(model_type, client_ip, frame)
            return (True, result, model_type)
    # Nothing was flagged, so no image is saved.
    print("❌ 未检测到任何违规内容,不保存图片")
    return (False, "未检测到任何内容", "none")