从服务器读取IP并将检测数据写入数据库
This commit is contained in:
		
							
								
								
									
										122
									
								
								core/all.py
									
									
									
									
									
								
							
							
						
						
									
										122
									
								
								core/all.py
									
									
									
									
									
								
							| @ -1,9 +1,18 @@ | |||||||
| import cv2 | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | from PIL.Image import Image | ||||||
|  |  | ||||||
| from core.ocr import load_model as ocrLoadModel, detect as ocrDetect | from core.ocr import load_model as ocrLoadModel, detect as ocrDetect | ||||||
| from core.face import load_model as faceLoadModel, detect as faceDetect | from core.face import load_model as faceLoadModel, detect as faceDetect | ||||||
| from core.yolo import load_model as yoloLoadModel, detect as yoloDetect | from core.yolo import load_model as yoloLoadModel, detect as yoloDetect | ||||||
| # 导入保存路径函数(根据实际文件位置调整导入路径) | # 导入保存路径函数(根据实际文件位置调整导入路径) | ||||||
| from core.establish import get_image_save_path | import numpy as np | ||||||
|  | import base64 | ||||||
|  | from io import BytesIO | ||||||
|  | from PIL import Image | ||||||
|  | from ds.db import db | ||||||
|  | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
| # 模型加载状态标记(避免重复加载) | # 模型加载状态标记(避免重复加载) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -26,7 +35,28 @@ def load_model(): | |||||||
|     print("所有检测模型加载完成") |     print("所有检测模型加载完成") | ||||||
|  |  | ||||||
|  |  | ||||||
| def detect(frame): | def save_db(model_type, client_ip, result): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         # 连接数据库 | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         # 往表插入数据 | ||||||
|  |         cursor = conn.cursor(dictionary=True)  # 返回字典格式结果 | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO device_danger (client_ip, type, result) | ||||||
|  |             VALUES (%s, %s, %s) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, (client_ip, model_type, result)) | ||||||
|  |         conn.commit() | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取设备列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def detect(client_ip, frame): | ||||||
|     """ |     """ | ||||||
|     执行模型检测,检测到违规时按指定格式保存图片 |     执行模型检测,检测到违规时按指定格式保存图片 | ||||||
|     参数: |     参数: | ||||||
| @ -38,23 +68,19 @@ def detect(frame): | |||||||
|     yolo_flag, yolo_result = yoloDetect(frame) |     yolo_flag, yolo_result = yoloDetect(frame) | ||||||
|     print(f"YOLO检测结果:{yolo_result}") |     print(f"YOLO检测结果:{yolo_result}") | ||||||
|     if yolo_flag: |     if yolo_flag: | ||||||
|         # 元组解构:获取「完整保存路径」和「显示用短路径」 |         save_db(model_type="yolo", client_ip=client_ip, result=numpy_array_to_base64(frame)) | ||||||
|         full_save_path, display_path = get_image_save_path(model_type="yolo") |         # if full_save_path:  # 只判断完整路径是否有效(用于保存) | ||||||
|         if full_save_path:  # 只判断完整路径是否有效(用于保存) |         #     cv2.imwrite(full_save_path, frame) | ||||||
|             cv2.imwrite(full_save_path, frame) |         #     # 打印时使用「显示用短路径」,符合需求格式 | ||||||
|             # 打印时使用「显示用短路径」,符合需求格式 |         #     print(f"✅ YOLO违规图片已保存:{display_path}") | ||||||
|             print(f"✅ YOLO违规图片已保存:{display_path}") |  | ||||||
|         return (True, yolo_result, "yolo") |         return (True, yolo_result, "yolo") | ||||||
|  |     # | ||||||
|     # 2. 人脸检测(优先级2) |     # # 2. 人脸检测(优先级2) | ||||||
|     face_flag, face_result = faceDetect(frame) |     face_flag, face_result = faceDetect(frame) | ||||||
|     print(f"人脸检测结果:{face_result}") |     print(f"人脸检测结果:{face_result}") | ||||||
|     if face_flag: |     if face_flag: | ||||||
|         # 同样解构元组,分离保存路径和显示路径 |         # 将帧转化为 base64 字符串 | ||||||
|         full_save_path, display_path = get_image_save_path(model_type="face") |         save_db(model_type="face", client_ip=client_ip, result=numpy_array_to_base64(frame)) | ||||||
|         if full_save_path: |  | ||||||
|             cv2.imwrite(full_save_path, frame) |  | ||||||
|             print(f"✅ 人脸违规图片已保存:{display_path}") |  | ||||||
|         return (True, face_result, "face") |         return (True, face_result, "face") | ||||||
|  |  | ||||||
|     # 3. OCR检测(优先级3) |     # 3. OCR检测(优先级3) | ||||||
| @ -62,12 +88,70 @@ def detect(frame): | |||||||
|     print(f"OCR检测结果:{ocr_result}") |     print(f"OCR检测结果:{ocr_result}") | ||||||
|     if ocr_flag: |     if ocr_flag: | ||||||
|         # 解构元组,保存用完整路径,打印用短路径 |         # 解构元组,保存用完整路径,打印用短路径 | ||||||
|         full_save_path, display_path = get_image_save_path(model_type="ocr") |         save_db(model_type="ocr", client_ip=client_ip, result=ocr_result) | ||||||
|         if full_save_path: |         # if full_save_path: | ||||||
|             cv2.imwrite(full_save_path, frame) |         #     cv2.imwrite(full_save_path, frame) | ||||||
|             print(f"✅ OCR违规图片已保存:{display_path}") |         #     print(f"✅ OCR违规图片已保存:{display_path}") | ||||||
|         return (True, ocr_result, "ocr") |         return (True, ocr_result, "ocr") | ||||||
|  |  | ||||||
|     # 4. 无违规内容(不保存图片) |     # 4. 无违规内容(不保存图片) | ||||||
|     print(f"❌ 未检测到任何违规内容,不保存图片") |     print(f"❌ 未检测到任何违规内容,不保存图片") | ||||||
|     return (False, "未检测到任何内容", "none") |     return (False, "未检测到任何内容", "none") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def numpy_array_to_base64(arr, img_format='PNG'): | ||||||
|  |     """ | ||||||
|  |     将numpy数组转换为base64字符串 | ||||||
|  |  | ||||||
|  |     参数: | ||||||
|  |         arr: numpy数组,通常是图像数据,形状为(height, width, channels) | ||||||
|  |         img_format: 图像格式,默认为'PNG',也可以是'JPEG'等PIL支持的格式 | ||||||
|  |  | ||||||
|  |     返回: | ||||||
|  |         str: 转换后的base64字符串 | ||||||
|  |  | ||||||
|  |     异常: | ||||||
|  |         ValueError: 当输入不是有效的numpy数组或不支持的形状时抛出 | ||||||
|  |         Exception: 处理过程中出现的其他异常 | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         # 检查输入是否为numpy数组 | ||||||
|  |         if not isinstance(arr, np.ndarray): | ||||||
|  |             raise ValueError("输入必须是numpy数组") | ||||||
|  |  | ||||||
|  |         # 处理单通道图像(灰度图) | ||||||
|  |         if len(arr.shape) == 2: | ||||||
|  |             arr = np.expand_dims(arr, axis=-1) | ||||||
|  |  | ||||||
|  |         # 检查数组形状是否有效 | ||||||
|  |         if len(arr.shape) != 3 or arr.shape[2] not in [1, 3, 4]: | ||||||
|  |             raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据,通道数应为1、3或4") | ||||||
|  |  | ||||||
|  |         # 处理数据类型,确保是uint8类型 | ||||||
|  |         if arr.dtype != np.uint8: | ||||||
|  |             # 归一化到0-255并转换为uint8 | ||||||
|  |             arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8) | ||||||
|  |  | ||||||
|  |         # 将单通道图像转换为PIL支持的模式 | ||||||
|  |         if arr.shape[2] == 1: | ||||||
|  |             arr = arr.squeeze(axis=-1) | ||||||
|  |             image = Image.fromarray(arr, mode='L')  # L模式表示灰度图 | ||||||
|  |         elif arr.shape[2] == 3: | ||||||
|  |             image = Image.fromarray(arr, mode='RGB') | ||||||
|  |         else:  # 4通道 | ||||||
|  |             image = Image.fromarray(arr, mode='RGBA') | ||||||
|  |  | ||||||
|  |         # 将图像保存到内存缓冲区 | ||||||
|  |         buffer = BytesIO() | ||||||
|  |         image.save(buffer, format=img_format) | ||||||
|  |  | ||||||
|  |         # 从缓冲区读取数据并编码为base64 | ||||||
|  |         buffer.seek(0) | ||||||
|  |         base64_str = base64.b64encode(buffer.read()).decode('utf-8') | ||||||
|  |  | ||||||
|  |         return base64_str | ||||||
|  |  | ||||||
|  |     except ValueError as ve: | ||||||
|  |         raise ve | ||||||
|  |     except Exception as e: | ||||||
|  |         raise Exception(f"转换过程中发生错误: {str(e)}") | ||||||
| @ -2,15 +2,11 @@ import os | |||||||
| import datetime | import datetime | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
|  | from service.device_service import get_unique_client_ips | ||||||
| # 配置IP文件路径(统一使用绝对路径) |  | ||||||
| IP_FILE_PATH = Path(r"D:\ccc\IP.txt") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_directory_structure(): | def create_directory_structure(): | ||||||
|     """创建项目所需的目录结构""" |     """创建项目所需的目录结构,为所有客户端IP预创建基础目录""" | ||||||
|     try: |     try: | ||||||
|         # 1. 创建根目录下的resource文件夹 |         # 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容) | ||||||
|         resource_dir = Path("resource") |         resource_dir = Path("resource") | ||||||
|         resource_dir.mkdir(exist_ok=True) |         resource_dir.mkdir(exist_ok=True) | ||||||
|         print(f"确保resource目录存在: {resource_dir.absolute()}") |         print(f"确保resource目录存在: {resource_dir.absolute()}") | ||||||
| @ -27,87 +23,95 @@ def create_directory_structure(): | |||||||
|             model_dir.mkdir(exist_ok=True) |             model_dir.mkdir(exist_ok=True) | ||||||
|             print(f"确保{model}模型目录存在: {model_dir.absolute()}") |             print(f"确保{model}模型目录存在: {model_dir.absolute()}") | ||||||
|  |  | ||||||
|         # 4. 读取ip.txt文件获取IP地址 |         # 4. 调用外部方法获取所有客户端IP地址 | ||||||
|         try: |         try: | ||||||
|             with open(IP_FILE_PATH, "r") as f: |             # 调用外部ip_read()方法获取所有客户端IP地址列表 | ||||||
|                 ip_addresses = [line.strip() for line in f if line.strip()] |             all_ip_addresses = get_unique_client_ips() | ||||||
|  |  | ||||||
|             if not ip_addresses: |             # 确保返回的是列表类型 | ||||||
|                 print("警告: ip.txt文件中未找到有效的IP地址") |             if not isinstance(all_ip_addresses, list): | ||||||
|  |                 all_ip_addresses = [all_ip_addresses] | ||||||
|  |  | ||||||
|  |             # 过滤有效IP(去除空字符串和空格) | ||||||
|  |             valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] | ||||||
|  |  | ||||||
|  |             if not valid_ips: | ||||||
|  |                 print("警告: 未获取到有效的客户端IP地址") | ||||||
|                 return |                 return | ||||||
|  |  | ||||||
|             print(f"从ip.txt中读取到的IP地址: {ip_addresses}") |             print(f"获取到的所有客户端IP地址: {valid_ips}") | ||||||
|  |  | ||||||
|             # 5. 获取当前日期 |             # 5. 获取当前日期(年、月) | ||||||
|             now = datetime.datetime.now() |             now = datetime.datetime.now() | ||||||
|             current_year = str(now.year) |             current_year = str(now.year) | ||||||
|             current_month = str(now.month) |             current_month = str(now.month) | ||||||
|  |  | ||||||
|             # 6. 为每个IP在每个模型文件夹下创建年->月的目录结构 |             # 6. 为每个客户端IP在每个模型文件夹下创建年->月的基础目录结构 | ||||||
|             for ip in ip_addresses: |             for ip in valid_ips: | ||||||
|                 # 直接使用原始IP格式 |                 # 处理IP地址中的特殊字符(将.替换为_,避免路径问题) | ||||||
|                 safe_ip = ip |                 safe_ip = ip.replace(".", "_") | ||||||
|  |  | ||||||
|                 for model in model_dirs: |                 for model in model_dirs: | ||||||
|                     # 构建路径: resource/dect/{model}/{ip}/{year}/{month} |                     # 构建路径: resource/dect/{model}/{safe_ip}/{year}/{month} | ||||||
|                     ip_dir = dect_dir / model / safe_ip |                     ip_dir = dect_dir / model / safe_ip | ||||||
|                     year_dir = ip_dir / current_year |                     year_dir = ip_dir / current_year | ||||||
|                     month_dir = year_dir / current_month |                     month_dir = year_dir / current_month | ||||||
|  |  | ||||||
|                     # 创建目录(如果不存在) |                     # 递归创建目录(存在则跳过,不覆盖) | ||||||
|                     month_dir.mkdir(parents=True, exist_ok=True) |                     month_dir.mkdir(parents=True, exist_ok=True) | ||||||
|                     print(f"创建/确保目录存在: {month_dir.absolute()}") |                     print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}") | ||||||
|  |  | ||||||
|         except FileNotFoundError: |  | ||||||
|             print(f"错误: 未找到ip.txt文件,请确保该文件存在于 {IP_FILE_PATH}") |  | ||||||
|         except Exception as e: |  | ||||||
|             print(f"处理IP和日期目录时发生错误: {str(e)}") |  | ||||||
|  |  | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|         print(f"创建目录结构时发生错误: {str(e)}") |             print(f"处理客户端IP和日期目录时发生错误: {str(e)}") | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"创建基础目录结构时发生错误: {str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_image_save_path(model_type: str) -> tuple: | def get_image_save_path(model_type: str, client_ip: str) -> tuple: | ||||||
|     """ |     """ | ||||||
|     获取图片保存的完整路径和显示用路径 |     获取图片保存的「完整路径」和「显示用短路径」 | ||||||
|  |  | ||||||
|     参数: |     参数: | ||||||
|         model_type: 模型类型,应为"ocr"、"face"或"yolo" |         model_type: 模型类型,应为"ocr"、"face"或"yolo" | ||||||
|  |         client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101) | ||||||
|  |  | ||||||
|     返回: |     返回: | ||||||
|         元组 (完整保存路径, 显示用路径) |         元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "") | ||||||
|     """ |     """ | ||||||
|     try: |     try: | ||||||
|         # 读取IP地址(假设只有一个IP或使用第一个IP) |         # 1. 验证客户端IP有效性(检查是否在已知IP列表中) | ||||||
|         with open(IP_FILE_PATH, "r") as f: |         all_ip_addresses = get_unique_client_ips() | ||||||
|             ip_addresses = [line.strip() for line in f if line.strip()] |         if not isinstance(all_ip_addresses, list): | ||||||
|  |             all_ip_addresses = [all_ip_addresses] | ||||||
|  |         valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] | ||||||
|  |  | ||||||
|         if not ip_addresses: |         if client_ip.strip() not in valid_ips: | ||||||
|             raise ValueError("ip.txt文件中未找到有效的IP地址") |             raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件") | ||||||
|  |  | ||||||
|         ip = ip_addresses[0] |         # 2. 处理IP地址(与目录创建逻辑一致,将.替换为_) | ||||||
|         safe_ip = ip  # 直接使用原始IP格式 |         safe_ip = client_ip.strip().replace(".", "_") | ||||||
|  |  | ||||||
|         # 获取当前日期和时间(精确到毫秒,确保文件名唯一) |         # 3. 获取当前日期和毫秒级时间戳(确保文件名唯一) | ||||||
|         now = datetime.datetime.now() |         now = datetime.datetime.now() | ||||||
|         current_year = str(now.year) |         current_year = str(now.year) | ||||||
|         current_month = str(now.month) |         current_month = str(now.month) | ||||||
|         current_day = str(now.day) |         current_day = str(now.day) | ||||||
|         # 生成时间戳字符串(格式:年月日时分秒毫秒) |         # 时间戳格式:年月日时分秒毫秒(如20250910143050123) | ||||||
|         timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]  # 去除最后三位,保留到毫秒 |         timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] | ||||||
|  |  | ||||||
|         # 构建基础目录路径 |         # 4. 定义基础目录(用于生成相对路径) | ||||||
|         base_dir = Path("resource") / "dect" |         base_dir = Path("resource") / "dect"  # 显示路径会去掉这个前缀 | ||||||
|         # 构建完整路径: resource/dect/{model}/{ip}/{year}/{month}/{day} |         # 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日}) | ||||||
|         day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day |         day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day | ||||||
|         day_dir.mkdir(parents=True, exist_ok=True) |         day_dir.mkdir(parents=True, exist_ok=True)  # 确保日目录存在 | ||||||
|  |  | ||||||
|         # 构建图片文件名(简化名称,去掉resource_dect_前缀) |         # 5. 构建唯一文件名 | ||||||
|         image_filename = f"{model_type}_{safe_ip}_{current_year}_{current_month}_{current_day}_{timestamp}.jpg" |         image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg" | ||||||
|         full_path = day_dir / image_filename |  | ||||||
|  |  | ||||||
|         # 计算显示用路径(相对于resource/dect的路径) |         # 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印) | ||||||
|         display_path = full_path.relative_to(base_dir) |         full_path = day_dir / image_filename  # 完整路径:resource/dect/.../xxx.jpg | ||||||
|  |         display_path = full_path.relative_to(base_dir)  # 短路径:{model}/.../xxx.jpg(去掉resource/dect) | ||||||
|  |  | ||||||
|         return str(full_path), str(display_path) |         return str(full_path), str(display_path) | ||||||
|  |  | ||||||
|  | |||||||
| @ -236,3 +236,29 @@ async def get_device_list( | |||||||
|         raise Exception(f"获取设备列表失败: {str(e)}") from e |         raise Exception(f"获取设备列表失败: {str(e)}") from e | ||||||
|     finally: |     finally: | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_unique_client_ips() -> list[str]: | ||||||
|  |     """ | ||||||
|  |     获取所有去重的客户端IP列表 | ||||||
|  |  | ||||||
|  |     :return: 去重后的客户端IP字符串列表,如果没有数据则返回空列表 | ||||||
|  |     """ | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 查询去重的客户端IP | ||||||
|  |         query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" | ||||||
|  |         cursor.execute(query) | ||||||
|  |  | ||||||
|  |         # 提取结果并转换为字符串列表 | ||||||
|  |         results = cursor.fetchall() | ||||||
|  |         return [item['client_ip'] for item in results] | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取客户端IP列表失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
							
								
								
									
										33
									
								
								ws/ws.py
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								ws/ws.py
									
									
									
									
									
								
							| @ -33,7 +33,7 @@ def get_current_time_file_str() -> str: | |||||||
| class ClientConnection: | class ClientConnection: | ||||||
|     def __init__(self, websocket: WebSocket, client_ip: str): |     def __init__(self, websocket: WebSocket, client_ip: str): | ||||||
|         self.websocket = websocket |         self.websocket = websocket | ||||||
|         self.client_ip = client_ip |         self.client_ip = client_ip  # 已初始化客户端IP,用于传递给detect | ||||||
|         self.last_heartbeat = datetime.datetime.now() |         self.last_heartbeat = datetime.datetime.now() | ||||||
|         self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) |         self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) | ||||||
|         self.consumer_task: Optional[asyncio.Task] = None |         self.consumer_task: Optional[asyncio.Task] = None | ||||||
| @ -84,7 +84,7 @@ class ClientConnection: | |||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") | ||||||
|  |  | ||||||
|     async def process_frame(self, frame_data: bytes) -> None: |     async def process_frame(self, frame_data: bytes) -> None: | ||||||
|         """处理单帧图像数据(核心修复:按3个返回值解包)""" |         """处理单帧图像数据(核心修改:detect函数传入 client_ip + img 双参数)""" | ||||||
|         # 二进制转OpenCV图像 |         # 二进制转OpenCV图像 | ||||||
|         nparr = np.frombuffer(frame_data, np.uint8) |         nparr = np.frombuffer(frame_data, np.uint8) | ||||||
|         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) |         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | ||||||
| @ -93,19 +93,21 @@ class ClientConnection: | |||||||
|             return |             return | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             # -------------------------- 修复核心:匹配detect返回的3个值 -------------------------- |             # -------------------------- 核心修改:按要求传入参数(1.client_ip 2.img) -------------------------- | ||||||
|             # 假设detect返回 (是否违规, 结果数据, 检测器类型) |             # detect函数参数顺序:第一个为client_ip,第二个为图像数据img | ||||||
|  |             # 保持返回值解包(是否违规, 结果数据, 检测器类型)不变 | ||||||
|             has_violation, data, detector_type = await asyncio.to_thread( |             has_violation, data, detector_type = await asyncio.to_thread( | ||||||
|                 detect,                  # 调用检测函数 |                 detect,                  # 调用检测函数 | ||||||
|                 img      # 传入图像参数 |                 self.client_ip,          # 第一个参数:客户端IP(新增,按需求顺序) | ||||||
|  |                 img                      # 第二个参数:图像数据(原参数,调整顺序) | ||||||
|             ) |             ) | ||||||
|             # ------------------------------------------------------------------------------------- |             # ------------------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|             # 打印检测结果(移除task_id相关内容) |             # 打印检测结果(包含客户端IP,与传入参数对应) | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " | ||||||
|                   f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}") |                   f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}") | ||||||
|  |  | ||||||
|             # 处理违规逻辑 |             # 处理违规逻辑(逻辑不变,基于detect返回结果执行) | ||||||
|             if has_violation: |             if has_violation: | ||||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " |                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " | ||||||
|                       f"类型: {detector_type}, 详情: {data}") |                       f"类型: {detector_type}, 详情: {data}") | ||||||
| @ -227,7 +229,7 @@ ws_router = APIRouter() | |||||||
|  |  | ||||||
| @ws_router.websocket(WS_ENDPOINT) | @ws_router.websocket(WS_ENDPOINT) | ||||||
| async def websocket_endpoint(websocket: WebSocket): | async def websocket_endpoint(websocket: WebSocket): | ||||||
|     load_model() |     load_model()  # 加载检测模型(仅在连接建立时加载一次,避免重复加载) | ||||||
|     await websocket.accept() |     await websocket.accept() | ||||||
|     client_ip = websocket.client.host if websocket.client else "unknown_ip" |     client_ip = websocket.client.host if websocket.client else "unknown_ip" | ||||||
|     current_time = get_current_time_str() |     current_time = get_current_time_str() | ||||||
| @ -236,7 +238,7 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|     is_online_updated = False |     is_online_updated = False | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 处理重复连接 |         # 处理重复连接(同一IP断开旧连接) | ||||||
|         if client_ip in connected_clients: |         if client_ip in connected_clients: | ||||||
|             old_conn = connected_clients[client_ip] |             old_conn = connected_clients[client_ip] | ||||||
|             if old_conn.consumer_task and not old_conn.consumer_task.done(): |             if old_conn.consumer_task and not old_conn.consumer_task.done(): | ||||||
| @ -245,13 +247,13 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|             connected_clients.pop(client_ip) |             connected_clients.pop(client_ip) | ||||||
|             print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接") |             print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接") | ||||||
|  |  | ||||||
|         # 注册新连接 |         # 注册新连接(绑定client_ip和WebSocket) | ||||||
|         new_conn = ClientConnection(websocket, client_ip) |         new_conn = ClientConnection(websocket, client_ip) | ||||||
|         connected_clients[client_ip] = new_conn |         connected_clients[client_ip] = new_conn | ||||||
|         new_conn.start_consumer() |         new_conn.start_consumer()  # 启动帧消费任务 | ||||||
|         await new_conn.send_frame_permit() |         await new_conn.send_frame_permit()  # 发送首次帧许可 | ||||||
|  |  | ||||||
|         # 标记上线 |         # 标记客户端上线 | ||||||
|         try: |         try: | ||||||
|             await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) |             await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) | ||||||
|             action_data = DeviceActionCreate(client_ip=client_ip, action=1) |             action_data = DeviceActionCreate(client_ip=client_ip, action=1) | ||||||
| @ -263,7 +265,7 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|  |  | ||||||
|         print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") |         print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") | ||||||
|  |  | ||||||
|         # 消息循环 |         # 消息循环(持续接收客户端消息) | ||||||
|         while True: |         while True: | ||||||
|             data = await websocket.receive() |             data = await websocket.receive() | ||||||
|             if "text" in data: |             if "text" in data: | ||||||
| @ -276,12 +278,13 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") |         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") | ||||||
|     finally: |     finally: | ||||||
|         # 清理资源 |         # 清理资源(断开后标记离线+删除连接) | ||||||
|         if client_ip in connected_clients: |         if client_ip in connected_clients: | ||||||
|             conn = connected_clients[client_ip] |             conn = connected_clients[client_ip] | ||||||
|             if conn.consumer_task and not conn.consumer_task.done(): |             if conn.consumer_task and not conn.consumer_task.done(): | ||||||
|                 conn.consumer_task.cancel() |                 conn.consumer_task.cancel() | ||||||
|  |  | ||||||
|  |             # 仅当上线状态更新成功时,才执行离线更新 | ||||||
|             if is_online_updated: |             if is_online_updated: | ||||||
|                 try: |                 try: | ||||||
|                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) |                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user