From ae177ca14a4cd4cc68a8b5c852a008ed583d9c36 Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Wed, 10 Sep 2025 08:57:56 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=8E=E6=9C=8D=E5=8A=A1=E5=99=A8=E8=AF=BB?= =?UTF-8?q?=E5=8F=96IP=E5=B9=B6=E5=B0=86=E6=A3=80=E6=B5=8B=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=86=99=E5=85=A5=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/all.py | 124 ++++++++++++++++++++++++++++++++------ core/establish.py | 98 +++++++++++++++--------------- service/device_service.py | 26 ++++++++ ws/ws.py | 35 ++++++----- 4 files changed, 200 insertions(+), 83 deletions(-) diff --git a/core/all.py b/core/all.py index 50f89fb..80267aa 100644 --- a/core/all.py +++ b/core/all.py @@ -1,9 +1,18 @@ import cv2 +import numpy as np +from PIL.Image import Image + from core.ocr import load_model as ocrLoadModel, detect as ocrDetect from core.face import load_model as faceLoadModel, detect as faceDetect from core.yolo import load_model as yoloLoadModel, detect as yoloDetect # 导入保存路径函数(根据实际文件位置调整导入路径) -from core.establish import get_image_save_path +import numpy as np +import base64 +from io import BytesIO +from PIL import Image +from ds.db import db +from mysql.connector import Error as MySQLError + # 模型加载状态标记(避免重复加载) @@ -26,7 +35,28 @@ def load_model(): print("所有检测模型加载完成") -def detect(frame): +def save_db(model_type, client_ip, result): + conn = None + cursor = None + try: + # 连接数据库 + conn = db.get_connection() + # 往表插入数据 + cursor = conn.cursor(dictionary=True) # 返回字典格式结果 + insert_query = """ + INSERT INTO device_danger (client_ip, type, result) + VALUES (%s, %s, %s) + """ + cursor.execute(insert_query, (client_ip, model_type, result)) + conn.commit() + except MySQLError as e: + raise Exception(f"获取设备列表失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + + +def detect(client_ip, frame): """ 执行模型检测,检测到违规时按指定格式保存图片 参数: @@ -38,23 +68,19 @@ def detect(frame): yolo_flag, yolo_result = yoloDetect(frame) print(f"YOLO检测结果:{yolo_result}") if yolo_flag: - # 元组解构:获取「完整保存路径」和「显示用短路径」 - full_save_path, display_path = get_image_save_path(model_type="yolo") - if full_save_path: # 只判断完整路径是否有效(用于保存) - cv2.imwrite(full_save_path, frame) - # 打印时使用「显示用短路径」,符合需求格式 - print(f"✅ YOLO违规图片已保存:{display_path}") + save_db(model_type="yolo", client_ip=client_ip, result=numpy_array_to_base64(frame)) + # if full_save_path: # 只判断完整路径是否有效(用于保存) + # cv2.imwrite(full_save_path, frame) + # # 打印时使用「显示用短路径」,符合需求格式 + # print(f"✅ YOLO违规图片已保存:{display_path}") return (True, yolo_result, "yolo") - - # 2. 人脸检测(优先级2) + # + # # 2. 人脸检测(优先级2) face_flag, face_result = faceDetect(frame) print(f"人脸检测结果:{face_result}") if face_flag: - # 同样解构元组,分离保存路径和显示路径 - full_save_path, display_path = get_image_save_path(model_type="face") - if full_save_path: - cv2.imwrite(full_save_path, frame) - print(f"✅ 人脸违规图片已保存:{display_path}") + # 将帧转化为 base64 字符串 + save_db(model_type="face", client_ip=client_ip, result=numpy_array_to_base64(frame)) return (True, face_result, "face") # 3. OCR检测(优先级3) @@ -62,12 +88,70 @@ def detect(frame): print(f"OCR检测结果:{ocr_result}") if ocr_flag: # 解构元组,保存用完整路径,打印用短路径 - full_save_path, display_path = get_image_save_path(model_type="ocr") - if full_save_path: - cv2.imwrite(full_save_path, frame) - print(f"✅ OCR违规图片已保存:{display_path}") + save_db(model_type="ocr", client_ip=client_ip, result=ocr_result) + # if full_save_path: + # cv2.imwrite(full_save_path, frame) + # print(f"✅ OCR违规图片已保存:{display_path}") return (True, ocr_result, "ocr") # 4. 无违规内容(不保存图片) print(f"❌ 未检测到任何违规内容,不保存图片") - return (False, "未检测到任何内容", "none") \ No newline at end of file + return (False, "未检测到任何内容", "none") + + +def numpy_array_to_base64(arr, img_format='PNG'): + """ + 将numpy数组转换为base64字符串 + + 参数: + arr: numpy数组,通常是图像数据,形状为(height, width, channels) + img_format: 图像格式,默认为'PNG',也可以是'JPEG'等PIL支持的格式 + + 返回: + str: 转换后的base64字符串 + + 异常: + ValueError: 当输入不是有效的numpy数组或不支持的形状时抛出 + Exception: 处理过程中出现的其他异常 + """ + try: + # 检查输入是否为numpy数组 + if not isinstance(arr, np.ndarray): + raise ValueError("输入必须是numpy数组") + + # 处理单通道图像(灰度图) + if len(arr.shape) == 2: + arr = np.expand_dims(arr, axis=-1) + + # 检查数组形状是否有效 + if len(arr.shape) != 3 or arr.shape[2] not in [1, 3, 4]: + raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据,通道数应为1、3或4") + + # 处理数据类型,确保是uint8类型 + if arr.dtype != np.uint8: + # 归一化到0-255并转换为uint8 + arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8) + + # 将单通道图像转换为PIL支持的模式 + if arr.shape[2] == 1: + arr = arr.squeeze(axis=-1) + image = Image.fromarray(arr, mode='L') # L模式表示灰度图 + elif arr.shape[2] == 3: + image = Image.fromarray(arr, mode='RGB') + else: # 4通道 + image = Image.fromarray(arr, mode='RGBA') + + # 将图像保存到内存缓冲区 + buffer = BytesIO() + image.save(buffer, format=img_format) + + # 从缓冲区读取数据并编码为base64 + buffer.seek(0) + base64_str = base64.b64encode(buffer.read()).decode('utf-8') + + return base64_str + + except ValueError as ve: + raise ve + except Exception as e: + raise Exception(f"转换过程中发生错误: {str(e)}") \ No newline at end of file diff --git a/core/establish.py b/core/establish.py index f679634..1a800dc 100644 --- a/core/establish.py +++ b/core/establish.py @@ -2,15 +2,11 @@ import os import datetime from pathlib import Path - -# 配置IP文件路径(统一使用绝对路径) -IP_FILE_PATH = Path(r"D:\ccc\IP.txt") - - +from service.device_service import get_unique_client_ips def create_directory_structure(): - """创建项目所需的目录结构""" + """创建项目所需的目录结构,为所有客户端IP预创建基础目录""" try: - # 1. 创建根目录下的resource文件夹 + # 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容) resource_dir = Path("resource") resource_dir.mkdir(exist_ok=True) print(f"确保resource目录存在: {resource_dir.absolute()}") @@ -27,87 +23,95 @@ def create_directory_structure(): model_dir.mkdir(exist_ok=True) print(f"确保{model}模型目录存在: {model_dir.absolute()}") - # 4. 读取ip.txt文件获取IP地址 + # 4. 调用外部方法获取所有客户端IP地址 try: - with open(IP_FILE_PATH, "r") as f: - ip_addresses = [line.strip() for line in f if line.strip()] + # 调用外部ip_read()方法获取所有客户端IP地址列表 + all_ip_addresses = get_unique_client_ips() - if not ip_addresses: - print("警告: ip.txt文件中未找到有效的IP地址") + # 确保返回的是列表类型 + if not isinstance(all_ip_addresses, list): + all_ip_addresses = [all_ip_addresses] + + # 过滤有效IP(去除空字符串和空格) + valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] + + if not valid_ips: + print("警告: 未获取到有效的客户端IP地址") return - print(f"从ip.txt中读取到的IP地址: {ip_addresses}") + print(f"获取到的所有客户端IP地址: {valid_ips}") - # 5. 获取当前日期 + # 5. 获取当前日期(年、月) now = datetime.datetime.now() current_year = str(now.year) current_month = str(now.month) - # 6. 为每个IP在每个模型文件夹下创建年->月的目录结构 - for ip in ip_addresses: - # 直接使用原始IP格式 - safe_ip = ip + # 6. 为每个客户端IP在每个模型文件夹下创建年->月的基础目录结构 + for ip in valid_ips: + # 处理IP地址中的特殊字符(将.替换为_,避免路径问题) + safe_ip = ip.replace(".", "_") for model in model_dirs: - # 构建路径: resource/dect/{model}/{ip}/{year}/{month} + # 构建路径: resource/dect/{model}/{safe_ip}/{year}/{month} ip_dir = dect_dir / model / safe_ip year_dir = ip_dir / current_year month_dir = year_dir / current_month - # 创建目录(如果不存在) + # 递归创建目录(存在则跳过,不覆盖) month_dir.mkdir(parents=True, exist_ok=True) - print(f"创建/确保目录存在: {month_dir.absolute()}") + print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}") - except FileNotFoundError: - print(f"错误: 未找到ip.txt文件,请确保该文件存在于 {IP_FILE_PATH}") except Exception as e: - print(f"处理IP和日期目录时发生错误: {str(e)}") + print(f"处理客户端IP和日期目录时发生错误: {str(e)}") except Exception as e: - print(f"创建目录结构时发生错误: {str(e)}") + print(f"创建基础目录结构时发生错误: {str(e)}") -def get_image_save_path(model_type: str) -> tuple: +def get_image_save_path(model_type: str, client_ip: str) -> tuple: """ - 获取图片保存的完整路径和显示用路径 + 获取图片保存的「完整路径」和「显示用短路径」 参数: model_type: 模型类型,应为"ocr"、"face"或"yolo" + client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101) 返回: - 元组 (完整保存路径, 显示用路径) + 元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "") """ try: - # 读取IP地址(假设只有一个IP或使用第一个IP) - with open(IP_FILE_PATH, "r") as f: - ip_addresses = [line.strip() for line in f if line.strip()] + # 1. 验证客户端IP有效性(检查是否在已知IP列表中) + all_ip_addresses = get_unique_client_ips() + if not isinstance(all_ip_addresses, list): + all_ip_addresses = [all_ip_addresses] + valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] - if not ip_addresses: - raise ValueError("ip.txt文件中未找到有效的IP地址") + if client_ip.strip() not in valid_ips: + raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件") - ip = ip_addresses[0] - safe_ip = ip # 直接使用原始IP格式 + # 2. 处理IP地址(与目录创建逻辑一致,将.替换为_) + safe_ip = client_ip.strip().replace(".", "_") - # 获取当前日期和时间(精确到毫秒,确保文件名唯一) + # 3. 获取当前日期和毫秒级时间戳(确保文件名唯一) now = datetime.datetime.now() current_year = str(now.year) current_month = str(now.month) current_day = str(now.day) - # 生成时间戳字符串(格式:年月日时分秒毫秒) - timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] # 去除最后三位,保留到毫秒 + # 时间戳格式:年月日时分秒毫秒(如20250910143050123) + timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] - # 构建基础目录路径 - base_dir = Path("resource") / "dect" - # 构建完整路径: resource/dect/{model}/{ip}/{year}/{month}/{day} + # 4. 定义基础目录(用于生成相对路径) + base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀 + # 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日}) day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day - day_dir.mkdir(parents=True, exist_ok=True) + day_dir.mkdir(parents=True, exist_ok=True) # 确保日目录存在 - # 构建图片文件名(简化名称,去掉resource_dect_前缀) - image_filename = f"{model_type}_{safe_ip}_{current_year}_{current_month}_{current_day}_{timestamp}.jpg" - full_path = day_dir / image_filename + # 5. 构建唯一文件名 + image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg" - # 计算显示用路径(相对于resource/dect的路径) - display_path = full_path.relative_to(base_dir) + # 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印) + full_path = day_dir / image_filename # 完整路径:resource/dect/.../xxx.jpg + display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg(去掉resource/dect) return str(full_path), str(display_path) diff --git a/service/device_service.py b/service/device_service.py index c3b3ef6..a5f15fb 100644 --- a/service/device_service.py +++ b/service/device_service.py @@ -236,3 +236,29 @@ async def get_device_list( raise Exception(f"获取设备列表失败: {str(e)}") from e finally: db.close_connection(conn, cursor) + + +def get_unique_client_ips() -> list[str]: + """ + 获取所有去重的客户端IP列表 + + :return: 去重后的客户端IP字符串列表,如果没有数据则返回空列表 + """ + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 查询去重的客户端IP + query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" + cursor.execute(query) + + # 提取结果并转换为字符串列表 + results = cursor.fetchall() + return [item['client_ip'] for item in results] + + except MySQLError as e: + raise Exception(f"获取客户端IP列表失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) \ No newline at end of file diff --git a/ws/ws.py b/ws/ws.py index dc570e0..afb5e4f 100644 --- a/ws/ws.py +++ b/ws/ws.py @@ -33,7 +33,7 @@ def get_current_time_file_str() -> str: class ClientConnection: def __init__(self, websocket: WebSocket, client_ip: str): self.websocket = websocket - self.client_ip = client_ip + self.client_ip = client_ip # 已初始化客户端IP,用于传递给detect self.last_heartbeat = datetime.datetime.now() self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) self.consumer_task: Optional[asyncio.Task] = None @@ -84,7 +84,7 @@ class ClientConnection: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") async def process_frame(self, frame_data: bytes) -> None: - """处理单帧图像数据(核心修复:按3个返回值解包)""" + """处理单帧图像数据(核心修改:detect函数传入 client_ip + img 双参数)""" # 二进制转OpenCV图像 nparr = np.frombuffer(frame_data, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) @@ -93,19 +93,21 @@ class ClientConnection: return try: - # -------------------------- 修复核心:匹配detect返回的3个值 -------------------------- - # 假设detect返回 (是否违规, 结果数据, 检测器类型) + # -------------------------- 核心修改:按要求传入参数(1.client_ip 2.img) -------------------------- + # detect函数参数顺序:第一个为client_ip,第二个为图像数据img + # 保持返回值解包(是否违规, 结果数据, 检测器类型)不变 has_violation, data, detector_type = await asyncio.to_thread( - detect, # 调用检测函数 - img # 传入图像参数 + detect, # 调用检测函数 + self.client_ip, # 第一个参数:客户端IP(新增,按需求顺序) + img # 第二个参数:图像数据(原参数,调整顺序) ) # ------------------------------------------------------------------------------------- - # 打印检测结果(移除task_id相关内容) + # 打印检测结果(包含客户端IP,与传入参数对应) print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}") - # 处理违规逻辑 + # 处理违规逻辑(逻辑不变,基于detect返回结果执行) if has_violation: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " f"类型: {detector_type}, 详情: {data}") @@ -227,7 +229,7 @@ ws_router = APIRouter() @ws_router.websocket(WS_ENDPOINT) async def websocket_endpoint(websocket: WebSocket): - load_model() + load_model() # 加载检测模型(仅在连接建立时加载一次,避免重复加载) await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown_ip" current_time = get_current_time_str() @@ -236,7 +238,7 @@ async def websocket_endpoint(websocket: WebSocket): is_online_updated = False try: - # 处理重复连接 + # 处理重复连接(同一IP断开旧连接) if client_ip in connected_clients: old_conn = connected_clients[client_ip] if old_conn.consumer_task and not old_conn.consumer_task.done(): @@ -245,13 +247,13 @@ async def websocket_endpoint(websocket: WebSocket): connected_clients.pop(client_ip) print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接") - # 注册新连接 + # 注册新连接(绑定client_ip和WebSocket) new_conn = ClientConnection(websocket, client_ip) connected_clients[client_ip] = new_conn - new_conn.start_consumer() - await new_conn.send_frame_permit() + new_conn.start_consumer() # 启动帧消费任务 + await new_conn.send_frame_permit() # 发送首次帧许可 - # 标记上线 + # 标记客户端上线 try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) action_data = DeviceActionCreate(client_ip=client_ip, action=1) @@ -263,7 +265,7 @@ async def websocket_endpoint(websocket: WebSocket): print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") - # 消息循环 + # 消息循环(持续接收客户端消息) while True: data = await websocket.receive() if "text" in data: @@ -276,12 +278,13 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") finally: - # 清理资源 + # 清理资源(断开后标记离线+删除连接) if client_ip in connected_clients: conn = connected_clients[client_ip] if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel() + # 仅当上线状态更新成功时,才执行离线更新 if is_online_updated: try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)