参数分离完全版

This commit is contained in:
2025-08-05 16:55:45 +08:00
commit 25a2ded11d
16 changed files with 2013 additions and 0 deletions

0
000.txt Normal file
View File

31
README.md Normal file
View File

@ -0,0 +1,31 @@
### models 目录
- 存放模型文件
### frame_transfer.py
- 从检测结果队列推送数据到 RTMP 服务器【不必修改】
- 从原始队列拿取数据、调用 yolo_core 封装的方法进行检测【四类】
### rtc_handler.py
- 从 WebRTC 实时视频流截取帧并持续推送到原始队列【不必修改】
### yolo_core.py
- 封装四类方法【参数均为原始队列、检测结果队列】
- 方法一:原始YOLO检测
- 方法二:原始YOLO检测 + 汉化 + 颜色
- 方法三:原始累计计数
- 方法四:原始累计计数 + 汉化 + 颜色
**读取 WebRTC 流和推送结果帧的代码不需要修改**

389
api_server.py Normal file
View File

@ -0,0 +1,389 @@
import json
import os
import shutil
from typing import Dict, Optional, Any
from fastapi import APIRouter, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from result import Response
# 假设 rfdetr_core.py 在同一目录下或 PYTHONPATH 中
from rfdetr_core import RFDETRDetector
# --- Global Variables and Configuration ---
model_management_router = APIRouter()
BASE_MODEL_DIR = "models"    # directory holding the uploaded .pth model files
BASE_CONFIG_DIR = "configs"  # directory holding the per-model JSON configs
os.makedirs(BASE_MODEL_DIR, exist_ok=True)
os.makedirs(BASE_CONFIG_DIR, exist_ok=True)
# Currently active detector instance and its identifier, plus a cache of every
# known model config keyed by model identifier.
current_detector: Optional[RFDETRDetector] = None
current_model_identifier: Optional[str] = None
available_models_info: Dict[str, Dict] = {}  # model identifier -> parsed config dict
def load_config(config_name: str) -> Optional[Dict]:
    """Load the JSON config for *config_name* from BASE_CONFIG_DIR.

    Returns the parsed dict, or None when the file is missing or unreadable.
    """
    path = os.path.join(BASE_CONFIG_DIR, f"{config_name}.json")
    if not os.path.exists(path):
        return None
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as exc:
        print(f"错误:加载配置文件 '{path}' 失败: {exc}")
        return None
def initialize_detector(model_identifier: str) -> Optional[RFDETRDetector]:
    """Initialize and activate the detector for *model_identifier*.

    Looks up the cached config (falling back to disk), validates that the
    referenced model file exists, builds an RFDETRDetector, publishes it via
    the module globals, and notifies the DataPusher (if any).

    Raises:
        HTTPException: 404 for missing files, 400 for bad config values,
            500 for internal detector initialization failures.
    """
    global current_detector, current_model_identifier
    try:
        config = available_models_info.get(model_identifier)
        if not config:
            # Cache miss: fall back to loading the config file from disk.
            print(f"错误:未找到模型 '{model_identifier}' 的缓存配置,尝试从磁盘加载。")
            config_data = load_config(model_identifier)
            if not config_data:
                raise FileNotFoundError(f"配置文件 {model_identifier}.json 未找到或无法加载。")
            available_models_info[model_identifier] = config_data
            config = config_data
        model_filename = config.get('model_pth_filename')
        if not model_filename:
            raise ValueError(f"配置文件 '{model_identifier}.json' 中缺少 'model_pth_filename' 字段。")
        model_full_path = os.path.join(BASE_MODEL_DIR, model_filename)
        if not os.path.exists(model_full_path):
            raise FileNotFoundError(f"模型文件 '{model_full_path}' (在 '{model_identifier}.json' 中指定) 不存在。")
        print(f"尝试使用配置 '{model_identifier}.json' 和模型 '{model_full_path}' 初始化检测器...")
        detector = RFDETRDetector(config_name=model_identifier, base_model_dir=BASE_MODEL_DIR,
                                  base_config_dir=BASE_CONFIG_DIR)
        print(f"检测器 '{model_identifier}' 初始化成功。")
        current_detector = detector
        current_model_identifier = model_identifier
        # Tell the DataPusher (if the main app created one) to use the freshly
        # activated detector instance.
        try:
            from data_pusher import get_data_pusher_instance  # local import avoids a circular dependency
            pusher = get_data_pusher_instance()
            if pusher:
                print(f"通知 DataPusher 更新其检测器实例为: {model_identifier}")
                pusher.update_detector_instance(current_detector)
            # A None pusher only means the main app has not initialized it yet.
        except ImportError:
            print(
                "警告: 无法导入 data_pusher 模块以更新 DataPusher 的检测器实例。如果您不使用数据推送功能,此消息可忽略。")
        except Exception as e_pusher:
            print(f"警告: 通知 DataPusher 更新检测器时发生意外错误: {e_pusher}")
        return detector
    except FileNotFoundError as e:
        print(f"初始化检测器 '{model_identifier}' 失败 (文件未找到): {e}")
        if current_model_identifier == model_identifier:
            current_detector = None
            current_model_identifier = None
        raise HTTPException(status_code=404, detail=str(e))  # missing file -> 404
    except ValueError as e:  # missing/invalid config fields
        print(f"初始化检测器 '{model_identifier}' 失败 (配置值错误): {e}")
        if current_model_identifier == model_identifier:
            current_detector = None
            current_model_identifier = None
        raise HTTPException(status_code=400, detail=str(e))  # bad config -> 400
    except Exception as e:  # errors raised inside RFDETRDetector itself
        print(f"初始化检测器 '{model_identifier}' 失败 (内部错误): {e}")
        if current_model_identifier == model_identifier:
            current_detector = None
            current_model_identifier = None
        if model_identifier in available_models_info:
            # Drop the cached config so a broken model is not blindly retried.
            del available_models_info[model_identifier]
        # Server-side / model-compatibility problems -> 500.
        raise HTTPException(status_code=500, detail=f"检测器 '{model_identifier}' 内部初始化失败: {e}")
def get_active_detector() -> Optional[RFDETRDetector]:
    """Return the currently active detector instance, or None if no model is loaded."""
    return current_detector
def get_active_model_identifier() -> Optional[str]:
    """Return the identifier of the currently active model, or None."""
    return current_model_identifier
def scan_and_load_available_models():
    """Rebuild the available_models_info cache from BASE_CONFIG_DIR.

    Only configs that both parse and reference an existing .pth file under
    BASE_MODEL_DIR are kept.
    """
    global available_models_info
    available_models_info = {}
    if not os.path.exists(BASE_CONFIG_DIR):
        # Nothing to scan without a configs directory.
        return
    for entry in os.listdir(BASE_CONFIG_DIR):
        if not entry.endswith(".json"):
            continue
        identifier = entry[:-5]  # strip the ".json" suffix
        cfg = load_config(identifier)
        if not cfg:
            continue
        pth_name = cfg.get('model_pth_filename')
        if pth_name and os.path.exists(os.path.join(BASE_MODEL_DIR, pth_name)):
            available_models_info[identifier] = cfg
    print(f"可用模型配置已扫描/更新: {list(available_models_info.keys())}")
# --- Extracted Startup Logic (to be called by the main app) ---
async def initialize_default_model_on_startup():
    """On app startup: scan for model configs and try to activate the first one."""
    print("执行模型管理模块的启动初始化...")
    scan_and_load_available_models()
    global current_model_identifier, current_detector
    if not available_models_info:
        print("没有可用的模型配置,服务器启动但无默认模型激活。")
        current_model_identifier = None
        current_detector = None
        return
    first_model = min(available_models_info)  # alphabetically first identifier
    print(f"尝试将第一个可用模型 '{first_model}' 设置为活动模型。")
    try:
        initialize_detector(first_model)
        print(f"默认模型 '{current_model_identifier}' 加载并激活成功。")
    except HTTPException as e:
        # initialize_detector converts every failure into an HTTPException.
        print(f"加载默认模型 '{first_model}' 失败: {e.detail} (状态码: {e.status_code})")
        current_model_identifier = None
        current_detector = None
# --- Pydantic Models for Request/Response ---
class ModelIdentifier(BaseModel):
    """Request body carrying a single model identifier."""
    model_identifier: str
# Define a standard response model for OpenAPI documentation
class StandardResponse(BaseModel):
    """Uniform response envelope (for OpenAPI docs): code + optional data + message."""
    code: int
    data: Optional[Any] = None
    message: str
# --- API Endpoints ---
@model_management_router.post("/upload_model_and_config/", response_model=StandardResponse)
async def upload_model_and_config(
        model_identifier_form: str = Form(..., alias="model_identifier"),
        config_file: UploadFile = File(...),
        model_file: UploadFile = File(...)
):
    """
    Upload a model file (.pth) together with its JSON config.
    - **model_identifier**: unique model name (e.g. "人车检测"); the config is saved under this name (e.g. "人车检测.json").
    - **config_file**: JSON configuration file.
    - **model_file**: PyTorch .pth file; its filename must match the config's 'model_pth_filename' field.
    """
    config_filename = f"{model_identifier_form}.json"
    config_path = os.path.join(BASE_CONFIG_DIR, config_filename)
    model_path = None  # declared before try so except/finally blocks can reference it
    global current_model_identifier, current_detector  # actually assigned inside initialize_detector
    try:
        print(f"开始上传模型 '{model_identifier_form}'...")
        config_content = await config_file.read()
        try:
            config_data = json.loads(config_content.decode('utf-8'))
        except json.JSONDecodeError:
            return JSONResponse(status_code=400,
                                content=Response.error(message="无效的JSON配置文件格式请检查JSON语法。", code=400))
        except UnicodeDecodeError:
            return JSONResponse(status_code=400,
                                content=Response.error(message="配置文件编码错误请确保为UTF-8编码。", code=400))
        if 'model_pth_filename' not in config_data:
            return JSONResponse(status_code=400,
                                content=Response.error(message="配置文件中必须包含 'model_pth_filename' 字段。",
                                                       code=400))
        target_model_filename = config_data['model_pth_filename']
        if not isinstance(target_model_filename, str) or not target_model_filename.endswith(".pth"):
            return JSONResponse(status_code=400, content=Response.error(
                message="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。", code=400))
        if model_file.filename != target_model_filename:
            # The uploaded .pth name must match what the config declares.
            return JSONResponse(
                status_code=400,
                content=Response.error(
                    message=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。",
                    code=400
                )
            )
        with open(config_path, 'wb') as f:
            f.write(config_content)
        print(f"配置文件 '{config_path}' 已保存。")
        model_path = os.path.join(BASE_MODEL_DIR, target_model_filename)
        with open(model_path, "wb") as buffer:
            shutil.copyfileobj(model_file.file, buffer)
        print(f"模型文件 '{model_path}' 已保存。")
        # Cache the config before initialization so initialize_detector can find
        # it even on a cache miss (it would otherwise fall back to load_config).
        available_models_info[model_identifier_form] = config_data
        try:
            print(f"尝试初始化并激活新上传的模型: '{model_identifier_form}'")
            initialize_detector(model_identifier_form)  # sets the global current_detector / current_model_identifier
            print(f"新上传的模型 '{model_identifier_form}' 验证并激活成功。")
        except HTTPException as e:  # raised by initialize_detector
            print(f"初始化新上传的模型 '{model_identifier_form}' 失败: {e.detail}")
            # Roll back: remove the files just written and the in-memory entry.
            if os.path.exists(config_path): os.remove(config_path)
            if model_path and os.path.exists(model_path): os.remove(model_path)
            if model_identifier_form in available_models_info: del available_models_info[model_identifier_form]
            # Re-scan so available_models_info is consistent after the rollback.
            scan_and_load_available_models()
            # 422: the upload succeeded but the model could not be initialized.
            return JSONResponse(status_code=422,
                                content=Response.error(
                                    message=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}",
                                    code=422
                                ))
        return Response.success(data={
            "message": f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。",  # Message is also in the wrapper
            "model_identifier": model_identifier_form,
            "config_filename": config_filename,
            "model_filename": target_model_filename
        }, message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。")
    except HTTPException as e:  # HTTPExceptions raised by FastAPI validation or elsewhere
        return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
    except Exception as e:
        identifier_for_log = model_identifier_form if 'model_identifier_form' in locals() else "unknown"
        print(f"上传模型 '{identifier_for_log}' 过程中发生意外的严重错误: {e}")
        # Best-effort cleanup in case files were partially written.
        if config_path and os.path.exists(config_path):
            os.remove(config_path)
        if model_path and os.path.exists(model_path):
            os.remove(model_path)
        if 'model_identifier_form' in locals() and model_identifier_form in available_models_info:
            del available_models_info[model_identifier_form]
        scan_and_load_available_models()  # always refresh the cache after a failure
        return JSONResponse(status_code=500, content=Response.error(
            message=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}", code=500))
    finally:
        if config_file: await config_file.close()
        if model_file: await model_file.close()
@model_management_router.post("/select_model/", response_model=StandardResponse)
async def select_model(model_identifier_path: str):
    """Select and activate an already-uploaded model by its unique identifier."""
    global current_model_identifier, current_detector
    # Re-scan first so external file changes are picked up before validating.
    scan_and_load_available_models()
    if model_identifier_path not in available_models_info:
        return JSONResponse(status_code=404, content=Response.error(
            message=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。", code=404))
    already_active = (current_model_identifier == model_identifier_path
                      and current_detector is not None)
    if already_active:
        print(f"模型 '{model_identifier_path}' 已经是活动模型。")
        return Response.success(data={
            "active_model": current_model_identifier
        }, message=f"模型 '{model_identifier_path}' 已是当前活动模型。")
    try:
        print(f"尝试激活模型: {model_identifier_path}")
        initialize_detector(model_identifier_path)
        print(f"模型 '{current_model_identifier}' 已成功激活。")
        return Response.success(data={
            "active_model": current_model_identifier
        }, message=f"模型 '{model_identifier_path}' 启动成功。")
    except HTTPException as e:
        # initialize_detector already reset the globals and (for internal
        # errors) dropped the cached config entry before raising.
        print(f"激活模型 '{model_identifier_path}' 失败: {e.detail} (原始状态码: {e.status_code})")
        scan_and_load_available_models()  # keep the model list fresh after a failure
        return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
@model_management_router.get("/available_models/", response_model=StandardResponse)
async def get_available_models():
    """Return the identifiers of every model currently available on disk."""
    # Refresh from disk so the listing reflects external changes.
    scan_and_load_available_models()
    return Response.success(data=list(available_models_info), message="成功获取可用模型列表。")
@model_management_router.get("/current_model/", response_model=StandardResponse)
async def get_current_model_endpoint():
    """Return the active model identifier (data=None when no model is active)."""
    if not current_model_identifier:
        return Response.success(data=None, message="当前没有激活的模型。")
    return Response.success(data=current_model_identifier, message="成功获取当前激活的模型。")

58
configs/人车.json Normal file
View File

@ -0,0 +1,58 @@
{
"model_id": "1",
"model_pth_filename": "人车.pth",
"resolution": 448,
"classes_en": [
"pedestrian", "person", "bicycle", "car", "van", "truck", "tricycle", "awning-tricycle", "bus", "motor"
],
"classes_zh_map": {
"pedestrian":"行人",
"person": "人群",
"bicycle": "自行车",
"car": "小汽车",
"van": "面包车",
"truck": "卡车",
"tricycle":"三轮车",
"awning-tricycle":"篷式三轮车",
"bus": "公交车",
"motor":"摩托车"
},
"class_colors_hex": {
"pedestrian": "#470024",
"person": "#00FF00",
"bicycle": "#003153",
"car": "#002FA7",
"van": "#800080",
"truck": "#D44848",
"tricycle": "#003153",
"awning-tricycle": "#FBDC6A",
"bus": "#492D22",
"motor": "#01847F"
},
"detection_settings": {
"enabled_classes": {
"pedestrian": true,
"person": true,
"bicycle": false,
"car": true,
"van": true,
"truck": true,
"tricycle": false,
"awning-tricycle": false,
"bus": true,
"motor": true
},
"default_confidence_threshold": 0.7
},
"default_color_hex": "#00FF00",
"tracker_activation_threshold": 0.5,
"tracker_lost_buffer": 120,
"tracker_match_threshold": 0.85,
"tracker_frame_rate": 25,
"tracker_consecutive_frames": 2
}

248
data_pusher.py Normal file
View File

@ -0,0 +1,248 @@
import base64
import cv2
import numpy as np
import requests
import time
import datetime
import logging
from fastapi import APIRouter, HTTPException, Body # 切换到 APIRouter
from pydantic import BaseModel, HttpUrl
# import uvicorn # 不再由此文件运行uvicorn
from apscheduler.schedulers.background import BackgroundScheduler
from typing import Optional # 用于类型提示
# 确保 RFDETRDetector 可以被导入,假设 rfdetr_core.py 在同一目录或 PYTHONPATH 中
# from rfdetr_core import RFDETRDetector # 在实际使用中取消注释并确保路径正确
# 配置日志记录
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# app = FastAPI(title="Data Pusher Service", version="1.0.0") # 移除独立的 FastAPI app
pusher_router = APIRouter() # 创建一个 APIRouter
class DataPusher:
    """Periodically POSTs the detector's latest category counts and annotated
    frame (JPEG, base64-encoded) to a configurable HTTP endpoint, driven by an
    APScheduler BackgroundScheduler."""

    def __init__(self, detector):  # detector: RFDETRDetector
        if detector is None:
            logger.error("DataPusher initialized with a None detector. Push functionality will be impaired.")
            # Still construct the instance; _get_data_payload degrades gracefully
            # when self.detector is None.
        self.detector = detector
        self.scheduler = BackgroundScheduler(daemon=True)
        self.push_job_id = "rt_push_job"
        self.target_url = None
        if not self.scheduler.running:
            try:
                self.scheduler.start()
            except Exception as e:
                logger.error(f"Error starting APScheduler in DataPusher: {e}")

    def update_detector_instance(self, detector):
        """Swap the detector at runtime, e.g. when the main app switches models."""
        logger.info(f"DataPusher's detector instance is being updated.")
        self.detector = detector
        if detector is None:
            logger.warning("DataPusher's detector instance updated to None.")

    def _get_data_payload(self):
        """Build the push payload: current category counts plus the latest annotated frame."""
        if self.detector is None:
            logger.warning("DataPusher: Detector not available. Cannot get data payload.")
            return {  # keep the payload structure even without a detector
                "category_counts": {},
                "frame_base64": None,
                "error": "Detector not available"
            }
        category_counts = getattr(self.detector, 'category_counts', {})
        # last_annotated_frame may be absent right after a model load,
        # before the first frame has been processed.
        last_frame_np = getattr(self.detector, 'last_annotated_frame', None)
        frame_base64 = None
        if last_frame_np is not None and isinstance(last_frame_np, np.ndarray):
            try:
                _, buffer = cv2.imencode('.jpg', last_frame_np)
                frame_base64 = base64.b64encode(buffer).decode('utf-8')
            except Exception as e:
                logger.error(f"Error encoding frame to base64: {e}")
        return {
            "category_counts": category_counts,
            "frame_base64": frame_base64
        }

    def _push_data_task(self):
        """Scheduled job body: POST the current payload to target_url."""
        if not self.target_url:
            # No target configured -> silently skip (avoids log spam).
            return
        payload = self._get_data_payload()
        try:
            response = requests.post(self.target_url, json=payload, timeout=5)
            response.raise_for_status()
            logger.debug(f"Data pushed successfully to {self.target_url}. Status: {response.status_code}")
        except requests.exceptions.RequestException as e:
            logger.error(f"Error pushing data to {self.target_url}: {e}")
        except Exception as e:
            logger.error(f"An unexpected error occurred during data push: {e}")

    def setup_push_schedule(self, frequency: float, target_url: str):
        """Create or replace the interval push job.

        Args:
            frequency: pushes per second; must be positive.
            target_url: HTTP endpoint to POST payloads to.

        Raises:
            ValueError: when frequency is not a positive number.
            RuntimeError: when the scheduler cannot be started.
        """
        if not isinstance(frequency, (int, float)) or frequency <= 0:
            raise ValueError("Frequency must be a positive number (pushes per second).")
        self.target_url = str(target_url)
        interval_seconds = 1.0 / frequency
        if not self.scheduler.running:  # the scheduler must run before jobs are added
            try:
                logger.info("APScheduler was not running. Attempting to start it now.")
                self.scheduler.start()
            except Exception as e:
                logger.error(f"Failed to start APScheduler in setup_push_schedule: {e}")
                raise RuntimeError(f"APScheduler could not be started: {e}")
        try:
            if self.scheduler.get_job(self.push_job_id):
                self.scheduler.remove_job(self.push_job_id)
                logger.info(f"Removed existing push job: {self.push_job_id}")
        except Exception as e:
            logger.error(f"Error removing existing job: {e}")
        # Delay the first push by 10s to let the pipeline warm up.
        first_run_time = datetime.datetime.now() + datetime.timedelta(seconds=10)
        self.scheduler.add_job(
            self._push_data_task,
            trigger='interval',
            seconds=interval_seconds,
            id=self.push_job_id,
            next_run_time=first_run_time,
            replace_existing=True
        )
        logger.info(f"Push task scheduled to {self.target_url} every {interval_seconds:.2f}s, starting in 10s.")

    def stop_push_schedule(self):
        """Remove the push job (if any) and forget the target URL."""
        if self.scheduler.get_job(self.push_job_id):
            try:
                self.scheduler.remove_job(self.push_job_id)
                logger.info(f"Push job {self.push_job_id} stopped successfully.")
                self.target_url = None  # clearing makes _push_data_task a no-op
            except Exception as e:
                logger.error(f"Error stopping push job {self.push_job_id}: {e}")
        else:
            logger.info("No active push job to stop.")

    def shutdown_scheduler(self):
        """Shut the scheduler down cleanly (call on application shutdown)."""
        if self.scheduler.running:
            try:
                self.scheduler.shutdown()
                logger.info("DataPusher's APScheduler shut down successfully.")
            except Exception as e:
                logger.error(f"Error shutting down DataPusher's APScheduler: {e}")

    def push_specific_payload(self, payload: dict):
        """One-off POST of a caller-supplied, pre-formatted payload to target_url."""
        if not self.target_url:
            logger.warning("DataPusher: Target URL not set. Cannot push specific payload.")
            return
        if not payload:
            logger.warning("DataPusher: Received empty payload for specific push. Skipping.")
            return
        logger.info(f"Attempting to push specific payload to {self.target_url}")
        try:
            response = requests.post(self.target_url, json=payload, timeout=10)  # Increased timeout for one-off
            response.raise_for_status()
            logger.info(f"Specific payload pushed successfully to {self.target_url}. Status: {response.status_code}")
        except requests.exceptions.RequestException as e:
            logger.error(f"Error pushing specific payload to {self.target_url}: {e}")
        except Exception as e:
            logger.error(f"An unexpected error occurred during specific payload push: {e}")
# Module-wide DataPusher singleton; created by initialize_data_pusher(), which
# the main application calls at startup.
data_pusher_instance: Optional[DataPusher] = None
# --- FastAPI Request Body Model ---
class PushConfigRequest(BaseModel):
    """Request body for /setup_push: push frequency (Hz) and target URL."""
    frequency: float
    url: HttpUrl
# --- FastAPI HTTP Endpoint (using APIRouter) ---
@pusher_router.post("/setup_push", summary="配置数据推送任务")
async def handle_setup_push(config: PushConfigRequest = Body(...)):
    """Configure (or reconfigure) the periodic data-push job."""
    global data_pusher_instance
    if data_pusher_instance is None:
        # Should never happen when the main app initialized the pusher at startup.
        logger.error("CRITICAL: /setup_push called but data_pusher_instance is None. Main app did not initialize it.")
        raise HTTPException(status_code=503, detail="DataPusher service not available. Initialization may have failed.")
    if config.frequency <= 0:  # Pydantic v2 could enforce this with gt=0
        raise HTTPException(status_code=400, detail="Invalid frequency value. Must be a positive number.")
    try:
        data_pusher_instance.setup_push_schedule(config.frequency, str(config.url))
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve))
    except RuntimeError as re:  # e.g. the APScheduler failed to start
        logger.error(f"Runtime error during push schedule setup: {re}")
        raise HTTPException(status_code=500, detail=str(re))
    except Exception as e:
        logger.error(f"Error setting up push schedule: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
    return {
        "message": "Push task configured successfully.",
        "frequency_hz": config.frequency,
        "interval_seconds": 1.0 / config.frequency,
        "target_url": str(config.url),
        "first_push_delay_seconds": 10
    }
@pusher_router.post("/stop_push", summary="停止当前数据推送任务")
async def handle_stop_push():
    """Stop the periodic push job if one is currently scheduled."""
    global data_pusher_instance
    if data_pusher_instance is None:
        logger.error("CRITICAL: /stop_push called but data_pusher_instance is None.")
        raise HTTPException(status_code=503, detail="DataPusher service not available.")
    try:
        data_pusher_instance.stop_push_schedule()
    except Exception as e:
        logger.error(f"Error stopping push schedule: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error while stopping push: {str(e)}")
    return {"message": "Push task stopped successfully if it was running."}
# --- Initialization and Shutdown Functions for Main App ---
def initialize_data_pusher(detector_instance_param):
    """Create the module-wide DataPusher, or refresh its detector, and return it.

    Called by the main application; safe to call again (e.g. after a model switch).
    """
    global data_pusher_instance
    if data_pusher_instance is not None:
        logger.info("DataPusher instance already initialized. Updating detector instance if provided.")
        data_pusher_instance.update_detector_instance(detector_instance_param)
    else:
        logger.info("Initializing DataPusher instance...")
        data_pusher_instance = DataPusher(detector_instance_param)
    return data_pusher_instance
def get_data_pusher_instance() -> Optional[DataPusher]:
    """Return the module-wide DataPusher instance (None if not yet initialized)."""
    return data_pusher_instance
# 移除 if __name__ == '__main__' 和 run_pusher_service因为不再独立运行
# 示例代码可以移至主应用的文档或测试脚本中。
# 注意:
# RFDETRDetector 实例的生命周期由 api_server.py (current_detector) 管理。
# 当 api_server.py 中的模型切换时,需要有一种机制来更新 DataPusher 内部的 detector 引用。
# initialize_data_pusher 可以被多次调用 (例如,在模型切换后),它会更新 DataPusher 持有的 detector 实例。

BIN
font/MSYH.TTC Normal file

Binary file not shown.

226
frame_transfer.py Normal file
View File

@ -0,0 +1,226 @@
import base64
import requests as http_client
import queue
import time # 确保 time 被导入,如果之前被误删
import cv2
import traceback
import threading # 确保导入 threading
import av # 重新导入 av
import numpy as np
# from fastapi import requests
# 从 rfdetr_core 导入 RFDETRDetector 仅用于类型提示 (可选)
from rfdetr_core import RFDETRDetector
# 目标检测处理函数
# 函数签名已更改:现在接受一个 detector_instance 作为参数
def yolo_frame(rtc_q: queue.Queue, yolo_q: queue.Queue, stream_detector_instance: "RFDETRDetector"):
    """Consume raw frames from rtc_q, run detection, and put annotated packets on yolo_q.

    Each packet is {"frame": ndarray|None, "category_counts": dict}; a packet
    with frame=None propagates the upstream stop signal to the push thread.
    Counts are remapped to Chinese class names when the detector exposes a
    VISDRONE_CLASSES_CHINESE mapping.

    Fixes vs. original: removed the dead `if frame is not None: try: pass`
    block in the exception handler, and made the detector annotation a string
    (forward reference) so the function is importable without rfdetr_core.
    """
    thread_name = threading.current_thread().name  # used in every log line
    print(f"处理线程 '{thread_name}' 已启动。")
    error_message_displayed_once = False
    no_detector_message_displayed_once = False  # warn only once while no detector is available
    if stream_detector_instance is None:
        print(f"错误 (线程 '{thread_name}'): 未提供有效的检测器实例给yolo_frame。线程将无法处理视频。")
        # The loop still runs so frames keep flowing, but they only get a
        # warning overlay instead of detections.
    while True:
        frame = None
        try:
            print(f"线程 '{thread_name}' - 原始队列长度: {rtc_q.qsize()}, 检测队列长度: {yolo_q.qsize()}")
            frame = rtc_q.get(timeout=0.1)
            if frame is None:
                print(f"处理线程 '{thread_name}' 接收到停止信号,正在退出...")
                # Forward the stop signal as a packet with a None frame.
                yolo_q.put({"frame": None, "category_counts": {}})
                break
            category_counts_for_packet = {}
            if stream_detector_instance:
                no_detector_message_displayed_once = False  # detector valid again, re-arm the warning
                annotated_frame = stream_detector_instance.detect_and_draw_count(frame)
                error_message_displayed_once = False
                # Counts keyed by English class name...
                english_counts = stream_detector_instance.category_counts.copy() if hasattr(stream_detector_instance, 'category_counts') else {}
                # ...remapped to Chinese names when a mapping is available.
                if hasattr(stream_detector_instance, 'VISDRONE_CLASSES_CHINESE'):
                    chinese_map = stream_detector_instance.VISDRONE_CLASSES_CHINESE
                    for eng_key, count_val in english_counts.items():
                        # Fall back to the English key if it has no Chinese entry.
                        category_counts_for_packet[chinese_map.get(eng_key, eng_key)] = count_val
                else:
                    category_counts_for_packet = english_counts
            else:
                if not no_detector_message_displayed_once:
                    print(f"警告 (线程 '{thread_name}'): 无有效检测器实例。将在帧上绘制提示。")
                    no_detector_message_displayed_once = True
                annotated_frame = frame.copy()
                cv2.putText(annotated_frame,
                            "No detector instance provided for this stream",
                            (30, 50), cv2.FONT_HERSHEY_SIMPLEX,
                            1, (0, 0, 255), 2, cv2.LINE_AA)
            data_packet = {"frame": annotated_frame, "category_counts": category_counts_for_packet}
            try:
                yolo_q.put_nowait(data_packet)
            except queue.Full:
                pass  # drop the frame rather than block the pipeline
        except queue.Empty:
            time.sleep(0.01)
            continue
        except Exception as e:
            if not error_message_displayed_once:
                print(f"线程 '{thread_name}' (yolo_frame) 处理时发生严重错误: {e}")
                traceback.print_exc()
                error_message_displayed_once = True
            time.sleep(1)
            continue
    print(f"处理线程 '{thread_name}' 已停止。")
def push_frame(yolo_q: queue.Queue, rtmp_url: str, gateway: str, frequency: int, push_url: str):
    """Pull annotated packets from yolo_q, stream frames to RTMP via PyAV and,
    at most once every *frequency* seconds, POST a base64 snapshot plus its
    category counts to *push_url*.

    A packet whose "frame" is None is the stop signal; the encoder is flushed
    and the container closed in the finally block.
    """
    thread_name = threading.current_thread().name
    print(f"推流线程 '{thread_name}' (RTMP: {rtmp_url}) 已启动。")
    output_container = None
    stream = None
    first_frame_processed = False
    last_push_time = 0  # timestamp of the last base64 push
    try:
        while True:
            frame_to_push = None
            received_category_counts = {}  # counts travelling with the current frame
            try:
                data_packet = yolo_q.get(timeout=0.1)
                if data_packet:
                    frame_to_push = data_packet.get("frame")
                    received_category_counts = data_packet.get("category_counts", {})
                else:  # defensive: a bare None packet should not normally appear
                    time.sleep(0.01)
                    continue
            except queue.Empty:
                time.sleep(0.01)
                continue
            if frame_to_push is None:  # stop signal: packet carrying frame=None
                print(f"推流线程 '{thread_name}' 接收到停止信号,正在清理并退出...")
                break
            if not first_frame_processed:
                # Lazy initialization: the RTMP encoder needs the frame size,
                # so it is created from the first real frame.
                if frame_to_push is not None:
                    try:
                        height, width, _ = frame_to_push.shape
                        print(f"线程 '{thread_name}': 首帧尺寸 {width}x{height}正在初始化RTMP推流器到 {rtmp_url}")
                        output_container = av.open(rtmp_url, 'w', format='flv')
                        stream = output_container.add_stream('libx264', rate=25)
                        stream.pix_fmt = 'yuv420p'
                        stream.width = width
                        stream.height = height
                        # Low-latency x264 settings for live streaming.
                        stream.options = {'preset': 'ultrafast', 'tune': 'zerolatency', 'crf': '25'}
                        print(f"线程 '{thread_name}': RTMP推流器初始化成功。")
                        first_frame_processed = True
                    except Exception as e_init:
                        print(f"错误 (线程 '{thread_name}'): 初始化PyAV推流容器/流失败: {e_init}")
                        traceback.print_exc()
                        return
                else:
                    continue
            if not output_container or not stream:
                print(f"错误 (线程 '{thread_name}'): 推流器未初始化,无法推流。可能是首帧处理失败。")
                time.sleep(1)
                continue
            # Encode and mux the frame into the RTMP stream.
            try:
                video_frame = av.VideoFrame.from_ndarray(frame_to_push, format='bgr24')
                for packet in stream.encode(video_frame):
                    output_container.mux(packet)
            except Exception as e_push:
                print(f"错误 (线程 '{thread_name}'): 推送帧到RTMP时发生错误: {e_push}")
                time.sleep(0.5)
            # Throttled base64 push (at most once every `frequency` seconds).
            current_time = time.time()
            if current_time - last_push_time >= frequency:
                # Pass the counts received with this frame along to the HTTP push.
                push_base64_frame(frame_to_push, gateway, push_url, thread_name, received_category_counts)
                last_push_time = current_time
    except Exception as e_outer:
        print(f"推流线程 '{thread_name}' 发生严重外部错误: {e_outer}")
        traceback.print_exc()
    finally:
        print(f"推流线程 '{thread_name}': 进入finally块准备关闭推流器。")
        if stream and output_container:
            try:
                # Flush any frames still buffered in the encoder.
                print(f"推流线程 '{thread_name}': 正在编码流的剩余部分...")
                for packet in stream.encode(None):
                    output_container.mux(packet)
                print(f"推流线程 '{thread_name}': 编码剩余部分完成。")
            except Exception as e_flush:
                print(f"错误 (线程 '{thread_name}'): 关闭推流流时发生编码/刷新错误: {e_flush}")
                traceback.print_exc()
        if output_container:
            try:
                print(f"推流线程 '{thread_name}': 正在关闭推流容器...")
                output_container.close()
                print(f"推流线程 '{thread_name}': 推流容器已关闭。")
            except Exception as e_close:
                print(f"错误 (线程 '{thread_name}'): 关闭推流容器时发生错误: {e_close}")
                traceback.print_exc()
    print(f"推流线程 '{thread_name}' 已停止并完成清理。")
def push_base64_frame(frame: np.ndarray, gateway: str, push_url: str, thread_name: str, category_counts: dict):
    """JPEG-encode *frame*, base64 it, and POST it with the counts to *push_url*."""
    try:
        # JPEG-encode the frame, then wrap the bytes as a base64 string.
        jpeg_buffer = cv2.imencode('.jpg', frame)[1]
        encoded = base64.b64encode(jpeg_buffer).decode('utf-8')
        data = {
            "gateway": gateway,
            "frame_base64": encoded,
            "category_counts": category_counts  # counts that travelled with this frame
        }
        print(f"DEBUG push_base64_frame: Pushing data: {data.get('category_counts')}")  # trace what is sent
        response = http_client.post(push_url, json=data, timeout=5)
        if response.status_code == 200:
            print(f"线程 '{thread_name}': base64帧已成功推送到 {push_url}")
        else:
            print(f"错误 (线程 '{thread_name}'): 推送base64帧失败状态码: {response.status_code}")
    except Exception as e:
        print(f"错误 (线程 '{thread_name}'): 处理或推送base64帧时发生错误: {e}")

5
main.py Normal file
View File

@ -0,0 +1,5 @@
# Entry point: launch the FastAPI application defined in web.py with uvicorn.
import uvicorn
from web import app
if __name__ == "__main__":
    # Listen on all interfaces, port 8000 (single process, no reload).
    uvicorn.run(app, host="0.0.0.0", port=8000)

BIN
models/人车.pth Normal file

Binary file not shown.

27
requirements - 副本.txt Normal file
View File

@ -0,0 +1,27 @@
aiohttp==3.11.14
aiortc==1.11.0
aiosignal==1.3.2
APScheduler==3.11.0
av==14.2.0
fastapi==0.115.11
huggingface-hub==0.30.1
numpy==2.1.1
nvidia-cuda-runtime-cu12==12.8.90
opencv-contrib-python==4.11.0.86
opencv-python==4.11.0.86
pillow==11.1.0
pillow_heif==0.22.0
pycuda==2025.1
pydantic==2.10.6
pydantic_core==2.27.2
requests==2.32.3
requests-toolbelt==1.0.0
rfdetr==1.1.0
safetensors==0.5.3
supervision==0.25.1
torch==2.6.0+cu126
torchaudio==2.6.0+cu126
torchvision==0.21.0+cu126
transformers==4.50.3
uvicorn==0.34.0
wandb==0.19.9

233
requirements.txt Normal file
View File

@ -0,0 +1,233 @@
absl-py==2.2.1
accelerate==1.6.0
addict==2.4.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.14
aioice==0.9.0
aiortc==1.11.0
aiosignal==1.3.2
albucore==0.0.23
albumentations==2.0.5
altgraph==0.17.4
annotated-types==0.7.0
anyio==4.8.0
anywidget==0.9.18
appdirs==1.4.4
APScheduler==3.11.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
async-timeout==5.0.1
attrs==25.3.0
av==14.2.0
basicsr==1.4.2
bbox_visualizer==0.2.0
blind-watermark==0.4.4
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
cmake==3.31.6
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
coloredlogs==15.0.1
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
contourpy==1.3.1
cryptography==44.0.2
cycler==0.12.1
Cython==3.0.12
debugpy @ file:///D:/bld/debugpy_1741148401445/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
deep-sort-realtime==1.3.2
defusedxml==0.7.1
dnspython==2.7.0
docker-pycreds==0.4.0
easydict==1.13
einops==0.8.1
et_xmlfile==2.0.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1733208806608/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1733569351617/work
facexlib==0.3.0
fairscale==0.4.13
fastapi==0.115.11
filelock==3.13.1
filetype==1.2.0
filterpy==1.4.5
fire==0.7.0
flatbuffers==25.2.10
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2024.6.1
ftfy==6.3.1
future==1.0.0
gfpgan==1.3.8
gitdb==4.0.12
GitPython==3.1.44
google-crc32c==1.7.1
grpcio==1.71.0
h11==0.14.0
httptools==0.6.4
huggingface-hub==0.30.1
humanfriendly==10.0
idna==3.7
ifaddr==0.2.0
imageio==2.37.0
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work
iniconfig==2.1.0
insightface==0.7.3
ipykernel @ file:///D:/bld/ipykernel_1719845595208/work
ipython @ file:///D:/bld/bld/rattler-build_ipython_1740856913/work
ipywidgets==8.1.5
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
Jinja2==3.1.4
joblib==1.4.2
jupyter_bbox_widget==0.6.0
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
jupyter_core @ file:///D:/bld/jupyter_core_1727163532151/work
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
lap==0.5.12
lazy_loader==0.4
llvmlite==0.44.0
lmdb==1.6.2
Mako==1.3.9
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.10.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
mdurl==0.1.2
memory-profiler==0.61.0
motmetrics==1.4.0
mpmath==1.3.0
multidict==6.2.0
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
networkx==3.3
ninja==1.11.1.3
Nuitka==2.6.7
numba==0.61.0
numpy==2.1.1
nvidia-cuda-runtime-cu12==12.8.90
onnx==1.16.1
onnx-graphsurgeon==0.5.7
onnxruntime-gpu==1.20.2
onnxsim==0.4.36
onnxslim==0.1.48
open_clip_torch==2.32.0
opencv-contrib-python==4.11.0.86
opencv-python==4.11.0.86
openpyxl==3.1.5
ordered-set==4.1.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1733203243479/work
pandas==2.2.3
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
pefile==2023.2.7
peft==0.15.1
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
pillow==11.1.0
pillow_heif==0.22.0
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1733232627818/work
pluggy==1.5.0
polygraphy==0.49.20
prettytable==3.15.1
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1737453357274/work
propcache==0.3.1
protobuf==3.20.2
psutil @ file:///D:/bld/psutil_1740663160591/work
psygnal==0.12.0
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
py-cpuinfo==9.0.0
pybboxes==0.1.6
pycocotools==2.0.8
pycparser==2.22
pycuda==2025.1
pydantic==2.10.6
pydantic_core==2.27.2
pyee==13.0.0
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work
pyinstaller==6.12.0
pyinstaller-hooks-contrib==2025.1
pylabel==0.1.55
pylibsrtp==0.11.0
PyMuPDF==1.25.4
pyOpenSSL==25.0.0
pyparsing==3.2.1
pyproj==3.7.1
pyreadline3==3.5.4
PySide6==6.8.2.1
PySide6_Addons==6.8.2.1
PySide6_Essentials==6.8.2.1
pytest==8.3.5
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
python-dotenv==1.0.1
python-multipart==0.0.20
pytools==2025.1.2
pytz==2025.1
PyWavelets==1.8.0
pywin32==307
pywin32-ctypes==0.2.3
PyYAML==6.0.2
pyzmq @ file:///D:/bld/pyzmq_1738270977186/work
realesrgan==0.3.0
regex==2024.11.6
requests==2.32.3
requests-toolbelt==1.0.0
rf100vl==1.0.0
rfdetr==1.1.0
rich==14.0.0
roboflow==1.1.60
safetensors==0.5.3
sahi==0.11.22
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.15.2
seaborn==0.13.2
sentry-sdk==2.25.1
setproctitle==1.3.5
shapely==2.0.7
shiboken6==6.8.2.1
simsimd==6.2.1
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
smmap==5.0.2
sniffio==1.3.1
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
starlette==0.46.1
stringzilla==3.12.2
supervision==0.25.1
sympy==1.13.1
tabulate==0.9.0
tb-nightly==2.20.0a20250326
tensorboard-data-server==0.7.2
tensorrt==10.9.0.34
tensorrt_cu12==10.9.0.34
tensorrt_cu12_bindings==10.9.0.34
tensorrt_cu12_libs==10.9.0.34
termcolor==2.5.0
terminaltables==3.1.10
threadpoolctl==3.5.0
tifffile==2025.2.18
timm==1.0.15
tokenizers==0.21.1
tomli==2.2.1
torch==2.6.0+cu126
torchaudio==2.6.0+cu126
torchvision==0.21.0+cu126
tornado @ file:///D:/bld/tornado_1732615925919/work
tqdm==4.67.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
transformers==4.50.3
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work
tzdata==2025.1
tzlocal==5.3.1
ultralytics==8.3.99
ultralytics-thop==2.0.14
urllib3==2.3.0
uvicorn==0.34.0
wandb==0.19.9
watchfiles==1.0.4
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
websockets==15.0.1
Werkzeug==3.1.3
widgetsnbextension==4.0.13
xmltodict==0.14.2
yapf==0.43.0
yarl==1.18.3
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
zstandard==0.23.0

21
result.py Normal file
View File

@ -0,0 +1,21 @@
# Uniform API response envelope.
class Response:
    """Envelope for API responses: a status code, a payload, and a message."""

    def __init__(self, code, data, message):
        self.code = code
        self.data = data
        self.message = message

    def to_dict(self):
        """Serialize the envelope to a plain dict for JSON responses."""
        return {"code": self.code, "data": self.data, "message": self.message}

    @classmethod
    def success(cls, data=None, message="操作成功"):
        """Build a 200 response dict."""
        return cls(200, data, message).to_dict()

    @classmethod
    def error(cls, data=None, message="操作失败", code=400):
        """Build an error response dict (default code 400)."""
        return cls(code, data, message).to_dict()

254
rfdetr_core.py Normal file
View File

@ -0,0 +1,254 @@
import cv2
import supervision as sv
from rfdetr import RFDETRBase
from collections import defaultdict
from typing import Dict, Set
from PIL import Image, ImageDraw, ImageFont # 导入PIL库
import numpy as np # 导入numpy用于图像格式转换
import json # 新增
import os # 新增
class RFDETRDetector:
    """RF-DETR detector with ByteTrack tracking, per-class counting and
    Chinese-label rendering, fully driven by a JSON config file.
    """

    def __init__(self, config_name: str, base_model_dir="models", base_config_dir="configs", default_font_path="./font/MSYH.TTC", default_font_size=15):
        """Load `<base_config_dir>/<config_name>.json` and build the model,
        tracker, class tables, font, counters and drawing settings from it.

        Raises:
            FileNotFoundError: if the config file does not exist.
        """
        self.config_path = os.path.join(base_config_dir, f"{config_name}.json")
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        model_path = os.path.join(base_model_dir, self.config['model_pth_filename'])
        resolution = self.config['resolution']
        # Font path/size come from the config when present, else the defaults.
        font_path = self.config.get('font_path', default_font_path)
        font_size = self.config.get('font_size', default_font_size)
        # 1. Detection model.
        self.model = RFDETRBase(
            pretrain_weights=model_path,
            resolution=resolution
        )
        # 2. ByteTrack tracker, thresholds taken verbatim from the config.
        self.tracker = sv.ByteTrack(
            track_activation_threshold=self.config['tracker_activation_threshold'],
            lost_track_buffer=self.config['tracker_lost_buffer'],
            minimum_matching_threshold=self.config['tracker_match_threshold'],
            minimum_consecutive_frames=self.config['tracker_consecutive_frames'],
            frame_rate=self.config['tracker_frame_rate']
        )
        # 3. Class tables: English names (indexed by class id) and the
        #    English -> Chinese display-name map.
        self.VISDRONE_CLASSES = self.config['classes_en']
        self.VISDRONE_CLASSES_CHINESE = self.config['classes_zh_map']
        # Per-class enable/disable filter from the config.
        self.detection_settings = self.config.get('detection_settings', {})
        self.enabled_classes_filter = self.detection_settings.get('enabled_classes', {})
        # Lookup table: classes absent from the filter default to enabled (True).
        self._active_classes_lookup = {
            cls_name: self.enabled_classes_filter.get(cls_name, True)
            for cls_name in self.VISDRONE_CLASSES
        }
        print(f"活动类别配置: {self._active_classes_lookup}")
        # 4. Label font (TrueType if loadable, PIL's built-in default otherwise).
        self.FONT_SIZE = font_size
        try:
            self.font = ImageFont.truetype(font_path, self.FONT_SIZE)
        except IOError:
            print(f"错误:无法加载字体 {font_path}。将使用默认字体。")
            self.font = ImageFont.load_default()
        # 5. Counters: set of seen track ids per class, and cumulative counts.
        self.class_tracks: Dict[str, Set[int]] = defaultdict(set)
        self.category_counts: Dict[str, int] = defaultdict(int)
        # 6. Drawing settings, with safe fallbacks when absent from the config.
        self.default_color_hex = self.config.get('default_color_hex', "#00FF00")  # green
        self.bounding_box_thickness = self.config.get('bounding_box_thickness', 2)
        # Per-class hex colors used by the PIL drawing code.
        self.class_colors_hex = self.config.get('class_colors_hex', {})
        # Most recent annotated frame (reported as final state on shutdown).
        self.last_annotated_frame: np.ndarray | None = None
def _hex_to_rgb(self, hex_color: str) -> tuple:
"""将十六进制颜色字符串转换为RGB元组。"""
hex_color = hex_color.lstrip('#')
try:
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
except ValueError:
print(f"警告: 无法解析十六进制颜色 '{hex_color}', 将使用默认颜色。")
# 解析失败时返回一个默认颜色,例如红色
return self._hex_to_rgb(self.default_color_hex if self.default_color_hex != hex_color else "#00FF00")
def _update_counter(self, detections: sv.Detections):
"""更新类别计数器"""
# 只统计有 tracker_id 的检测结果
valid_indices = detections.tracker_id != None
if not np.any(valid_indices): # 处理 detections 为空或 tracker_id 都为 None 的情况
return
class_ids = detections.class_id[valid_indices]
track_ids = detections.tracker_id[valid_indices]
for class_id, track_id in zip(class_ids, track_ids):
if track_id is None: # 跳过没有 tracker_id 的项
continue
# 使用英文类别名作为内部 key
class_name = self.VISDRONE_CLASSES[class_id]
if track_id not in self.class_tracks[class_name]:
self.class_tracks[class_name].add(track_id)
self.category_counts[class_name] += 1
    def _draw_frame(self, frame: np.ndarray, detections: sv.Detections) -> np.ndarray:
        """Render detections (boxes + Chinese labels) and the cumulative
        per-class count panel onto a copy of `frame`.

        Args:
            frame: BGR input image; not modified in place.
            detections: supervision Detections (may be empty).

        Returns:
            A new BGR image with annotations.
        """
        # Work in PIL/RGB so the TrueType CJK font can render Chinese text.
        pil_image = Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        # --- boxes and Chinese labels (only detections that have a tracker id) ---
        valid_indices = detections.tracker_id != None
        if np.any(valid_indices):
            boxes = detections.xyxy[valid_indices]
            class_ids = detections.class_id[valid_indices]
            for box, class_id in zip(boxes, class_ids):
                x1, y1, x2, y2 = map(int, box)
                english_label = self.VISDRONE_CLASSES[class_id]
                chinese_label = self.VISDRONE_CLASSES_CHINESE.get(english_label, english_label)
                # Box color: per-class hex from config, else the default color.
                box_color_hex = self.class_colors_hex.get(english_label, self.default_color_hex)
                box_rgb_color = self._hex_to_rgb(box_color_hex)
                # Bounding box.
                draw.rectangle([x1, y1, x2, y2], outline=box_rgb_color, width=self.bounding_box_thickness)
                # Chinese label text (white, no background panel).
                text_to_draw = f"{chinese_label}"
                text_color = (255, 255, 255)  # white (RGB)
                text_x = x1 + 2  # small offset so text does not touch the box edge
                text_y = y1 - self.FONT_SIZE - 2
                if text_y < 0:  # label would go above the image: draw inside the box instead
                    text_y = y1 + 2
                draw.text((text_x, text_y), text_to_draw, font=self.font, fill=text_color)
        # --- statistics panel (top-right corner): one line per non-zero class ---
        stats_text_lines = [
            f"{self.VISDRONE_CLASSES_CHINESE.get(cls, cls)}: {self.category_counts[cls]}"
            for cls in self.VISDRONE_CLASSES if self.category_counts[cls] > 0
        ]
        frame_height, frame_width, _ = frame.shape
        stats_start_x = frame_width - self.config.get('stats_panel_width', 200)
        stats_start_y = self.config.get('stats_panel_margin_y', 10)
        line_height = self.FONT_SIZE + self.config.get('stats_line_spacing', 5)
        stats_text_color_hex = self.config.get('stats_text_color_hex', "#FFFFFF")
        stats_text_color = self._hex_to_rgb(stats_text_color_hex)
        if stats_text_lines:
            # NOTE(review): panel_height/panel_y2 fed an optional panel
            # background that is currently disabled; they are unused otherwise.
            panel_height = len(stats_text_lines) * line_height + 10
            panel_y2 = stats_start_y + panel_height
            for i, line in enumerate(stats_text_lines):
                text_pos = (stats_start_x, stats_start_y + i * line_height)
                draw.text(text_pos, line, font=self.font, fill=stats_text_color)
        # Convert back to OpenCV's BGR layout.
        final_annotated_frame = cv2.cvtColor(np.array(pil_image.convert('RGB')), cv2.COLOR_RGB2BGR)
        return final_annotated_frame
    def detect_and_draw_count(self, frame: np.ndarray, conf: float = -1.0) -> np.ndarray:
        """Detect, filter by enabled classes, track, count and draw one frame.

        Args:
            frame: BGR input image.
            conf: confidence threshold; -1.0 (the sentinel default) means
                "resolve from config" (detection_settings first, then the
                top-level config, then 0.8).

        Returns:
            Annotated BGR frame; on any error the original frame is returned
            unchanged. The result is also stored in `self.last_annotated_frame`.
        """
        if conf == -1.0:
            # Resolve the effective threshold from config with layered fallbacks.
            effective_conf = float(
                self.detection_settings.get('default_confidence_threshold',
                                            self.config.get('default_confidence_threshold', 0.8))
            )
        else:
            effective_conf = conf
        try:
            # 1. Detection.
            detections = self.model.predict(frame, threshold=effective_conf)
            # Nothing detected: draw (counts panel only) and return early.
            if detections is None or len(detections) == 0:
                detections = sv.Detections.empty()
                annotated_frame = self._draw_frame(frame, detections)
                self.last_annotated_frame = annotated_frame.copy()
                return annotated_frame
            # Filter detections down to the classes enabled in the config.
            if detections is not None and len(detections) > 0:
                keep_indices = []
                for i, class_id in enumerate(detections.class_id):
                    if class_id < len(self.VISDRONE_CLASSES):  # guard against out-of-range ids
                        class_name = self.VISDRONE_CLASSES[class_id]
                        if self._active_classes_lookup.get(class_name, True):  # default: enabled
                            keep_indices.append(i)
                    else:
                        print(f"警告: 检测到无效的 class_id {class_id},超出了已知类别范围。")
                if not keep_indices:
                    detections = sv.Detections.empty()
                else:
                    detections = detections[keep_indices]
            # Everything filtered out: draw the counts panel only.
            if len(detections) == 0:
                annotated_frame = self._draw_frame(frame, sv.Detections.empty())
                self.last_annotated_frame = annotated_frame.copy()
                return annotated_frame
            # 2. Tracking (on the filtered detections only).
            detections = self.tracker.update_with_detections(detections)
            # 3. Counting (on the filtered + tracked detections only).
            self._update_counter(detections)
            # 4. Drawing.
            annotated_frame = self._draw_frame(frame, detections)
            self.last_annotated_frame = annotated_frame.copy()
            return annotated_frame
        except Exception as e:
            print(f"处理帧时发生错误: {e}")
            # Fall back to the raw frame so the stream keeps flowing.
            if frame is not None:
                self.last_annotated_frame = frame.copy()
            else:
                self.last_annotated_frame = None
            return frame

94
rtc_handler.py Normal file
View File

@ -0,0 +1,94 @@
import asyncio
import queue
from fractions import Fraction
from urllib.parse import urlparse
import aiohttp
import av
import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, VideoStreamTrack
class DummyVideoTrack(VideoStreamTrack):
    """Placeholder outbound video track used only to trigger SDP negotiation.

    aiortc's VideoStreamTrack contract requires recv() to return an
    av.VideoFrame carrying pts/time_base; the previous implementation
    returned a bare numpy ndarray, which breaks aiortc's sender pipeline
    if the track is ever pulled.
    """

    async def recv(self):
        # next_timestamp() paces the track and supplies monotonic pts values.
        pts, time_base = await self.next_timestamp()
        # Solid red (BGR) 640x480 frame, same content as before.
        image = np.full((480, 640, 3), (0, 0, 255), dtype=np.uint8)
        frame = av.VideoFrame.from_ndarray(image, format="bgr24")
        frame.pts = pts
        frame.time_base = time_base
        return frame
async def receive_video_frames(whep_url):
    """Async generator: connect to a WHEP endpoint and yield decoded frames.

    Performs the WHEP offer/answer exchange over HTTP, waits for the ICE
    connection to establish, then yields BGR ndarrays until the track signals
    the end (None sentinel) or the caller stops iterating. The peer
    connection is closed on exit.
    """
    pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
    frames_queue = asyncio.Queue()
    # Dummy outbound track so the SDP offer contains a video m-line.
    pc.addTrack(DummyVideoTrack())
    @pc.on("track")
    def on_track(track):
        if track.kind == "video":
            # Drain the incoming track into frames_queue in the background.
            asyncio.create_task(consume_track(track, frames_queue))
    @pc.on("iceconnectionstatechange")
    def on_ice_connection_state_change():
        print(f"ICE 连接状态: {pc.iceConnectionState}")
    offer = await pc.createOffer()
    await pc.setLocalDescription(offer)
    headers = {"Content-Type": "application/sdp"}
    async with aiohttp.ClientSession() as session:
        # WHEP: POST the SDP offer; a 201 response carries the SDP answer.
        async with session.post(whep_url, data=pc.localDescription.sdp, headers=headers) as response:
            if response.status != 201:
                raise Exception(f"服务器返回错误: {response.status}")
            answer = RTCSessionDescription(sdp=await response.text(), type="answer")
            await pc.setRemoteDescription(answer)
            if "Location" in response.headers:
                base_url = f"{urlparse(whep_url).scheme}://{urlparse(whep_url).netloc}"
                print("ICE 协商 URL:", base_url + response.headers["Location"])
    # Poll until ICE is established before yielding any frames.
    while pc.iceConnectionState not in ["connected", "completed"]:
        await asyncio.sleep(1)
    print("ICE 连接完成,开始接收视频流")
    try:
        while True:
            frame = await frames_queue.get()
            if frame is None:  # sentinel from consume_track: stream ended
                break
            yield frame
    except KeyboardInterrupt:
        pass
    finally:
        await pc.close()
async def consume_track(track, frames_queue):
    """Drain frames from an aiortc track into frames_queue as BGR ndarrays.

    A None sentinel is enqueued when the track yields None or raises,
    signalling downstream consumers to stop.
    """
    try:
        while (frame := await track.recv()) is not None:
            # Convert the av.VideoFrame to an OpenCV-style BGR array.
            await frames_queue.put(frame.to_ndarray(format="bgr24"))
        print("没有接收到有效的帧数据")
        await frames_queue.put(None)
    except Exception as exc:
        print("处理帧错误:", exc)
        await frames_queue.put(None)
def rtc_frame(url, frame_queue):
    """Blocking worker-thread entry point: pull WebRTC frames from `url`
    and feed them into `frame_queue`, dropping the oldest frame when full.
    """
    async def main():
        async for frame in receive_video_frames(url):
            try:
                frame_queue.put_nowait(frame)
            except queue.Full:
                # Queue full: discard the oldest frame to make room, keeping
                # the stream as fresh as possible.
                frame_queue.get_nowait()
                frame_queue.put_nowait(frame)
    # Use a private event loop: this function runs inside a worker thread,
    # and the default loop belongs to the main thread.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
    loop.close()

276
web.py Normal file
View File

@ -0,0 +1,276 @@
import base64
import hashlib
import queue
import time
from multiprocessing import Process, Event, Queue as MpQueue
from typing import Dict, Any
import cv2
import numpy as np
from fastapi import FastAPI, Request, HTTPException
from api_server import model_management_router, initialize_default_model_on_startup, get_active_detector, \
get_active_model_identifier
from data_pusher import pusher_router, initialize_data_pusher, get_data_pusher_instance
from frame_transfer import yolo_frame, push_frame
from result import Response
from rfdetr_core import RFDETRDetector
from rtc_handler import rtc_frame
# FastAPI application wiring: mount the model-management and data-pusher routers.
app = FastAPI(title="Real-time Video Processing, Model Management, and Data Pusher API")
app.include_router(model_management_router, prefix="/api", tags=["Model Management"])
app.include_router(pusher_router, prefix="/api/pusher", tags=["Data Pusher"])
# Registry of running stream tasks, keyed by source URL:
# {url: {'process': Process, 'stop_event': Event, 'data_queue': MpQueue}}
process_map: Dict[str, Dict[str, Any]] = {}
@app.on_event("startup")
async def web_app_startup_event():
print("主应用服务启动中...")
await initialize_default_model_on_startup()
active_detector_instance = get_active_detector()
if active_detector_instance:
print(f"主应用启动:检测到活动模型 '{get_active_model_identifier()}',将用于初始化 DataPusher。")
else:
print("主应用启动未检测到活动模型。DataPusher 将以无检测器状态初始化。")
initialize_data_pusher(active_detector_instance)
print("DataPusher 服务已初始化。")
print("主应用启动完成。")
@app.on_event("shutdown")
async def web_app_shutdown_event():
print("主应用服务关闭中...")
pusher = get_data_pusher_instance()
if pusher:
print("正在关闭 DataPusher 调度器...")
pusher.shutdown_scheduler()
else:
print("DataPusher 实例未找到,跳过调度器关闭。")
print("正在尝试终止所有活动的视频处理子进程...")
for url, task_info in list(process_map.items()):
process: Process = task_info['process']
stop_event: Event = task_info['stop_event']
data_q: MpQueue = task_info['data_queue']
if process.is_alive():
print(f"向进程 {url} 发送停止信号...")
stop_event.set()
try:
process.join(timeout=15)
if process.is_alive():
print(f"进程 {url} 在优雅关闭超时后仍然存活,尝试 terminate。")
process.terminate()
process.join(timeout=5)
if process.is_alive():
print(f"进程 {url} 在 terminate 后仍然存活,尝试 kill。")
process.kill()
process.join(timeout=2)
except Exception as e:
print(f"关闭/终止进程 {url} 时发生错误: {e}")
try:
if not data_q.empty():
pass
data_q.close()
data_q.join_thread()
except Exception as e_q_cleanup:
print(f"清理进程 {url} 的数据队列时出错: {e_q_cleanup}")
del process_map[url]
print("所有视频处理子进程已尝试终止和清理。")
print("主应用服务关闭完成。")
def start_video_processing(url: str, rtmp_url: str, model_config_name: str, stop_event: Event, data_queue: MpQueue,
                           gateway: str, frequency: int, push_url: str):
    """Subprocess entry point: run the capture -> detect -> push pipeline for
    one stream until `stop_event` is set.

    Spawns three daemon threads (WebRTC capture, detection, RTMP/base64 push)
    connected by in-process queues. On shutdown, the final category counts
    and last annotated frame are put on `data_queue` for the parent process.
    """
    print(f"视频处理子进程启动 (URL: {url}, Model: {model_config_name})")
    detector_instance_for_stream: RFDETRDetector = None
    producer_thread, transfer_thread, consumer_thread = None, None, None
    try:
        print(f"正在为流 {url} 初始化模型: {model_config_name}...")
        detector_instance_for_stream = RFDETRDetector(config_name=model_config_name)
        print(f"模型 {model_config_name} 为流 {url} 初始化成功。")
        # rtc_q: raw frames from WebRTC; yolo_q: annotated frames for pushing.
        rtc_q = queue.Queue(maxsize=10000)
        yolo_q = queue.Queue(maxsize=10000)
        import threading
        producer_thread = threading.Thread(target=rtc_frame, args=(url, rtc_q), name=f"RTC-{url[:20]}", daemon=True)
        transfer_thread = threading.Thread(target=yolo_frame, args=(rtc_q, yolo_q, detector_instance_for_stream),
                                           name=f"YOLO-{url[:20]}", daemon=True)
        consumer_thread = threading.Thread(target=push_frame, args=(yolo_q, rtmp_url, gateway, frequency, push_url), name=f"Push-{url[:20]}",
                                           daemon=True)
        producer_thread.start()
        transfer_thread.start()
        consumer_thread.start()
        # Block until the parent asks this subprocess to stop.
        stop_event.wait()
        print(f"子进程 {url}: 收到停止信号。准备关闭线程...")
    except FileNotFoundError as e:
        print(f"错误 (视频进程 {url}): 模型配置文件 '{model_config_name}.json' 未找到。错误: {e}")
    except Exception as e:
        print(f"错误 (视频进程 {url}): 初始化或运行时错误。错误: {e}")
    finally:
        print(f"视频处理子进程 {url} 进入 finally 块。")
        # Daemon threads die with the process; just report their state.
        if producer_thread and producer_thread.is_alive():
            print(f"子进程 {url}: producer_thread is still alive (daemon).")
        if transfer_thread and transfer_thread.is_alive():
            print(f"子进程 {url}: transfer_thread is still alive (daemon).")
        if consumer_thread and consumer_thread.is_alive():
            print(f"子进程 {url}: consumer_thread is still alive (daemon).")
        if detector_instance_for_stream:
            print(f"子进程 {url}: 收集最后数据...")
            # Final snapshot: cumulative counts and the last annotated frame.
            final_counts = getattr(detector_instance_for_stream, 'category_counts', {})
            final_frame_np = getattr(detector_instance_for_stream, 'last_annotated_frame', None)
            frame_base64 = None
            if final_frame_np is not None and isinstance(final_frame_np, np.ndarray):
                try:
                    _, buffer = cv2.imencode('.jpg', final_frame_np)
                    frame_base64 = base64.b64encode(buffer).decode('utf-8')
                except Exception as e_encode:
                    print(f"子进程 {url}: 帧编码错误: {e_encode}")
            payload = {
                "timestamp": time.time(),
                "category_counts": final_counts,
                "frame_base64": frame_base64,
                "source_url": url,
                "event": "task_stopped_final_data"
            }
            try:
                data_queue.put(payload, timeout=5)
                print(f"子进程 {url}: 已将最终数据放入队列。")
            except queue.Full:
                print(f"子进程 {url}: 无法将最终数据放入队列 (队列已满或超时)。")
            except Exception as e_put:
                print(f"子进程 {url}: 将最终数据放入队列时发生错误: {e_put}")
        else:
            print(f"子进程 {url}: 检测器实例不可用,无法发送最终数据。")
        try:
            data_queue.close()
        except Exception as e_q_close:
            print(f"子进程 {url}: 关闭数据队列时出错: {e_q_close}")
    print(f"视频处理子进程 {url} 执行完毕。")
@app.post("/start_video", tags=["Video Processing"])
async def start_video(request: Request):
data = await request.json()
url = data.get("url")
model_identifier_to_use = data.get("model_identifier")
host = data.get("host")
rtmp_port = data.get("rtmp_port")
rtc_port = data.get("rtc_port")
gateway = data.get("gateway")
frequency = data.get("frequency")
push_url = data.get("push_url")
# 生成MD5
md5_hash = hashlib.md5(url.encode()).hexdigest()
rtmp_url = f"rtmp://{host}:{rtmp_port}/live/{md5_hash}"
rtc_url = f"http://{host}:{rtc_port}/rtc/v1/whep/?{md5_hash}"
if not url or not rtmp_url:
raise HTTPException(status_code=400, detail="'url''rtmp_url' 字段是必须的。")
if not model_identifier_to_use:
print(f"请求中未指定 model_identifier尝试使用全局激活的模型。")
model_identifier_to_use = get_active_model_identifier()
if not model_identifier_to_use:
raise HTTPException(status_code=400, detail="请求中未指定 'model_identifier',且当前无全局激活的默认模型。")
print(f"将为流 {url} 使用当前全局激活的模型: {model_identifier_to_use}")
if url in process_map and process_map[url]['process'].is_alive():
raise HTTPException(status_code=409, detail=f"视频处理进程已在运行: {url}")
print(f"请求启动视频处理: URL = {url}, RTMP = {rtmp_url}, Model = {model_identifier_to_use}")
stop_event = Event()
data_queue = MpQueue(maxsize=1)
process = Process(target=start_video_processing,
args=(url, rtmp_url, model_identifier_to_use, stop_event, data_queue, gateway, frequency, push_url))
process.start()
process_map[url] = {'process': process, 'stop_event': stop_event, 'data_queue': data_queue}
return Response.success(message=f"视频处理已为 URL '{url}' 使用模型 '{model_identifier_to_use}' 启动。",
data=rtc_url)
@app.post("/stop_video", tags=["Video Processing"])
async def stop_video(request: Request):
data = await request.json()
url = data.get("url")
if not url:
raise HTTPException(status_code=400, detail="'url' 字段是必须的。")
task_info = process_map.get(url)
if not task_info:
raise HTTPException(status_code=404, detail=f"没有找到与 URL '{url}' 匹配的活动视频处理进程。")
process: Process = task_info['process']
stop_event: Event = task_info['stop_event']
data_q: MpQueue = task_info['data_queue']
final_data_pushed = False
if process.is_alive():
print(f"请求停止视频处理: {url}. 发送停止信号...")
stop_event.set()
process.join(timeout=20)
if process.is_alive():
print(f"警告: 视频处理进程 {url} 在超时后未能正常终止,尝试强制结束。")
process.terminate()
process.join(timeout=5)
if process.is_alive():
print(f"错误: 视频处理进程 {url} 强制结束后仍然存在。尝试 kill。")
process.kill()
process.join(timeout=2)
else:
print(f"进程 {url} 已优雅停止。尝试获取最后数据...")
try:
final_payload = data_q.get(timeout=10)
print(f"从停止的任务 {url} 收到最终数据。")
pusher_instance = get_data_pusher_instance()
if pusher_instance and pusher_instance.target_url:
print(f"准备将任务 {url} 的最后数据推送到 {pusher_instance.target_url}")
pusher_instance.push_specific_payload(final_payload)
final_data_pushed = True
elif pusher_instance:
print(
f"DataPusher 服务已配置但未设置目标URL (pusher.target_url is None)。无法推送任务 {url} 的最后数据。")
else:
print(f"DataPusher 服务未初始化或不可用。无法推送任务 {url} 的最后数据。")
except queue.Empty:
print(f"警告: 任务 {url} 优雅停止后,未从其数据队列中获取到最终数据 (队列为空或超时)。")
except Exception as e_q_get:
print(f"获取或处理来自任务 {url} 的最终数据时发生错误: {e_q_get}")
else:
print(f"视频处理进程先前已停止或已结束: {url}")
try:
while not data_q.empty():
try:
data_q.get_nowait()
except queue.Empty:
break
data_q.close()
data_q.join_thread()
except Exception as e_q_final_cleanup:
print(f"清理任务 {url} 的数据队列的最后步骤中发生错误: {e_q_final_cleanup}")
del process_map[url]
message = f"视频处理已为 URL '{url}' 停止。"
if final_data_pushed:
message += " 已尝试推送最后的数据。"
elif process.exitcode == 0:
message += " 进程已退出,但未确认最后数据推送 (可能未配置推送或队列问题)。"
return Response.success(message=message)

151
yolo_core.py Normal file
View File

@ -0,0 +1,151 @@
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
class YOLODetector:
    """YOLO (TensorRT engine) detector with English and Chinese rendering."""

    def __init__(self, model_path='models/best.engine'):
        """Load the exported YOLO model plus label/color/font resources.

        Args:
            model_path: path to the exported TensorRT .engine model file.
        """
        # TensorRT engine loaded through the Ultralytics wrapper.
        self.model = YOLO(model_path, task="detect")
        # English class name -> Chinese display name.
        self.class_name_mapping = {
            'pedestrian': '行人',
            'people': '人群',
            'bicycle': '自行车',
            'car': '轿车',
            'van': '面包车',
            'truck': '卡车',
            'tricycle': '三轮车',
            'awning-tricycle': '篷式三轮车',
            'bus': '公交车',
            'motor': '摩托车'
        }
        # Fixed RGB color per class.
        # NOTE(review): 'tricycle' reuses the same value as 'bicycle'
        # (0, 49, 83); the original comment called it orange — confirm intent.
        self.color_mapping = {
            'pedestrian': (71, 0, 36),   # burgundy
            'people': (0, 255, 0),       # green
            'bicycle': (0, 49, 83),      # Prussian blue
            'car': (0, 47, 167),         # Klein blue
            'van': (128, 0, 128),        # purple
            'truck': (212, 72, 72),      # Titian red
            'tricycle': (0, 49, 83),     # Prussian blue (duplicate of bicycle)
            'awning-tricycle': (251, 220, 106),  # Schönbrunn yellow
            'bus': (73, 45, 34),         # Van Dyck brown
            'motor': (1, 132, 127)       # Marrs green
        }
        # Per-class counters, one zeroed slot per known class.
        self.class_counts = {cls_name: 0 for cls_name in self.class_name_mapping.keys()}
        # CJK-capable font for label rendering; PIL's default as fallback.
        try:
            self.font = ImageFont.truetype("simhei.ttf", 20)
        except IOError:
            self.font = ImageFont.load_default()
def detect_and_draw_English(self, frame, conf=0.3, iou=0.5):
"""
对输入帧进行目标检测并返回绘制结果
Args:
frame: 输入的图像帧BGR格式
conf: 置信度阈值
iou: IOU阈值
Returns:
annotated_frame: 绘制了检测结果的图像帧
"""
try:
# 进行 YOLO 目标检测
results = self.model(
frame,
conf=conf,
iou=iou,
half=True,
)
result = results[0]
# 使用YOLO自带的绘制功能
annotated_frame = result.plot()
return annotated_frame
except Exception as e:
print(f"Detection error: {e}")
return frame
def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3):
"""
对输入帧进行目标检测并绘制中文标注
Args:
frame: 输入的图像帧BGR格式
conf: 置信度阈值
iou: IOU阈值
Returns:
annotated_frame: 绘制了检测结果的图像帧
"""
try:
# 进行 YOLO 目标检测
results = self.model(
frame,
conf=conf,
iou=iou,
# half=True,
)
result = results[0]
# 获取原始帧的副本
img = frame.copy()
# 转换为PIL图像以绘制中文
pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_img)
# 绘制检测结果
for box in result.boxes:
# 获取边框坐标
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
# 获取类别ID和置信度
cls_id = int(box.cls[0].item())
conf = box.conf[0].item()
# 获取类别名称并转换为中文
cls_name = result.names[cls_id]
chinese_name = self.class_name_mapping.get(cls_name, cls_name)
# 获取该类别的颜色
color = self.color_mapping.get(cls_name, (255, 255, 255))
# 绘制边框
draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)
# 准备标签文本
text = f"{chinese_name} {conf:.2f}"
text_size = draw.textbbox((0, 0), text, font=self.font)
text_width = text_size[2] - text_size[0]
text_height = text_size[3] - text_size[1]
# 绘制标签背景(使用与边框相同的颜色)
draw.rectangle(
[(x1, y1 - text_height - 4), (x1 + text_width, y1)],
fill=color
)
# 绘制白色文本
draw.text(
(x1, y1 - text_height - 2),
text,
fill=(255, 255, 255), # 白色文本
font=self.font
)
# 转换回OpenCV格式
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
except Exception as e:
print(f"Detection error: {e}")
return frame