Files
video/core/all.py

95 lines
3.3 KiB
Python
Raw Normal View History

2025-09-09 16:30:12 +08:00
import cv2
import numpy as np
from PIL.Image import Image
2025-09-10 10:53:07 +08:00
from core.establish import get_image_save_path
2025-09-04 22:59:27 +08:00
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
2025-09-09 16:30:12 +08:00
# 导入保存路径函数(根据实际文件位置调整导入路径)
import numpy as np
import base64
from io import BytesIO
from PIL import Image
from ds.db import db
from mysql.connector import Error as MySQLError
2025-09-09 16:30:12 +08:00
# 模型加载状态标记(避免重复加载)
2025-09-04 22:59:27 +08:00
2025-09-09 16:30:12 +08:00
_model_loaded = False
2025-09-05 17:23:50 +08:00
2025-09-04 22:59:27 +08:00
def load_model():
2025-09-09 16:30:12 +08:00
"""加载所有检测模型(仅首次调用时执行)"""
2025-09-04 22:59:27 +08:00
global _model_loaded
2025-09-09 16:30:12 +08:00
if _model_loaded:
print("模型已加载,无需重复执行")
return
2025-09-05 17:23:50 +08:00
2025-09-09 16:30:12 +08:00
# 依次加载OCR、人脸、YOLO模型
ocrLoadModel()
faceLoadModel()
yoloLoadModel()
2025-09-05 17:23:50 +08:00
2025-09-09 16:30:12 +08:00
_model_loaded = True
print("所有检测模型加载完成")
2025-09-05 17:23:50 +08:00
def save_db(model_type, client_ip, result):
conn = None
cursor = None
try:
# 连接数据库
conn = db.get_connection()
# 往表插入数据
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
insert_query = """
INSERT INTO device_danger (client_ip, type, result)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (client_ip, model_type, result))
conn.commit()
except MySQLError as e:
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 修正后的 detect 函数关键部分
def detect(client_ip, frame):
# 1. YOLO检测
2025-09-09 16:30:12 +08:00
yolo_flag, yolo_result = yoloDetect(frame)
if yolo_flag:
# model_type 传入 "yolo"(正确)
2025-09-10 10:53:07 +08:00
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path:
2025-09-10 10:53:07 +08:00
cv2.imwrite(full_save_path, frame)
print(f"✅ yolo违规图片已保存{display_path}") # 日志也修正
2025-09-10 10:53:07 +08:00
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
2025-09-09 16:30:12 +08:00
return (True, yolo_result, "yolo")
# 2. 人脸检测
2025-09-09 16:30:12 +08:00
face_flag, face_result = faceDetect(frame)
if face_flag:
full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip) # 这里改了
if full_save_path:
2025-09-10 10:53:07 +08:00
cv2.imwrite(full_save_path, frame)
print(f"✅ face违规图片已保存{display_path}") # 日志也修正
2025-09-10 10:53:07 +08:00
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
2025-09-09 16:30:12 +08:00
return (True, face_result, "face")
# 3. OCR检测
2025-09-09 16:30:12 +08:00
ocr_flag, ocr_result = ocrDetect(frame)
if ocr_flag:
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了
if full_save_path:
2025-09-10 10:53:07 +08:00
cv2.imwrite(full_save_path, frame)
print(f"✅ ocr违规图片已保存{display_path}") # 日志也修正
2025-09-10 10:53:07 +08:00
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
2025-09-09 16:30:12 +08:00
return (True, ocr_result, "ocr")
2025-09-09 16:30:12 +08:00
# 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片")
2025-09-10 10:53:07 +08:00
return (False, "未检测到任何内容", "none")