1
This commit is contained in:
		
							
								
								
									
										22
									
								
								ws/ws.py
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								ws/ws.py
									
									
									
									
									
								
							| @ -18,9 +18,7 @@ from ocr.model_violation_detector import MultiModelViolationDetector | ||||
|  | ||||
| # 配置文件路径(建议实际部署时改为相对路径或环境变量) | ||||
| YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" | ||||
| FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt" | ||||
| OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" | ||||
| KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces" | ||||
|  | ||||
| # 模型池配置(根据GPU显存调整,每个模型约占1G显存) | ||||
| MODEL_POOL_SIZE = 3  # 最大并发客户端数 | ||||
| @ -32,6 +30,15 @@ WS_ENDPOINT = "/ws"  # WebSocket端点路径 | ||||
| FRAME_QUEUE_SIZE = 1  # 帧队列大小限制 | ||||
|  | ||||
|  | ||||
| # 工具函数:获取格式化时间字符串 | ||||
| def get_current_time_str() -> str: | ||||
|     return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") | ||||
|  | ||||
|  | ||||
| def get_current_time_file_str() -> str: | ||||
|     return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") | ||||
|  | ||||
|  | ||||
| # 模型池实现 - 提前初始化固定数量的模型实例 | ||||
| class ModelPool: | ||||
|     def __init__(self, pool_size: int = MODEL_POOL_SIZE): | ||||
| @ -40,10 +47,8 @@ class ModelPool: | ||||
|         # 提前初始化模型实例(显存会在此阶段预分配) | ||||
|         for i in range(pool_size): | ||||
|             detector = MultiModelViolationDetector( | ||||
|                 forbidden_words_path=FORBIDDEN_WORDS_PATH, | ||||
|                 ocr_config_path=OCR_CONFIG_PATH, | ||||
|                 yolo_model_path=YOLO_MODEL_PATH, | ||||
|                 known_faces_dir=KNOWN_FACES_DIR, | ||||
|                 ocr_confidence_threshold=0.5 | ||||
|             ) | ||||
|             self.pool.put(detector) | ||||
| @ -64,15 +69,6 @@ class ModelPool: | ||||
| model_pool = ModelPool(pool_size=MODEL_POOL_SIZE) | ||||
|  | ||||
|  | ||||
| # 工具函数:获取格式化时间字符串 | ||||
| def get_current_time_str() -> str: | ||||
|     return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") | ||||
|  | ||||
|  | ||||
| def get_current_time_file_str() -> str: | ||||
|     return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") | ||||
|  | ||||
|  | ||||
| # 客户端连接封装 | ||||
| class ClientConnection: | ||||
|     def __init__(self, websocket: WebSocket, client_ip: str): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user