import base64
from io import BytesIO

import cv2  # noqa: F401  kept: frames handled here are OpenCV images
import numpy as np
from PIL import Image  # the PIL.Image *module* (provides Image.fromarray)
from mysql.connector import Error as MySQLError

from core.face import load_model as faceLoadModel, detect as faceDetect
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
from ds.db import db

# Set once every detection model has been loaded, so repeated calls are no-ops.
_model_loaded = False


def load_model():
    """Load the OCR, face and YOLO models (only on the first call)."""
    global _model_loaded
    if _model_loaded:
        print("模型已加载,无需重复执行")
        return
    # Load the OCR, face and YOLO models in that order.
    ocrLoadModel()
    faceLoadModel()
    yoloLoadModel()
    _model_loaded = True
    print("所有检测模型加载完成")


def save_db(model_type, client_ip, result):
    """Insert one detection record into the ``device_danger`` table.

    Args:
        model_type: Name of the model that flagged the frame ("yolo", "face", "ocr").
        client_ip: IP address of the client that sent the frame.
        result: Value stored in the ``result`` column (a base64-encoded image
            when called from ``detect``).

    Raises:
        Exception: Wrapping the underlying ``mysql.connector.Error`` if the
            connection, insert or commit fails.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor()
        insert_query = (
            "INSERT INTO device_danger (client_ip, type, result) "
            "VALUES (%s, %s, %s)"
        )
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        # Undo the partial transaction; a rollback failure must not mask the
        # original error, which is re-raised below.
        if conn is not None:
            try:
                conn.rollback()
            except MySQLError:
                pass
        raise Exception(f"保存检测结果失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)


def detect(client_ip, frame):
    """Run the detectors in priority order and persist the first violation.

    Detectors are tried as YOLO -> face -> OCR. The first one that reports a
    hit causes the frame to be stored (base64-encoded) via ``save_db`` and its
    result to be returned. The remaining detectors are not run.

    Args:
        client_ip: IP address of the client that sent the frame; stored with
            the record.
        frame: Image frame to check (OpenCV format, ``numpy.ndarray``).

    Returns:
        tuple: ``(hit, detail, model_type)`` where ``hit`` is a bool,
        ``detail`` is the detector's result (or a message when nothing was
        found) and ``model_type`` is "yolo", "face", "ocr" or "none".
    """
    # Looked up at call time so the module-level names can be patched.
    detectors = (
        ("yolo", "YOLO", yoloDetect),
        ("face", "人脸", faceDetect),
        ("ocr", "OCR", ocrDetect),
    )
    for model_type, label, run_detector in detectors:
        flag, result = run_detector(frame)
        print(f"{label}检测结果:{result}")
        if flag:
            # OpenCV frames are BGR; reorder to RGB so stored images have the
            # correct colours. Every branch stores the frame image, as the
            # docstring describes (previously OCR stored its raw result).
            encoded = numpy_array_to_base64(frame, channel_order="BGR")
            save_db(model_type=model_type, client_ip=client_ip, result=encoded)
            return (True, result, model_type)

    # No violation: nothing is saved.
    print("❌ 未检测到任何违规内容,不保存图片")
    return (False, "未检测到任何内容", "none")


def numpy_array_to_base64(arr, img_format='PNG', channel_order='RGB'):
    """Encode an image stored in a numpy array as a base64 string.

    Args:
        arr: Image array of shape (height, width) or (height, width, C) with
            C in {1, 3, 4}. Non-uint8 data is min-max normalised to 0-255.
        img_format: Any format Pillow can save, e.g. 'PNG' (default) or 'JPEG'.
            RGBA images are flattened to RGB for JPEG, which has no alpha.
        channel_order: 'RGB' (default) or 'BGR'. Use 'BGR' for OpenCV images;
            3/4-channel data is then reordered to RGB(A) before encoding.

    Returns:
        str: The base64-encoded (UTF-8 decoded) image bytes.

    Raises:
        ValueError: If ``arr`` is not a numpy array, has an unsupported shape,
            or ``channel_order`` is not 'RGB'/'BGR'.
        Exception: For any other failure during conversion (cause is chained).
    """
    if not isinstance(arr, np.ndarray):
        raise ValueError("输入必须是numpy数组")

    # Treat a 2-D array as a single-channel (grayscale) image.
    if arr.ndim == 2:
        arr = arr[..., np.newaxis]

    if arr.ndim != 3 or arr.shape[2] not in (1, 3, 4):
        raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据,通道数应为1、3或4")

    order = str(channel_order).upper()
    if order not in ('RGB', 'BGR'):
        raise ValueError(f"不支持的通道顺序: {channel_order}")

    try:
        if arr.dtype != np.uint8:
            # Normalise in float64: subtracting the minimum in a narrow signed
            # integer dtype (e.g. int8) overflows, and bool arrays cannot be
            # subtracted at all.
            as_float = arr.astype(np.float64)
            lo = as_float.min()
            hi = as_float.max()
            arr = ((as_float - lo) / (hi - lo + 1e-8) * 255).astype(np.uint8)

        channels = arr.shape[2]
        if order == 'BGR' and channels >= 3:
            # BGR -> RGB (and BGRA -> RGBA); fancy indexing yields a contiguous copy.
            arr = arr[..., [2, 1, 0] + ([3] if channels == 4 else [])]

        if channels == 1:
            arr = arr[..., 0]

        # For uint8 data Pillow infers L / RGB / RGBA from the shape; the
        # explicit ``mode=`` argument is deprecated in recent Pillow.
        image = Image.fromarray(np.ascontiguousarray(arr))
        if image.mode == 'RGBA' and str(img_format).upper() == 'JPEG':
            image = image.convert('RGB')

        buffer = BytesIO()
        image.save(buffer, format=img_format)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except ValueError:
        raise
    except Exception as e:
        raise Exception(f"转换过程中发生错误: {str(e)}") from e