从服务器读取IP并将检测数据写入数据库

This commit is contained in:
2025-09-10 08:57:56 +08:00
parent d3c4820b73
commit ae177ca14a
4 changed files with 200 additions and 83 deletions

View File

@ -1,9 +1,18 @@
import cv2
import numpy as np
from PIL.Image import Image
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# 导入保存路径函数(根据实际文件位置调整导入路径)
from core.establish import get_image_save_path
import numpy as np
import base64
from io import BytesIO
from PIL import Image
from ds.db import db
from mysql.connector import Error as MySQLError
# 模型加载状态标记(避免重复加载)
@ -26,7 +35,28 @@ def load_model():
print("所有检测模型加载完成")
def detect(frame):
def save_db(model_type, client_ip, result):
conn = None
cursor = None
try:
# 连接数据库
conn = db.get_connection()
# 往表插入数据
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
insert_query = """
INSERT INTO device_danger (client_ip, type, result)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (client_ip, model_type, result))
conn.commit()
except MySQLError as e:
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def detect(client_ip, frame):
"""
执行模型检测,检测到违规时按指定格式保存图片
参数:
@ -38,23 +68,19 @@ def detect(frame):
yolo_flag, yolo_result = yoloDetect(frame)
print(f"YOLO检测结果{yolo_result}")
if yolo_flag:
# 元组解构:获取「完整保存路径」和「显示用短路径」
full_save_path, display_path = get_image_save_path(model_type="yolo")
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ YOLO违规图片已保存{display_path}")
save_db(model_type="yolo", client_ip=client_ip, result=numpy_array_to_base64(frame))
# if full_save_path: # 只判断完整路径是否有效(用于保存)
# cv2.imwrite(full_save_path, frame)
# # 打印时使用「显示用短路径」,符合需求格式
# print(f"✅ YOLO违规图片已保存{display_path}")
return (True, yolo_result, "yolo")
# 2. 人脸检测优先级2
#
# # 2. 人脸检测优先级2
face_flag, face_result = faceDetect(frame)
print(f"人脸检测结果:{face_result}")
if face_flag:
# 同样解构元组,分离保存路径和显示路径
full_save_path, display_path = get_image_save_path(model_type="face")
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ 人脸违规图片已保存:{display_path}")
# 将帧转化为 base64 字符串
save_db(model_type="face", client_ip=client_ip, result=numpy_array_to_base64(frame))
return (True, face_result, "face")
# 3. OCR检测优先级3
@ -62,12 +88,70 @@ def detect(frame):
print(f"OCR检测结果{ocr_result}")
if ocr_flag:
# 解构元组,保存用完整路径,打印用短路径
full_save_path, display_path = get_image_save_path(model_type="ocr")
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ OCR违规图片已保存{display_path}")
save_db(model_type="ocr", client_ip=client_ip, result=ocr_result)
# if full_save_path:
# cv2.imwrite(full_save_path, frame)
# print(f"✅ OCR违规图片已保存{display_path}")
return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片")
return (False, "未检测到任何内容", "none")
return (False, "未检测到任何内容", "none")
def numpy_array_to_base64(arr, img_format='PNG'):
"""
将numpy数组转换为base64字符串
参数:
arr: numpy数组通常是图像数据形状为(height, width, channels)
img_format: 图像格式,默认为'PNG',也可以是'JPEG'等PIL支持的格式
返回:
str: 转换后的base64字符串
异常:
ValueError: 当输入不是有效的numpy数组或不支持的形状时抛出
Exception: 处理过程中出现的其他异常
"""
try:
# 检查输入是否为numpy数组
if not isinstance(arr, np.ndarray):
raise ValueError("输入必须是numpy数组")
# 处理单通道图像(灰度图)
if len(arr.shape) == 2:
arr = np.expand_dims(arr, axis=-1)
# 检查数组形状是否有效
if len(arr.shape) != 3 or arr.shape[2] not in [1, 3, 4]:
raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据通道数应为1、3或4")
# 处理数据类型确保是uint8类型
if arr.dtype != np.uint8:
# 归一化到0-255并转换为uint8
arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8)
# 将单通道图像转换为PIL支持的模式
if arr.shape[2] == 1:
arr = arr.squeeze(axis=-1)
image = Image.fromarray(arr, mode='L') # L模式表示灰度图
elif arr.shape[2] == 3:
image = Image.fromarray(arr, mode='RGB')
else: # 4通道
image = Image.fromarray(arr, mode='RGBA')
# 将图像保存到内存缓冲区
buffer = BytesIO()
image.save(buffer, format=img_format)
# 从缓冲区读取数据并编码为base64
buffer.seek(0)
base64_str = base64.b64encode(buffer.read()).decode('utf-8')
return base64_str
except ValueError as ve:
raise ve
except Exception as e:
raise Exception(f"转换过程中发生错误: {str(e)}")

View File

@ -2,15 +2,11 @@ import os
import datetime
from pathlib import Path
# 配置IP文件路径统一使用绝对路径
IP_FILE_PATH = Path(r"D:\ccc\IP.txt")
from service.device_service import get_unique_client_ips
def create_directory_structure():
"""创建项目所需的目录结构"""
"""创建项目所需的目录结构为所有客户端IP预创建基础目录"""
try:
# 1. 创建根目录下的resource文件夹
# 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容)
resource_dir = Path("resource")
resource_dir.mkdir(exist_ok=True)
print(f"确保resource目录存在: {resource_dir.absolute()}")
@ -27,87 +23,95 @@ def create_directory_structure():
model_dir.mkdir(exist_ok=True)
print(f"确保{model}模型目录存在: {model_dir.absolute()}")
# 4. 读取ip.txt文件获取IP地址
# 4. 调用外部方法获取所有客户端IP地址
try:
with open(IP_FILE_PATH, "r") as f:
ip_addresses = [line.strip() for line in f if line.strip()]
# 调用外部ip_read()方法获取所有客户端IP地址列表
all_ip_addresses = get_unique_client_ips()
if not ip_addresses:
print("警告: ip.txt文件中未找到有效的IP地址")
# 确保返回的是列表类型
if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses]
# 过滤有效IP去除空字符串和空格
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if not valid_ips:
print("警告: 未获取到有效的客户端IP地址")
return
print(f"从ip.txt中读取到的IP地址: {ip_addresses}")
print(f"获取到的所有客户端IP地址: {valid_ips}")
# 5. 获取当前日期
# 5. 获取当前日期(年、月)
now = datetime.datetime.now()
current_year = str(now.year)
current_month = str(now.month)
# 6. 为每个IP在每个模型文件夹下创建年->月的目录结构
for ip in ip_addresses:
# 直接使用原始IP格式
safe_ip = ip
# 6. 为每个客户端IP在每个模型文件夹下创建年->月的基础目录结构
for ip in valid_ips:
# 处理IP地址中的特殊字符将.替换为_避免路径问题
safe_ip = ip.replace(".", "_")
for model in model_dirs:
# 构建路径: resource/dect/{model}/{ip}/{year}/{month}
# 构建路径: resource/dect/{model}/{safe_ip}/{year}/{month}
ip_dir = dect_dir / model / safe_ip
year_dir = ip_dir / current_year
month_dir = year_dir / current_month
# 创建目录(如果不存在)
# 递归创建目录(存在则跳过,不覆盖
month_dir.mkdir(parents=True, exist_ok=True)
print(f"创建/确保目录存在: {month_dir.absolute()}")
print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}")
except FileNotFoundError:
print(f"错误: 未找到ip.txt文件请确保该文件存在于 {IP_FILE_PATH}")
except Exception as e:
print(f"处理IP和日期目录时发生错误: {str(e)}")
print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
except Exception as e:
print(f"创建目录结构时发生错误: {str(e)}")
print(f"创建基础目录结构时发生错误: {str(e)}")
def get_image_save_path(model_type: str) -> tuple:
def get_image_save_path(model_type: str, client_ip: str) -> tuple:
"""
获取图片保存的完整路径显示用路径
获取图片保存的完整路径」和「显示用路径
参数:
model_type: 模型类型,应为"ocr""face""yolo"
client_ip: 检测到违禁的客户端IP地址原始格式如192.168.1.101
返回:
元组 (完整保存路径, 显示用路径)
元组 (完整保存路径, 显示用路径);若出错则返回 ("", "")
"""
try:
# 读取IP地址假设只有一个IP或使用第一个IP
with open(IP_FILE_PATH, "r") as f:
ip_addresses = [line.strip() for line in f if line.strip()]
# 1. 验证客户端IP有效性检查是否在已知IP列表中
all_ip_addresses = get_unique_client_ips()
if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses]
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if not ip_addresses:
raise ValueError("ip.txt文件中未找到有效的IP地址")
if client_ip.strip() not in valid_ips:
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中无法保存文件")
ip = ip_addresses[0]
safe_ip = ip # 直接使用原始IP格式
# 2. 处理IP地址与目录创建逻辑一致将.替换为_
safe_ip = client_ip.strip().replace(".", "_")
# 获取当前日期和时间(精确到毫秒,确保文件名唯一)
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
now = datetime.datetime.now()
current_year = str(now.year)
current_month = str(now.month)
current_day = str(now.day)
# 生成时间戳字符串(格式:年月日时分秒毫秒)
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] # 去除最后三位,保留到毫秒
# 时间戳格式:年月日时分秒毫秒如20250910143050123
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]
# 构建基础目录路径
base_dir = Path("resource") / "dect"
# 构建完整路径: resource/dect/{model}/{ip}/{year}/{month}/{day}
# 4. 定义基础目录(用于生成相对路径
base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀
# 构建日级目录(完整路径resource/dect/{model}/{safe_ip}/{年}/{月}/{日}
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
day_dir.mkdir(parents=True, exist_ok=True)
day_dir.mkdir(parents=True, exist_ok=True) # 确保日目录存在
# 构建图片文件名简化名称去掉resource_dect_前缀
image_filename = f"{model_type}_{safe_ip}_{current_year}_{current_month}_{current_day}_{timestamp}.jpg"
full_path = day_dir / image_filename
# 5. 构建唯一文件名
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
# 计算显示用路径相对于resource/dect的路径
display_path = full_path.relative_to(base_dir)
# 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印
full_path = day_dir / image_filename # 完整路径resource/dect/.../xxx.jpg
display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg去掉resource/dect
return str(full_path), str(display_path)

View File

@ -236,3 +236,29 @@ async def get_device_list(
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def get_unique_client_ips() -> list[str]:
"""
获取所有去重的客户端IP列表
:return: 去重后的客户端IP字符串列表如果没有数据则返回空列表
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询去重的客户端IP
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
cursor.execute(query)
# 提取结果并转换为字符串列表
results = cursor.fetchall()
return [item['client_ip'] for item in results]
except MySQLError as e:
raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -33,7 +33,7 @@ def get_current_time_file_str() -> str:
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.client_ip = client_ip # 已初始化客户端IP用于传递给detect
self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None
@ -84,7 +84,7 @@ class ClientConnection:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(核心修按3个返回值解包"""
"""处理单帧图像数据(核心修detect函数传入 client_ip + img 双参数"""
# 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@ -93,19 +93,21 @@ class ClientConnection:
return
try:
# -------------------------- 修复核心匹配detect返回的3个值 --------------------------
# 假设detect返回 (是否违规, 结果数据, 检测器类型)
# -------------------------- 核心修改按要求传入参数1.client_ip 2.img --------------------------
# detect函数参数顺序第一个为client_ip第二个为图像数据img
# 保持返回值解包(是否违规, 结果数据, 检测器类型)不变
has_violation, data, detector_type = await asyncio.to_thread(
detect, # 调用检测函数
img # 传入图像参数
detect, # 调用检测函数
self.client_ip, # 第一个参数客户端IP新增按需求顺序
img # 第二个参数:图像数据(原参数,调整顺序)
)
# -------------------------------------------------------------------------------------
# 打印检测结果(移除task_id相关内容
# 打印检测结果(包含客户端IP与传入参数对应
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}")
# 处理违规逻辑
# 处理违规逻辑逻辑不变基于detect返回结果执行
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
f"类型: {detector_type}, 详情: {data}")
@ -227,7 +229,7 @@ ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
load_model()
load_model() # 加载检测模型(仅在连接建立时加载一次,避免重复加载)
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
@ -236,7 +238,7 @@ async def websocket_endpoint(websocket: WebSocket):
is_online_updated = False
try:
# 处理重复连接
# 处理重复连接同一IP断开旧连接
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
@ -245,13 +247,13 @@ async def websocket_endpoint(websocket: WebSocket):
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接
# 注册新连接绑定client_ip和WebSocket
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer()
await new_conn.send_frame_permit()
new_conn.start_consumer() # 启动帧消费任务
await new_conn.send_frame_permit() # 发送首次帧许可
# 标记上线
# 标记客户端上线
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
@ -263,7 +265,7 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
# 消息循环
# 消息循环(持续接收客户端消息)
while True:
data = await websocket.receive()
if "text" in data:
@ -276,12 +278,13 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理资源
# 清理资源(断开后标记离线+删除连接)
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 仅当上线状态更新成功时,才执行离线更新
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)