From 435b2a0e6c88098659900b3a67f1b2d28623ef2c Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Wed, 10 Sep 2025 10:53:07 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B7=AF=E5=BE=84=E5=86=99=E5=85=A5=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/all.py | 91 ++++++++++++----------------------------------------- 1 file changed, 20 insertions(+), 71 deletions(-) diff --git a/core/all.py b/core/all.py index 80267aa..7cc307d 100644 --- a/core/all.py +++ b/core/all.py @@ -2,6 +2,7 @@ import cv2 import numpy as np from PIL.Image import Image +from core.establish import get_image_save_path from core.ocr import load_model as ocrLoadModel, detect as ocrDetect from core.face import load_model as faceLoadModel, detect as faceDetect from core.yolo import load_model as yoloLoadModel, detect as yoloDetect @@ -68,19 +69,24 @@ def detect(client_ip, frame): yolo_flag, yolo_result = yoloDetect(frame) print(f"YOLO检测结果:{yolo_result}") if yolo_flag: - save_db(model_type="yolo", client_ip=client_ip, result=numpy_array_to_base64(frame)) - # if full_save_path: # 只判断完整路径是否有效(用于保存) - # cv2.imwrite(full_save_path, frame) - # # 打印时使用「显示用短路径」,符合需求格式 - # print(f"✅ YOLO违规图片已保存:{display_path}") + full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) + if full_save_path: # 只判断完整路径是否有效(用于保存) + cv2.imwrite(full_save_path, frame) + # 打印时使用「显示用短路径」,符合需求格式 + print(f"✅ YOLO违规图片已保存:{display_path}") + save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) return (True, yolo_result, "yolo") # # # 2. 人脸检测(优先级2) face_flag, face_result = faceDetect(frame) print(f"人脸检测结果:{face_result}") if face_flag: - # 将帧转化为 base64 字符串 - save_db(model_type="face", client_ip=client_ip, result=numpy_array_to_base64(frame)) + full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) + if full_save_path: # 只判断完整路径是否有效(用于保存) + cv2.imwrite(full_save_path, frame) + # 打印时使用「显示用短路径」,符合需求格式 + print(f"✅ face违规图片已保存:{display_path}") + save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) return (True, face_result, "face") # 3. OCR检测(优先级3) @@ -88,70 +94,13 @@ def detect(client_ip, frame): print(f"OCR检测结果:{ocr_result}") if ocr_flag: # 解构元组,保存用完整路径,打印用短路径 - save_db(model_type="ocr", client_ip=client_ip, result=ocr_result) - # if full_save_path: - # cv2.imwrite(full_save_path, frame) - # print(f"✅ OCR违规图片已保存:{display_path}") + full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) + if full_save_path: # 只判断完整路径是否有效(用于保存) + cv2.imwrite(full_save_path, frame) + # 打印时使用「显示用短路径」,符合需求格式 + print(f"✅ ocr违规图片已保存:{display_path}") + save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) return (True, ocr_result, "ocr") - # 4. 无违规内容(不保存图片) print(f"❌ 未检测到任何违规内容,不保存图片") - return (False, "未检测到任何内容", "none") - - -def numpy_array_to_base64(arr, img_format='PNG'): - """ - 将numpy数组转换为base64字符串 - - 参数: - arr: numpy数组,通常是图像数据,形状为(height, width, channels) - img_format: 图像格式,默认为'PNG',也可以是'JPEG'等PIL支持的格式 - - 返回: - str: 转换后的base64字符串 - - 异常: - ValueError: 当输入不是有效的numpy数组或不支持的形状时抛出 - Exception: 处理过程中出现的其他异常 - """ - try: - # 检查输入是否为numpy数组 - if not isinstance(arr, np.ndarray): - raise ValueError("输入必须是numpy数组") - - # 处理单通道图像(灰度图) - if len(arr.shape) == 2: - arr = np.expand_dims(arr, axis=-1) - - # 检查数组形状是否有效 - if len(arr.shape) != 3 or arr.shape[2] not in [1, 3, 4]: - raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据,通道数应为1、3或4") - - # 处理数据类型,确保是uint8类型 - if arr.dtype != np.uint8: - # 归一化到0-255并转换为uint8 - arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8) - - # 将单通道图像转换为PIL支持的模式 - if arr.shape[2] == 1: - arr = arr.squeeze(axis=-1) - image = Image.fromarray(arr, mode='L') # L模式表示灰度图 - elif arr.shape[2] == 3: - image = Image.fromarray(arr, mode='RGB') - else: # 4通道 - image = Image.fromarray(arr, mode='RGBA') - - # 将图像保存到内存缓冲区 - buffer = BytesIO() - image.save(buffer, format=img_format) - - # 从缓冲区读取数据并编码为base64 - buffer.seek(0) - base64_str = base64.b64encode(buffer.read()).decode('utf-8') - - return base64_str - - except ValueError as ve: - raise ve - except Exception as e: - raise Exception(f"转换过程中发生错误: {str(e)}") \ No newline at end of file + return (False, "未检测到任何内容", "none") \ No newline at end of file