This commit is contained in:
2025-09-03 17:02:22 +08:00
parent 1911cd6588
commit 9d940e7fd2

View File

@ -1,15 +1,34 @@
import asyncio import asyncio
import datetime import datetime
import json import json
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Optional from typing import Dict, Optional
import cv2
import numpy as np import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" # 关键修正从core目录向上一级找ocr文件夹
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# -------------------------- 配置常量(简化硬编码) -------------------------- # -------------------------- 配置常量(简化硬编码) --------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 60 # 客户端超时阈值(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径 WS_ENDPOINT = "/ws" # WebSocket端点路径
# -------------------------- 核心数据结构与全局变量 -------------------------- # -------------------------- 核心数据结构与全局变量 --------------------------
@ -113,9 +132,31 @@ async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection):
async def handle_binary_msg(client_ip: str, data: bytes): async def handle_binary_msg(client_ip: str, data: bytes):
"""处理二进制消息(保留扩展入口)""" """处理二进制消息(保留扩展入口)"""
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到{len(data)}字节二进制数据") print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到{len(data)}字节二进制数据")
# 将二进制转化为 numpy 数组 # 将二进制数据转换为NumPy数组uint8类型
data_ndarray = np.frombuffer(data, dtype=np.uint8) nparr = np.frombuffer(data, np.uint8)
# 解码为图像返回与cv2.imread相同的格式BGR通道的ndarray
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
#转存到本地images文件夹下
# 确保images文件夹存在
if not os.path.exists('images'):
os.makedirs('images')
# 生成唯一的文件名包含时间戳和客户端IP避免文件名冲突
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"images/{client_ip.replace('.', '_')}_{timestamp}.jpg"
# 保存图像到本地
cv2.imwrite(filename, img)
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}")
# 进行检测 # 进行检测
if img is not None:
has_violation, violation_type, details = detector.detect_violations(img)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
else:
print("未检测到任何违规内容")
else:
print(f"无法读取测试图像:")
# -------------------------- WebSocket核心端点 -------------------------- # -------------------------- WebSocket核心端点 --------------------------
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)