从服务器读取IP并将检测数据写入数据库
This commit is contained in:
124
core/all.py
124
core/all.py
@ -1,9 +1,18 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
|
||||
from core.face import load_model as faceLoadModel, detect as faceDetect
|
||||
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
|
||||
# 导入保存路径函数(根据实际文件位置调整导入路径)
|
||||
from core.establish import get_image_save_path
|
||||
import numpy as np
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from ds.db import db
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
# 模型加载状态标记(避免重复加载)
|
||||
|
||||
|
||||
@ -26,7 +35,28 @@ def load_model():
|
||||
print("所有检测模型加载完成")
|
||||
|
||||
|
||||
def detect(frame):
|
||||
def save_db(model_type, client_ip, result):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 连接数据库
|
||||
conn = db.get_connection()
|
||||
# 往表插入数据
|
||||
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
|
||||
insert_query = """
|
||||
INSERT INTO device_danger (client_ip, type, result)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (client_ip, model_type, result))
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
|
||||
def detect(client_ip, frame):
|
||||
"""
|
||||
执行模型检测,检测到违规时按指定格式保存图片
|
||||
参数:
|
||||
@ -38,23 +68,19 @@ def detect(frame):
|
||||
yolo_flag, yolo_result = yoloDetect(frame)
|
||||
print(f"YOLO检测结果:{yolo_result}")
|
||||
if yolo_flag:
|
||||
# 元组解构:获取「完整保存路径」和「显示用短路径」
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo")
|
||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
# 打印时使用「显示用短路径」,符合需求格式
|
||||
print(f"✅ YOLO违规图片已保存:{display_path}")
|
||||
save_db(model_type="yolo", client_ip=client_ip, result=numpy_array_to_base64(frame))
|
||||
# if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
# cv2.imwrite(full_save_path, frame)
|
||||
# # 打印时使用「显示用短路径」,符合需求格式
|
||||
# print(f"✅ YOLO违规图片已保存:{display_path}")
|
||||
return (True, yolo_result, "yolo")
|
||||
|
||||
# 2. 人脸检测(优先级2)
|
||||
#
|
||||
# # 2. 人脸检测(优先级2)
|
||||
face_flag, face_result = faceDetect(frame)
|
||||
print(f"人脸检测结果:{face_result}")
|
||||
if face_flag:
|
||||
# 同样解构元组,分离保存路径和显示路径
|
||||
full_save_path, display_path = get_image_save_path(model_type="face")
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ 人脸违规图片已保存:{display_path}")
|
||||
# 将帧转化为 base64 字符串
|
||||
save_db(model_type="face", client_ip=client_ip, result=numpy_array_to_base64(frame))
|
||||
return (True, face_result, "face")
|
||||
|
||||
# 3. OCR检测(优先级3)
|
||||
@ -62,12 +88,70 @@ def detect(frame):
|
||||
print(f"OCR检测结果:{ocr_result}")
|
||||
if ocr_flag:
|
||||
# 解构元组,保存用完整路径,打印用短路径
|
||||
full_save_path, display_path = get_image_save_path(model_type="ocr")
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ OCR违规图片已保存:{display_path}")
|
||||
save_db(model_type="ocr", client_ip=client_ip, result=ocr_result)
|
||||
# if full_save_path:
|
||||
# cv2.imwrite(full_save_path, frame)
|
||||
# print(f"✅ OCR违规图片已保存:{display_path}")
|
||||
return (True, ocr_result, "ocr")
|
||||
|
||||
# 4. 无违规内容(不保存图片)
|
||||
print(f"❌ 未检测到任何违规内容,不保存图片")
|
||||
return (False, "未检测到任何内容", "none")
|
||||
return (False, "未检测到任何内容", "none")
|
||||
|
||||
|
||||
def numpy_array_to_base64(arr, img_format='PNG'):
|
||||
"""
|
||||
将numpy数组转换为base64字符串
|
||||
|
||||
参数:
|
||||
arr: numpy数组,通常是图像数据,形状为(height, width, channels)
|
||||
img_format: 图像格式,默认为'PNG',也可以是'JPEG'等PIL支持的格式
|
||||
|
||||
返回:
|
||||
str: 转换后的base64字符串
|
||||
|
||||
异常:
|
||||
ValueError: 当输入不是有效的numpy数组或不支持的形状时抛出
|
||||
Exception: 处理过程中出现的其他异常
|
||||
"""
|
||||
try:
|
||||
# 检查输入是否为numpy数组
|
||||
if not isinstance(arr, np.ndarray):
|
||||
raise ValueError("输入必须是numpy数组")
|
||||
|
||||
# 处理单通道图像(灰度图)
|
||||
if len(arr.shape) == 2:
|
||||
arr = np.expand_dims(arr, axis=-1)
|
||||
|
||||
# 检查数组形状是否有效
|
||||
if len(arr.shape) != 3 or arr.shape[2] not in [1, 3, 4]:
|
||||
raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据,通道数应为1、3或4")
|
||||
|
||||
# 处理数据类型,确保是uint8类型
|
||||
if arr.dtype != np.uint8:
|
||||
# 归一化到0-255并转换为uint8
|
||||
arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8)
|
||||
|
||||
# 将单通道图像转换为PIL支持的模式
|
||||
if arr.shape[2] == 1:
|
||||
arr = arr.squeeze(axis=-1)
|
||||
image = Image.fromarray(arr, mode='L') # L模式表示灰度图
|
||||
elif arr.shape[2] == 3:
|
||||
image = Image.fromarray(arr, mode='RGB')
|
||||
else: # 4通道
|
||||
image = Image.fromarray(arr, mode='RGBA')
|
||||
|
||||
# 将图像保存到内存缓冲区
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format=img_format)
|
||||
|
||||
# 从缓冲区读取数据并编码为base64
|
||||
buffer.seek(0)
|
||||
base64_str = base64.b64encode(buffer.read()).decode('utf-8')
|
||||
|
||||
return base64_str
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
except Exception as e:
|
||||
raise Exception(f"转换过程中发生错误: {str(e)}")
|
Reference in New Issue
Block a user