Files
video/core/all.py

157 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from PIL.Image import Image
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# Helpers for base64 image encoding and database access.
import numpy as np
import base64
from io import BytesIO
from PIL import Image
from ds.db import db
from mysql.connector import Error as MySQLError
# Set to True once every detector has been loaded, so later calls are no-ops.
_model_loaded = False


def load_model():
    """Load the OCR, face and YOLO detection models on the first call only.

    Later calls just report that the models are already loaded. The flag is
    set only after all three loaders have returned.
    """
    global _model_loaded
    if not _model_loaded:
        # Same order as before: OCR, then face, then YOLO.
        for loader in (ocrLoadModel, faceLoadModel, yoloLoadModel):
            loader()
        _model_loaded = True
        print("所有检测模型加载完成")
        return
    print("模型已加载,无需重复执行")
def save_db(model_type, client_ip, result):
    """Insert one detection record into the ``device_danger`` table.

    Args:
        model_type: Value stored in the ``type`` column; callers in this file
            pass "yolo", "face" or "ocr".
        client_ip: Value stored in the ``client_ip`` column.
        result: Value stored in the ``result`` column; callers in this file
            pass either a base64-encoded image or the OCR detection result.

    Raises:
        Exception: If the database driver reports an error; the original
            ``MySQLError`` is chained as the cause.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        # dictionary=True only affects fetched rows; this INSERT fetches
        # nothing, the flag is kept as in the original call.
        cursor = conn.cursor(dictionary=True)
        insert_query = """
            INSERT INTO device_danger (client_ip, type, result)
            VALUES (%s, %s, %s)
        """
        # Parameterized query: values are passed separately from the SQL text.
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        # The previous message ("failed to fetch device list") was copied
        # from another function and misdescribed this insert.
        raise Exception(f"保存检测结果失败: {e}") from e
    finally:
        # Runs on success and failure; conn/cursor may still be None here.
        db.close_connection(conn, cursor)
def _frame_to_rgb(frame):
    """Return *frame* with OpenCV channel order (BGR/BGRA) swapped to RGB/RGBA.

    Frames that are not 3- or 4-channel arrays are returned unchanged.
    """
    if isinstance(frame, np.ndarray) and frame.ndim == 3:
        channels = frame.shape[2]
        if channels == 3:
            return np.ascontiguousarray(frame[:, :, ::-1])
        if channels == 4:
            # Assumes BGRA layout: swap colour channels, keep alpha in place.
            return np.ascontiguousarray(frame[:, :, [2, 1, 0, 3]])
    return frame


def detect(client_ip, frame):
    """Run the detectors in priority order and record the first hit.

    The detectors run as YOLO, then face, then OCR; the first one that
    reports a hit ends the check and a record is written through ``save_db``.

    Args:
        client_ip: Client address stored alongside the record.
        frame: Image frame to check, an OpenCV-style ``numpy.ndarray``
            (assumed to be in BGR channel order, OpenCV's convention).

    Returns:
        A tuple ``(hit, detail, model_type)``. ``hit`` is True when a detector
        fired, ``detail`` is that detector's result, and ``model_type`` is
        "yolo", "face" or "ocr". When nothing fires the tuple is
        ``(False, "未检测到任何内容", "none")`` and nothing is saved.
    """
    # 1. YOLO detection (priority 1)
    yolo_flag, yolo_result = yoloDetect(frame)
    print(f"YOLO检测结果{yolo_result}")
    if yolo_flag:
        # The encoder labels 3-channel data as RGB, so the frame is reordered
        # first; otherwise the stored image has red and blue swapped.
        encoded = numpy_array_to_base64(_frame_to_rgb(frame))
        save_db(model_type="yolo", client_ip=client_ip, result=encoded)
        return (True, yolo_result, "yolo")

    # 2. Face detection (priority 2)
    face_flag, face_result = faceDetect(frame)
    print(f"人脸检测结果:{face_result}")
    if face_flag:
        encoded = numpy_array_to_base64(_frame_to_rgb(frame))
        save_db(model_type="face", client_ip=client_ip, result=encoded)
        return (True, face_result, "face")

    # 3. OCR detection (priority 3)
    ocr_flag, ocr_result = ocrDetect(frame)
    print(f"OCR检测结果{ocr_result}")
    if ocr_flag:
        # Unlike the two branches above, this stores the OCR result itself
        # rather than the encoded frame.
        save_db(model_type="ocr", client_ip=client_ip, result=ocr_result)
        return (True, ocr_result, "ocr")

    # 4. Nothing detected: no record is written.
    print("❌ 未检测到任何违规内容,不保存图片")
    return (False, "未检测到任何内容", "none")
def numpy_array_to_base64(arr, img_format='PNG'):
    """Encode an image array as a base64 string.

    Args:
        arr: ``numpy.ndarray`` of shape ``(height, width)`` or
            ``(height, width, channels)`` with 1, 3 or 4 channels. Channels
            are treated as grayscale, RGB and RGBA respectively. Arrays whose
            dtype is not ``uint8`` are min-max rescaled to 0-255 first.
        img_format: Image format name passed to PIL when saving, 'PNG' by
            default ('JPEG' and other PIL formats are also accepted).

    Returns:
        The encoded image file as a base64 text string.

    Raises:
        ValueError: If ``arr`` is not a numpy array, is empty, or does not
            have a supported shape.
        Exception: If any other error occurs while converting or encoding;
            the original error is chained as the cause.
    """
    if not isinstance(arr, np.ndarray):
        raise ValueError("输入必须是numpy数组")

    # Give 2-D (grayscale) input an explicit channel axis so one shape check
    # covers every case.
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    if arr.ndim != 3 or arr.shape[2] not in (1, 3, 4):
        raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据通道数应为1、3或4")
    # Fail clearly up front instead of with an incidental error from
    # min()/max() or from PIL further down.
    if arr.size == 0:
        raise ValueError("numpy数组不能为空")

    try:
        if arr.dtype != np.uint8:
            # Rescale in floating point: in the original integer dtype the
            # subtraction can overflow (e.g. int8 spanning -128..127), and
            # boolean arrays do not support "-" at all.
            if arr.dtype.kind in "biu":
                arr = arr.astype(np.float64)
            low = arr.min()
            high = arr.max()
            # The small epsilon avoids division by zero for constant images.
            arr = ((arr - low) / (high - low + 1e-8) * 255).astype(np.uint8)

        # PIL expects grayscale data as a 2-D array.
        if arr.shape[2] == 1:
            arr = arr[:, :, 0]
        # For uint8 input PIL infers L / RGB / RGBA from the array shape.
        image = Image.fromarray(arr)

        buffer = BytesIO()
        image.save(buffer, format=img_format)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except ValueError:
        # Keep ValueError as-is so callers can tell bad input from failures.
        raise
    except Exception as e:
        raise Exception(f"转换过程中发生错误: {str(e)}") from e