commit 25a2ded11dcc69023a5886f7e7f415425ee0d56f Author: Mashiro50070 <1251294066@qq.com> Date: Tue Aug 5 16:55:45 2025 +0800 参数分离完全版 diff --git a/000.txt b/000.txt new file mode 100644 index 0000000..e69de29 diff --git a/README.md b/README.md new file mode 100644 index 0000000..60d20ac --- /dev/null +++ b/README.md @@ -0,0 +1,31 @@ +### models 目录 + +- 存放模型文件 + + + +### frame_transfer.py + +- 从检测结果队列推送数据到 RTMP 服务器【不必修改】 +- 从原始队列拿取数据、调用 yolo_core 封装的方法进行检测【四类】 + + + +### rtc_handler.py + +- 从 WebRTC 实时视频流截取帧并持续推送到原始队列【不必修改】 + + + +### yolo_core.py + +- 封装四类方法【参数均为原始队列、检测结果队列】 +- 方法一:原始YOLO检测 +- 方法二:原始YOLO检测 + 汉化 + 颜色 +- 方法三:原始累计计数 +- 方法四:原始累计计数 + 汉化 + 颜色 + + + +**读取 WebRTC 流和推送结果帧的代码不需要修改** + diff --git a/api_server.py b/api_server.py new file mode 100644 index 0000000..1706497 --- /dev/null +++ b/api_server.py @@ -0,0 +1,389 @@ +import json +import os +import shutil +from typing import Dict, Optional, Any + +from fastapi import APIRouter, File, UploadFile, HTTPException, Form +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from result import Response +# 假设 rfdetr_core.py 在同一目录下或 PYTHONPATH 中 +from rfdetr_core import RFDETRDetector + +# --- Global Variables and Configuration --- +model_management_router = APIRouter() + +BASE_MODEL_DIR = "models" +BASE_CONFIG_DIR = "configs" +os.makedirs(BASE_MODEL_DIR, exist_ok=True) +os.makedirs(BASE_CONFIG_DIR, exist_ok=True) + +# 用于存储当前激活的检测器实例和可用模型信息 +current_detector: Optional[RFDETRDetector] = None +current_model_identifier: Optional[str] = None +available_models_info: Dict[str, Dict] = {} # 存储模型标识符及其配置内容 + + +def load_config(config_name: str) -> Optional[Dict]: + """加载指定的JSON配置文件。""" + config_path = os.path.join(BASE_CONFIG_DIR, f"{config_name}.json") + if os.path.exists(config_path): + try: + with open(config_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"错误:加载配置文件 '{config_path}' 失败: {e}") + return None + return None + + +def initialize_detector(model_identifier: str) -> Optional[RFDETRDetector]: + """根据模型标识符初始化检测器。""" + global current_detector, current_model_identifier + try: + config = available_models_info.get(model_identifier) + if not config: + print(f"错误:未找到模型 '{model_identifier}' 的缓存配置,尝试从磁盘加载。") # 更明确的日志 + config_data = load_config(model_identifier) + if not config_data: + raise FileNotFoundError(f"配置文件 {model_identifier}.json 未找到或无法加载。") + available_models_info[model_identifier] = config_data + config = config_data + + model_filename = config.get('model_pth_filename') + if not model_filename: + raise ValueError(f"配置文件 '{model_identifier}.json' 中缺少 'model_pth_filename' 字段。") + + model_full_path = os.path.join(BASE_MODEL_DIR, model_filename) + if not os.path.exists(model_full_path): + raise FileNotFoundError(f"模型文件 '{model_full_path}' (在 '{model_identifier}.json' 中指定) 不存在。") + + print(f"尝试使用配置 '{model_identifier}.json' 和模型 '{model_full_path}' 初始化检测器...") + detector = RFDETRDetector(config_name=model_identifier, base_model_dir=BASE_MODEL_DIR, + base_config_dir=BASE_CONFIG_DIR) + print(f"检测器 '{model_identifier}' 初始化成功。") + current_detector = detector + current_model_identifier = model_identifier + + # 通知 DataPusher 更新其检测器实例 + try: + from data_pusher import get_data_pusher_instance # 动态导入以避免潜在的循环依赖问题 + pusher = get_data_pusher_instance() + if pusher: + print(f"通知 DataPusher 更新其检测器实例为: {model_identifier}") + pusher.update_detector_instance(current_detector) + # else: + # 如果 pusher 为 None,可能是因为它尚未在主应用启动时被完全初始化 + # data_pusher 模块内部的 initialize_data_pusher 负责记录其自身初始化状态 + # print(f"DataPusher 实例尚未初始化 (或初始化失败),无法更新其检测器。") + except ImportError: + print( + "警告: 无法导入 data_pusher 模块以更新 DataPusher 的检测器实例。如果您不使用数据推送功能,此消息可忽略。") + except Exception as e_pusher: + print(f"警告: 通知 DataPusher 更新检测器时发生意外错误: {e_pusher}") + + return detector + except FileNotFoundError as e: + print(f"初始化检测器 '{model_identifier}' 失败 (文件未找到): {e}") + if current_model_identifier == model_identifier: + current_detector = None + current_model_identifier = None + raise HTTPException(status_code=404, detail=str(e)) # 保持 404,让上层知道是文件问题 + except ValueError as e: # 捕获配置字段缺失等问题 + print(f"初始化检测器 '{model_identifier}' 失败 (配置值错误): {e}") + if current_model_identifier == model_identifier: + current_detector = None + current_model_identifier = None + raise HTTPException(status_code=400, detail=str(e)) # 配置问题用 400 + except Exception as e: # 其他来自 RFDETRDetector 内部的初始化错误 + print(f"初始化检测器 '{model_identifier}' 失败 (内部错误): {e}") + # import traceback + # traceback.print_exc() # 在服务器日志中打印完整堆栈,方便调试 + if current_model_identifier == model_identifier: + current_detector = None + current_model_identifier = None + if model_identifier in available_models_info: + del available_models_info[model_identifier] + # 这些通常是服务器端问题或模型/库的兼容性问题,所以用500 + raise HTTPException(status_code=500, detail=f"检测器 '{model_identifier}' 内部初始化失败: {e}") + + +def get_active_detector() -> Optional[RFDETRDetector]: + """获取当前激活的检测器实例。""" + return current_detector + + +def get_active_model_identifier() -> Optional[str]: + """获取当前激活的模型标识符。""" + return current_model_identifier + + +def scan_and_load_available_models(): + """扫描配置目录,加载所有有效的模型配置。""" + global available_models_info + available_models_info = {} + # print(f"扫描目录 '{BASE_CONFIG_DIR}' 以查找配置文件...") # 减少日志冗余 + if not os.path.exists(BASE_CONFIG_DIR): + # print(f"配置目录 '{BASE_CONFIG_DIR}' 不存在,跳过扫描。") + return + for filename in os.listdir(BASE_CONFIG_DIR): + if filename.endswith(".json"): + model_identifier = filename[:-5] + # print(f"找到配置文件: {filename},模型标识符: {model_identifier}") + config_data = load_config(model_identifier) + if config_data: + model_pth = config_data.get('model_pth_filename') + model_full_path = os.path.join(BASE_MODEL_DIR, model_pth) if model_pth else None + if model_pth and os.path.exists(model_full_path): + available_models_info[model_identifier] = config_data + # print(f"模型配置 '{model_identifier}' 加载成功。") + # elif not model_pth: + # print(f"警告:模型 '{model_identifier}' 的配置文件中未指定 'model_pth_filename'。") + # else: + # print(f"警告:模型 '{model_identifier}' 的模型文件 '{model_full_path}' 未找到。") + # else: + # print(f"警告:无法加载模型 '{model_identifier}' 的配置文件。") + print(f"可用模型配置已扫描/更新: {list(available_models_info.keys())}") + + +# --- Extracted Startup Logic (to be called by the main app) --- +async def initialize_default_model_on_startup(): + """应用启动时,扫描并加载可用模型,尝试激活第一个。""" + print("执行模型管理模块的启动初始化...") + scan_and_load_available_models() + global current_model_identifier, current_detector + + if available_models_info: + first_model = sorted(list(available_models_info.keys()))[0] + print(f"尝试将第一个可用模型 '{first_model}' 设置为活动模型。") + try: + initialize_detector(first_model) + print(f"默认模型 '{current_model_identifier}' 加载并激活成功。") + except HTTPException as e: # 捕获 initialize_detector 抛出的HTTPException + print(f"加载默认模型 '{first_model}' 失败: {e.detail} (状态码: {e.status_code})") + current_model_identifier = None + current_detector = None + # 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException + else: + print("没有可用的模型配置,服务器启动但无默认模型激活。") + current_model_identifier = None + current_detector = None + + +# --- Pydantic Models for Request/Response --- + +class ModelIdentifier(BaseModel): + model_identifier: str + + +# Define a standard response model for OpenAPI documentation +class StandardResponse(BaseModel): + code: int + data: Optional[Any] = None + message: str + + +# --- API Endpoints --- + +@model_management_router.post("/upload_model_and_config/", response_model=StandardResponse) +async def upload_model_and_config( + model_identifier_form: str = Form(..., alias="model_identifier"), + config_file: UploadFile = File(...), + model_file: UploadFile = File(...) +): + """ + 上传模型文件 (.pth) 和配置文件 (.json)。 + - **model_identifier**: 模型的唯一名称 (例如 "人车检测")。配置文件将以此名称保存 (e.g., "人车检测.json")。 + - **config_file**: JSON 配置文件。 + - **model_file**: Pytorch 模型文件 (.pth)。其文件名必须与配置文件中 'model_pth_filename' 字段指定的一致。 + """ + config_filename = f"{model_identifier_form}.json" + config_path = os.path.join(BASE_CONFIG_DIR, config_filename) + model_path = None # 在 try 块外部声明,以便 finally 和 except 中可用 + global current_model_identifier, current_detector # current_detector 实际上在这里主要由 initialize_detector 设置 + + try: + print(f"开始上传模型 '{model_identifier_form}'...") + config_content = await config_file.read() + try: + config_data = json.loads(config_content.decode('utf-8')) + except json.JSONDecodeError: + # raise HTTPException(status_code=400, detail="无效的JSON配置文件格式,请检查JSON语法。") + return JSONResponse(status_code=400, + content=Response.error(message="无效的JSON配置文件格式,请检查JSON语法。", code=400)) + except UnicodeDecodeError: + # raise HTTPException(status_code=400, detail="配置文件编码错误,请确保为UTF-8编码。") + return JSONResponse(status_code=400, + content=Response.error(message="配置文件编码错误,请确保为UTF-8编码。", code=400)) + + if 'model_pth_filename' not in config_data: + # raise HTTPException(status_code=400, detail="配置文件中必须包含 'model_pth_filename' 字段。") + return JSONResponse(status_code=400, + content=Response.error(message="配置文件中必须包含 'model_pth_filename' 字段。", + code=400)) + + target_model_filename = config_data['model_pth_filename'] + if not isinstance(target_model_filename, str) or not target_model_filename.endswith(".pth"): + # raise HTTPException(status_code=400, detail="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。") + return JSONResponse(status_code=400, content=Response.error( + message="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。", code=400)) + + if model_file.filename != target_model_filename: + # raise HTTPException( + # status_code=400, + # detail=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。" + # ) + return JSONResponse( + status_code=400, + content=Response.error( + message=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。", + code=400 + ) + ) + + with open(config_path, 'wb') as f: + f.write(config_content) + print(f"配置文件 '{config_path}' 已保存。") + + model_path = os.path.join(BASE_MODEL_DIR, target_model_filename) + with open(model_path, "wb") as buffer: + shutil.copyfileobj(model_file.file, buffer) + print(f"模型文件 '{model_path}' 已保存。") + + # 在尝试初始化之前,将配置数据添加到 available_models_info + # 这样 initialize_detector 即使在缓存未命中的情况下也能通过 load_config 找到它 + available_models_info[model_identifier_form] = config_data + + try: + print(f"尝试初始化并激活新上传的模型: '{model_identifier_form}'") + initialize_detector(model_identifier_form) # 此函数会设置全局的 current_detector 和 current_model_identifier + print(f"新上传的模型 '{model_identifier_form}' 验证并激活成功。") + # 成功后,available_models_info 已包含此模型,current_detector 和 current_model_identifier 已更新 + # 无需在此处再次调用 scan_and_load_available_models(),因为它会重新扫描所有,可能覆盖内存中的一些状态或引入不必要的IO + except HTTPException as e: # 从 initialize_detector 抛出的错误 + print(f"初始化新上传的模型 '{model_identifier_form}' 失败: {e.detail}") + # 清理已保存的文件和内存中的条目 + if os.path.exists(config_path): os.remove(config_path) + if model_path and os.path.exists(model_path): os.remove(model_path) + if model_identifier_form in available_models_info: del available_models_info[model_identifier_form] + # scan_and_load_available_models() # 失败后重新扫描是好的,以确保 available_models_info 准确 + # 但如果 initialize_detector 内部已经删除了 available_models_info 中的条目,这里可能不需要 + # 为保持一致性,如果上面删除了,这里重新扫描一下比较稳妥 + scan_and_load_available_models() + + # 重新抛出为 422,并包含原始错误信息 + # raise HTTPException(status_code=422, + # detail=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}") + return JSONResponse(status_code=422, + content=Response.error( + message=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}", + code=422 + )) + + # 如果成功,确保全局模型列表是最新的(虽然 initialize_detector 已更新了 current_*,但列表可能需要刷新以供 /available_models 使用) + # scan_and_load_available_models() # 移除这个,因为当前模型已激活,列表会在下次调用 /available_models 时刷新 + + # return UploadResponse( + # message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。", + # model_identifier=model_identifier_form, + # config_filename=config_filename, + # model_filename=target_model_filename + # ) + return Response.success(data={ + "message": f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。", # Message is also in the wrapper + "model_identifier": model_identifier_form, + "config_filename": config_filename, + "model_filename": target_model_filename + }, message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。") + + except HTTPException as e: # 捕获直接由 FastAPI 验证或其他地方抛出的 HTTPException + # This might catch exceptions from initialize_detector if they are not caught internally by the above try-except for initialize_detector + return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code)) + except Exception as e: + identifier_for_log = model_identifier_form if 'model_identifier_form' in locals() else "unknown" + print(f"上传模型 '{identifier_for_log}' 过程中发生意外的严重错误: {e}") + # import traceback + # traceback.print_exc() + # 尝试清理,以防文件部分写入 + if config_path and os.path.exists(config_path): + os.remove(config_path) + if model_path and os.path.exists(model_path): + os.remove(model_path) + if 'model_identifier_form' in locals() and model_identifier_form in available_models_info: + del available_models_info[model_identifier_form] + scan_and_load_available_models() # 出错后务必刷新列表 + + # raise HTTPException(status_code=500, detail=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}") + return JSONResponse(status_code=500, content=Response.error( + message=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}", code=500)) + finally: + if config_file: await config_file.close() + if model_file: await model_file.close() + + +@model_management_router.post("/select_model/", response_model=StandardResponse) +# @model_management_router.post("/select_model/{model_identifier_path}", response_model=StandardResponse) +async def select_model(model_identifier_path: str): + """ + 根据提供的标识符选择并激活一个已上传的模型。 + 路径参数 `model_identifier_path` 即为模型的唯一名称。 + """ + global current_model_identifier, current_detector + + # 总是先扫描以获取最新的可用模型列表,以防外部文件更改 + scan_and_load_available_models() + if model_identifier_path not in available_models_info: + # raise HTTPException(status_code=404, detail=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。") + return JSONResponse(status_code=404, content=Response.error( + message=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。", code=404)) + + if current_model_identifier == model_identifier_path and current_detector is not None: + print(f"模型 '{model_identifier_path}' 已经是活动模型。") + # return SelectModelResponse( + # message=f"模型 '{model_identifier_path}' 已是当前活动模型。", + # active_model=current_model_identifier + # ) + return Response.success(data={ + "active_model": current_model_identifier + }, message=f"模型 '{model_identifier_path}' 已是当前活动模型。") + try: + print(f"尝试激活模型: {model_identifier_path}") + initialize_detector(model_identifier_path) + print(f"模型 '{current_model_identifier}' 已成功激活。") + # return SelectModelResponse( + # message=f"模型 '{model_identifier_path}' 启动成功。", + # active_model=current_model_identifier + # ) + return Response.success(data={ + "active_model": current_model_identifier + }, message=f"模型 '{model_identifier_path}' 启动成功。") + except HTTPException as e: + # initialize_detector 内部已经处理了 current_model_identifier 和 current_detector 的清理 + # 以及 available_models_info 中对应条目的移除(如果是其内部错误) + print(f"激活模型 '{model_identifier_path}' 失败: {e.detail} (原始状态码: {e.status_code})") + # 此处不再需要手动清理 available_models_info,因为 initialize_detector 如果因内部错误(非文件找不到) + # 导致无法实例化 RFDETRDetector,它会自己删除条目。 + # 如果是文件找不到 (404 from initialize_detector),条目可能还在,但下次 scan 会处理。 + scan_and_load_available_models() # 确保 select 失败后,列表也是最新的 + # raise HTTPException(status_code=e.status_code, detail=e.detail) # 重新抛出原始的 HTTPException + return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code)) + # 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException + + +@model_management_router.get("/available_models/", response_model=StandardResponse) +async def get_available_models(): + """列出所有当前可用的模型标识符。""" + scan_and_load_available_models() + # return list(available_models_info.keys()) + return Response.success(data=list(available_models_info.keys()), message="成功获取可用模型列表。") + + +@model_management_router.get("/current_model/", response_model=StandardResponse) +async def get_current_model_endpoint(): + """获取当前激活的模型标识符。""" + # return current_model_identifier + if current_model_identifier: + return Response.success(data=current_model_identifier, message="成功获取当前激活的模型。") + else: + return Response.success(data=None, message="当前没有激活的模型。") diff --git a/configs/人车.json b/configs/人车.json new file mode 100644 index 0000000..7641276 --- /dev/null +++ b/configs/人车.json @@ -0,0 +1,58 @@ +{ + "model_id": "1", + "model_pth_filename": "人车.pth", + "resolution": 448, + "classes_en": [ + "pedestrian", "person", "bicycle", "car", "van", "truck", "tricycle", "awning-tricycle", "bus", "motor" + ], + + "classes_zh_map": { + "pedestrian":"行人", + "person": "人群", + "bicycle": "自行车", + "car": "小汽车", + "van": "面包车", + "truck": "卡车", + "tricycle":"三轮车", + "awning-tricycle":"篷式三轮车", + "bus": "公交车", + "motor":"摩托车" + }, + + "class_colors_hex": { + "pedestrian": "#470024", + "person": "#00FF00", + "bicycle": "#003153", + "car": "#002FA7", + "van": "#800080", + "truck": "#D44848", + "tricycle": "#003153", + "awning-tricycle": "#FBDC6A", + "bus": "#492D22", + "motor": "#01847F" + }, + + "detection_settings": { + "enabled_classes": { + "pedestrian": true, + "person": true, + "bicycle": false, + "car": true, + "van": true, + "truck": true, + "tricycle": false, + "awning-tricycle": false, + "bus": true, + "motor": true + }, + "default_confidence_threshold": 0.7 + }, + + "default_color_hex": "#00FF00", + "tracker_activation_threshold": 0.5, + "tracker_lost_buffer": 120, + "tracker_match_threshold": 0.85, + "tracker_frame_rate": 25, + "tracker_consecutive_frames": 2 + } + \ No newline at end of file diff --git a/data_pusher.py b/data_pusher.py new file mode 100644 index 0000000..d026b04 --- /dev/null +++ b/data_pusher.py @@ -0,0 +1,248 @@ +import base64 +import cv2 +import numpy as np +import requests +import time +import datetime +import logging +from fastapi import APIRouter, HTTPException, Body # 切换到 APIRouter +from pydantic import BaseModel, HttpUrl +# import uvicorn # 不再由此文件运行uvicorn +from apscheduler.schedulers.background import BackgroundScheduler +from typing import Optional # 用于类型提示 +# 确保 RFDETRDetector 可以被导入,假设 rfdetr_core.py 在同一目录或 PYTHONPATH 中 +# from rfdetr_core import RFDETRDetector # 在实际使用中取消注释并确保路径正确 + +# 配置日志记录 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# app = FastAPI(title="Data Pusher Service", version="1.0.0") # 移除独立的 FastAPI app +pusher_router = APIRouter() # 创建一个 APIRouter + +class DataPusher: + def __init__(self, detector): # detector: RFDETRDetector + if detector is None: + logger.error("DataPusher initialized with a None detector. Push functionality will be impaired.") + # 仍然创建实例,但功能会受限,_get_data_payload 会处理 detector is None + self.detector = detector + self.scheduler = BackgroundScheduler(daemon=True) + self.push_job_id = "rt_push_job" + self.target_url = None + if not self.scheduler.running: + try: + self.scheduler.start() + except Exception as e: + logger.error(f"Error starting APScheduler in DataPusher: {e}") + + + def update_detector_instance(self, detector): + """允许在运行时更新检测器实例,例如当主应用切换模型时""" + logger.info(f"DataPusher's detector instance is being updated.") + self.detector = detector + if detector is None: + logger.warning("DataPusher's detector instance updated to None.") + + def _get_data_payload(self): + """获取当前的类别计数和最新标注的帧""" + if self.detector is None: + logger.warning("DataPusher: Detector not available. Cannot get data payload.") + return { # 即使检测器不可用,也返回一个结构,包含空数据 + # "timestamp": time.time(), + "category_counts": {}, + "frame_base64": None, + "error": "Detector not available" + } + + category_counts = getattr(self.detector, 'category_counts', {}) + # 如果 detector 存在但没有 last_annotated_frame (例如模型刚加载还没处理第一帧) + last_frame_np = getattr(self.detector, 'last_annotated_frame', None) + + frame_base64 = None + if last_frame_np is not None and isinstance(last_frame_np, np.ndarray): + try: + _, buffer = cv2.imencode('.jpg', last_frame_np) + frame_base64 = base64.b64encode(buffer).decode('utf-8') + except Exception as e: + logger.error(f"Error encoding frame to base64: {e}") + + return { + # "timestamp": time.time(), + "category_counts": category_counts, + "frame_base64": frame_base64 + } + + def _push_data_task(self): + """执行数据推送的任务""" + if not self.target_url: + # logger.warning("Target URL not set. Skipping push task.") # 减少日志噪音,仅在初次设置时记录 + return + + payload = self._get_data_payload() + # if payload is None: # _get_data_payload 现在总会返回一个字典 + # logger.warning("No payload to push.") + # return + + try: + response = requests.post(self.target_url, json=payload, timeout=5) + response.raise_for_status() + logger.debug(f"Data pushed successfully to {self.target_url}. Status: {response.status_code}") # 改为 debug 级别 + except requests.exceptions.RequestException as e: + logger.error(f"Error pushing data to {self.target_url}: {e}") + except Exception as e: + logger.error(f"An unexpected error occurred during data push: {e}") + + def setup_push_schedule(self, frequency: float, target_url: str): + """设置或更新推送计划""" + if not isinstance(frequency, (int, float)) or frequency <= 0: + raise ValueError("Frequency must be a positive number (pushes per second).") + + self.target_url = str(target_url) + interval_seconds = 1.0 / frequency + + if not self.scheduler.running: # 确保调度器正在运行 + try: + logger.info("APScheduler was not running. Attempting to start it now.") + self.scheduler.start() + except Exception as e: + logger.error(f"Failed to start APScheduler in setup_push_schedule: {e}") + raise RuntimeError(f"APScheduler could not be started: {e}") + + + try: + if self.scheduler.get_job(self.push_job_id): + self.scheduler.remove_job(self.push_job_id) + logger.info(f"Removed existing push job: {self.push_job_id}") + except Exception as e: + logger.error(f"Error removing existing job: {e}") + + first_run_time = datetime.datetime.now() + datetime.timedelta(seconds=10) + self.scheduler.add_job( + self._push_data_task, + trigger='interval', + seconds=interval_seconds, + id=self.push_job_id, + next_run_time=first_run_time, + replace_existing=True + ) + logger.info(f"Push task scheduled to {self.target_url} every {interval_seconds:.2f}s, starting in 10s.") + + def stop_push_schedule(self): + """停止数据推送任务""" + if self.scheduler.get_job(self.push_job_id): + try: + self.scheduler.remove_job(self.push_job_id) + logger.info(f"Push job {self.push_job_id} stopped successfully.") + self.target_url = None # 清除目标 URL + except Exception as e: + logger.error(f"Error stopping push job {self.push_job_id}: {e}") + else: + logger.info("No active push job to stop.") + + def shutdown_scheduler(self): + """安全关闭调度器""" + if self.scheduler.running: + try: + self.scheduler.shutdown() + logger.info("DataPusher's APScheduler shut down successfully.") + except Exception as e: + logger.error(f"Error shutting down DataPusher's APScheduler: {e}") + + def push_specific_payload(self, payload: dict): + """推送一个特定的、已格式化的数据负载到配置的 target_url。""" + if not self.target_url: + logger.warning("DataPusher: Target URL not set. Cannot push specific payload.") + return + + if not payload: + logger.warning("DataPusher: Received empty payload for specific push. Skipping.") + return + + logger.info(f"Attempting to push specific payload to {self.target_url}") + try: + response = requests.post(self.target_url, json=payload, timeout=10) # Increased timeout for one-off + response.raise_for_status() + logger.info(f"Specific payload pushed successfully to {self.target_url}. Status: {response.status_code}") + except requests.exceptions.RequestException as e: + logger.error(f"Error pushing specific payload to {self.target_url}: {e}") + except Exception as e: + logger.error(f"An unexpected error occurred during specific payload push: {e}") + +# 全局 DataPusher 实例,将由主应用初始化 +data_pusher_instance: Optional[DataPusher] = None + +# --- FastAPI Request Body Model --- +class PushConfigRequest(BaseModel): + frequency: float + url: HttpUrl + +# --- FastAPI HTTP Endpoint (using APIRouter) --- +@pusher_router.post("/setup_push", summary="配置数据推送任务") +async def handle_setup_push(config: PushConfigRequest = Body(...)): + global data_pusher_instance + if data_pusher_instance is None: + # 这个错误理论上不应该发生,如果主应用正确初始化了 data_pusher_instance + logger.error("CRITICAL: /setup_push called but data_pusher_instance is None. Main app did not initialize it.") + raise HTTPException(status_code=503, detail="DataPusher service not available. Initialization may have failed.") + + if config.frequency <= 0: # Pydantic v2 中可以直接用 gt=0 + raise HTTPException(status_code=400, detail="Invalid frequency value. Must be a positive number.") + + try: + data_pusher_instance.setup_push_schedule(config.frequency, str(config.url)) + return { + "message": "Push task configured successfully.", + "frequency_hz": config.frequency, + "interval_seconds": 1.0 / config.frequency, + "target_url": str(config.url), + "first_push_delay_seconds": 10 + } + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + except RuntimeError as re: # 例如 APScheduler 启动失败 + logger.error(f"Runtime error during push schedule setup: {re}") + raise HTTPException(status_code=500, detail=str(re)) + except Exception as e: + logger.error(f"Error setting up push schedule: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + +@pusher_router.post("/stop_push", summary="停止当前数据推送任务") +async def handle_stop_push(): + global data_pusher_instance + if data_pusher_instance is None: + logger.error("CRITICAL: /stop_push called but data_pusher_instance is None.") + raise HTTPException(status_code=503, detail="DataPusher service not available.") + + try: + data_pusher_instance.stop_push_schedule() + return {"message": "Push task stopped successfully if it was running."} + except Exception as e: + logger.error(f"Error stopping push schedule: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error while stopping push: {str(e)}") + + +# --- Initialization and Shutdown Functions for Main App --- +def initialize_data_pusher(detector_instance_param): # Renamed to avoid conflict + """ + 由主应用程序调用以创建和配置 DataPusher 实例。 + """ + global data_pusher_instance + if data_pusher_instance is None: + logger.info("Initializing DataPusher instance...") + data_pusher_instance = DataPusher(detector_instance_param) + else: + logger.info("DataPusher instance already initialized. Updating detector instance if provided.") + data_pusher_instance.update_detector_instance(detector_instance_param) + return data_pusher_instance + +def get_data_pusher_instance() -> Optional[DataPusher]: + """获取 DataPusher 实例 (主要用于主应用可能需要访问它的其他方法,如 shutdown)""" + return data_pusher_instance + +# 移除 if __name__ == '__main__' 和 run_pusher_service,因为不再独立运行 +# 示例代码可以移至主应用的文档或测试脚本中。 + +# 注意: +# RFDETRDetector 实例的生命周期由 api_server.py (current_detector) 管理。 +# 当 api_server.py 中的模型切换时,需要有一种机制来更新 DataPusher 内部的 detector 引用。 +# initialize_data_pusher 可以被多次调用 (例如,在模型切换后),它会更新 DataPusher 持有的 detector 实例。 diff --git a/font/MSYH.TTC b/font/MSYH.TTC new file mode 100644 index 0000000..ea174b2 Binary files /dev/null and b/font/MSYH.TTC differ diff --git a/frame_transfer.py b/frame_transfer.py new file mode 100644 index 0000000..6c81a14 --- /dev/null +++ b/frame_transfer.py @@ -0,0 +1,226 @@ +import base64 +import requests as http_client +import queue +import time # 确保 time 被导入,如果之前被误删 +import cv2 +import traceback +import threading # 确保导入 threading +import av # 重新导入 av +import numpy as np + +# from fastapi import requests + +# 从 rfdetr_core 导入 RFDETRDetector 仅用于类型提示 (可选) +from rfdetr_core import RFDETRDetector + + +# 目标检测处理函数 +# 函数签名已更改:现在接受一个 detector_instance 作为参数 +def yolo_frame(rtc_q: queue.Queue, yolo_q: queue.Queue, stream_detector_instance: RFDETRDetector): + thread_name = threading.current_thread().name # 获取线程名称用于日志 + print(f"处理线程 '{thread_name}' 已启动。") + error_message_displayed_once = False + no_detector_message_displayed_once = False # 用于只提示一次没有检测器 + + if stream_detector_instance is None: + print(f"错误 (线程 '{thread_name}'): 未提供有效的检测器实例给yolo_frame。线程将无法处理视频。") + # 此线程实际上无法做任何有用的工作,可以考虑直接退出或进入一个安全循环 + # 为简单起见,我们允许它进入主循环,但它会在每次迭代时打印警告 + + while True: + frame = None + # current_category_counts = {} # 将在获取后转换 + try: + # 恢复队列长度打印 + print(f"线程 '{thread_name}' - 原始队列长度: {rtc_q.qsize()}, 检测队列长度: {yolo_q.qsize()}") + + frame = rtc_q.get(timeout=0.1) + if frame is None: + print(f"处理线程 '{thread_name}' 接收到停止信号,正在退出...") + # 发送包含None frame和空计数的字典作为停止信号 + yolo_q.put({"frame": None, "category_counts": {}}) + break + + category_counts_for_packet = {} + if stream_detector_instance: + no_detector_message_displayed_once = False # 检测器有效,重置提示 + annotated_frame = stream_detector_instance.detect_and_draw_count(frame) + error_message_displayed_once = False + + # 获取英文键的类别计数 + english_counts = stream_detector_instance.category_counts.copy() if hasattr(stream_detector_instance, 'category_counts') else {} + + # 转换为中文键的类别计数 + if hasattr(stream_detector_instance, 'VISDRONE_CLASSES_CHINESE'): + chinese_map = stream_detector_instance.VISDRONE_CLASSES_CHINESE + for eng_key, count_val in english_counts.items(): + # 使用 get 提供一个默认值,以防某个英文类别在中文映射表中确实没有 + chi_key = chinese_map.get(eng_key, eng_key) + category_counts_for_packet[chi_key] = count_val + else: + # 如果没有中文映射表,则直接使用英文计数 (或记录警告) + category_counts_for_packet = english_counts + # logger.warning(f"线程 '{thread_name}': stream_detector_instance 没有 VISDRONE_CLASSES_CHINESE 属性,将使用英文类别计数。") + + else: + # 如果没有有效的检测器实例传递进来 + if not no_detector_message_displayed_once: + print(f"警告 (线程 '{thread_name}'): 无有效检测器实例。将在帧上绘制提示。") + no_detector_message_displayed_once = True + + annotated_frame = frame.copy() + cv2.putText(annotated_frame, + "No detector instance provided for this stream", + (30, 50), cv2.FONT_HERSHEY_SIMPLEX, + 1, (0, 0, 255), 2, cv2.LINE_AA) + category_counts_for_packet = {} # 无检测器,计数为空 + + # 将帧和类别计数一起放入队列 + data_packet = {"frame": annotated_frame, "category_counts": category_counts_for_packet} + try: + yolo_q.put_nowait(data_packet) + except queue.Full: + # print(f"警告 (线程 '{thread_name}'): yolo_q 已满,丢弃帧。") # 避免刷屏 + pass + + except queue.Empty: + time.sleep(0.01) + continue + except Exception as e: + if not error_message_displayed_once: + print(f"线程 '{thread_name}' (yolo_frame) 处理时发生严重错误: {e}") + traceback.print_exc() + error_message_displayed_once = True + time.sleep(1) + if frame is not None: + try: + pass + except queue.Full: + pass + continue + print(f"处理线程 '{thread_name}' 已停止。") + + +def push_frame(yolo_q: queue.Queue, rtmp_url: str, gateway: str, frequency: int, push_url: str): + thread_name = threading.current_thread().name + print(f"推流线程 '{thread_name}' (RTMP: {rtmp_url}) 已启动。") + + output_container = None + stream = None + first_frame_processed = False + last_push_time = 0 # 记录上次推送base64的时间 + + try: + while True: + frame_to_push = None + received_category_counts = {} # 初始化为空字典 + try: + data_packet = yolo_q.get(timeout=0.1) + if data_packet: + frame_to_push = data_packet.get("frame") + received_category_counts = data_packet.get("category_counts", {}) + else: # data_packet is None (不太可能,除非队列明确放入None) + time.sleep(0.01) + continue + + except queue.Empty: + time.sleep(0.01) + continue + + if frame_to_push is None: # 这是通过 data_packet["frame"] is None 来判断的 + print(f"推流线程 '{thread_name}' 接收到停止信号,正在清理并退出...") + break + + if not first_frame_processed: + if frame_to_push is not None: + try: + height, width, _ = frame_to_push.shape + print(f"线程 '{thread_name}': 首帧尺寸 {width}x{height},正在初始化RTMP推流器到 {rtmp_url}") + output_container = av.open(rtmp_url, 'w', format='flv') + stream = output_container.add_stream('libx264', rate=25) + stream.pix_fmt = 'yuv420p' + stream.width = width + stream.height = height + stream.options = {'preset': 'ultrafast', 'tune': 'zerolatency', 'crf': '25'} + print(f"线程 '{thread_name}': RTMP推流器初始化成功。") + first_frame_processed = True + except Exception as e_init: + print(f"错误 (线程 '{thread_name}'): 初始化PyAV推流容器/流失败: {e_init}") + traceback.print_exc() + return + else: + continue + + if not output_container or not stream: + print(f"错误 (线程 '{thread_name}'): 推流器未初始化,无法推流。可能是首帧处理失败。") + time.sleep(1) + continue + + # 持续推流到RTMP + try: + video_frame = av.VideoFrame.from_ndarray(frame_to_push, format='bgr24') + for packet in stream.encode(video_frame): + output_container.mux(packet) + except Exception as e_push: + print(f"错误 (线程 '{thread_name}'): 推送帧到RTMP时发生错误: {e_push}") + time.sleep(0.5) + + # 定时推送base64帧 + current_time = time.time() + if current_time - last_push_time >= frequency: + # 将接收到的类别计数传递给 push_base64_frame + push_base64_frame(frame_to_push, gateway, push_url, thread_name, received_category_counts) + last_push_time = current_time + + except Exception as e_outer: + print(f"推流线程 '{thread_name}' 发生严重外部错误: {e_outer}") + traceback.print_exc() + finally: + print(f"推流线程 '{thread_name}': 进入finally块,准备关闭推流器。") + if stream and output_container: + try: + print(f"推流线程 '{thread_name}': 正在编码流的剩余部分...") + for packet in stream.encode(None): + output_container.mux(packet) + print(f"推流线程 '{thread_name}': 编码剩余部分完成。") + except Exception as e_flush: + print(f"错误 (线程 '{thread_name}'): 关闭推流流时发生编码/刷新错误: {e_flush}") + traceback.print_exc() + if output_container: + try: + print(f"推流线程 '{thread_name}': 正在关闭推流容器...") + output_container.close() + print(f"推流线程 '{thread_name}': 推流容器已关闭。") + except Exception as e_close: + print(f"错误 (线程 '{thread_name}'): 关闭推流容器时发生错误: {e_close}") + traceback.print_exc() + print(f"推流线程 '{thread_name}' 已停止并完成清理。") + + +def push_base64_frame(frame: np.ndarray, gateway: str, push_url: str, thread_name: str, category_counts: dict): + """将帧转换为base64并推送到指定URL""" + try: + # 转换为JPEG格式 + _, buffer = cv2.imencode('.jpg', frame) + # 转换为base64字符串 + frame_base64 = base64.b64encode(buffer).decode('utf-8') + + # 构建JSON数据 + data = { + "gateway": gateway, + "frame_base64": frame_base64, + "category_counts": category_counts # 使用传入的 category_counts + } + + print(f"DEBUG push_base64_frame: Pushing data: {data.get('category_counts')}") # 调试打印,检查发送的数据 + # 发送POST请求 + response = http_client.post(push_url, json=data, timeout=5) + + # 检查响应 + if response.status_code == 200: + print(f"线程 '{thread_name}': base64帧已成功推送到 {push_url}") + else: + print(f"错误 (线程 '{thread_name}'): 推送base64帧失败,状态码: {response.status_code}") + + except Exception as e: + print(f"错误 (线程 '{thread_name}'): 处理或推送base64帧时发生错误: {e}") \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..0bb3e1c --- /dev/null +++ b/main.py @@ -0,0 +1,5 @@ +import uvicorn +from web import app + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/models/人车.pth b/models/人车.pth new file mode 100644 index 0000000..9341a8c Binary files /dev/null and b/models/人车.pth differ diff --git a/requirements - 副本.txt b/requirements - 副本.txt new file mode 100644 index 0000000..b2d2d7d --- /dev/null +++ b/requirements - 副本.txt @@ -0,0 +1,27 @@ +aiohttp==3.11.14 +aiortc==1.11.0 +aiosignal==1.3.2 +APScheduler==3.11.0 +av==14.2.0 +fastapi==0.115.11 +huggingface-hub==0.30.1 +numpy==2.1.1 +nvidia-cuda-runtime-cu12==12.8.90 +opencv-contrib-python==4.11.0.86 +opencv-python==4.11.0.86 +pillow==11.1.0 +pillow_heif==0.22.0 +pycuda==2025.1 +pydantic==2.10.6 +pydantic_core==2.27.2 +requests==2.32.3 +requests-toolbelt==1.0.0 +rfdetr==1.1.0 +safetensors==0.5.3 +supervision==0.25.1 +torch==2.6.0+cu126 +torchaudio==2.6.0+cu126 +torchvision==0.21.0+cu126 +transformers==4.50.3 +uvicorn==0.34.0 +wandb==0.19.9 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c401283 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,233 @@ +absl-py==2.2.1 +accelerate==1.6.0 +addict==2.4.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.11.14 +aioice==0.9.0 +aiortc==1.11.0 +aiosignal==1.3.2 +albucore==0.0.23 +albumentations==2.0.5 +altgraph==0.17.4 +annotated-types==0.7.0 +anyio==4.8.0 +anywidget==0.9.18 +appdirs==1.4.4 +APScheduler==3.11.0 +asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work +async-timeout==5.0.1 +attrs==25.3.0 +av==14.2.0 +basicsr==1.4.2 +bbox_visualizer==0.2.0 +blind-watermark==0.4.4 +certifi==2025.1.31 +cffi==1.17.1 +charset-normalizer==3.4.1 +click==8.1.8 +cmake==3.31.6 +colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work +coloredlogs==15.0.1 +comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work +contourpy==1.3.1 +cryptography==44.0.2 +cycler==0.12.1 +Cython==3.0.12 +debugpy @ file:///D:/bld/debugpy_1741148401445/work +decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work +deep-sort-realtime==1.3.2 +defusedxml==0.7.1 +dnspython==2.7.0 +docker-pycreds==0.4.0 +easydict==1.13 +einops==0.8.1 +et_xmlfile==2.0.0 +exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1733208806608/work +executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1733569351617/work +facexlib==0.3.0 +fairscale==0.4.13 +fastapi==0.115.11 +filelock==3.13.1 +filetype==1.2.0 +filterpy==1.4.5 +fire==0.7.0 +flatbuffers==25.2.10 +fonttools==4.56.0 +frozenlist==1.5.0 +fsspec==2024.6.1 +ftfy==6.3.1 +future==1.0.0 +gfpgan==1.3.8 +gitdb==4.0.12 +GitPython==3.1.44 +google-crc32c==1.7.1 +grpcio==1.71.0 +h11==0.14.0 +httptools==0.6.4 +huggingface-hub==0.30.1 +humanfriendly==10.0 +idna==3.7 +ifaddr==0.2.0 +imageio==2.37.0 +importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work +iniconfig==2.1.0 +insightface==0.7.3 +ipykernel @ file:///D:/bld/ipykernel_1719845595208/work +ipython @ file:///D:/bld/bld/rattler-build_ipython_1740856913/work +ipywidgets==8.1.5 +jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work +Jinja2==3.1.4 +joblib==1.4.2 +jupyter_bbox_widget==0.6.0 +jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work +jupyter_core @ file:///D:/bld/jupyter_core_1727163532151/work +jupyterlab_widgets==3.0.13 +kiwisolver==1.4.8 +lap==0.5.12 +lazy_loader==0.4 +llvmlite==0.44.0 +lmdb==1.6.2 +Mako==1.3.9 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.10.1 +matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work +mdurl==0.1.2 +memory-profiler==0.61.0 +motmetrics==1.4.0 +mpmath==1.3.0 +multidict==6.2.0 +nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work +networkx==3.3 +ninja==1.11.1.3 +Nuitka==2.6.7 +numba==0.61.0 +numpy==2.1.1 +nvidia-cuda-runtime-cu12==12.8.90 +onnx==1.16.1 +onnx-graphsurgeon==0.5.7 +onnxruntime-gpu==1.20.2 +onnxsim==0.4.36 +onnxslim==0.1.48 +open_clip_torch==2.32.0 +opencv-contrib-python==4.11.0.86 +opencv-python==4.11.0.86 +openpyxl==3.1.5 +ordered-set==4.1.0 +packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1733203243479/work +pandas==2.2.3 +parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work +pefile==2023.2.7 +peft==0.15.1 +pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work +pillow==11.1.0 +pillow_heif==0.22.0 +platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1733232627818/work +pluggy==1.5.0 +polygraphy==0.49.20 +prettytable==3.15.1 +prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1737453357274/work +propcache==0.3.1 +protobuf==3.20.2 +psutil @ file:///D:/bld/psutil_1740663160591/work +psygnal==0.12.0 +pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work +py-cpuinfo==9.0.0 +pybboxes==0.1.6 +pycocotools==2.0.8 +pycparser==2.22 +pycuda==2025.1 +pydantic==2.10.6 +pydantic_core==2.27.2 +pyee==13.0.0 +Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work +pyinstaller==6.12.0 +pyinstaller-hooks-contrib==2025.1 +pylabel==0.1.55 +pylibsrtp==0.11.0 +PyMuPDF==1.25.4 +pyOpenSSL==25.0.0 +pyparsing==3.2.1 +pyproj==3.7.1 +pyreadline3==3.5.4 +PySide6==6.8.2.1 +PySide6_Addons==6.8.2.1 +PySide6_Essentials==6.8.2.1 +pytest==8.3.5 +python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work +python-dotenv==1.0.1 +python-multipart==0.0.20 +pytools==2025.1.2 +pytz==2025.1 +PyWavelets==1.8.0 +pywin32==307 +pywin32-ctypes==0.2.3 +PyYAML==6.0.2 +pyzmq @ file:///D:/bld/pyzmq_1738270977186/work +realesrgan==0.3.0 +regex==2024.11.6 +requests==2.32.3 +requests-toolbelt==1.0.0 +rf100vl==1.0.0 +rfdetr==1.1.0 +rich==14.0.0 +roboflow==1.1.60 +safetensors==0.5.3 +sahi==0.11.22 +scikit-image==0.25.2 +scikit-learn==1.6.1 +scipy==1.15.2 +seaborn==0.13.2 +sentry-sdk==2.25.1 +setproctitle==1.3.5 +shapely==2.0.7 +shiboken6==6.8.2.1 +simsimd==6.2.1 +six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work +smmap==5.0.2 +sniffio==1.3.1 +stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work +starlette==0.46.1 +stringzilla==3.12.2 +supervision==0.25.1 +sympy==1.13.1 +tabulate==0.9.0 +tb-nightly==2.20.0a20250326 +tensorboard-data-server==0.7.2 +tensorrt==10.9.0.34 +tensorrt_cu12==10.9.0.34 +tensorrt_cu12_bindings==10.9.0.34 +tensorrt_cu12_libs==10.9.0.34 +termcolor==2.5.0 +terminaltables==3.1.10 +threadpoolctl==3.5.0 +tifffile==2025.2.18 +timm==1.0.15 +tokenizers==0.21.1 +tomli==2.2.1 +torch==2.6.0+cu126 +torchaudio==2.6.0+cu126 +torchvision==0.21.0+cu126 +tornado @ file:///D:/bld/tornado_1732615925919/work +tqdm==4.67.1 +traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work +transformers==4.50.3 +typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work +tzdata==2025.1 +tzlocal==5.3.1 +ultralytics==8.3.99 +ultralytics-thop==2.0.14 +urllib3==2.3.0 +uvicorn==0.34.0 +wandb==0.19.9 +watchfiles==1.0.4 +wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work +websockets==15.0.1 +Werkzeug==3.1.3 +widgetsnbextension==4.0.13 +xmltodict==0.14.2 +yapf==0.43.0 +yarl==1.18.3 +zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work +zstandard==0.23.0 diff --git a/result.py b/result.py new file mode 100644 index 0000000..aaee6e7 --- /dev/null +++ b/result.py @@ -0,0 +1,21 @@ +# 定义返回值类 +class Response: + def __init__(self, code, data, message): + self.code = code + self.data = data + self.message = message + + def to_dict(self): + return { + "code": self.code, + "data": self.data, + "message": self.message + } + + @classmethod + def success(cls, data=None, message="操作成功"): + return cls(200, data, message).to_dict() + + @classmethod + def error(cls, data=None, message="操作失败", code=400): + return cls(code, data, message).to_dict() diff --git a/rfdetr_core.py b/rfdetr_core.py new file mode 100644 index 0000000..4116525 --- /dev/null +++ b/rfdetr_core.py @@ -0,0 +1,254 @@ +import cv2 +import supervision as sv +from rfdetr import RFDETRBase +from collections import defaultdict +from typing import Dict, Set +from PIL import Image, ImageDraw, ImageFont # 导入PIL库 +import numpy as np # 导入numpy用于图像格式转换 +import json # 新增 +import os # 新增 + +class RFDETRDetector: + def __init__(self, config_name: str, base_model_dir="models", base_config_dir="configs", default_font_path="./font/MSYH.TTC", default_font_size=15): + self.config_path = os.path.join(base_config_dir, f"{config_name}.json") + if not os.path.exists(self.config_path): + raise FileNotFoundError(f"配置文件不存在: {self.config_path}") + + with open(self.config_path, 'r', encoding='utf-8') as f: + self.config = json.load(f) + + model_path = os.path.join(base_model_dir, self.config['model_pth_filename']) + resolution = self.config['resolution'] + + # 从配置读取字体路径和大小,如果未提供则使用默认值 + font_path = self.config.get('font_path', default_font_path) + font_size = self.config.get('font_size', default_font_size) + + # 1. 初始化模型 + self.model = RFDETRBase( + pretrain_weights=model_path, + # pretrain_weights=model_path or r"E:\A\rf-detr-main\output\pre-train1\checkpoint_best_ema.pth", + resolution=resolution + ) + + # 2. 初始化跟踪器 + self.tracker = sv.ByteTrack( + track_activation_threshold=self.config['tracker_activation_threshold'], + lost_track_buffer=self.config['tracker_lost_buffer'], + minimum_matching_threshold=self.config['tracker_match_threshold'], + minimum_consecutive_frames=self.config['tracker_consecutive_frames'], + frame_rate=self.config['tracker_frame_rate'] + ) + + # 3. 类别定义 + self.VISDRONE_CLASSES = self.config['classes_en'] + self.VISDRONE_CLASSES_CHINESE = self.config['classes_zh_map'] + + # 新增:加载类别启用配置 + self.detection_settings = self.config.get('detection_settings', {}) + self.enabled_classes_filter = self.detection_settings.get('enabled_classes', {}) + # 构建一个查找表,对于未在filter中指定的类别,默认为 True (启用) + self._active_classes_lookup = { + cls_name: self.enabled_classes_filter.get(cls_name, True) + for cls_name in self.VISDRONE_CLASSES + } + print(f"活动类别配置: {self._active_classes_lookup}") + + # 4. 初始化字体 + self.FONT_SIZE = font_size + try: + self.font = ImageFont.truetype(font_path, self.FONT_SIZE) + except IOError: + print(f"错误:无法加载字体 {font_path}。将使用默认字体。") + self.font = ImageFont.load_default() # 使用真正通用的默认字体 + + # 5. 类别计数器 (作为类属性) + self.class_tracks: Dict[str, Set[int]] = defaultdict(set) + self.category_counts: Dict[str, int] = defaultdict(int) + + # 6. 初始化标注器 + # 从配置加载默认颜色,如果失败则使用预设颜色 + self.default_color_hex = self.config.get('default_color_hex', "#00FF00") # 默认绿色 + self.bounding_box_thickness = self.config.get('bounding_box_thickness', 2) + + # 加载颜色配置,用于 PIL 绘制 + self.class_colors_hex = self.config.get('class_colors_hex', {}) + self.last_annotated_frame: np.ndarray | None = None # 新增: 用于存储最新的标注帧 + + def _hex_to_rgb(self, hex_color: str) -> tuple: + """将十六进制颜色字符串转换为RGB元组。""" + hex_color = hex_color.lstrip('#') + try: + return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) + except ValueError: + print(f"警告: 无法解析十六进制颜色 '{hex_color}', 将使用默认颜色。") + # 解析失败时返回一个默认颜色,例如红色 + return self._hex_to_rgb(self.default_color_hex if self.default_color_hex != hex_color else "#00FF00") + + def _update_counter(self, detections: sv.Detections): + """更新类别计数器""" + # 只统计有 tracker_id 的检测结果 + valid_indices = detections.tracker_id != None + if not np.any(valid_indices): # 处理 detections 为空或 tracker_id 都为 None 的情况 + return + + class_ids = detections.class_id[valid_indices] + track_ids = detections.tracker_id[valid_indices] + + for class_id, track_id in zip(class_ids, track_ids): + if track_id is None: # 跳过没有 tracker_id 的项 + continue + # 使用英文类别名作为内部 key + class_name = self.VISDRONE_CLASSES[class_id] + if track_id not in self.class_tracks[class_name]: + self.class_tracks[class_name].add(track_id) + self.category_counts[class_name] += 1 + + def _draw_frame(self, frame: np.ndarray, detections: sv.Detections) -> np.ndarray: + """使用PIL绘制检测框、中文标签和计数信息""" + + pil_image = Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_image) + + # --- 使用 PIL 绘制检测框和中文标签 --- + valid_indices = detections.tracker_id != None # 或直接使用 detections.xyxy 如果不过滤无 tracker_id 的 + if np.any(valid_indices): + boxes = detections.xyxy[valid_indices] + class_ids = detections.class_id[valid_indices] + # tracker_ids = detections.tracker_id[valid_indices] # 如果需要 tracker_id + + for box, class_id in zip(boxes, class_ids): + x1, y1, x2, y2 = map(int, box) + + english_label = self.VISDRONE_CLASSES[class_id] + chinese_label = self.VISDRONE_CLASSES_CHINESE.get(english_label, english_label) + + # 获取边界框颜色 + box_color_hex = self.class_colors_hex.get(english_label, self.default_color_hex) + box_rgb_color = self._hex_to_rgb(box_color_hex) + + # 绘制边界框 + draw.rectangle([x1, y1, x2, y2], outline=box_rgb_color, width=self.bounding_box_thickness) + + # 绘制中文标签 (与之前逻辑类似) + text_to_draw = f"{chinese_label}" + # 标签背景 (可选,使其更易读) + # label_text_bbox = draw.textbbox((0,0), text_to_draw, font=self.font) + # label_width = label_text_bbox[2] - label_text_bbox[0] + # label_height = label_text_bbox[3] - label_text_bbox[1] + # label_bg_y1 = y1 - label_height - 4 if y1 - label_height - 4 > 0 else y1 + 2 + # draw.rectangle([x1, label_bg_y1, x1 + label_width + 4, label_bg_y1 + label_height + 2], fill=box_rgb_color) + # text_color = (255,255,255) if sum(box_rgb_color) < 382 else (0,0,0) # 简易对比色 + text_color = (255, 255, 255) # 白色 (RGB) + + text_x = x1 + 2 # 稍微偏移,避免紧贴边框 + text_y = y1 - self.FONT_SIZE - 2 + if text_y < 0: # 如果标签超出图像顶部 + text_y = y1 + 2 + + draw.text((text_x, text_y), text_to_draw, font=self.font, fill=text_color) + + # --- 绘制统计面板 (右上角) --- + stats_text_lines = [ + f"{self.VISDRONE_CLASSES_CHINESE.get(cls, cls)}: {self.category_counts[cls]}" + for cls in self.VISDRONE_CLASSES if self.category_counts[cls] > 0 + ] + + frame_height, frame_width, _ = frame.shape + stats_start_x = frame_width - self.config.get('stats_panel_width', 200) + stats_start_y = self.config.get('stats_panel_margin_y', 10) + line_height = self.FONT_SIZE + self.config.get('stats_line_spacing', 5) + + stats_text_color_hex = self.config.get('stats_text_color_hex', "#FFFFFF") + stats_text_color = self._hex_to_rgb(stats_text_color_hex) + + # 可选:为统计面板添加背景 + if stats_text_lines: + panel_height = len(stats_text_lines) * line_height + 10 + panel_y2 = stats_start_y + panel_height + # 半透明背景 + # overlay = Image.new('RGBA', pil_image.size, (0,0,0,0)) + # panel_draw = ImageDraw.Draw(overlay) + # panel_draw.rectangle( + # [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2], + # fill=(100, 100, 100, 128) # 半透明灰色 + # ) + # pil_image = Image.alpha_composite(pil_image.convert('RGBA'), overlay) + # draw = ImageDraw.Draw(pil_image) # 如果用了 alpha_composite, 需要重新获取 draw 对象 + + # 或者简单不透明背景 + # draw.rectangle( + # [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2], + # fill=self._hex_to_rgb(self.config.get('stats_panel_bg_color_hex', "#808080")) # 例如灰色背景 + # ) + + for i, line in enumerate(stats_text_lines): + text_pos = (stats_start_x, stats_start_y + i * line_height) + draw.text(text_pos, line, font=self.font, fill=stats_text_color) + + final_annotated_frame = cv2.cvtColor(np.array(pil_image.convert('RGB')), cv2.COLOR_RGB2BGR) + return final_annotated_frame + + def detect_and_draw_count(self, frame: np.ndarray, conf: float = -1.0) -> np.ndarray: + """执行单帧检测、跟踪、计数并绘制结果(包含类别过滤)。""" + if conf == -1.0: + # 优先从 detection_settings 中获取,其次是顶层config,最后是硬编码默认值 + effective_conf = float( + self.detection_settings.get('default_confidence_threshold', + self.config.get('default_confidence_threshold', 0.8)) + ) + else: + effective_conf = conf + + try: + # 1. 执行检测 + detections = self.model.predict(frame, threshold=effective_conf) + + # 处理 detections 为 None 或空的情况 + if detections is None or len(detections) == 0: + detections = sv.Detections.empty() + annotated_frame = self._draw_frame(frame, detections) + self.last_annotated_frame = annotated_frame.copy() # 新增 + return annotated_frame + + # 新增:根据配置过滤检测到的类别 + if detections is not None and len(detections) > 0: + keep_indices = [] + for i, class_id in enumerate(detections.class_id): + if class_id < len(self.VISDRONE_CLASSES): # 确保 class_id 有效 + class_name = self.VISDRONE_CLASSES[class_id] + if self._active_classes_lookup.get(class_name, True): # 默认为 True + keep_indices.append(i) + else: + print(f"警告: 检测到无效的 class_id {class_id},超出了已知类别范围。") + + if not keep_indices: + detections = sv.Detections.empty() + else: + detections = detections[keep_indices] + + # 如果过滤后没有检测结果 + if len(detections) == 0: + annotated_frame = self._draw_frame(frame, sv.Detections.empty()) + self.last_annotated_frame = annotated_frame.copy() # 新增 + return annotated_frame + + # 2. 执行跟踪 (只对过滤后的结果进行跟踪) + detections = self.tracker.update_with_detections(detections) + + # 3. 更新计数器 (只对过滤并跟踪后的结果进行计数) + self._update_counter(detections) + + # 4. 绘制结果 + annotated_frame = self._draw_frame(frame, detections) + self.last_annotated_frame = annotated_frame.copy() # 新增 + + return annotated_frame + + except Exception as e: + print(f"处理帧时发生错误: {e}") + if frame is not None: + self.last_annotated_frame = frame.copy() # 新增 + else: + self.last_annotated_frame = None # 新增 + return frame \ No newline at end of file diff --git a/rtc_handler.py b/rtc_handler.py new file mode 100644 index 0000000..e558409 --- /dev/null +++ b/rtc_handler.py @@ -0,0 +1,94 @@ +import asyncio +import queue +from fractions import Fraction +from urllib.parse import urlparse + +import aiohttp +import av +import numpy as np +from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, VideoStreamTrack + + +class DummyVideoTrack(VideoStreamTrack): + async def recv(self): + # 简洁初始化、返回固定颜色的帧 + return np.full((480, 640, 3), (0, 0, 255), dtype=np.uint8) + + +async def receive_video_frames(whep_url): + pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) + frames_queue = asyncio.Queue() + + pc.addTrack(DummyVideoTrack()) + + @pc.on("track") + def on_track(track): + if track.kind == "video": + asyncio.create_task(consume_track(track, frames_queue)) + + @pc.on("iceconnectionstatechange") + def on_ice_connection_state_change(): + print(f"ICE 连接状态: {pc.iceConnectionState}") + + offer = await pc.createOffer() + await pc.setLocalDescription(offer) + + headers = {"Content-Type": "application/sdp"} + + async with aiohttp.ClientSession() as session: + async with session.post(whep_url, data=pc.localDescription.sdp, headers=headers) as response: + if response.status != 201: + raise Exception(f"服务器返回错误: {response.status}") + + answer = RTCSessionDescription(sdp=await response.text(), type="answer") + await pc.setRemoteDescription(answer) + + if "Location" in response.headers: + base_url = f"{urlparse(whep_url).scheme}://{urlparse(whep_url).netloc}" + print("ICE 协商 URL:", base_url + response.headers["Location"]) + + while pc.iceConnectionState not in ["connected", "completed"]: + await asyncio.sleep(1) + + print("ICE 连接完成,开始接收视频流") + + try: + while True: + frame = await frames_queue.get() + if frame is None: + break + yield frame + except KeyboardInterrupt: + pass + finally: + await pc.close() + + +async def consume_track(track, frames_queue): + try: + while True: + frame = await track.recv() + if frame is None: + print("没有接收到有效的帧数据") + await frames_queue.put(None) + break + img = frame.to_ndarray(format="bgr24") + await frames_queue.put(img) + except Exception as e: + print("处理帧错误:", e) + await frames_queue.put(None) + + +def rtc_frame(url, frame_queue): + async def main(): + async for frame in receive_video_frames(url): + try: + frame_queue.put_nowait(frame) + except queue.Full: + frame_queue.get_nowait() + frame_queue.put_nowait(frame) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(main()) + loop.close() diff --git a/web.py b/web.py new file mode 100644 index 0000000..d1b3d9c --- /dev/null +++ b/web.py @@ -0,0 +1,276 @@ +import base64 +import hashlib +import queue +import time +from multiprocessing import Process, Event, Queue as MpQueue +from typing import Dict, Any + +import cv2 +import numpy as np +from fastapi import FastAPI, Request, HTTPException + +from api_server import model_management_router, initialize_default_model_on_startup, get_active_detector, \ + get_active_model_identifier +from data_pusher import pusher_router, initialize_data_pusher, get_data_pusher_instance +from frame_transfer import yolo_frame, push_frame +from result import Response +from rfdetr_core import RFDETRDetector +from rtc_handler import rtc_frame + +app = FastAPI(title="Real-time Video Processing, Model Management, and Data Pusher API") + +app.include_router(model_management_router, prefix="/api", tags=["Model Management"]) +app.include_router(pusher_router, prefix="/api/pusher", tags=["Data Pusher"]) + +process_map: Dict[str, Dict[str, Any]] = {} + + +@app.on_event("startup") +async def web_app_startup_event(): + print("主应用服务启动中...") + await initialize_default_model_on_startup() + active_detector_instance = get_active_detector() + if active_detector_instance: + print(f"主应用启动:检测到活动模型 '{get_active_model_identifier()}',将用于初始化 DataPusher。") + else: + print("主应用启动:未检测到活动模型。DataPusher 将以无检测器状态初始化。") + initialize_data_pusher(active_detector_instance) + print("DataPusher 服务已初始化。") + print("主应用启动完成。") + + +@app.on_event("shutdown") +async def web_app_shutdown_event(): + print("主应用服务关闭中...") + pusher = get_data_pusher_instance() + if pusher: + print("正在关闭 DataPusher 调度器...") + pusher.shutdown_scheduler() + else: + print("DataPusher 实例未找到,跳过调度器关闭。") + + print("正在尝试终止所有活动的视频处理子进程...") + for url, task_info in list(process_map.items()): + process: Process = task_info['process'] + stop_event: Event = task_info['stop_event'] + data_q: MpQueue = task_info['data_queue'] + + if process.is_alive(): + print(f"向进程 {url} 发送停止信号...") + stop_event.set() + try: + process.join(timeout=15) + if process.is_alive(): + print(f"进程 {url} 在优雅关闭超时后仍然存活,尝试 terminate。") + process.terminate() + process.join(timeout=5) + if process.is_alive(): + print(f"进程 {url} 在 terminate 后仍然存活,尝试 kill。") + process.kill() + process.join(timeout=2) + except Exception as e: + print(f"关闭/终止进程 {url} 时发生错误: {e}") + + try: + if not data_q.empty(): + pass + data_q.close() + data_q.join_thread() + except Exception as e_q_cleanup: + print(f"清理进程 {url} 的数据队列时出错: {e_q_cleanup}") + + del process_map[url] + print("所有视频处理子进程已尝试终止和清理。") + print("主应用服务关闭完成。") + + +def start_video_processing(url: str, rtmp_url: str, model_config_name: str, stop_event: Event, data_queue: MpQueue, + gateway: str,frequency:int, push_url:str): + print(f"视频处理子进程启动 (URL: {url}, Model: {model_config_name})") + detector_instance_for_stream: RFDETRDetector = None + producer_thread, transfer_thread, consumer_thread = None, None, None + + try: + print(f"正在为流 {url} 初始化模型: {model_config_name}...") + detector_instance_for_stream = RFDETRDetector(config_name=model_config_name) + print(f"模型 {model_config_name} 为流 {url} 初始化成功。") + rtc_q = queue.Queue(maxsize=10000) + yolo_q = queue.Queue(maxsize=10000) + import threading + producer_thread = threading.Thread(target=rtc_frame, args=(url, rtc_q), name=f"RTC-{url[:20]}", daemon=True) + transfer_thread = threading.Thread(target=yolo_frame, args=(rtc_q, yolo_q, detector_instance_for_stream), + name=f"YOLO-{url[:20]}", daemon=True) + consumer_thread = threading.Thread(target=push_frame, args=(yolo_q, rtmp_url,gateway,frequency,push_url), name=f"Push-{url[:20]}", + daemon=True) + + producer_thread.start() + transfer_thread.start() + consumer_thread.start() + + stop_event.wait() + print(f"子进程 {url}: 收到停止信号。准备关闭线程...") + + except FileNotFoundError as e: + print(f"错误 (视频进程 {url}): 模型配置文件 '{model_config_name}.json' 未找到。错误: {e}") + except Exception as e: + print(f"错误 (视频进程 {url}): 初始化或运行时错误。错误: {e}") + finally: + print(f"视频处理子进程 {url} 进入 finally 块。") + + if producer_thread and producer_thread.is_alive(): + print(f"子进程 {url}: producer_thread is still alive (daemon).") + if transfer_thread and transfer_thread.is_alive(): + print(f"子进程 {url}: transfer_thread is still alive (daemon).") + if consumer_thread and consumer_thread.is_alive(): + print(f"子进程 {url}: consumer_thread is still alive (daemon).") + + if detector_instance_for_stream: + print(f"子进程 {url}: 收集最后数据...") + final_counts = getattr(detector_instance_for_stream, 'category_counts', {}) + final_frame_np = getattr(detector_instance_for_stream, 'last_annotated_frame', None) + frame_base64 = None + if final_frame_np is not None and isinstance(final_frame_np, np.ndarray): + try: + _, buffer = cv2.imencode('.jpg', final_frame_np) + frame_base64 = base64.b64encode(buffer).decode('utf-8') + except Exception as e_encode: + print(f"子进程 {url}: 帧编码错误: {e_encode}") + + payload = { + "timestamp": time.time(), + "category_counts": final_counts, + "frame_base64": frame_base64, + "source_url": url, + "event": "task_stopped_final_data" + } + try: + data_queue.put(payload, timeout=5) + print(f"子进程 {url}: 已将最终数据放入队列。") + except queue.Full: + print(f"子进程 {url}: 无法将最终数据放入队列 (队列已满或超时)。") + except Exception as e_put: + print(f"子进程 {url}: 将最终数据放入队列时发生错误: {e_put}") + else: + print(f"子进程 {url}: 检测器实例不可用,无法发送最终数据。") + + try: + data_queue.close() + except Exception as e_q_close: + print(f"子进程 {url}: 关闭数据队列时出错: {e_q_close}") + print(f"视频处理子进程 {url} 执行完毕。") + + +@app.post("/start_video", tags=["Video Processing"]) +async def start_video(request: Request): + data = await request.json() + url = data.get("url") + model_identifier_to_use = data.get("model_identifier") + host = data.get("host") + rtmp_port = data.get("rtmp_port") + rtc_port = data.get("rtc_port") + gateway = data.get("gateway") + frequency = data.get("frequency") + push_url = data.get("push_url") + + # 生成MD5 + md5_hash = hashlib.md5(url.encode()).hexdigest() + rtmp_url = f"rtmp://{host}:{rtmp_port}/live/{md5_hash}" + rtc_url = f"http://{host}:{rtc_port}/rtc/v1/whep/?{md5_hash}" + if not url or not rtmp_url: + raise HTTPException(status_code=400, detail="'url' 和 'rtmp_url' 字段是必须的。") + + if not model_identifier_to_use: + print(f"请求中未指定 model_identifier,尝试使用全局激活的模型。") + model_identifier_to_use = get_active_model_identifier() + if not model_identifier_to_use: + raise HTTPException(status_code=400, detail="请求中未指定 'model_identifier',且当前无全局激活的默认模型。") + print(f"将为流 {url} 使用当前全局激活的模型: {model_identifier_to_use}") + + if url in process_map and process_map[url]['process'].is_alive(): + raise HTTPException(status_code=409, detail=f"视频处理进程已在运行: {url}") + + print(f"请求启动视频处理: URL = {url}, RTMP = {rtmp_url}, Model = {model_identifier_to_use}") + + stop_event = Event() + data_queue = MpQueue(maxsize=1) + + process = Process(target=start_video_processing, + args=(url, rtmp_url, model_identifier_to_use, stop_event, data_queue, gateway, frequency, push_url)) + process.start() + process_map[url] = {'process': process, 'stop_event': stop_event, 'data_queue': data_queue} + return Response.success(message=f"视频处理已为 URL '{url}' 使用模型 '{model_identifier_to_use}' 启动。", + data=rtc_url) + + +@app.post("/stop_video", tags=["Video Processing"]) +async def stop_video(request: Request): + data = await request.json() + url = data.get("url") + if not url: + raise HTTPException(status_code=400, detail="'url' 字段是必须的。") + + task_info = process_map.get(url) + if not task_info: + raise HTTPException(status_code=404, detail=f"没有找到与 URL '{url}' 匹配的活动视频处理进程。") + + process: Process = task_info['process'] + stop_event: Event = task_info['stop_event'] + data_q: MpQueue = task_info['data_queue'] + + final_data_pushed = False + if process.is_alive(): + print(f"请求停止视频处理: {url}. 发送停止信号...") + stop_event.set() + process.join(timeout=20) + + if process.is_alive(): + print(f"警告: 视频处理进程 {url} 在超时后未能正常终止,尝试强制结束。") + process.terminate() + process.join(timeout=5) + if process.is_alive(): + print(f"错误: 视频处理进程 {url} 强制结束后仍然存在。尝试 kill。") + process.kill() + process.join(timeout=2) + else: + print(f"进程 {url} 已优雅停止。尝试获取最后数据...") + try: + final_payload = data_q.get(timeout=10) + print(f"从停止的任务 {url} 收到最终数据。") + + pusher_instance = get_data_pusher_instance() + if pusher_instance and pusher_instance.target_url: + print(f"准备将任务 {url} 的最后数据推送到 {pusher_instance.target_url}") + pusher_instance.push_specific_payload(final_payload) + final_data_pushed = True + elif pusher_instance: + print( + f"DataPusher 服务已配置,但未设置目标URL (pusher.target_url is None)。无法推送任务 {url} 的最后数据。") + else: + print(f"DataPusher 服务未初始化或不可用。无法推送任务 {url} 的最后数据。") + + except queue.Empty: + print(f"警告: 任务 {url} 优雅停止后,未从其数据队列中获取到最终数据 (队列为空或超时)。") + except Exception as e_q_get: + print(f"获取或处理来自任务 {url} 的最终数据时发生错误: {e_q_get}") + else: + print(f"视频处理进程先前已停止或已结束: {url}") + + try: + while not data_q.empty(): + try: + data_q.get_nowait() + except queue.Empty: + break + data_q.close() + data_q.join_thread() + except Exception as e_q_final_cleanup: + print(f"清理任务 {url} 的数据队列的最后步骤中发生错误: {e_q_final_cleanup}") + + del process_map[url] + message = f"视频处理已为 URL '{url}' 停止。" + if final_data_pushed: + message += " 已尝试推送最后的数据。" + elif process.exitcode == 0: + message += " 进程已退出,但未确认最后数据推送 (可能未配置推送或队列问题)。" + + return Response.success(message=message) diff --git a/yolo_core.py b/yolo_core.py new file mode 100644 index 0000000..5b0e9b8 --- /dev/null +++ b/yolo_core.py @@ -0,0 +1,151 @@ +from ultralytics import YOLO +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit + +class YOLODetector: + def __init__(self, model_path='models/best.engine'): + # 加载 TensorRT 模型 + self.model = YOLO(model_path, task="detect") + # 英文类别名称到中文的映射 + self.class_name_mapping = { + 'pedestrian': '行人', + 'people': '人群', + 'bicycle': '自行车', + 'car': '轿车', + 'van': '面包车', + 'truck': '卡车', + 'tricycle': '三轮车', + 'awning-tricycle': '篷式三轮车', + 'bus': '公交车', + 'motor': '摩托车' + } + # 为每个类别设置固定的RGB颜色 + self.color_mapping = { + 'pedestrian': (71, 0, 36), # 勃艮第红 + 'people': (0, 255, 0), # 绿色 + 'bicycle': (0, 49, 83), # 普鲁士蓝 + 'car': (0, 47, 167), # 克莱茵蓝 + 'van': (128, 0, 128), # 紫色 + 'truck': (212, 72, 72), # 缇香红 + 'tricycle': (0, 49, 83), # 橙色 + 'awning-tricycle': (251, 220, 106), # 申布伦黄 + 'bus': (73, 45, 34), # 凡戴克棕 + 'motor': (1, 132, 127) # 马尔斯绿 + } + # 初始化类别计数器 + self.class_counts = {cls_name: 0 for cls_name in self.class_name_mapping.keys()} + # 初始化字体 + try: + self.font = ImageFont.truetype("simhei.ttf", 20) + except IOError: + self.font = ImageFont.load_default() + + def detect_and_draw_English(self, frame, conf=0.3, iou=0.5): + """ + 对输入帧进行目标检测并返回绘制结果 + + Args: + frame: 输入的图像帧(BGR格式) + conf: 置信度阈值 + iou: IOU阈值 + + Returns: + annotated_frame: 绘制了检测结果的图像帧 + """ + try: + # 进行 YOLO 目标检测 + results = self.model( + frame, + conf=conf, + iou=iou, + half=True, + ) + result = results[0] + + # 使用YOLO自带的绘制功能 + annotated_frame = result.plot() + + return annotated_frame + + except Exception as e: + print(f"Detection error: {e}") + return frame + + def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3): + """ + 对输入帧进行目标检测并绘制中文标注 + + Args: + frame: 输入的图像帧(BGR格式) + conf: 置信度阈值 + iou: IOU阈值 + + Returns: + annotated_frame: 绘制了检测结果的图像帧 + """ + try: + # 进行 YOLO 目标检测 + results = self.model( + frame, + conf=conf, + iou=iou, + # half=True, + ) + result = results[0] + + # 获取原始帧的副本 + img = frame.copy() + + # 转换为PIL图像以绘制中文 + pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_img) + + # 绘制检测结果 + for box in result.boxes: + # 获取边框坐标 + x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) + + # 获取类别ID和置信度 + cls_id = int(box.cls[0].item()) + conf = box.conf[0].item() + + # 获取类别名称并转换为中文 + cls_name = result.names[cls_id] + chinese_name = self.class_name_mapping.get(cls_name, cls_name) + + # 获取该类别的颜色 + color = self.color_mapping.get(cls_name, (255, 255, 255)) + + # 绘制边框 + draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3) + + # 准备标签文本 + text = f"{chinese_name} {conf:.2f}" + text_size = draw.textbbox((0, 0), text, font=self.font) + text_width = text_size[2] - text_size[0] + text_height = text_size[3] - text_size[1] + + # 绘制标签背景(使用与边框相同的颜色) + draw.rectangle( + [(x1, y1 - text_height - 4), (x1 + text_width, y1)], + fill=color + ) + + # 绘制白色文本 + draw.text( + (x1, y1 - text_height - 2), + text, + fill=(255, 255, 255), # 白色文本 + font=self.font + ) + + # 转换回OpenCV格式 + return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + + except Exception as e: + print(f"Detection error: {e}") + return frame