参数分离完全版
This commit is contained in:
		
							
								
								
									
										31
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,31 @@ | |||||||
|  | ### models 目录 | ||||||
|  |  | ||||||
|  | - 存放模型文件 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### frame_transfer.py | ||||||
|  |  | ||||||
|  | - 从检测结果队列推送数据到 RTMP 服务器【不必修改】 | ||||||
|  | - 从原始队列拿取数据、调用 yolo_core 封装的方法进行检测【四类】 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### rtc_handler.py | ||||||
|  |  | ||||||
|  | - 从 WebRTC 实时视频流截取帧并持续推送到原始队列【不必修改】 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### yolo_core.py | ||||||
|  |  | ||||||
|  | - 封装四类方法【参数均为原始队列、检测结果队列】 | ||||||
|  | - 方法一:原始YOLO检测 | ||||||
|  | - 方法二:原始YOLO检测 + 汉化 + 颜色 | ||||||
|  | - 方法三:原始累计计数 | ||||||
|  | - 方法四:原始累计计数 + 汉化 + 颜色 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | **读取 WebRTC 流和推送结果帧的代码不需要修改** | ||||||
|  |  | ||||||
							
								
								
									
										389
									
								
								api_server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										389
									
								
								api_server.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,389 @@ | |||||||
|  | import json | ||||||
|  | import os | ||||||
|  | import shutil | ||||||
|  | from typing import Dict, Optional, Any | ||||||
|  |  | ||||||
|  | from fastapi import APIRouter, File, UploadFile, HTTPException, Form | ||||||
|  | from fastapi.responses import JSONResponse | ||||||
|  | from pydantic import BaseModel | ||||||
|  |  | ||||||
|  | from result import Response | ||||||
|  | # 假设 rfdetr_core.py 在同一目录下或 PYTHONPATH 中 | ||||||
|  | from rfdetr_core import RFDETRDetector | ||||||
|  |  | ||||||
|  | # --- Global Variables and Configuration --- | ||||||
|  | model_management_router = APIRouter() | ||||||
|  |  | ||||||
|  | BASE_MODEL_DIR = "models" | ||||||
|  | BASE_CONFIG_DIR = "configs" | ||||||
|  | os.makedirs(BASE_MODEL_DIR, exist_ok=True) | ||||||
|  | os.makedirs(BASE_CONFIG_DIR, exist_ok=True) | ||||||
|  |  | ||||||
|  | # 用于存储当前激活的检测器实例和可用模型信息 | ||||||
|  | current_detector: Optional[RFDETRDetector] = None | ||||||
|  | current_model_identifier: Optional[str] = None | ||||||
|  | available_models_info: Dict[str, Dict] = {}  # 存储模型标识符及其配置内容 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load_config(config_name: str) -> Optional[Dict]: | ||||||
|  |     """加载指定的JSON配置文件。""" | ||||||
|  |     config_path = os.path.join(BASE_CONFIG_DIR, f"{config_name}.json") | ||||||
|  |     if os.path.exists(config_path): | ||||||
|  |         try: | ||||||
|  |             with open(config_path, 'r', encoding='utf-8') as f: | ||||||
|  |                 return json.load(f) | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"错误:加载配置文件 '{config_path}' 失败: {e}") | ||||||
|  |             return None | ||||||
|  |     return None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def initialize_detector(model_identifier: str) -> Optional[RFDETRDetector]: | ||||||
|  |     """根据模型标识符初始化检测器。""" | ||||||
|  |     global current_detector, current_model_identifier | ||||||
|  |     try: | ||||||
|  |         config = available_models_info.get(model_identifier) | ||||||
|  |         if not config: | ||||||
|  |             print(f"错误:未找到模型 '{model_identifier}' 的缓存配置,尝试从磁盘加载。")  # 更明确的日志 | ||||||
|  |             config_data = load_config(model_identifier) | ||||||
|  |             if not config_data: | ||||||
|  |                 raise FileNotFoundError(f"配置文件 {model_identifier}.json 未找到或无法加载。") | ||||||
|  |             available_models_info[model_identifier] = config_data | ||||||
|  |             config = config_data | ||||||
|  |  | ||||||
|  |         model_filename = config.get('model_pth_filename') | ||||||
|  |         if not model_filename: | ||||||
|  |             raise ValueError(f"配置文件 '{model_identifier}.json' 中缺少 'model_pth_filename' 字段。") | ||||||
|  |  | ||||||
|  |         model_full_path = os.path.join(BASE_MODEL_DIR, model_filename) | ||||||
|  |         if not os.path.exists(model_full_path): | ||||||
|  |             raise FileNotFoundError(f"模型文件 '{model_full_path}' (在 '{model_identifier}.json' 中指定) 不存在。") | ||||||
|  |  | ||||||
|  |         print(f"尝试使用配置 '{model_identifier}.json' 和模型 '{model_full_path}' 初始化检测器...") | ||||||
|  |         detector = RFDETRDetector(config_name=model_identifier, base_model_dir=BASE_MODEL_DIR, | ||||||
|  |                                   base_config_dir=BASE_CONFIG_DIR) | ||||||
|  |         print(f"检测器 '{model_identifier}' 初始化成功。") | ||||||
|  |         current_detector = detector | ||||||
|  |         current_model_identifier = model_identifier | ||||||
|  |  | ||||||
|  |         # 通知 DataPusher 更新其检测器实例 | ||||||
|  |         try: | ||||||
|  |             from data_pusher import get_data_pusher_instance  # 动态导入以避免潜在的循环依赖问题 | ||||||
|  |             pusher = get_data_pusher_instance() | ||||||
|  |             if pusher: | ||||||
|  |                 print(f"通知 DataPusher 更新其检测器实例为: {model_identifier}") | ||||||
|  |                 pusher.update_detector_instance(current_detector) | ||||||
|  |             # else: | ||||||
|  |             # 如果 pusher 为 None,可能是因为它尚未在主应用启动时被完全初始化 | ||||||
|  |             # data_pusher 模块内部的 initialize_data_pusher 负责记录其自身初始化状态 | ||||||
|  |             # print(f"DataPusher 实例尚未初始化 (或初始化失败),无法更新其检测器。") | ||||||
|  |         except ImportError: | ||||||
|  |             print( | ||||||
|  |                 "警告: 无法导入 data_pusher 模块以更新 DataPusher 的检测器实例。如果您不使用数据推送功能,此消息可忽略。") | ||||||
|  |         except Exception as e_pusher: | ||||||
|  |             print(f"警告: 通知 DataPusher 更新检测器时发生意外错误: {e_pusher}") | ||||||
|  |  | ||||||
|  |         return detector | ||||||
|  |     except FileNotFoundError as e: | ||||||
|  |         print(f"初始化检测器 '{model_identifier}' 失败 (文件未找到): {e}") | ||||||
|  |         if current_model_identifier == model_identifier: | ||||||
|  |             current_detector = None | ||||||
|  |             current_model_identifier = None | ||||||
|  |         raise HTTPException(status_code=404, detail=str(e))  # 保持 404,让上层知道是文件问题 | ||||||
|  |     except ValueError as e:  # 捕获配置字段缺失等问题 | ||||||
|  |         print(f"初始化检测器 '{model_identifier}' 失败 (配置值错误): {e}") | ||||||
|  |         if current_model_identifier == model_identifier: | ||||||
|  |             current_detector = None | ||||||
|  |             current_model_identifier = None | ||||||
|  |         raise HTTPException(status_code=400, detail=str(e))  # 配置问题用 400 | ||||||
|  |     except Exception as e:  # 其他来自 RFDETRDetector 内部的初始化错误 | ||||||
|  |         print(f"初始化检测器 '{model_identifier}' 失败 (内部错误): {e}") | ||||||
|  |         # import traceback | ||||||
|  |         # traceback.print_exc() # 在服务器日志中打印完整堆栈,方便调试 | ||||||
|  |         if current_model_identifier == model_identifier: | ||||||
|  |             current_detector = None | ||||||
|  |             current_model_identifier = None | ||||||
|  |         if model_identifier in available_models_info: | ||||||
|  |             del available_models_info[model_identifier] | ||||||
|  |         # 这些通常是服务器端问题或模型/库的兼容性问题,所以用500 | ||||||
|  |         raise HTTPException(status_code=500, detail=f"检测器 '{model_identifier}' 内部初始化失败: {e}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_active_detector() -> Optional[RFDETRDetector]: | ||||||
|  |     """获取当前激活的检测器实例。""" | ||||||
|  |     return current_detector | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_active_model_identifier() -> Optional[str]: | ||||||
|  |     """获取当前激活的模型标识符。""" | ||||||
|  |     return current_model_identifier | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def scan_and_load_available_models(): | ||||||
|  |     """扫描配置目录,加载所有有效的模型配置。""" | ||||||
|  |     global available_models_info | ||||||
|  |     available_models_info = {} | ||||||
|  |     # print(f"扫描目录 '{BASE_CONFIG_DIR}' 以查找配置文件...") # 减少日志冗余 | ||||||
|  |     if not os.path.exists(BASE_CONFIG_DIR): | ||||||
|  |         # print(f"配置目录 '{BASE_CONFIG_DIR}' 不存在,跳过扫描。") | ||||||
|  |         return | ||||||
|  |     for filename in os.listdir(BASE_CONFIG_DIR): | ||||||
|  |         if filename.endswith(".json"): | ||||||
|  |             model_identifier = filename[:-5] | ||||||
|  |             # print(f"找到配置文件: {filename},模型标识符: {model_identifier}") | ||||||
|  |             config_data = load_config(model_identifier) | ||||||
|  |             if config_data: | ||||||
|  |                 model_pth = config_data.get('model_pth_filename') | ||||||
|  |                 model_full_path = os.path.join(BASE_MODEL_DIR, model_pth) if model_pth else None | ||||||
|  |                 if model_pth and os.path.exists(model_full_path): | ||||||
|  |                     available_models_info[model_identifier] = config_data | ||||||
|  |                     # print(f"模型配置 '{model_identifier}' 加载成功。") | ||||||
|  |                 # elif not model_pth: | ||||||
|  |                 # print(f"警告:模型 '{model_identifier}' 的配置文件中未指定 'model_pth_filename'。") | ||||||
|  |                 # else: | ||||||
|  |                 # print(f"警告:模型 '{model_identifier}' 的模型文件 '{model_full_path}' 未找到。") | ||||||
|  |             # else: | ||||||
|  |             # print(f"警告:无法加载模型 '{model_identifier}' 的配置文件。") | ||||||
|  |     print(f"可用模型配置已扫描/更新: {list(available_models_info.keys())}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # --- Extracted Startup Logic (to be called by the main app) --- | ||||||
|  | async def initialize_default_model_on_startup(): | ||||||
|  |     """应用启动时,扫描并加载可用模型,尝试激活第一个。""" | ||||||
|  |     print("执行模型管理模块的启动初始化...") | ||||||
|  |     scan_and_load_available_models() | ||||||
|  |     global current_model_identifier, current_detector | ||||||
|  |  | ||||||
|  |     if available_models_info: | ||||||
|  |         first_model = sorted(list(available_models_info.keys()))[0] | ||||||
|  |         print(f"尝试将第一个可用模型 '{first_model}' 设置为活动模型。") | ||||||
|  |         try: | ||||||
|  |             initialize_detector(first_model) | ||||||
|  |             print(f"默认模型 '{current_model_identifier}' 加载并激活成功。") | ||||||
|  |         except HTTPException as e:  # 捕获 initialize_detector 抛出的HTTPException | ||||||
|  |             print(f"加载默认模型 '{first_model}' 失败: {e.detail} (状态码: {e.status_code})") | ||||||
|  |             current_model_identifier = None | ||||||
|  |             current_detector = None | ||||||
|  |         # 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException | ||||||
|  |     else: | ||||||
|  |         print("没有可用的模型配置,服务器启动但无默认模型激活。") | ||||||
|  |         current_model_identifier = None | ||||||
|  |         current_detector = None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # --- Pydantic Models for Request/Response --- | ||||||
|  |  | ||||||
|  | class ModelIdentifier(BaseModel): | ||||||
|  |     model_identifier: str | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Define a standard response model for OpenAPI documentation | ||||||
|  | class StandardResponse(BaseModel): | ||||||
|  |     code: int | ||||||
|  |     data: Optional[Any] = None | ||||||
|  |     message: str | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # --- API Endpoints --- | ||||||
|  |  | ||||||
|  | @model_management_router.post("/upload_model_and_config/", response_model=StandardResponse) | ||||||
|  | async def upload_model_and_config( | ||||||
|  |         model_identifier_form: str = Form(..., alias="model_identifier"), | ||||||
|  |         config_file: UploadFile = File(...), | ||||||
|  |         model_file: UploadFile = File(...) | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     上传模型文件 (.pth) 和配置文件 (.json)。 | ||||||
|  |     - **model_identifier**: 模型的唯一名称 (例如 "人车检测")。配置文件将以此名称保存 (e.g., "人车检测.json")。 | ||||||
|  |     - **config_file**: JSON 配置文件。 | ||||||
|  |     - **model_file**: Pytorch 模型文件 (.pth)。其文件名必须与配置文件中 'model_pth_filename' 字段指定的一致。 | ||||||
|  |     """ | ||||||
|  |     config_filename = f"{model_identifier_form}.json" | ||||||
|  |     config_path = os.path.join(BASE_CONFIG_DIR, config_filename) | ||||||
|  |     model_path = None  # 在 try 块外部声明,以便 finally 和 except 中可用 | ||||||
|  |     global current_model_identifier, current_detector  # current_detector 实际上在这里主要由 initialize_detector 设置 | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         print(f"开始上传模型 '{model_identifier_form}'...") | ||||||
|  |         config_content = await config_file.read() | ||||||
|  |         try: | ||||||
|  |             config_data = json.loads(config_content.decode('utf-8')) | ||||||
|  |         except json.JSONDecodeError: | ||||||
|  |             # raise HTTPException(status_code=400, detail="无效的JSON配置文件格式,请检查JSON语法。") | ||||||
|  |             return JSONResponse(status_code=400, | ||||||
|  |                                 content=Response.error(message="无效的JSON配置文件格式,请检查JSON语法。", code=400)) | ||||||
|  |         except UnicodeDecodeError: | ||||||
|  |             # raise HTTPException(status_code=400, detail="配置文件编码错误,请确保为UTF-8编码。") | ||||||
|  |             return JSONResponse(status_code=400, | ||||||
|  |                                 content=Response.error(message="配置文件编码错误,请确保为UTF-8编码。", code=400)) | ||||||
|  |  | ||||||
|  |         if 'model_pth_filename' not in config_data: | ||||||
|  |             # raise HTTPException(status_code=400, detail="配置文件中必须包含 'model_pth_filename' 字段。") | ||||||
|  |             return JSONResponse(status_code=400, | ||||||
|  |                                 content=Response.error(message="配置文件中必须包含 'model_pth_filename' 字段。", | ||||||
|  |                                                        code=400)) | ||||||
|  |  | ||||||
|  |         target_model_filename = config_data['model_pth_filename'] | ||||||
|  |         if not isinstance(target_model_filename, str) or not target_model_filename.endswith(".pth"): | ||||||
|  |             # raise HTTPException(status_code=400, detail="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。") | ||||||
|  |             return JSONResponse(status_code=400, content=Response.error( | ||||||
|  |                 message="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。", code=400)) | ||||||
|  |  | ||||||
|  |         if model_file.filename != target_model_filename: | ||||||
|  |             # raise HTTPException( | ||||||
|  |             #     status_code=400, | ||||||
|  |             #     detail=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。" | ||||||
|  |             # ) | ||||||
|  |             return JSONResponse( | ||||||
|  |                 status_code=400, | ||||||
|  |                 content=Response.error( | ||||||
|  |                     message=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。", | ||||||
|  |                     code=400 | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         with open(config_path, 'wb') as f: | ||||||
|  |             f.write(config_content) | ||||||
|  |         print(f"配置文件 '{config_path}' 已保存。") | ||||||
|  |  | ||||||
|  |         model_path = os.path.join(BASE_MODEL_DIR, target_model_filename) | ||||||
|  |         with open(model_path, "wb") as buffer: | ||||||
|  |             shutil.copyfileobj(model_file.file, buffer) | ||||||
|  |         print(f"模型文件 '{model_path}' 已保存。") | ||||||
|  |  | ||||||
|  |         # 在尝试初始化之前,将配置数据添加到 available_models_info | ||||||
|  |         # 这样 initialize_detector 即使在缓存未命中的情况下也能通过 load_config 找到它 | ||||||
|  |         available_models_info[model_identifier_form] = config_data | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             print(f"尝试初始化并激活新上传的模型: '{model_identifier_form}'") | ||||||
|  |             initialize_detector(model_identifier_form)  # 此函数会设置全局的 current_detector 和 current_model_identifier | ||||||
|  |             print(f"新上传的模型 '{model_identifier_form}' 验证并激活成功。") | ||||||
|  |             # 成功后,available_models_info 已包含此模型,current_detector 和 current_model_identifier 已更新 | ||||||
|  |             # 无需在此处再次调用 scan_and_load_available_models(),因为它会重新扫描所有,可能覆盖内存中的一些状态或引入不必要的IO | ||||||
|  |         except HTTPException as e:  # 从 initialize_detector 抛出的错误 | ||||||
|  |             print(f"初始化新上传的模型 '{model_identifier_form}' 失败: {e.detail}") | ||||||
|  |             # 清理已保存的文件和内存中的条目 | ||||||
|  |             if os.path.exists(config_path): os.remove(config_path) | ||||||
|  |             if model_path and os.path.exists(model_path): os.remove(model_path) | ||||||
|  |             if model_identifier_form in available_models_info: del available_models_info[model_identifier_form] | ||||||
|  |             # scan_and_load_available_models() # 失败后重新扫描是好的,以确保 available_models_info 准确 | ||||||
|  |             # 但如果 initialize_detector 内部已经删除了 available_models_info 中的条目,这里可能不需要 | ||||||
|  |             # 为保持一致性,如果上面删除了,这里重新扫描一下比较稳妥 | ||||||
|  |             scan_and_load_available_models() | ||||||
|  |  | ||||||
|  |             # 重新抛出为 422,并包含原始错误信息 | ||||||
|  |             # raise HTTPException(status_code=422,  | ||||||
|  |             #                     detail=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}") | ||||||
|  |             return JSONResponse(status_code=422, | ||||||
|  |                                 content=Response.error( | ||||||
|  |                                     message=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}", | ||||||
|  |                                     code=422 | ||||||
|  |                                 )) | ||||||
|  |  | ||||||
|  |         # 如果成功,确保全局模型列表是最新的(虽然 initialize_detector 已更新了 current_*,但列表可能需要刷新以供 /available_models 使用) | ||||||
|  |         # scan_and_load_available_models() # 移除这个,因为当前模型已激活,列表会在下次调用 /available_models 时刷新 | ||||||
|  |  | ||||||
|  |         # return UploadResponse( | ||||||
|  |         #     message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。", | ||||||
|  |         #     model_identifier=model_identifier_form, | ||||||
|  |         #     config_filename=config_filename, | ||||||
|  |         #     model_filename=target_model_filename | ||||||
|  |         # ) | ||||||
|  |         return Response.success(data={ | ||||||
|  |             "message": f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。",  # Message is also in the wrapper | ||||||
|  |             "model_identifier": model_identifier_form, | ||||||
|  |             "config_filename": config_filename, | ||||||
|  |             "model_filename": target_model_filename | ||||||
|  |         }, message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。") | ||||||
|  |  | ||||||
|  |     except HTTPException as e:  # 捕获直接由 FastAPI 验证或其他地方抛出的 HTTPException | ||||||
|  |         # This might catch exceptions from initialize_detector if they are not caught internally by the above try-except for initialize_detector | ||||||
|  |         return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code)) | ||||||
|  |     except Exception as e: | ||||||
|  |         identifier_for_log = model_identifier_form if 'model_identifier_form' in locals() else "unknown" | ||||||
|  |         print(f"上传模型 '{identifier_for_log}' 过程中发生意外的严重错误: {e}") | ||||||
|  |         # import traceback | ||||||
|  |         # traceback.print_exc() | ||||||
|  |         # 尝试清理,以防文件部分写入 | ||||||
|  |         if config_path and os.path.exists(config_path): | ||||||
|  |             os.remove(config_path) | ||||||
|  |         if model_path and os.path.exists(model_path): | ||||||
|  |             os.remove(model_path) | ||||||
|  |         if 'model_identifier_form' in locals() and model_identifier_form in available_models_info: | ||||||
|  |             del available_models_info[model_identifier_form] | ||||||
|  |         scan_and_load_available_models()  # 出错后务必刷新列表 | ||||||
|  |  | ||||||
|  |         # raise HTTPException(status_code=500, detail=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}") | ||||||
|  |         return JSONResponse(status_code=500, content=Response.error( | ||||||
|  |             message=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}", code=500)) | ||||||
|  |     finally: | ||||||
|  |         if config_file: await config_file.close() | ||||||
|  |         if model_file: await model_file.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @model_management_router.post("/select_model/", response_model=StandardResponse) | ||||||
|  | # @model_management_router.post("/select_model/{model_identifier_path}", response_model=StandardResponse) | ||||||
|  | async def select_model(model_identifier_path: str): | ||||||
|  |     """ | ||||||
|  |     根据提供的标识符选择并激活一个已上传的模型。 | ||||||
|  |     路径参数 `model_identifier_path` 即为模型的唯一名称。 | ||||||
|  |     """ | ||||||
|  |     global current_model_identifier, current_detector | ||||||
|  |  | ||||||
|  |     # 总是先扫描以获取最新的可用模型列表,以防外部文件更改 | ||||||
|  |     scan_and_load_available_models() | ||||||
|  |     if model_identifier_path not in available_models_info: | ||||||
|  |         # raise HTTPException(status_code=404, detail=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。") | ||||||
|  |         return JSONResponse(status_code=404, content=Response.error( | ||||||
|  |             message=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。", code=404)) | ||||||
|  |  | ||||||
|  |     if current_model_identifier == model_identifier_path and current_detector is not None: | ||||||
|  |         print(f"模型 '{model_identifier_path}' 已经是活动模型。") | ||||||
|  |         # return SelectModelResponse( | ||||||
|  |         #     message=f"模型 '{model_identifier_path}' 已是当前活动模型。", | ||||||
|  |         #     active_model=current_model_identifier | ||||||
|  |         # ) | ||||||
|  |         return Response.success(data={ | ||||||
|  |             "active_model": current_model_identifier | ||||||
|  |         }, message=f"模型 '{model_identifier_path}' 已是当前活动模型。") | ||||||
|  |     try: | ||||||
|  |         print(f"尝试激活模型: {model_identifier_path}") | ||||||
|  |         initialize_detector(model_identifier_path) | ||||||
|  |         print(f"模型 '{current_model_identifier}' 已成功激活。") | ||||||
|  |         # return SelectModelResponse( | ||||||
|  |         #     message=f"模型 '{model_identifier_path}' 启动成功。", | ||||||
|  |         #     active_model=current_model_identifier | ||||||
|  |         # ) | ||||||
|  |         return Response.success(data={ | ||||||
|  |             "active_model": current_model_identifier | ||||||
|  |         }, message=f"模型 '{model_identifier_path}' 启动成功。") | ||||||
|  |     except HTTPException as e: | ||||||
|  |         # initialize_detector 内部已经处理了 current_model_identifier 和 current_detector 的清理 | ||||||
|  |         # 以及 available_models_info 中对应条目的移除(如果是其内部错误) | ||||||
|  |         print(f"激活模型 '{model_identifier_path}' 失败: {e.detail} (原始状态码: {e.status_code})") | ||||||
|  |         # 此处不再需要手动清理 available_models_info,因为 initialize_detector 如果因内部错误(非文件找不到) | ||||||
|  |         # 导致无法实例化 RFDETRDetector,它会自己删除条目。 | ||||||
|  |         # 如果是文件找不到 (404 from initialize_detector),条目可能还在,但下次 scan 会处理。 | ||||||
|  |         scan_and_load_available_models()  # 确保 select 失败后,列表也是最新的 | ||||||
|  |         # raise HTTPException(status_code=e.status_code, detail=e.detail) # 重新抛出原始的 HTTPException | ||||||
|  |         return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code)) | ||||||
|  |     # 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @model_management_router.get("/available_models/", response_model=StandardResponse) | ||||||
|  | async def get_available_models(): | ||||||
|  |     """列出所有当前可用的模型标识符。""" | ||||||
|  |     scan_and_load_available_models() | ||||||
|  |     # return list(available_models_info.keys()) | ||||||
|  |     return Response.success(data=list(available_models_info.keys()), message="成功获取可用模型列表。") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @model_management_router.get("/current_model/", response_model=StandardResponse) | ||||||
|  | async def get_current_model_endpoint(): | ||||||
|  |     """获取当前激活的模型标识符。""" | ||||||
|  |     # return current_model_identifier | ||||||
|  |     if current_model_identifier: | ||||||
|  |         return Response.success(data=current_model_identifier, message="成功获取当前激活的模型。") | ||||||
|  |     else: | ||||||
|  |         return Response.success(data=None, message="当前没有激活的模型。") | ||||||
							
								
								
									
										58
									
								
								configs/人车.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								configs/人车.json
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | |||||||
|  | { | ||||||
|  |     "model_id": "1", | ||||||
|  |     "model_pth_filename": "人车.pth", | ||||||
|  |     "resolution": 448, | ||||||
|  |     "classes_en": [ | ||||||
|  |       "pedestrian", "person", "bicycle", "car", "van", "truck", "tricycle", "awning-tricycle", "bus", "motor" | ||||||
|  |     ], | ||||||
|  |  | ||||||
|  |     "classes_zh_map": { | ||||||
|  |       "pedestrian":"行人", | ||||||
|  |       "person": "人群", | ||||||
|  |       "bicycle": "自行车", | ||||||
|  |       "car": "小汽车", | ||||||
|  |       "van": "面包车", | ||||||
|  |       "truck": "卡车", | ||||||
|  |       "tricycle":"三轮车", | ||||||
|  |       "awning-tricycle":"篷式三轮车", | ||||||
|  |       "bus": "公交车", | ||||||
|  |       "motor":"摩托车" | ||||||
|  |     }, | ||||||
|  |  | ||||||
|  |     "class_colors_hex": { | ||||||
|  |       "pedestrian": "#470024",        | ||||||
|  |       "person": "#00FF00",            | ||||||
|  |       "bicycle": "#003153",           | ||||||
|  |       "car": "#002FA7",               | ||||||
|  |       "van": "#800080",               | ||||||
|  |       "truck": "#D44848",             | ||||||
|  |       "tricycle": "#003153",          | ||||||
|  |       "awning-tricycle": "#FBDC6A",   | ||||||
|  |       "bus": "#492D22",               | ||||||
|  |       "motor": "#01847F"              | ||||||
|  |     }, | ||||||
|  |      | ||||||
|  |     "detection_settings": { | ||||||
|  |       "enabled_classes": { | ||||||
|  |         "pedestrian": true, | ||||||
|  |         "person": true, | ||||||
|  |         "bicycle": false,  | ||||||
|  |         "car": true, | ||||||
|  |         "van": true, | ||||||
|  |         "truck": true, | ||||||
|  |         "tricycle": false,  | ||||||
|  |         "awning-tricycle": false, | ||||||
|  |         "bus": true, | ||||||
|  |         "motor": true | ||||||
|  |       }, | ||||||
|  |       "default_confidence_threshold": 0.7 | ||||||
|  |     }, | ||||||
|  |  | ||||||
|  |     "default_color_hex": "#00FF00", | ||||||
|  |     "tracker_activation_threshold": 0.5, | ||||||
|  |     "tracker_lost_buffer": 120, | ||||||
|  |     "tracker_match_threshold": 0.85, | ||||||
|  |     "tracker_frame_rate": 25, | ||||||
|  |     "tracker_consecutive_frames": 2 | ||||||
|  |   } | ||||||
|  |    | ||||||
							
								
								
									
										248
									
								
								data_pusher.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								data_pusher.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,248 @@ | |||||||
|  | import base64 | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | import requests | ||||||
|  | import time | ||||||
|  | import datetime | ||||||
|  | import logging | ||||||
|  | from fastapi import APIRouter, HTTPException, Body # 切换到 APIRouter | ||||||
|  | from pydantic import BaseModel, HttpUrl | ||||||
|  | # import uvicorn # 不再由此文件运行uvicorn | ||||||
|  | from apscheduler.schedulers.background import BackgroundScheduler | ||||||
|  | from typing import Optional # 用于类型提示 | ||||||
|  | # 确保 RFDETRDetector 可以被导入,假设 rfdetr_core.py 在同一目录或 PYTHONPATH 中 | ||||||
|  | # from rfdetr_core import RFDETRDetector # 在实际使用中取消注释并确保路径正确 | ||||||
|  |  | ||||||
|  | # 配置日志记录 | ||||||
|  | logging.basicConfig(level=logging.INFO) | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  | # app = FastAPI(title="Data Pusher Service", version="1.0.0") # 移除独立的 FastAPI app | ||||||
|  | pusher_router = APIRouter() # 创建一个 APIRouter | ||||||
|  |  | ||||||
|  | class DataPusher: | ||||||
|  |     def __init__(self, detector): # detector: RFDETRDetector | ||||||
|  |         if detector is None: | ||||||
|  |             logger.error("DataPusher initialized with a None detector. Push functionality will be impaired.") | ||||||
|  |             # 仍然创建实例,但功能会受限,_get_data_payload 会处理 detector is None | ||||||
|  |         self.detector = detector | ||||||
|  |         self.scheduler = BackgroundScheduler(daemon=True) | ||||||
|  |         self.push_job_id = "rt_push_job" | ||||||
|  |         self.target_url = None | ||||||
|  |         if not self.scheduler.running: | ||||||
|  |             try: | ||||||
|  |                 self.scheduler.start() | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"Error starting APScheduler in DataPusher: {e}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def update_detector_instance(self, detector): | ||||||
|  |         """允许在运行时更新检测器实例,例如当主应用切换模型时""" | ||||||
|  |         logger.info(f"DataPusher's detector instance is being updated.") | ||||||
|  |         self.detector = detector | ||||||
|  |         if detector is None: | ||||||
|  |             logger.warning("DataPusher's detector instance updated to None.") | ||||||
|  |  | ||||||
|  |     def _get_data_payload(self): | ||||||
|  |         """获取当前的类别计数和最新标注的帧""" | ||||||
|  |         if self.detector is None: | ||||||
|  |             logger.warning("DataPusher: Detector not available. Cannot get data payload.") | ||||||
|  |             return { # 即使检测器不可用,也返回一个结构,包含空数据 | ||||||
|  |                 # "timestamp": time.time(), | ||||||
|  |                 "category_counts": {}, | ||||||
|  |                 "frame_base64": None, | ||||||
|  |                 "error": "Detector not available" | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         category_counts = getattr(self.detector, 'category_counts', {}) | ||||||
|  |         # 如果 detector 存在但没有 last_annotated_frame (例如模型刚加载还没处理第一帧) | ||||||
|  |         last_frame_np = getattr(self.detector, 'last_annotated_frame', None) | ||||||
|  |  | ||||||
|  |         frame_base64 = None | ||||||
|  |         if last_frame_np is not None and isinstance(last_frame_np, np.ndarray): | ||||||
|  |             try: | ||||||
|  |                 _, buffer = cv2.imencode('.jpg', last_frame_np) | ||||||
|  |                 frame_base64 = base64.b64encode(buffer).decode('utf-8') | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"Error encoding frame to base64: {e}") | ||||||
|  |          | ||||||
|  |         return { | ||||||
|  |             # "timestamp": time.time(), | ||||||
|  |             "category_counts": category_counts, | ||||||
|  |             "frame_base64": frame_base64 | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |     def _push_data_task(self): | ||||||
|  |         """执行数据推送的任务""" | ||||||
|  |         if not self.target_url: | ||||||
|  |             # logger.warning("Target URL not set. Skipping push task.") # 减少日志噪音,仅在初次设置时记录 | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         payload = self._get_data_payload() | ||||||
|  |         # if payload is None: # _get_data_payload 现在总会返回一个字典 | ||||||
|  |         #     logger.warning("No payload to push.") | ||||||
|  |         #     return | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             response = requests.post(self.target_url, json=payload, timeout=5) | ||||||
|  |             response.raise_for_status()  | ||||||
|  |             logger.debug(f"Data pushed successfully to {self.target_url}. Status: {response.status_code}") # 改为 debug 级别 | ||||||
|  |         except requests.exceptions.RequestException as e: | ||||||
|  |             logger.error(f"Error pushing data to {self.target_url}: {e}") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"An unexpected error occurred during data push: {e}") | ||||||
|  |  | ||||||
|  |     def setup_push_schedule(self, frequency: float, target_url: str): | ||||||
|  |         """设置或更新推送计划""" | ||||||
|  |         if not isinstance(frequency, (int, float)) or frequency <= 0: | ||||||
|  |             raise ValueError("Frequency must be a positive number (pushes per second).") | ||||||
|  |  | ||||||
|  |         self.target_url = str(target_url)  | ||||||
|  |         interval_seconds = 1.0 / frequency | ||||||
|  |  | ||||||
|  |         if not self.scheduler.running: # 确保调度器正在运行 | ||||||
|  |             try: | ||||||
|  |                 logger.info("APScheduler was not running. Attempting to start it now.") | ||||||
|  |                 self.scheduler.start() | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"Failed to start APScheduler in setup_push_schedule: {e}") | ||||||
|  |                 raise RuntimeError(f"APScheduler could not be started: {e}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             if self.scheduler.get_job(self.push_job_id): | ||||||
|  |                 self.scheduler.remove_job(self.push_job_id) | ||||||
|  |                 logger.info(f"Removed existing push job: {self.push_job_id}") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Error removing existing job: {e}") | ||||||
|  |  | ||||||
|  |         first_run_time = datetime.datetime.now() + datetime.timedelta(seconds=10) | ||||||
|  |         self.scheduler.add_job( | ||||||
|  |             self._push_data_task, | ||||||
|  |             trigger='interval', | ||||||
|  |             seconds=interval_seconds, | ||||||
|  |             id=self.push_job_id, | ||||||
|  |             next_run_time=first_run_time, | ||||||
|  |             replace_existing=True | ||||||
|  |         ) | ||||||
|  |         logger.info(f"Push task scheduled to {self.target_url} every {interval_seconds:.2f}s, starting in 10s.") | ||||||
|  |  | ||||||
|  |     def stop_push_schedule(self): | ||||||
|  |         """停止数据推送任务""" | ||||||
|  |         if self.scheduler.get_job(self.push_job_id): | ||||||
|  |             try: | ||||||
|  |                 self.scheduler.remove_job(self.push_job_id) | ||||||
|  |                 logger.info(f"Push job {self.push_job_id} stopped successfully.") | ||||||
|  |                 self.target_url = None # 清除目标 URL | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"Error stopping push job {self.push_job_id}: {e}") | ||||||
|  |         else: | ||||||
|  |             logger.info("No active push job to stop.") | ||||||
|  |              | ||||||
|  |     def shutdown_scheduler(self): | ||||||
|  |         """安全关闭调度器""" | ||||||
|  |         if self.scheduler.running: | ||||||
|  |             try: | ||||||
|  |                 self.scheduler.shutdown() | ||||||
|  |                 logger.info("DataPusher's APScheduler shut down successfully.") | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"Error shutting down DataPusher's APScheduler: {e}") | ||||||
|  |  | ||||||
|  |     def push_specific_payload(self, payload: dict): | ||||||
|  |         """推送一个特定的、已格式化的数据负载到配置的 target_url。""" | ||||||
|  |         if not self.target_url: | ||||||
|  |             logger.warning("DataPusher: Target URL not set. Cannot push specific payload.") | ||||||
|  |             return | ||||||
|  |          | ||||||
|  |         if not payload: | ||||||
|  |             logger.warning("DataPusher: Received empty payload for specific push. Skipping.") | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         logger.info(f"Attempting to push specific payload to {self.target_url}") | ||||||
|  |         try: | ||||||
|  |             response = requests.post(self.target_url, json=payload, timeout=10) # Increased timeout for one-off | ||||||
|  |             response.raise_for_status()  | ||||||
|  |             logger.info(f"Specific payload pushed successfully to {self.target_url}. Status: {response.status_code}") | ||||||
|  |         except requests.exceptions.RequestException as e: | ||||||
|  |             logger.error(f"Error pushing specific payload to {self.target_url}: {e}") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"An unexpected error occurred during specific payload push: {e}") | ||||||
|  |  | ||||||
|  | # 全局 DataPusher 实例,将由主应用初始化 | ||||||
|  | data_pusher_instance: Optional[DataPusher] = None | ||||||
|  |  | ||||||
|  | # --- FastAPI Request Body Model --- | ||||||
|  | class PushConfigRequest(BaseModel): | ||||||
|  |     frequency: float | ||||||
|  |     url: HttpUrl  | ||||||
|  |  | ||||||
|  | # --- FastAPI HTTP Endpoint (using APIRouter) --- | ||||||
|  | @pusher_router.post("/setup_push", summary="配置数据推送任务") | ||||||
|  | async def handle_setup_push(config: PushConfigRequest = Body(...)): | ||||||
|  |     global data_pusher_instance | ||||||
|  |     if data_pusher_instance is None: | ||||||
|  |         # 这个错误理论上不应该发生,如果主应用正确初始化了 data_pusher_instance | ||||||
|  |         logger.error("CRITICAL: /setup_push called but data_pusher_instance is None. Main app did not initialize it.") | ||||||
|  |         raise HTTPException(status_code=503, detail="DataPusher service not available. Initialization may have failed.") | ||||||
|  |      | ||||||
|  |     if config.frequency <= 0: # Pydantic v2 中可以直接用 gt=0 | ||||||
|  |         raise HTTPException(status_code=400, detail="Invalid frequency value. Must be a positive number.") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         data_pusher_instance.setup_push_schedule(config.frequency, str(config.url)) | ||||||
|  |         return { | ||||||
|  |             "message": "Push task configured successfully.", | ||||||
|  |             "frequency_hz": config.frequency, | ||||||
|  |             "interval_seconds": 1.0 / config.frequency, | ||||||
|  |             "target_url": str(config.url), | ||||||
|  |             "first_push_delay_seconds": 10 | ||||||
|  |         } | ||||||
|  |     except ValueError as ve:  | ||||||
|  |         raise HTTPException(status_code=400, detail=str(ve)) | ||||||
|  |     except RuntimeError as re: # 例如 APScheduler 启动失败 | ||||||
|  |         logger.error(f"Runtime error during push schedule setup: {re}") | ||||||
|  |         raise HTTPException(status_code=500, detail=str(re)) | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"Error setting up push schedule: {e}", exc_info=True) | ||||||
|  |         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") | ||||||
|  |  | ||||||
|  | @pusher_router.post("/stop_push", summary="停止当前数据推送任务") | ||||||
|  | async def handle_stop_push(): | ||||||
|  |     global data_pusher_instance | ||||||
|  |     if data_pusher_instance is None: | ||||||
|  |         logger.error("CRITICAL: /stop_push called but data_pusher_instance is None.") | ||||||
|  |         raise HTTPException(status_code=503, detail="DataPusher service not available.") | ||||||
|  |      | ||||||
|  |     try: | ||||||
|  |         data_pusher_instance.stop_push_schedule() | ||||||
|  |         return {"message": "Push task stopped successfully if it was running."} | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"Error stopping push schedule: {e}", exc_info=True) | ||||||
|  |         raise HTTPException(status_code=500, detail=f"Internal server error while stopping push: {str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # --- Initialization and Shutdown Functions for Main App --- | ||||||
|  | def initialize_data_pusher(detector_instance_param): # Renamed to avoid conflict | ||||||
|  |     """ | ||||||
|  |     由主应用程序调用以创建和配置 DataPusher 实例。 | ||||||
|  |     """ | ||||||
|  |     global data_pusher_instance | ||||||
|  |     if data_pusher_instance is None: | ||||||
|  |         logger.info("Initializing DataPusher instance...") | ||||||
|  |         data_pusher_instance = DataPusher(detector_instance_param) | ||||||
|  |     else: | ||||||
|  |         logger.info("DataPusher instance already initialized. Updating detector instance if provided.") | ||||||
|  |         data_pusher_instance.update_detector_instance(detector_instance_param) | ||||||
|  |     return data_pusher_instance | ||||||
|  |  | ||||||
|  | def get_data_pusher_instance() -> Optional[DataPusher]: | ||||||
|  |     """获取 DataPusher 实例 (主要用于主应用可能需要访问它的其他方法,如 shutdown)""" | ||||||
|  |     return data_pusher_instance | ||||||
|  |  | ||||||
|  | # 移除 if __name__ == '__main__' 和 run_pusher_service,因为不再独立运行 | ||||||
|  | # 示例代码可以移至主应用的文档或测试脚本中。 | ||||||
|  |  | ||||||
|  | # 注意: | ||||||
|  | # RFDETRDetector 实例的生命周期由 api_server.py (current_detector) 管理。 | ||||||
|  | # 当 api_server.py 中的模型切换时,需要有一种机制来更新 DataPusher 内部的 detector 引用。 | ||||||
|  | # initialize_data_pusher 可以被多次调用 (例如,在模型切换后),它会更新 DataPusher 持有的 detector 实例。 | ||||||
							
								
								
									
										
											BIN
										
									
								
								font/MSYH.TTC
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								font/MSYH.TTC
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										226
									
								
								frame_transfer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										226
									
								
								frame_transfer.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,226 @@ | |||||||
|  | import base64 | ||||||
|  | import requests as http_client | ||||||
|  | import queue | ||||||
|  | import time  # 确保 time 被导入,如果之前被误删 | ||||||
|  | import cv2 | ||||||
|  | import traceback | ||||||
|  | import threading  # 确保导入 threading | ||||||
|  | import av  # 重新导入 av | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | # from fastapi import requests | ||||||
|  |  | ||||||
|  | # 从 rfdetr_core 导入 RFDETRDetector 仅用于类型提示 (可选) | ||||||
|  | from rfdetr_core import RFDETRDetector | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 目标检测处理函数 | ||||||
|  | # 函数签名已更改:现在接受一个 detector_instance 作为参数 | ||||||
|  | def yolo_frame(rtc_q: queue.Queue, yolo_q: queue.Queue, stream_detector_instance: RFDETRDetector): | ||||||
|  |     thread_name = threading.current_thread().name  # 获取线程名称用于日志 | ||||||
|  |     print(f"处理线程 '{thread_name}' 已启动。") | ||||||
|  |     error_message_displayed_once = False | ||||||
|  |     no_detector_message_displayed_once = False  # 用于只提示一次没有检测器 | ||||||
|  |  | ||||||
|  |     if stream_detector_instance is None: | ||||||
|  |         print(f"错误 (线程 '{thread_name}'): 未提供有效的检测器实例给yolo_frame。线程将无法处理视频。") | ||||||
|  |         # 此线程实际上无法做任何有用的工作,可以考虑直接退出或进入一个安全循环 | ||||||
|  |         # 为简单起见,我们允许它进入主循环,但它会在每次迭代时打印警告 | ||||||
|  |  | ||||||
|  |     while True: | ||||||
|  |         frame = None | ||||||
|  |         # current_category_counts = {} # 将在获取后转换 | ||||||
|  |         try: | ||||||
|  |             # 恢复队列长度打印 | ||||||
|  |             print(f"线程 '{thread_name}' - 原始队列长度: {rtc_q.qsize()}, 检测队列长度: {yolo_q.qsize()}") | ||||||
|  |  | ||||||
|  |             frame = rtc_q.get(timeout=0.1) | ||||||
|  |             if frame is None: | ||||||
|  |                 print(f"处理线程 '{thread_name}' 接收到停止信号,正在退出...") | ||||||
|  |                 # 发送包含None frame和空计数的字典作为停止信号 | ||||||
|  |                 yolo_q.put({"frame": None, "category_counts": {}}) | ||||||
|  |                 break | ||||||
|  |  | ||||||
|  |             category_counts_for_packet = {} | ||||||
|  |             if stream_detector_instance: | ||||||
|  |                 no_detector_message_displayed_once = False  # 检测器有效,重置提示 | ||||||
|  |                 annotated_frame = stream_detector_instance.detect_and_draw_count(frame) | ||||||
|  |                 error_message_displayed_once = False | ||||||
|  |                  | ||||||
|  |                 # 获取英文键的类别计数 | ||||||
|  |                 english_counts = stream_detector_instance.category_counts.copy() if hasattr(stream_detector_instance, 'category_counts') else {} | ||||||
|  |                  | ||||||
|  |                 # 转换为中文键的类别计数 | ||||||
|  |                 if hasattr(stream_detector_instance, 'VISDRONE_CLASSES_CHINESE'): | ||||||
|  |                     chinese_map = stream_detector_instance.VISDRONE_CLASSES_CHINESE | ||||||
|  |                     for eng_key, count_val in english_counts.items(): | ||||||
|  |                         # 使用 get 提供一个默认值,以防某个英文类别在中文映射表中确实没有 | ||||||
|  |                         chi_key = chinese_map.get(eng_key, eng_key)  | ||||||
|  |                         category_counts_for_packet[chi_key] = count_val | ||||||
|  |                 else: | ||||||
|  |                     # 如果没有中文映射表,则直接使用英文计数 (或记录警告) | ||||||
|  |                     category_counts_for_packet = english_counts | ||||||
|  |                     # logger.warning(f"线程 '{thread_name}': stream_detector_instance 没有 VISDRONE_CLASSES_CHINESE 属性,将使用英文类别计数。") | ||||||
|  |  | ||||||
|  |             else: | ||||||
|  |                 # 如果没有有效的检测器实例传递进来 | ||||||
|  |                 if not no_detector_message_displayed_once: | ||||||
|  |                     print(f"警告 (线程 '{thread_name}'): 无有效检测器实例。将在帧上绘制提示。") | ||||||
|  |                     no_detector_message_displayed_once = True | ||||||
|  |  | ||||||
|  |                 annotated_frame = frame.copy() | ||||||
|  |                 cv2.putText(annotated_frame, | ||||||
|  |                             "No detector instance provided for this stream", | ||||||
|  |                             (30, 50), cv2.FONT_HERSHEY_SIMPLEX, | ||||||
|  |                             1, (0, 0, 255), 2, cv2.LINE_AA) | ||||||
|  |                 category_counts_for_packet = {} # 无检测器,计数为空 | ||||||
|  |  | ||||||
|  |             # 将帧和类别计数一起放入队列 | ||||||
|  |             data_packet = {"frame": annotated_frame, "category_counts": category_counts_for_packet} | ||||||
|  |             try: | ||||||
|  |                 yolo_q.put_nowait(data_packet) | ||||||
|  |             except queue.Full: | ||||||
|  |                 # print(f"警告 (线程 '{thread_name}'): yolo_q 已满,丢弃帧。") # 避免刷屏 | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |         except queue.Empty: | ||||||
|  |             time.sleep(0.01) | ||||||
|  |             continue | ||||||
|  |         except Exception as e: | ||||||
|  |             if not error_message_displayed_once: | ||||||
|  |                 print(f"线程 '{thread_name}' (yolo_frame) 处理时发生严重错误: {e}") | ||||||
|  |                 traceback.print_exc() | ||||||
|  |                 error_message_displayed_once = True | ||||||
|  |             time.sleep(1) | ||||||
|  |             if frame is not None: | ||||||
|  |                 try: | ||||||
|  |                     pass | ||||||
|  |                 except queue.Full: | ||||||
|  |                     pass | ||||||
|  |             continue | ||||||
|  |     print(f"处理线程 '{thread_name}' 已停止。") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def push_frame(yolo_q: queue.Queue, rtmp_url: str, gateway: str, frequency: int, push_url: str): | ||||||
|  |     thread_name = threading.current_thread().name | ||||||
|  |     print(f"推流线程 '{thread_name}' (RTMP: {rtmp_url}) 已启动。") | ||||||
|  |  | ||||||
|  |     output_container = None | ||||||
|  |     stream = None | ||||||
|  |     first_frame_processed = False | ||||||
|  |     last_push_time = 0  # 记录上次推送base64的时间 | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         while True: | ||||||
|  |             frame_to_push = None | ||||||
|  |             received_category_counts = {} # 初始化为空字典 | ||||||
|  |             try: | ||||||
|  |                 data_packet = yolo_q.get(timeout=0.1) | ||||||
|  |                 if data_packet: | ||||||
|  |                     frame_to_push = data_packet.get("frame") | ||||||
|  |                     received_category_counts = data_packet.get("category_counts", {}) | ||||||
|  |                 else: # data_packet is None (不太可能,除非队列明确放入None) | ||||||
|  |                     time.sleep(0.01) | ||||||
|  |                     continue | ||||||
|  |  | ||||||
|  |             except queue.Empty: | ||||||
|  |                 time.sleep(0.01) | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             if frame_to_push is None: # 这是通过 data_packet["frame"] is None 来判断的 | ||||||
|  |                 print(f"推流线程 '{thread_name}' 接收到停止信号,正在清理并退出...") | ||||||
|  |                 break | ||||||
|  |  | ||||||
|  |             if not first_frame_processed: | ||||||
|  |                 if frame_to_push is not None: | ||||||
|  |                     try: | ||||||
|  |                         height, width, _ = frame_to_push.shape | ||||||
|  |                         print(f"线程 '{thread_name}': 首帧尺寸 {width}x{height},正在初始化RTMP推流器到 {rtmp_url}") | ||||||
|  |                         output_container = av.open(rtmp_url, 'w', format='flv') | ||||||
|  |                         stream = output_container.add_stream('libx264', rate=25) | ||||||
|  |                         stream.pix_fmt = 'yuv420p' | ||||||
|  |                         stream.width = width | ||||||
|  |                         stream.height = height | ||||||
|  |                         stream.options = {'preset': 'ultrafast', 'tune': 'zerolatency', 'crf': '25'} | ||||||
|  |                         print(f"线程 '{thread_name}': RTMP推流器初始化成功。") | ||||||
|  |                         first_frame_processed = True | ||||||
|  |                     except Exception as e_init: | ||||||
|  |                         print(f"错误 (线程 '{thread_name}'): 初始化PyAV推流容器/流失败: {e_init}") | ||||||
|  |                         traceback.print_exc() | ||||||
|  |                         return | ||||||
|  |                 else: | ||||||
|  |                     continue | ||||||
|  |  | ||||||
|  |             if not output_container or not stream: | ||||||
|  |                 print(f"错误 (线程 '{thread_name}'): 推流器未初始化,无法推流。可能是首帧处理失败。") | ||||||
|  |                 time.sleep(1) | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             # 持续推流到RTMP | ||||||
|  |             try: | ||||||
|  |                 video_frame = av.VideoFrame.from_ndarray(frame_to_push, format='bgr24') | ||||||
|  |                 for packet in stream.encode(video_frame): | ||||||
|  |                     output_container.mux(packet) | ||||||
|  |             except Exception as e_push: | ||||||
|  |                 print(f"错误 (线程 '{thread_name}'): 推送帧到RTMP时发生错误: {e_push}") | ||||||
|  |                 time.sleep(0.5) | ||||||
|  |  | ||||||
|  |             # 定时推送base64帧 | ||||||
|  |             current_time = time.time() | ||||||
|  |             if current_time - last_push_time >= frequency: | ||||||
|  |                 # 将接收到的类别计数传递给 push_base64_frame | ||||||
|  |                 push_base64_frame(frame_to_push, gateway, push_url, thread_name, received_category_counts) | ||||||
|  |                 last_push_time = current_time | ||||||
|  |  | ||||||
|  |     except Exception as e_outer: | ||||||
|  |         print(f"推流线程 '{thread_name}' 发生严重外部错误: {e_outer}") | ||||||
|  |         traceback.print_exc() | ||||||
|  |     finally: | ||||||
|  |         print(f"推流线程 '{thread_name}': 进入finally块,准备关闭推流器。") | ||||||
|  |         if stream and output_container: | ||||||
|  |             try: | ||||||
|  |                 print(f"推流线程 '{thread_name}': 正在编码流的剩余部分...") | ||||||
|  |                 for packet in stream.encode(None): | ||||||
|  |                     output_container.mux(packet) | ||||||
|  |                 print(f"推流线程 '{thread_name}': 编码剩余部分完成。") | ||||||
|  |             except Exception as e_flush: | ||||||
|  |                 print(f"错误 (线程 '{thread_name}'): 关闭推流流时发生编码/刷新错误: {e_flush}") | ||||||
|  |                 traceback.print_exc() | ||||||
|  |         if output_container: | ||||||
|  |             try: | ||||||
|  |                 print(f"推流线程 '{thread_name}': 正在关闭推流容器...") | ||||||
|  |                 output_container.close() | ||||||
|  |                 print(f"推流线程 '{thread_name}': 推流容器已关闭。") | ||||||
|  |             except Exception as e_close: | ||||||
|  |                 print(f"错误 (线程 '{thread_name}'): 关闭推流容器时发生错误: {e_close}") | ||||||
|  |                 traceback.print_exc() | ||||||
|  |         print(f"推流线程 '{thread_name}' 已停止并完成清理。") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def push_base64_frame(frame: np.ndarray, gateway: str, push_url: str, thread_name: str, category_counts: dict): | ||||||
|  |     """将帧转换为base64并推送到指定URL""" | ||||||
|  |     try: | ||||||
|  |         # 转换为JPEG格式 | ||||||
|  |         _, buffer = cv2.imencode('.jpg', frame) | ||||||
|  |         # 转换为base64字符串 | ||||||
|  |         frame_base64 = base64.b64encode(buffer).decode('utf-8') | ||||||
|  |  | ||||||
|  |         # 构建JSON数据 | ||||||
|  |         data = { | ||||||
|  |             "gateway": gateway, | ||||||
|  |             "frame_base64": frame_base64, | ||||||
|  |             "category_counts": category_counts # 使用传入的 category_counts | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         print(f"DEBUG push_base64_frame: Pushing data: {data.get('category_counts')}") # 调试打印,检查发送的数据 | ||||||
|  |         # 发送POST请求 | ||||||
|  |         response = http_client.post(push_url, json=data, timeout=5) | ||||||
|  |  | ||||||
|  |         # 检查响应 | ||||||
|  |         if response.status_code == 200: | ||||||
|  |             print(f"线程 '{thread_name}': base64帧已成功推送到 {push_url}") | ||||||
|  |         else: | ||||||
|  |             print(f"错误 (线程 '{thread_name}'): 推送base64帧失败,状态码: {response.status_code}") | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"错误 (线程 '{thread_name}'): 处理或推送base64帧时发生错误: {e}") | ||||||
							
								
								
									
										5
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,5 @@ | |||||||
|  | import uvicorn | ||||||
|  | from web import app | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     uvicorn.run(app, host="0.0.0.0", port=8000) | ||||||
							
								
								
									
										
											BIN
										
									
								
								models/人车.pth
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								models/人车.pth
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										27
									
								
								requirements - 副本.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								requirements - 副本.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,27 @@ | |||||||
|  | aiohttp==3.11.14 | ||||||
|  | aiortc==1.11.0 | ||||||
|  | aiosignal==1.3.2 | ||||||
|  | APScheduler==3.11.0 | ||||||
|  | av==14.2.0 | ||||||
|  | fastapi==0.115.11 | ||||||
|  | huggingface-hub==0.30.1 | ||||||
|  | numpy==2.1.1 | ||||||
|  | nvidia-cuda-runtime-cu12==12.8.90 | ||||||
|  | opencv-contrib-python==4.11.0.86 | ||||||
|  | opencv-python==4.11.0.86 | ||||||
|  | pillow==11.1.0 | ||||||
|  | pillow_heif==0.22.0 | ||||||
|  | pycuda==2025.1 | ||||||
|  | pydantic==2.10.6 | ||||||
|  | pydantic_core==2.27.2 | ||||||
|  | requests==2.32.3 | ||||||
|  | requests-toolbelt==1.0.0 | ||||||
|  | rfdetr==1.1.0 | ||||||
|  | safetensors==0.5.3 | ||||||
|  | supervision==0.25.1 | ||||||
|  | torch==2.6.0+cu126 | ||||||
|  | torchaudio==2.6.0+cu126 | ||||||
|  | torchvision==0.21.0+cu126 | ||||||
|  | transformers==4.50.3 | ||||||
|  | uvicorn==0.34.0 | ||||||
|  | wandb==0.19.9 | ||||||
							
								
								
									
										233
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,233 @@ | |||||||
|  | absl-py==2.2.1 | ||||||
|  | accelerate==1.6.0 | ||||||
|  | addict==2.4.0 | ||||||
|  | aiohappyeyeballs==2.6.1 | ||||||
|  | aiohttp==3.11.14 | ||||||
|  | aioice==0.9.0 | ||||||
|  | aiortc==1.11.0 | ||||||
|  | aiosignal==1.3.2 | ||||||
|  | albucore==0.0.23 | ||||||
|  | albumentations==2.0.5 | ||||||
|  | altgraph==0.17.4 | ||||||
|  | annotated-types==0.7.0 | ||||||
|  | anyio==4.8.0 | ||||||
|  | anywidget==0.9.18 | ||||||
|  | appdirs==1.4.4 | ||||||
|  | APScheduler==3.11.0 | ||||||
|  | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work | ||||||
|  | async-timeout==5.0.1 | ||||||
|  | attrs==25.3.0 | ||||||
|  | av==14.2.0 | ||||||
|  | basicsr==1.4.2 | ||||||
|  | bbox_visualizer==0.2.0 | ||||||
|  | blind-watermark==0.4.4 | ||||||
|  | certifi==2025.1.31 | ||||||
|  | cffi==1.17.1 | ||||||
|  | charset-normalizer==3.4.1 | ||||||
|  | click==8.1.8 | ||||||
|  | cmake==3.31.6 | ||||||
|  | colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work | ||||||
|  | coloredlogs==15.0.1 | ||||||
|  | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work | ||||||
|  | contourpy==1.3.1 | ||||||
|  | cryptography==44.0.2 | ||||||
|  | cycler==0.12.1 | ||||||
|  | Cython==3.0.12 | ||||||
|  | debugpy @ file:///D:/bld/debugpy_1741148401445/work | ||||||
|  | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work | ||||||
|  | deep-sort-realtime==1.3.2 | ||||||
|  | defusedxml==0.7.1 | ||||||
|  | dnspython==2.7.0 | ||||||
|  | docker-pycreds==0.4.0 | ||||||
|  | easydict==1.13 | ||||||
|  | einops==0.8.1 | ||||||
|  | et_xmlfile==2.0.0 | ||||||
|  | exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1733208806608/work | ||||||
|  | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1733569351617/work | ||||||
|  | facexlib==0.3.0 | ||||||
|  | fairscale==0.4.13 | ||||||
|  | fastapi==0.115.11 | ||||||
|  | filelock==3.13.1 | ||||||
|  | filetype==1.2.0 | ||||||
|  | filterpy==1.4.5 | ||||||
|  | fire==0.7.0 | ||||||
|  | flatbuffers==25.2.10 | ||||||
|  | fonttools==4.56.0 | ||||||
|  | frozenlist==1.5.0 | ||||||
|  | fsspec==2024.6.1 | ||||||
|  | ftfy==6.3.1 | ||||||
|  | future==1.0.0 | ||||||
|  | gfpgan==1.3.8 | ||||||
|  | gitdb==4.0.12 | ||||||
|  | GitPython==3.1.44 | ||||||
|  | google-crc32c==1.7.1 | ||||||
|  | grpcio==1.71.0 | ||||||
|  | h11==0.14.0 | ||||||
|  | httptools==0.6.4 | ||||||
|  | huggingface-hub==0.30.1 | ||||||
|  | humanfriendly==10.0 | ||||||
|  | idna==3.7 | ||||||
|  | ifaddr==0.2.0 | ||||||
|  | imageio==2.37.0 | ||||||
|  | importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work | ||||||
|  | iniconfig==2.1.0 | ||||||
|  | insightface==0.7.3 | ||||||
|  | ipykernel @ file:///D:/bld/ipykernel_1719845595208/work | ||||||
|  | ipython @ file:///D:/bld/bld/rattler-build_ipython_1740856913/work | ||||||
|  | ipywidgets==8.1.5 | ||||||
|  | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work | ||||||
|  | Jinja2==3.1.4 | ||||||
|  | joblib==1.4.2 | ||||||
|  | jupyter_bbox_widget==0.6.0 | ||||||
|  | jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work | ||||||
|  | jupyter_core @ file:///D:/bld/jupyter_core_1727163532151/work | ||||||
|  | jupyterlab_widgets==3.0.13 | ||||||
|  | kiwisolver==1.4.8 | ||||||
|  | lap==0.5.12 | ||||||
|  | lazy_loader==0.4 | ||||||
|  | llvmlite==0.44.0 | ||||||
|  | lmdb==1.6.2 | ||||||
|  | Mako==1.3.9 | ||||||
|  | Markdown==3.7 | ||||||
|  | markdown-it-py==3.0.0 | ||||||
|  | MarkupSafe==2.1.5 | ||||||
|  | matplotlib==3.10.1 | ||||||
|  | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work | ||||||
|  | mdurl==0.1.2 | ||||||
|  | memory-profiler==0.61.0 | ||||||
|  | motmetrics==1.4.0 | ||||||
|  | mpmath==1.3.0 | ||||||
|  | multidict==6.2.0 | ||||||
|  | nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work | ||||||
|  | networkx==3.3 | ||||||
|  | ninja==1.11.1.3 | ||||||
|  | Nuitka==2.6.7 | ||||||
|  | numba==0.61.0 | ||||||
|  | numpy==2.1.1 | ||||||
|  | nvidia-cuda-runtime-cu12==12.8.90 | ||||||
|  | onnx==1.16.1 | ||||||
|  | onnx-graphsurgeon==0.5.7 | ||||||
|  | onnxruntime-gpu==1.20.2 | ||||||
|  | onnxsim==0.4.36 | ||||||
|  | onnxslim==0.1.48 | ||||||
|  | open_clip_torch==2.32.0 | ||||||
|  | opencv-contrib-python==4.11.0.86 | ||||||
|  | opencv-python==4.11.0.86 | ||||||
|  | openpyxl==3.1.5 | ||||||
|  | ordered-set==4.1.0 | ||||||
|  | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1733203243479/work | ||||||
|  | pandas==2.2.3 | ||||||
|  | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work | ||||||
|  | pefile==2023.2.7 | ||||||
|  | peft==0.15.1 | ||||||
|  | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work | ||||||
|  | pillow==11.1.0 | ||||||
|  | pillow_heif==0.22.0 | ||||||
|  | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1733232627818/work | ||||||
|  | pluggy==1.5.0 | ||||||
|  | polygraphy==0.49.20 | ||||||
|  | prettytable==3.15.1 | ||||||
|  | prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1737453357274/work | ||||||
|  | propcache==0.3.1 | ||||||
|  | protobuf==3.20.2 | ||||||
|  | psutil @ file:///D:/bld/psutil_1740663160591/work | ||||||
|  | psygnal==0.12.0 | ||||||
|  | pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work | ||||||
|  | py-cpuinfo==9.0.0 | ||||||
|  | pybboxes==0.1.6 | ||||||
|  | pycocotools==2.0.8 | ||||||
|  | pycparser==2.22 | ||||||
|  | pycuda==2025.1 | ||||||
|  | pydantic==2.10.6 | ||||||
|  | pydantic_core==2.27.2 | ||||||
|  | pyee==13.0.0 | ||||||
|  | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work | ||||||
|  | pyinstaller==6.12.0 | ||||||
|  | pyinstaller-hooks-contrib==2025.1 | ||||||
|  | pylabel==0.1.55 | ||||||
|  | pylibsrtp==0.11.0 | ||||||
|  | PyMuPDF==1.25.4 | ||||||
|  | pyOpenSSL==25.0.0 | ||||||
|  | pyparsing==3.2.1 | ||||||
|  | pyproj==3.7.1 | ||||||
|  | pyreadline3==3.5.4 | ||||||
|  | PySide6==6.8.2.1 | ||||||
|  | PySide6_Addons==6.8.2.1 | ||||||
|  | PySide6_Essentials==6.8.2.1 | ||||||
|  | pytest==8.3.5 | ||||||
|  | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work | ||||||
|  | python-dotenv==1.0.1 | ||||||
|  | python-multipart==0.0.20 | ||||||
|  | pytools==2025.1.2 | ||||||
|  | pytz==2025.1 | ||||||
|  | PyWavelets==1.8.0 | ||||||
|  | pywin32==307 | ||||||
|  | pywin32-ctypes==0.2.3 | ||||||
|  | PyYAML==6.0.2 | ||||||
|  | pyzmq @ file:///D:/bld/pyzmq_1738270977186/work | ||||||
|  | realesrgan==0.3.0 | ||||||
|  | regex==2024.11.6 | ||||||
|  | requests==2.32.3 | ||||||
|  | requests-toolbelt==1.0.0 | ||||||
|  | rf100vl==1.0.0 | ||||||
|  | rfdetr==1.1.0 | ||||||
|  | rich==14.0.0 | ||||||
|  | roboflow==1.1.60 | ||||||
|  | safetensors==0.5.3 | ||||||
|  | sahi==0.11.22 | ||||||
|  | scikit-image==0.25.2 | ||||||
|  | scikit-learn==1.6.1 | ||||||
|  | scipy==1.15.2 | ||||||
|  | seaborn==0.13.2 | ||||||
|  | sentry-sdk==2.25.1 | ||||||
|  | setproctitle==1.3.5 | ||||||
|  | shapely==2.0.7 | ||||||
|  | shiboken6==6.8.2.1 | ||||||
|  | simsimd==6.2.1 | ||||||
|  | six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work | ||||||
|  | smmap==5.0.2 | ||||||
|  | sniffio==1.3.1 | ||||||
|  | stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work | ||||||
|  | starlette==0.46.1 | ||||||
|  | stringzilla==3.12.2 | ||||||
|  | supervision==0.25.1 | ||||||
|  | sympy==1.13.1 | ||||||
|  | tabulate==0.9.0 | ||||||
|  | tb-nightly==2.20.0a20250326 | ||||||
|  | tensorboard-data-server==0.7.2 | ||||||
|  | tensorrt==10.9.0.34 | ||||||
|  | tensorrt_cu12==10.9.0.34 | ||||||
|  | tensorrt_cu12_bindings==10.9.0.34 | ||||||
|  | tensorrt_cu12_libs==10.9.0.34 | ||||||
|  | termcolor==2.5.0 | ||||||
|  | terminaltables==3.1.10 | ||||||
|  | threadpoolctl==3.5.0 | ||||||
|  | tifffile==2025.2.18 | ||||||
|  | timm==1.0.15 | ||||||
|  | tokenizers==0.21.1 | ||||||
|  | tomli==2.2.1 | ||||||
|  | torch==2.6.0+cu126 | ||||||
|  | torchaudio==2.6.0+cu126 | ||||||
|  | torchvision==0.21.0+cu126 | ||||||
|  | tornado @ file:///D:/bld/tornado_1732615925919/work | ||||||
|  | tqdm==4.67.1 | ||||||
|  | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work | ||||||
|  | transformers==4.50.3 | ||||||
|  | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work | ||||||
|  | tzdata==2025.1 | ||||||
|  | tzlocal==5.3.1 | ||||||
|  | ultralytics==8.3.99 | ||||||
|  | ultralytics-thop==2.0.14 | ||||||
|  | urllib3==2.3.0 | ||||||
|  | uvicorn==0.34.0 | ||||||
|  | wandb==0.19.9 | ||||||
|  | watchfiles==1.0.4 | ||||||
|  | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work | ||||||
|  | websockets==15.0.1 | ||||||
|  | Werkzeug==3.1.3 | ||||||
|  | widgetsnbextension==4.0.13 | ||||||
|  | xmltodict==0.14.2 | ||||||
|  | yapf==0.43.0 | ||||||
|  | yarl==1.18.3 | ||||||
|  | zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work | ||||||
|  | zstandard==0.23.0 | ||||||
							
								
								
									
										21
									
								
								result.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								result.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | |||||||
|  | # 定义返回值类 | ||||||
|  | class Response: | ||||||
|  |     def __init__(self, code, data, message): | ||||||
|  |         self.code = code | ||||||
|  |         self.data = data | ||||||
|  |         self.message = message | ||||||
|  |  | ||||||
|  |     def to_dict(self): | ||||||
|  |         return { | ||||||
|  |             "code": self.code, | ||||||
|  |             "data": self.data, | ||||||
|  |             "message": self.message | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def success(cls, data=None, message="操作成功"): | ||||||
|  |         return cls(200, data, message).to_dict() | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def error(cls, data=None, message="操作失败", code=400): | ||||||
|  |         return cls(code, data, message).to_dict() | ||||||
							
								
								
									
										254
									
								
								rfdetr_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										254
									
								
								rfdetr_core.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,254 @@ | |||||||
|  | import cv2 | ||||||
|  | import supervision as sv | ||||||
|  | from rfdetr import RFDETRBase | ||||||
|  | from collections import defaultdict | ||||||
|  | from typing import Dict, Set | ||||||
|  | from PIL import Image, ImageDraw, ImageFont  # 导入PIL库 | ||||||
|  | import numpy as np  # 导入numpy用于图像格式转换 | ||||||
|  | import json # 新增 | ||||||
|  | import os   # 新增 | ||||||
|  |  | ||||||
|  | class RFDETRDetector: | ||||||
|  |     def __init__(self, config_name: str, base_model_dir="models", base_config_dir="configs", default_font_path="./font/MSYH.TTC", default_font_size=15): | ||||||
|  |         self.config_path = os.path.join(base_config_dir, f"{config_name}.json") | ||||||
|  |         if not os.path.exists(self.config_path): | ||||||
|  |             raise FileNotFoundError(f"配置文件不存在: {self.config_path}") | ||||||
|  |  | ||||||
|  |         with open(self.config_path, 'r', encoding='utf-8') as f: | ||||||
|  |             self.config = json.load(f) | ||||||
|  |  | ||||||
|  |         model_path = os.path.join(base_model_dir, self.config['model_pth_filename']) | ||||||
|  |         resolution = self.config['resolution'] | ||||||
|  |          | ||||||
|  |         # 从配置读取字体路径和大小,如果未提供则使用默认值 | ||||||
|  |         font_path = self.config.get('font_path', default_font_path) | ||||||
|  |         font_size = self.config.get('font_size', default_font_size) | ||||||
|  |  | ||||||
|  |         # 1. 初始化模型 | ||||||
|  |         self.model = RFDETRBase( | ||||||
|  |             pretrain_weights=model_path, | ||||||
|  |             # pretrain_weights=model_path or r"E:\A\rf-detr-main\output\pre-train1\checkpoint_best_ema.pth", | ||||||
|  |             resolution=resolution | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # 2. 初始化跟踪器 | ||||||
|  |         self.tracker = sv.ByteTrack( | ||||||
|  |             track_activation_threshold=self.config['tracker_activation_threshold'], | ||||||
|  |             lost_track_buffer=self.config['tracker_lost_buffer'], | ||||||
|  |             minimum_matching_threshold=self.config['tracker_match_threshold'], | ||||||
|  |             minimum_consecutive_frames=self.config['tracker_consecutive_frames'], | ||||||
|  |             frame_rate=self.config['tracker_frame_rate'] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # 3. 类别定义 | ||||||
|  |         self.VISDRONE_CLASSES = self.config['classes_en'] | ||||||
|  |         self.VISDRONE_CLASSES_CHINESE = self.config['classes_zh_map'] | ||||||
|  |  | ||||||
|  |         # 新增:加载类别启用配置 | ||||||
|  |         self.detection_settings = self.config.get('detection_settings', {}) | ||||||
|  |         self.enabled_classes_filter = self.detection_settings.get('enabled_classes', {}) | ||||||
|  |         # 构建一个查找表,对于未在filter中指定的类别,默认为 True (启用) | ||||||
|  |         self._active_classes_lookup = { | ||||||
|  |             cls_name: self.enabled_classes_filter.get(cls_name, True) | ||||||
|  |             for cls_name in self.VISDRONE_CLASSES | ||||||
|  |         } | ||||||
|  |         print(f"活动类别配置: {self._active_classes_lookup}") | ||||||
|  |  | ||||||
|  |         # 4. 初始化字体 | ||||||
|  |         self.FONT_SIZE = font_size | ||||||
|  |         try: | ||||||
|  |             self.font = ImageFont.truetype(font_path, self.FONT_SIZE) | ||||||
|  |         except IOError: | ||||||
|  |             print(f"错误:无法加载字体 {font_path}。将使用默认字体。") | ||||||
|  |             self.font = ImageFont.load_default() # 使用真正通用的默认字体 | ||||||
|  |  | ||||||
|  |         # 5. 类别计数器 (作为类属性) | ||||||
|  |         self.class_tracks: Dict[str, Set[int]] = defaultdict(set) | ||||||
|  |         self.category_counts: Dict[str, int] = defaultdict(int) | ||||||
|  |  | ||||||
|  |         # 6. 初始化标注器 | ||||||
|  |         # 从配置加载默认颜色,如果失败则使用预设颜色 | ||||||
|  |         self.default_color_hex = self.config.get('default_color_hex', "#00FF00") # 默认绿色 | ||||||
|  |         self.bounding_box_thickness = self.config.get('bounding_box_thickness', 2) | ||||||
|  |  | ||||||
|  |         # 加载颜色配置,用于 PIL 绘制 | ||||||
|  |         self.class_colors_hex = self.config.get('class_colors_hex', {}) | ||||||
|  |         self.last_annotated_frame: np.ndarray | None = None # 新增: 用于存储最新的标注帧 | ||||||
|  |  | ||||||
|  |     def _hex_to_rgb(self, hex_color: str) -> tuple: | ||||||
|  |         """将十六进制颜色字符串转换为RGB元组。""" | ||||||
|  |         hex_color = hex_color.lstrip('#') | ||||||
|  |         try: | ||||||
|  |             return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) | ||||||
|  |         except ValueError: | ||||||
|  |             print(f"警告: 无法解析十六进制颜色 '{hex_color}', 将使用默认颜色。") | ||||||
|  |             # 解析失败时返回一个默认颜色,例如红色 | ||||||
|  |             return self._hex_to_rgb(self.default_color_hex if self.default_color_hex != hex_color else "#00FF00") | ||||||
|  |  | ||||||
|  |     def _update_counter(self, detections: sv.Detections): | ||||||
|  |         """更新类别计数器""" | ||||||
|  |         # 只统计有 tracker_id 的检测结果 | ||||||
|  |         valid_indices = detections.tracker_id != None | ||||||
|  |         if not np.any(valid_indices): # 处理 detections 为空或 tracker_id 都为 None 的情况 | ||||||
|  |              return | ||||||
|  |  | ||||||
|  |         class_ids = detections.class_id[valid_indices] | ||||||
|  |         track_ids = detections.tracker_id[valid_indices] | ||||||
|  |  | ||||||
|  |         for class_id, track_id in zip(class_ids, track_ids): | ||||||
|  |             if track_id is None: # 跳过没有 tracker_id 的项 | ||||||
|  |                 continue | ||||||
|  |             # 使用英文类别名作为内部 key | ||||||
|  |             class_name = self.VISDRONE_CLASSES[class_id] | ||||||
|  |             if track_id not in self.class_tracks[class_name]: | ||||||
|  |                 self.class_tracks[class_name].add(track_id) | ||||||
|  |                 self.category_counts[class_name] += 1 | ||||||
|  |  | ||||||
|  |     def _draw_frame(self, frame: np.ndarray, detections: sv.Detections) -> np.ndarray: | ||||||
|  |         """使用PIL绘制检测框、中文标签和计数信息""" | ||||||
|  |          | ||||||
|  |         pil_image = Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB)) | ||||||
|  |         draw = ImageDraw.Draw(pil_image) | ||||||
|  |  | ||||||
|  |         # --- 使用 PIL 绘制检测框和中文标签 --- | ||||||
|  |         valid_indices = detections.tracker_id != None # 或直接使用 detections.xyxy 如果不过滤无 tracker_id 的 | ||||||
|  |         if np.any(valid_indices):  | ||||||
|  |             boxes = detections.xyxy[valid_indices] | ||||||
|  |             class_ids = detections.class_id[valid_indices] | ||||||
|  |             # tracker_ids = detections.tracker_id[valid_indices] # 如果需要 tracker_id | ||||||
|  |  | ||||||
|  |             for box, class_id in zip(boxes, class_ids): | ||||||
|  |                 x1, y1, x2, y2 = map(int, box) | ||||||
|  |  | ||||||
|  |                 english_label = self.VISDRONE_CLASSES[class_id] | ||||||
|  |                 chinese_label = self.VISDRONE_CLASSES_CHINESE.get(english_label, english_label) | ||||||
|  |  | ||||||
|  |                 # 获取边界框颜色 | ||||||
|  |                 box_color_hex = self.class_colors_hex.get(english_label, self.default_color_hex) | ||||||
|  |                 box_rgb_color = self._hex_to_rgb(box_color_hex) | ||||||
|  |                  | ||||||
|  |                 # 绘制边界框 | ||||||
|  |                 draw.rectangle([x1, y1, x2, y2], outline=box_rgb_color, width=self.bounding_box_thickness) | ||||||
|  |  | ||||||
|  |                 # 绘制中文标签 (与之前逻辑类似) | ||||||
|  |                 text_to_draw = f"{chinese_label}" | ||||||
|  |                 # 标签背景 (可选,使其更易读) | ||||||
|  |                 # label_text_bbox = draw.textbbox((0,0), text_to_draw, font=self.font) | ||||||
|  |                 # label_width = label_text_bbox[2] - label_text_bbox[0] | ||||||
|  |                 # label_height = label_text_bbox[3] - label_text_bbox[1] | ||||||
|  |                 # label_bg_y1 = y1 - label_height - 4 if y1 - label_height - 4 > 0 else y1 + 2 | ||||||
|  |                 # draw.rectangle([x1, label_bg_y1, x1 + label_width + 4, label_bg_y1 + label_height + 2], fill=box_rgb_color) | ||||||
|  |                 # text_color = (255,255,255) if sum(box_rgb_color) < 382 else (0,0,0) # 简易对比色 | ||||||
|  |                 text_color = (255, 255, 255) # 白色 (RGB) | ||||||
|  |  | ||||||
|  |                 text_x = x1 + 2 # 稍微偏移,避免紧贴边框 | ||||||
|  |                 text_y = y1 - self.FONT_SIZE - 2 | ||||||
|  |                 if text_y < 0: # 如果标签超出图像顶部 | ||||||
|  |                     text_y = y1 + 2 | ||||||
|  |                  | ||||||
|  |                 draw.text((text_x, text_y), text_to_draw, font=self.font, fill=text_color) | ||||||
|  |  | ||||||
|  |         # --- 绘制统计面板 (右上角) --- | ||||||
|  |         stats_text_lines = [ | ||||||
|  |             f"{self.VISDRONE_CLASSES_CHINESE.get(cls, cls)}: {self.category_counts[cls]}" | ||||||
|  |             for cls in self.VISDRONE_CLASSES if self.category_counts[cls] > 0 | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         frame_height, frame_width, _ = frame.shape | ||||||
|  |         stats_start_x = frame_width - self.config.get('stats_panel_width', 200)  | ||||||
|  |         stats_start_y = self.config.get('stats_panel_margin_y', 10) | ||||||
|  |         line_height = self.FONT_SIZE + self.config.get('stats_line_spacing', 5) | ||||||
|  |          | ||||||
|  |         stats_text_color_hex = self.config.get('stats_text_color_hex', "#FFFFFF") | ||||||
|  |         stats_text_color = self._hex_to_rgb(stats_text_color_hex) | ||||||
|  |  | ||||||
|  |         # 可选:为统计面板添加背景 | ||||||
|  |         if stats_text_lines: | ||||||
|  |             panel_height = len(stats_text_lines) * line_height + 10 | ||||||
|  |             panel_y2 = stats_start_y + panel_height | ||||||
|  |             # 半透明背景 | ||||||
|  |             # overlay = Image.new('RGBA', pil_image.size, (0,0,0,0)) | ||||||
|  |             # panel_draw = ImageDraw.Draw(overlay) | ||||||
|  |             # panel_draw.rectangle( | ||||||
|  |             #     [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2], | ||||||
|  |             #     fill=(100, 100, 100, 128) # 半透明灰色 | ||||||
|  |             # ) | ||||||
|  |             # pil_image = Image.alpha_composite(pil_image.convert('RGBA'), overlay) | ||||||
|  |             # draw = ImageDraw.Draw(pil_image) # 如果用了 alpha_composite, 需要重新获取 draw 对象 | ||||||
|  |              | ||||||
|  |             # 或者简单不透明背景 | ||||||
|  |             # draw.rectangle( | ||||||
|  |             #    [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2], | ||||||
|  |             #    fill=self._hex_to_rgb(self.config.get('stats_panel_bg_color_hex', "#808080")) # 例如灰色背景 | ||||||
|  |             # ) | ||||||
|  |  | ||||||
|  |         for i, line in enumerate(stats_text_lines): | ||||||
|  |             text_pos = (stats_start_x, stats_start_y + i * line_height) | ||||||
|  |             draw.text(text_pos, line, font=self.font, fill=stats_text_color) | ||||||
|  |  | ||||||
|  |         final_annotated_frame = cv2.cvtColor(np.array(pil_image.convert('RGB')), cv2.COLOR_RGB2BGR) | ||||||
|  |         return final_annotated_frame | ||||||
|  |  | ||||||
|  |     def detect_and_draw_count(self, frame: np.ndarray, conf: float = -1.0) -> np.ndarray: | ||||||
|  |         """执行单帧检测、跟踪、计数并绘制结果(包含类别过滤)。""" | ||||||
|  |         if conf == -1.0: | ||||||
|  |             # 优先从 detection_settings 中获取,其次是顶层config,最后是硬编码默认值 | ||||||
|  |             effective_conf = float( | ||||||
|  |                 self.detection_settings.get('default_confidence_threshold',  | ||||||
|  |                                            self.config.get('default_confidence_threshold', 0.8)) | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             effective_conf = conf | ||||||
|  |          | ||||||
|  |         try: | ||||||
|  |             # 1. 执行检测 | ||||||
|  |             detections = self.model.predict(frame, threshold=effective_conf) | ||||||
|  |  | ||||||
|  |             # 处理 detections 为 None 或空的情况 | ||||||
|  |             if detections is None or len(detections) == 0: | ||||||
|  |                  detections = sv.Detections.empty() | ||||||
|  |                  annotated_frame = self._draw_frame(frame, detections) | ||||||
|  |                  self.last_annotated_frame = annotated_frame.copy() # 新增 | ||||||
|  |                  return annotated_frame | ||||||
|  |  | ||||||
|  |             # 新增:根据配置过滤检测到的类别 | ||||||
|  |             if detections is not None and len(detections) > 0: | ||||||
|  |                 keep_indices = [] | ||||||
|  |                 for i, class_id in enumerate(detections.class_id): | ||||||
|  |                     if class_id < len(self.VISDRONE_CLASSES): # 确保 class_id 有效 | ||||||
|  |                         class_name = self.VISDRONE_CLASSES[class_id] | ||||||
|  |                         if self._active_classes_lookup.get(class_name, True): # 默认为 True | ||||||
|  |                             keep_indices.append(i) | ||||||
|  |                     else: | ||||||
|  |                         print(f"警告: 检测到无效的 class_id {class_id},超出了已知类别范围。") | ||||||
|  |  | ||||||
|  |                 if not keep_indices: | ||||||
|  |                     detections = sv.Detections.empty() | ||||||
|  |                 else: | ||||||
|  |                     detections = detections[keep_indices] | ||||||
|  |              | ||||||
|  |             # 如果过滤后没有检测结果 | ||||||
|  |             if len(detections) == 0: | ||||||
|  |                 annotated_frame = self._draw_frame(frame, sv.Detections.empty()) | ||||||
|  |                 self.last_annotated_frame = annotated_frame.copy() # 新增 | ||||||
|  |                 return annotated_frame | ||||||
|  |  | ||||||
|  |             # 2. 执行跟踪 (只对过滤后的结果进行跟踪) | ||||||
|  |             detections = self.tracker.update_with_detections(detections) | ||||||
|  |  | ||||||
|  |             # 3. 更新计数器 (只对过滤并跟踪后的结果进行计数) | ||||||
|  |             self._update_counter(detections) | ||||||
|  |  | ||||||
|  |             # 4. 绘制结果 | ||||||
|  |             annotated_frame = self._draw_frame(frame, detections) | ||||||
|  |             self.last_annotated_frame = annotated_frame.copy() # 新增 | ||||||
|  |  | ||||||
|  |             return annotated_frame | ||||||
|  |  | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"处理帧时发生错误: {e}") | ||||||
|  |             if frame is not None: | ||||||
|  |                 self.last_annotated_frame = frame.copy() # 新增 | ||||||
|  |             else: | ||||||
|  |                 self.last_annotated_frame = None # 新增 | ||||||
|  |             return frame | ||||||
							
								
								
									
										94
									
								
								rtc_handler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								rtc_handler.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,94 @@ | |||||||
|  | import asyncio | ||||||
|  | import queue | ||||||
|  | from fractions import Fraction | ||||||
|  | from urllib.parse import urlparse | ||||||
|  |  | ||||||
|  | import aiohttp | ||||||
|  | import av | ||||||
|  | import numpy as np | ||||||
|  | from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, VideoStreamTrack | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DummyVideoTrack(VideoStreamTrack): | ||||||
|  |     async def recv(self): | ||||||
|  |         # 简洁初始化、返回固定颜色的帧 | ||||||
|  |         return np.full((480, 640, 3), (0, 0, 255), dtype=np.uint8) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def receive_video_frames(whep_url): | ||||||
|  |     pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) | ||||||
|  |     frames_queue = asyncio.Queue() | ||||||
|  |  | ||||||
|  |     pc.addTrack(DummyVideoTrack()) | ||||||
|  |  | ||||||
|  |     @pc.on("track") | ||||||
|  |     def on_track(track): | ||||||
|  |         if track.kind == "video": | ||||||
|  |             asyncio.create_task(consume_track(track, frames_queue)) | ||||||
|  |  | ||||||
|  |     @pc.on("iceconnectionstatechange") | ||||||
|  |     def on_ice_connection_state_change(): | ||||||
|  |         print(f"ICE 连接状态: {pc.iceConnectionState}") | ||||||
|  |  | ||||||
|  |     offer = await pc.createOffer() | ||||||
|  |     await pc.setLocalDescription(offer) | ||||||
|  |  | ||||||
|  |     headers = {"Content-Type": "application/sdp"} | ||||||
|  |  | ||||||
|  |     async with aiohttp.ClientSession() as session: | ||||||
|  |         async with session.post(whep_url, data=pc.localDescription.sdp, headers=headers) as response: | ||||||
|  |             if response.status != 201: | ||||||
|  |                 raise Exception(f"服务器返回错误: {response.status}") | ||||||
|  |  | ||||||
|  |             answer = RTCSessionDescription(sdp=await response.text(), type="answer") | ||||||
|  |             await pc.setRemoteDescription(answer) | ||||||
|  |  | ||||||
|  |             if "Location" in response.headers: | ||||||
|  |                 base_url = f"{urlparse(whep_url).scheme}://{urlparse(whep_url).netloc}" | ||||||
|  |                 print("ICE 协商 URL:", base_url + response.headers["Location"]) | ||||||
|  |  | ||||||
|  |     while pc.iceConnectionState not in ["connected", "completed"]: | ||||||
|  |         await asyncio.sleep(1) | ||||||
|  |  | ||||||
|  |     print("ICE 连接完成,开始接收视频流") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         while True: | ||||||
|  |             frame = await frames_queue.get() | ||||||
|  |             if frame is None: | ||||||
|  |                 break | ||||||
|  |             yield frame | ||||||
|  |     except KeyboardInterrupt: | ||||||
|  |         pass | ||||||
|  |     finally: | ||||||
|  |         await pc.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def consume_track(track, frames_queue): | ||||||
|  |     try: | ||||||
|  |         while True: | ||||||
|  |             frame = await track.recv() | ||||||
|  |             if frame is None: | ||||||
|  |                 print("没有接收到有效的帧数据") | ||||||
|  |                 await frames_queue.put(None) | ||||||
|  |                 break | ||||||
|  |             img = frame.to_ndarray(format="bgr24") | ||||||
|  |             await frames_queue.put(img) | ||||||
|  |     except Exception as e: | ||||||
|  |         print("处理帧错误:", e) | ||||||
|  |         await frames_queue.put(None) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def rtc_frame(url, frame_queue): | ||||||
|  |     async def main(): | ||||||
|  |         async for frame in receive_video_frames(url): | ||||||
|  |             try: | ||||||
|  |                 frame_queue.put_nowait(frame) | ||||||
|  |             except queue.Full: | ||||||
|  |                 frame_queue.get_nowait() | ||||||
|  |                 frame_queue.put_nowait(frame) | ||||||
|  |  | ||||||
|  |     loop = asyncio.new_event_loop() | ||||||
|  |     asyncio.set_event_loop(loop) | ||||||
|  |     loop.run_until_complete(main()) | ||||||
|  |     loop.close() | ||||||
							
								
								
									
										276
									
								
								web.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										276
									
								
								web.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,276 @@ | |||||||
|  | import base64 | ||||||
|  | import hashlib | ||||||
|  | import queue | ||||||
|  | import time | ||||||
|  | from multiprocessing import Process, Event, Queue as MpQueue | ||||||
|  | from typing import Dict, Any | ||||||
|  |  | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | from fastapi import FastAPI, Request, HTTPException | ||||||
|  |  | ||||||
|  | from api_server import model_management_router, initialize_default_model_on_startup, get_active_detector, \ | ||||||
|  |     get_active_model_identifier | ||||||
|  | from data_pusher import pusher_router, initialize_data_pusher, get_data_pusher_instance | ||||||
|  | from frame_transfer import yolo_frame, push_frame | ||||||
|  | from result import Response | ||||||
|  | from rfdetr_core import RFDETRDetector | ||||||
|  | from rtc_handler import rtc_frame | ||||||
|  |  | ||||||
|  | app = FastAPI(title="Real-time Video Processing, Model Management, and Data Pusher API") | ||||||
|  |  | ||||||
|  | app.include_router(model_management_router, prefix="/api", tags=["Model Management"]) | ||||||
|  | app.include_router(pusher_router, prefix="/api/pusher", tags=["Data Pusher"]) | ||||||
|  |  | ||||||
|  | process_map: Dict[str, Dict[str, Any]] = {} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.on_event("startup") | ||||||
|  | async def web_app_startup_event(): | ||||||
|  |     print("主应用服务启动中...") | ||||||
|  |     await initialize_default_model_on_startup() | ||||||
|  |     active_detector_instance = get_active_detector() | ||||||
|  |     if active_detector_instance: | ||||||
|  |         print(f"主应用启动:检测到活动模型 '{get_active_model_identifier()}',将用于初始化 DataPusher。") | ||||||
|  |     else: | ||||||
|  |         print("主应用启动:未检测到活动模型。DataPusher 将以无检测器状态初始化。") | ||||||
|  |     initialize_data_pusher(active_detector_instance) | ||||||
|  |     print("DataPusher 服务已初始化。") | ||||||
|  |     print("主应用启动完成。") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.on_event("shutdown") | ||||||
|  | async def web_app_shutdown_event(): | ||||||
|  |     print("主应用服务关闭中...") | ||||||
|  |     pusher = get_data_pusher_instance() | ||||||
|  |     if pusher: | ||||||
|  |         print("正在关闭 DataPusher 调度器...") | ||||||
|  |         pusher.shutdown_scheduler() | ||||||
|  |     else: | ||||||
|  |         print("DataPusher 实例未找到,跳过调度器关闭。") | ||||||
|  |  | ||||||
|  |     print("正在尝试终止所有活动的视频处理子进程...") | ||||||
|  |     for url, task_info in list(process_map.items()): | ||||||
|  |         process: Process = task_info['process'] | ||||||
|  |         stop_event: Event = task_info['stop_event'] | ||||||
|  |         data_q: MpQueue = task_info['data_queue'] | ||||||
|  |  | ||||||
|  |         if process.is_alive(): | ||||||
|  |             print(f"向进程 {url} 发送停止信号...") | ||||||
|  |             stop_event.set() | ||||||
|  |             try: | ||||||
|  |                 process.join(timeout=15) | ||||||
|  |                 if process.is_alive(): | ||||||
|  |                     print(f"进程 {url} 在优雅关闭超时后仍然存活,尝试 terminate。") | ||||||
|  |                     process.terminate() | ||||||
|  |                     process.join(timeout=5) | ||||||
|  |                     if process.is_alive(): | ||||||
|  |                         print(f"进程 {url} 在 terminate 后仍然存活,尝试 kill。") | ||||||
|  |                         process.kill() | ||||||
|  |                         process.join(timeout=2) | ||||||
|  |             except Exception as e: | ||||||
|  |                 print(f"关闭/终止进程 {url} 时发生错误: {e}") | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             if not data_q.empty(): | ||||||
|  |                 pass | ||||||
|  |             data_q.close() | ||||||
|  |             data_q.join_thread() | ||||||
|  |         except Exception as e_q_cleanup: | ||||||
|  |             print(f"清理进程 {url} 的数据队列时出错: {e_q_cleanup}") | ||||||
|  |  | ||||||
|  |         del process_map[url] | ||||||
|  |     print("所有视频处理子进程已尝试终止和清理。") | ||||||
|  |     print("主应用服务关闭完成。") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def start_video_processing(url: str, rtmp_url: str, model_config_name: str, stop_event: Event, data_queue: MpQueue, | ||||||
|  |                            gateway: str,frequency:int, push_url:str): | ||||||
|  |     print(f"视频处理子进程启动 (URL: {url}, Model: {model_config_name})") | ||||||
|  |     detector_instance_for_stream: RFDETRDetector = None | ||||||
|  |     producer_thread, transfer_thread, consumer_thread = None, None, None | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         print(f"正在为流 {url} 初始化模型: {model_config_name}...") | ||||||
|  |         detector_instance_for_stream = RFDETRDetector(config_name=model_config_name) | ||||||
|  |         print(f"模型 {model_config_name} 为流 {url} 初始化成功。") | ||||||
|  |         rtc_q = queue.Queue(maxsize=10000) | ||||||
|  |         yolo_q = queue.Queue(maxsize=10000) | ||||||
|  |         import threading | ||||||
|  |         producer_thread = threading.Thread(target=rtc_frame, args=(url, rtc_q), name=f"RTC-{url[:20]}", daemon=True) | ||||||
|  |         transfer_thread = threading.Thread(target=yolo_frame, args=(rtc_q, yolo_q, detector_instance_for_stream), | ||||||
|  |                                            name=f"YOLO-{url[:20]}", daemon=True) | ||||||
|  |         consumer_thread = threading.Thread(target=push_frame, args=(yolo_q, rtmp_url,gateway,frequency,push_url), name=f"Push-{url[:20]}", | ||||||
|  |                                            daemon=True) | ||||||
|  |  | ||||||
|  |         producer_thread.start() | ||||||
|  |         transfer_thread.start() | ||||||
|  |         consumer_thread.start() | ||||||
|  |  | ||||||
|  |         stop_event.wait() | ||||||
|  |         print(f"子进程 {url}: 收到停止信号。准备关闭线程...") | ||||||
|  |  | ||||||
|  |     except FileNotFoundError as e: | ||||||
|  |         print(f"错误 (视频进程 {url}): 模型配置文件 '{model_config_name}.json' 未找到。错误: {e}") | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"错误 (视频进程 {url}): 初始化或运行时错误。错误: {e}") | ||||||
|  |     finally: | ||||||
|  |         print(f"视频处理子进程 {url} 进入 finally 块。") | ||||||
|  |  | ||||||
|  |         if producer_thread and producer_thread.is_alive(): | ||||||
|  |             print(f"子进程 {url}: producer_thread is still alive (daemon).") | ||||||
|  |         if transfer_thread and transfer_thread.is_alive(): | ||||||
|  |             print(f"子进程 {url}: transfer_thread is still alive (daemon).") | ||||||
|  |         if consumer_thread and consumer_thread.is_alive(): | ||||||
|  |             print(f"子进程 {url}: consumer_thread is still alive (daemon).") | ||||||
|  |  | ||||||
|  |         if detector_instance_for_stream: | ||||||
|  |             print(f"子进程 {url}: 收集最后数据...") | ||||||
|  |             final_counts = getattr(detector_instance_for_stream, 'category_counts', {}) | ||||||
|  |             final_frame_np = getattr(detector_instance_for_stream, 'last_annotated_frame', None) | ||||||
|  |             frame_base64 = None | ||||||
|  |             if final_frame_np is not None and isinstance(final_frame_np, np.ndarray): | ||||||
|  |                 try: | ||||||
|  |                     _, buffer = cv2.imencode('.jpg', final_frame_np) | ||||||
|  |                     frame_base64 = base64.b64encode(buffer).decode('utf-8') | ||||||
|  |                 except Exception as e_encode: | ||||||
|  |                     print(f"子进程 {url}: 帧编码错误: {e_encode}") | ||||||
|  |  | ||||||
|  |             payload = { | ||||||
|  |                 "timestamp": time.time(), | ||||||
|  |                 "category_counts": final_counts, | ||||||
|  |                 "frame_base64": frame_base64, | ||||||
|  |                 "source_url": url, | ||||||
|  |                 "event": "task_stopped_final_data" | ||||||
|  |             } | ||||||
|  |             try: | ||||||
|  |                 data_queue.put(payload, timeout=5) | ||||||
|  |                 print(f"子进程 {url}: 已将最终数据放入队列。") | ||||||
|  |             except queue.Full: | ||||||
|  |                 print(f"子进程 {url}: 无法将最终数据放入队列 (队列已满或超时)。") | ||||||
|  |             except Exception as e_put: | ||||||
|  |                 print(f"子进程 {url}: 将最终数据放入队列时发生错误: {e_put}") | ||||||
|  |         else: | ||||||
|  |             print(f"子进程 {url}: 检测器实例不可用,无法发送最终数据。") | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             data_queue.close() | ||||||
|  |         except Exception as e_q_close: | ||||||
|  |             print(f"子进程 {url}: 关闭数据队列时出错: {e_q_close}") | ||||||
|  |         print(f"视频处理子进程 {url} 执行完毕。") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.post("/start_video", tags=["Video Processing"]) | ||||||
|  | async def start_video(request: Request): | ||||||
|  |     data = await request.json() | ||||||
|  |     url = data.get("url") | ||||||
|  |     model_identifier_to_use = data.get("model_identifier") | ||||||
|  |     host = data.get("host") | ||||||
|  |     rtmp_port = data.get("rtmp_port") | ||||||
|  |     rtc_port = data.get("rtc_port") | ||||||
|  |     gateway = data.get("gateway") | ||||||
|  |     frequency = data.get("frequency") | ||||||
|  |     push_url = data.get("push_url") | ||||||
|  |  | ||||||
|  |     # 生成MD5 | ||||||
|  |     md5_hash = hashlib.md5(url.encode()).hexdigest() | ||||||
|  |     rtmp_url = f"rtmp://{host}:{rtmp_port}/live/{md5_hash}" | ||||||
|  |     rtc_url = f"http://{host}:{rtc_port}/rtc/v1/whep/?{md5_hash}" | ||||||
|  |     if not url or not rtmp_url: | ||||||
|  |         raise HTTPException(status_code=400, detail="'url' 和 'rtmp_url' 字段是必须的。") | ||||||
|  |  | ||||||
|  |     if not model_identifier_to_use: | ||||||
|  |         print(f"请求中未指定 model_identifier,尝试使用全局激活的模型。") | ||||||
|  |         model_identifier_to_use = get_active_model_identifier() | ||||||
|  |         if not model_identifier_to_use: | ||||||
|  |             raise HTTPException(status_code=400, detail="请求中未指定 'model_identifier',且当前无全局激活的默认模型。") | ||||||
|  |         print(f"将为流 {url} 使用当前全局激活的模型: {model_identifier_to_use}") | ||||||
|  |  | ||||||
|  |     if url in process_map and process_map[url]['process'].is_alive(): | ||||||
|  |         raise HTTPException(status_code=409, detail=f"视频处理进程已在运行: {url}") | ||||||
|  |  | ||||||
|  |     print(f"请求启动视频处理: URL = {url}, RTMP = {rtmp_url}, Model = {model_identifier_to_use}") | ||||||
|  |  | ||||||
|  |     stop_event = Event() | ||||||
|  |     data_queue = MpQueue(maxsize=1) | ||||||
|  |  | ||||||
|  |     process = Process(target=start_video_processing, | ||||||
|  |                       args=(url, rtmp_url, model_identifier_to_use, stop_event, data_queue,  gateway,  frequency, push_url)) | ||||||
|  |     process.start() | ||||||
|  |     process_map[url] = {'process': process, 'stop_event': stop_event, 'data_queue': data_queue} | ||||||
|  |     return Response.success(message=f"视频处理已为 URL '{url}' 使用模型 '{model_identifier_to_use}' 启动。", | ||||||
|  |                             data=rtc_url) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @app.post("/stop_video", tags=["Video Processing"]) | ||||||
|  | async def stop_video(request: Request): | ||||||
|  |     data = await request.json() | ||||||
|  |     url = data.get("url") | ||||||
|  |     if not url: | ||||||
|  |         raise HTTPException(status_code=400, detail="'url' 字段是必须的。") | ||||||
|  |  | ||||||
|  |     task_info = process_map.get(url) | ||||||
|  |     if not task_info: | ||||||
|  |         raise HTTPException(status_code=404, detail=f"没有找到与 URL '{url}' 匹配的活动视频处理进程。") | ||||||
|  |  | ||||||
|  |     process: Process = task_info['process'] | ||||||
|  |     stop_event: Event = task_info['stop_event'] | ||||||
|  |     data_q: MpQueue = task_info['data_queue'] | ||||||
|  |  | ||||||
|  |     final_data_pushed = False | ||||||
|  |     if process.is_alive(): | ||||||
|  |         print(f"请求停止视频处理: {url}. 发送停止信号...") | ||||||
|  |         stop_event.set() | ||||||
|  |         process.join(timeout=20) | ||||||
|  |  | ||||||
|  |         if process.is_alive(): | ||||||
|  |             print(f"警告: 视频处理进程 {url} 在超时后未能正常终止,尝试强制结束。") | ||||||
|  |             process.terminate() | ||||||
|  |             process.join(timeout=5) | ||||||
|  |             if process.is_alive(): | ||||||
|  |                 print(f"错误: 视频处理进程 {url} 强制结束后仍然存在。尝试 kill。") | ||||||
|  |                 process.kill() | ||||||
|  |                 process.join(timeout=2) | ||||||
|  |         else: | ||||||
|  |             print(f"进程 {url} 已优雅停止。尝试获取最后数据...") | ||||||
|  |             try: | ||||||
|  |                 final_payload = data_q.get(timeout=10) | ||||||
|  |                 print(f"从停止的任务 {url} 收到最终数据。") | ||||||
|  |  | ||||||
|  |                 pusher_instance = get_data_pusher_instance() | ||||||
|  |                 if pusher_instance and pusher_instance.target_url: | ||||||
|  |                     print(f"准备将任务 {url} 的最后数据推送到 {pusher_instance.target_url}") | ||||||
|  |                     pusher_instance.push_specific_payload(final_payload) | ||||||
|  |                     final_data_pushed = True | ||||||
|  |                 elif pusher_instance: | ||||||
|  |                     print( | ||||||
|  |                         f"DataPusher 服务已配置,但未设置目标URL (pusher.target_url is None)。无法推送任务 {url} 的最后数据。") | ||||||
|  |                 else: | ||||||
|  |                     print(f"DataPusher 服务未初始化或不可用。无法推送任务 {url} 的最后数据。") | ||||||
|  |  | ||||||
|  |             except queue.Empty: | ||||||
|  |                 print(f"警告: 任务 {url} 优雅停止后,未从其数据队列中获取到最终数据 (队列为空或超时)。") | ||||||
|  |             except Exception as e_q_get: | ||||||
|  |                 print(f"获取或处理来自任务 {url} 的最终数据时发生错误: {e_q_get}") | ||||||
|  |     else: | ||||||
|  |         print(f"视频处理进程先前已停止或已结束: {url}") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         while not data_q.empty(): | ||||||
|  |             try: | ||||||
|  |                 data_q.get_nowait() | ||||||
|  |             except queue.Empty: | ||||||
|  |                 break | ||||||
|  |         data_q.close() | ||||||
|  |         data_q.join_thread() | ||||||
|  |     except Exception as e_q_final_cleanup: | ||||||
|  |         print(f"清理任务 {url} 的数据队列的最后步骤中发生错误: {e_q_final_cleanup}") | ||||||
|  |  | ||||||
|  |     del process_map[url] | ||||||
|  |     message = f"视频处理已为 URL '{url}' 停止。" | ||||||
|  |     if final_data_pushed: | ||||||
|  |         message += " 已尝试推送最后的数据。" | ||||||
|  |     elif process.exitcode == 0: | ||||||
|  |         message += " 进程已退出,但未确认最后数据推送 (可能未配置推送或队列问题)。" | ||||||
|  |  | ||||||
|  |     return Response.success(message=message) | ||||||
							
								
								
									
										151
									
								
								yolo_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								yolo_core.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,151 @@ | |||||||
|  | from ultralytics import YOLO | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | from PIL import Image, ImageDraw, ImageFont | ||||||
|  | import tensorrt as trt | ||||||
|  | import pycuda.driver as cuda | ||||||
|  | import pycuda.autoinit | ||||||
|  |  | ||||||
|  | class YOLODetector: | ||||||
|  |     def __init__(self, model_path='models/best.engine'): | ||||||
|  |         # 加载 TensorRT 模型 | ||||||
|  |         self.model = YOLO(model_path, task="detect") | ||||||
|  |         # 英文类别名称到中文的映射 | ||||||
|  |         self.class_name_mapping = { | ||||||
|  |             'pedestrian': '行人', | ||||||
|  |             'people': '人群', | ||||||
|  |             'bicycle': '自行车', | ||||||
|  |             'car': '轿车', | ||||||
|  |             'van': '面包车', | ||||||
|  |             'truck': '卡车', | ||||||
|  |             'tricycle': '三轮车', | ||||||
|  |             'awning-tricycle': '篷式三轮车', | ||||||
|  |             'bus': '公交车', | ||||||
|  |             'motor': '摩托车' | ||||||
|  |         } | ||||||
|  |         # 为每个类别设置固定的RGB颜色 | ||||||
|  |         self.color_mapping = { | ||||||
|  |             'pedestrian': (71, 0, 36),      # 勃艮第红 | ||||||
|  |             'people': (0, 255, 0),          # 绿色 | ||||||
|  |             'bicycle': (0, 49, 83),         # 普鲁士蓝 | ||||||
|  |             'car': (0, 47, 167),            # 克莱茵蓝 | ||||||
|  |             'van': (128, 0, 128),           # 紫色 | ||||||
|  |             'truck': (212, 72, 72),         # 缇香红 | ||||||
|  |             'tricycle': (0, 49, 83),        # 橙色 | ||||||
|  |             'awning-tricycle': (251, 220, 106), # 申布伦黄 | ||||||
|  |             'bus': (73, 45, 34),            # 凡戴克棕 | ||||||
|  |             'motor': (1, 132, 127)          # 马尔斯绿 | ||||||
|  |         } | ||||||
|  |         # 初始化类别计数器 | ||||||
|  |         self.class_counts = {cls_name: 0 for cls_name in self.class_name_mapping.keys()} | ||||||
|  |         # 初始化字体 | ||||||
|  |         try: | ||||||
|  |             self.font = ImageFont.truetype("simhei.ttf", 20) | ||||||
|  |         except IOError: | ||||||
|  |             self.font = ImageFont.load_default() | ||||||
|  |  | ||||||
|  |     def detect_and_draw_English(self, frame, conf=0.3, iou=0.5): | ||||||
|  |         """ | ||||||
|  |         对输入帧进行目标检测并返回绘制结果 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             frame: 输入的图像帧(BGR格式) | ||||||
|  |             conf: 置信度阈值 | ||||||
|  |             iou: IOU阈值 | ||||||
|  |          | ||||||
|  |         Returns: | ||||||
|  |             annotated_frame: 绘制了检测结果的图像帧 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 进行 YOLO 目标检测 | ||||||
|  |             results = self.model( | ||||||
|  |                 frame, | ||||||
|  |                 conf=conf, | ||||||
|  |                 iou=iou, | ||||||
|  |                 half=True, | ||||||
|  |             ) | ||||||
|  |             result = results[0] | ||||||
|  |              | ||||||
|  |             # 使用YOLO自带的绘制功能 | ||||||
|  |             annotated_frame = result.plot() | ||||||
|  |              | ||||||
|  |             return annotated_frame | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"Detection error: {e}") | ||||||
|  |             return frame | ||||||
|  |  | ||||||
|  |     def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3): | ||||||
|  |         """ | ||||||
|  |         对输入帧进行目标检测并绘制中文标注 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             frame: 输入的图像帧(BGR格式) | ||||||
|  |             conf: 置信度阈值 | ||||||
|  |             iou: IOU阈值 | ||||||
|  |          | ||||||
|  |         Returns: | ||||||
|  |             annotated_frame: 绘制了检测结果的图像帧 | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             # 进行 YOLO 目标检测 | ||||||
|  |             results = self.model( | ||||||
|  |                 frame, | ||||||
|  |                 conf=conf, | ||||||
|  |                 iou=iou, | ||||||
|  |                 # half=True, | ||||||
|  |             ) | ||||||
|  |             result = results[0] | ||||||
|  |              | ||||||
|  |             # 获取原始帧的副本 | ||||||
|  |             img = frame.copy() | ||||||
|  |              | ||||||
|  |             # 转换为PIL图像以绘制中文 | ||||||
|  |             pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | ||||||
|  |             draw = ImageDraw.Draw(pil_img) | ||||||
|  |              | ||||||
|  |             # 绘制检测结果 | ||||||
|  |             for box in result.boxes: | ||||||
|  |                 # 获取边框坐标 | ||||||
|  |                 x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) | ||||||
|  |                  | ||||||
|  |                 # 获取类别ID和置信度 | ||||||
|  |                 cls_id = int(box.cls[0].item()) | ||||||
|  |                 conf = box.conf[0].item() | ||||||
|  |                  | ||||||
|  |                 # 获取类别名称并转换为中文 | ||||||
|  |                 cls_name = result.names[cls_id] | ||||||
|  |                 chinese_name = self.class_name_mapping.get(cls_name, cls_name) | ||||||
|  |                  | ||||||
|  |                 # 获取该类别的颜色 | ||||||
|  |                 color = self.color_mapping.get(cls_name, (255, 255, 255)) | ||||||
|  |                  | ||||||
|  |                 # 绘制边框 | ||||||
|  |                 draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3) | ||||||
|  |                  | ||||||
|  |                 # 准备标签文本 | ||||||
|  |                 text = f"{chinese_name} {conf:.2f}" | ||||||
|  |                 text_size = draw.textbbox((0, 0), text, font=self.font) | ||||||
|  |                 text_width = text_size[2] - text_size[0] | ||||||
|  |                 text_height = text_size[3] - text_size[1] | ||||||
|  |                  | ||||||
|  |                 # 绘制标签背景(使用与边框相同的颜色) | ||||||
|  |                 draw.rectangle( | ||||||
|  |                     [(x1, y1 - text_height - 4), (x1 + text_width, y1)], | ||||||
|  |                     fill=color | ||||||
|  |                 ) | ||||||
|  |                  | ||||||
|  |                 # 绘制白色文本 | ||||||
|  |                 draw.text( | ||||||
|  |                     (x1, y1 - text_height - 2), | ||||||
|  |                     text, | ||||||
|  |                     fill=(255, 255, 255),  # 白色文本 | ||||||
|  |                     font=self.font | ||||||
|  |                 ) | ||||||
|  |              | ||||||
|  |             # 转换回OpenCV格式 | ||||||
|  |             return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) | ||||||
|  |              | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"Detection error: {e}") | ||||||
|  |             return frame | ||||||
		Reference in New Issue
	
	Block a user