1
This commit is contained in:
22
ws/ws.py
22
ws/ws.py
@ -18,9 +18,7 @@ from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
# 配置文件路径(建议实际部署时改为相对路径或环境变量)
|
||||
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
|
||||
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
|
||||
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
|
||||
|
||||
# 模型池配置(根据GPU显存调整,每个模型约占1G显存)
|
||||
MODEL_POOL_SIZE = 3 # 最大并发客户端数
|
||||
@ -32,6 +30,15 @@ WS_ENDPOINT = "/ws" # WebSocket端点路径
|
||||
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
|
||||
|
||||
|
||||
# 工具函数:获取格式化时间字符串
|
||||
def get_current_time_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_current_time_file_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
|
||||
# 模型池实现 - 提前初始化固定数量的模型实例
|
||||
class ModelPool:
|
||||
def __init__(self, pool_size: int = MODEL_POOL_SIZE):
|
||||
@ -40,10 +47,8 @@ class ModelPool:
|
||||
# 提前初始化模型实例(显存会在此阶段预分配)
|
||||
for i in range(pool_size):
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
self.pool.put(detector)
|
||||
@ -64,15 +69,6 @@ class ModelPool:
|
||||
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE)
|
||||
|
||||
|
||||
# 工具函数:获取格式化时间字符串
|
||||
def get_current_time_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_current_time_file_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
|
||||
# 客户端连接封装
|
||||
class ClientConnection:
|
||||
def __init__(self, websocket: WebSocket, client_ip: str):
|
||||
|
Reference in New Issue
Block a user