1
This commit is contained in:
		
							
								
								
									
										22
									
								
								ws/ws.py
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								ws/ws.py
									
									
									
									
									
								
							@ -18,9 +18,7 @@ from ocr.model_violation_detector import MultiModelViolationDetector
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# 配置文件路径(建议实际部署时改为相对路径或环境变量)
 | 
					# 配置文件路径(建议实际部署时改为相对路径或环境变量)
 | 
				
			||||||
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
 | 
					YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
 | 
				
			||||||
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
 | 
					 | 
				
			||||||
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
 | 
					OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
 | 
				
			||||||
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 模型池配置(根据GPU显存调整,每个模型约占1G显存)
 | 
					# 模型池配置(根据GPU显存调整,每个模型约占1G显存)
 | 
				
			||||||
MODEL_POOL_SIZE = 3  # 最大并发客户端数
 | 
					MODEL_POOL_SIZE = 3  # 最大并发客户端数
 | 
				
			||||||
@ -32,6 +30,15 @@ WS_ENDPOINT = "/ws"  # WebSocket端点路径
 | 
				
			|||||||
FRAME_QUEUE_SIZE = 1  # 帧队列大小限制
 | 
					FRAME_QUEUE_SIZE = 1  # 帧队列大小限制
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 工具函数:获取格式化时间字符串
 | 
				
			||||||
 | 
					def get_current_time_str() -> str:
 | 
				
			||||||
 | 
					    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_current_time_file_str() -> str:
 | 
				
			||||||
 | 
					    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 模型池实现 - 提前初始化固定数量的模型实例
 | 
					# 模型池实现 - 提前初始化固定数量的模型实例
 | 
				
			||||||
class ModelPool:
 | 
					class ModelPool:
 | 
				
			||||||
    def __init__(self, pool_size: int = MODEL_POOL_SIZE):
 | 
					    def __init__(self, pool_size: int = MODEL_POOL_SIZE):
 | 
				
			||||||
@ -40,10 +47,8 @@ class ModelPool:
 | 
				
			|||||||
        # 提前初始化模型实例(显存会在此阶段预分配)
 | 
					        # 提前初始化模型实例(显存会在此阶段预分配)
 | 
				
			||||||
        for i in range(pool_size):
 | 
					        for i in range(pool_size):
 | 
				
			||||||
            detector = MultiModelViolationDetector(
 | 
					            detector = MultiModelViolationDetector(
 | 
				
			||||||
                forbidden_words_path=FORBIDDEN_WORDS_PATH,
 | 
					 | 
				
			||||||
                ocr_config_path=OCR_CONFIG_PATH,
 | 
					                ocr_config_path=OCR_CONFIG_PATH,
 | 
				
			||||||
                yolo_model_path=YOLO_MODEL_PATH,
 | 
					                yolo_model_path=YOLO_MODEL_PATH,
 | 
				
			||||||
                known_faces_dir=KNOWN_FACES_DIR,
 | 
					 | 
				
			||||||
                ocr_confidence_threshold=0.5
 | 
					                ocr_confidence_threshold=0.5
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            self.pool.put(detector)
 | 
					            self.pool.put(detector)
 | 
				
			||||||
@ -64,15 +69,6 @@ class ModelPool:
 | 
				
			|||||||
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE)
 | 
					model_pool = ModelPool(pool_size=MODEL_POOL_SIZE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 工具函数:获取格式化时间字符串
 | 
					 | 
				
			||||||
def get_current_time_str() -> str:
 | 
					 | 
				
			||||||
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_current_time_file_str() -> str:
 | 
					 | 
				
			||||||
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# 客户端连接封装
 | 
					# 客户端连接封装
 | 
				
			||||||
class ClientConnection:
 | 
					class ClientConnection:
 | 
				
			||||||
    def __init__(self, websocket: WebSocket, client_ip: str):
 | 
					    def __init__(self, websocket: WebSocket, client_ip: str):
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user