路径写入数据库

This commit is contained in:
2025-09-10 10:53:07 +08:00
parent ae177ca14a
commit 435b2a0e6c

View File

@ -2,6 +2,7 @@ import cv2
import numpy as np
from PIL.Image import Image
from core.establish import get_image_save_path
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
@ -68,19 +69,24 @@ def detect(client_ip, frame):
yolo_flag, yolo_result = yoloDetect(frame)
print(f"YOLO检测结果{yolo_result}")
if yolo_flag:
save_db(model_type="yolo", client_ip=client_ip, result=numpy_array_to_base64(frame))
# if full_save_path: # 只判断完整路径是否有效(用于保存)
# cv2.imwrite(full_save_path, frame)
# # 打印时使用「显示用短路径」,符合需求格式
# print(f"✅ YOLO违规图片已保存{display_path}")
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ YOLO违规图片已保存{display_path}")
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
return (True, yolo_result, "yolo")
#
# # 2. 人脸检测优先级2
face_flag, face_result = faceDetect(frame)
print(f"人脸检测结果:{face_result}")
if face_flag:
# 将帧转化为 base64 字符串
save_db(model_type="face", client_ip=client_ip, result=numpy_array_to_base64(frame))
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ face违规图片已保存{display_path}")
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
return (True, face_result, "face")
# 3. OCR检测优先级3
@ -88,70 +94,13 @@ def detect(client_ip, frame):
print(f"OCR检测结果{ocr_result}")
if ocr_flag:
# 解构元组,保存用完整路径,打印用短路径
save_db(model_type="ocr", client_ip=client_ip, result=ocr_result)
# if full_save_path:
# cv2.imwrite(full_save_path, frame)
# print(f"✅ OCR违规图片已保存{display_path}")
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ ocr违规图片已保存{display_path}")
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片")
return (False, "未检测到任何内容", "none")
def numpy_array_to_base64(arr, img_format='PNG'):
"""
将numpy数组转换为base64字符串
参数:
arr: numpy数组通常是图像数据形状为(height, width, channels)
img_format: 图像格式,默认为'PNG',也可以是'JPEG'等PIL支持的格式
返回:
str: 转换后的base64字符串
异常:
ValueError: 当输入不是有效的numpy数组或不支持的形状时抛出
Exception: 处理过程中出现的其他异常
"""
try:
# 检查输入是否为numpy数组
if not isinstance(arr, np.ndarray):
raise ValueError("输入必须是numpy数组")
# 处理单通道图像(灰度图)
if len(arr.shape) == 2:
arr = np.expand_dims(arr, axis=-1)
# 检查数组形状是否有效
if len(arr.shape) != 3 or arr.shape[2] not in [1, 3, 4]:
raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据通道数应为1、3或4")
# 处理数据类型确保是uint8类型
if arr.dtype != np.uint8:
# 归一化到0-255并转换为uint8
arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8)
# 将单通道图像转换为PIL支持的模式
if arr.shape[2] == 1:
arr = arr.squeeze(axis=-1)
image = Image.fromarray(arr, mode='L') # L模式表示灰度图
elif arr.shape[2] == 3:
image = Image.fromarray(arr, mode='RGB')
else: # 4通道
image = Image.fromarray(arr, mode='RGBA')
# 将图像保存到内存缓冲区
buffer = BytesIO()
image.save(buffer, format=img_format)
# 从缓冲区读取数据并编码为base64
buffer.seek(0)
base64_str = base64.b64encode(buffer.read()).decode('utf-8')
return base64_str
except ValueError as ve:
raise ve
except Exception as e:
raise Exception(f"转换过程中发生错误: {str(e)}")