参数分离完全版
This commit is contained in:
31
README.md
Normal file
31
README.md
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
### models 目录
|
||||||
|
|
||||||
|
- 存放模型文件
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### frame_transfer.py
|
||||||
|
|
||||||
|
- 从检测结果队列推送数据到 RTMP 服务器【不必修改】
|
||||||
|
- 从原始队列拿取数据、调用 yolo_core 封装的方法进行检测【四类】
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### rtc_handler.py
|
||||||
|
|
||||||
|
- 从 WebRTC 实时视频流截取帧并持续推送到原始队列【不必修改】
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### yolo_core.py
|
||||||
|
|
||||||
|
- 封装四类方法【参数均为原始队列、检测结果队列】
|
||||||
|
- 方法一:原始YOLO检测
|
||||||
|
- 方法二:原始YOLO检测 + 汉化 + 颜色
|
||||||
|
- 方法三:原始累计计数
|
||||||
|
- 方法四:原始累计计数 + 汉化 + 颜色
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
**读取 WebRTC 流和推送结果帧的代码不需要修改**
|
||||||
|
|
389
api_server.py
Normal file
389
api_server.py
Normal file
@ -0,0 +1,389 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, File, UploadFile, HTTPException, Form
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from result import Response
|
||||||
|
# 假设 rfdetr_core.py 在同一目录下或 PYTHONPATH 中
|
||||||
|
from rfdetr_core import RFDETRDetector
|
||||||
|
|
||||||
|
# --- Global Variables and Configuration ---
|
||||||
|
model_management_router = APIRouter()
|
||||||
|
|
||||||
|
BASE_MODEL_DIR = "models"
|
||||||
|
BASE_CONFIG_DIR = "configs"
|
||||||
|
os.makedirs(BASE_MODEL_DIR, exist_ok=True)
|
||||||
|
os.makedirs(BASE_CONFIG_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 用于存储当前激活的检测器实例和可用模型信息
|
||||||
|
current_detector: Optional[RFDETRDetector] = None
|
||||||
|
current_model_identifier: Optional[str] = None
|
||||||
|
available_models_info: Dict[str, Dict] = {} # 存储模型标识符及其配置内容
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_name: str) -> Optional[Dict]:
|
||||||
|
"""加载指定的JSON配置文件。"""
|
||||||
|
config_path = os.path.join(BASE_CONFIG_DIR, f"{config_name}.json")
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
try:
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误:加载配置文件 '{config_path}' 失败: {e}")
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_detector(model_identifier: str) -> Optional[RFDETRDetector]:
|
||||||
|
"""根据模型标识符初始化检测器。"""
|
||||||
|
global current_detector, current_model_identifier
|
||||||
|
try:
|
||||||
|
config = available_models_info.get(model_identifier)
|
||||||
|
if not config:
|
||||||
|
print(f"错误:未找到模型 '{model_identifier}' 的缓存配置,尝试从磁盘加载。") # 更明确的日志
|
||||||
|
config_data = load_config(model_identifier)
|
||||||
|
if not config_data:
|
||||||
|
raise FileNotFoundError(f"配置文件 {model_identifier}.json 未找到或无法加载。")
|
||||||
|
available_models_info[model_identifier] = config_data
|
||||||
|
config = config_data
|
||||||
|
|
||||||
|
model_filename = config.get('model_pth_filename')
|
||||||
|
if not model_filename:
|
||||||
|
raise ValueError(f"配置文件 '{model_identifier}.json' 中缺少 'model_pth_filename' 字段。")
|
||||||
|
|
||||||
|
model_full_path = os.path.join(BASE_MODEL_DIR, model_filename)
|
||||||
|
if not os.path.exists(model_full_path):
|
||||||
|
raise FileNotFoundError(f"模型文件 '{model_full_path}' (在 '{model_identifier}.json' 中指定) 不存在。")
|
||||||
|
|
||||||
|
print(f"尝试使用配置 '{model_identifier}.json' 和模型 '{model_full_path}' 初始化检测器...")
|
||||||
|
detector = RFDETRDetector(config_name=model_identifier, base_model_dir=BASE_MODEL_DIR,
|
||||||
|
base_config_dir=BASE_CONFIG_DIR)
|
||||||
|
print(f"检测器 '{model_identifier}' 初始化成功。")
|
||||||
|
current_detector = detector
|
||||||
|
current_model_identifier = model_identifier
|
||||||
|
|
||||||
|
# 通知 DataPusher 更新其检测器实例
|
||||||
|
try:
|
||||||
|
from data_pusher import get_data_pusher_instance # 动态导入以避免潜在的循环依赖问题
|
||||||
|
pusher = get_data_pusher_instance()
|
||||||
|
if pusher:
|
||||||
|
print(f"通知 DataPusher 更新其检测器实例为: {model_identifier}")
|
||||||
|
pusher.update_detector_instance(current_detector)
|
||||||
|
# else:
|
||||||
|
# 如果 pusher 为 None,可能是因为它尚未在主应用启动时被完全初始化
|
||||||
|
# data_pusher 模块内部的 initialize_data_pusher 负责记录其自身初始化状态
|
||||||
|
# print(f"DataPusher 实例尚未初始化 (或初始化失败),无法更新其检测器。")
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"警告: 无法导入 data_pusher 模块以更新 DataPusher 的检测器实例。如果您不使用数据推送功能,此消息可忽略。")
|
||||||
|
except Exception as e_pusher:
|
||||||
|
print(f"警告: 通知 DataPusher 更新检测器时发生意外错误: {e_pusher}")
|
||||||
|
|
||||||
|
return detector
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(f"初始化检测器 '{model_identifier}' 失败 (文件未找到): {e}")
|
||||||
|
if current_model_identifier == model_identifier:
|
||||||
|
current_detector = None
|
||||||
|
current_model_identifier = None
|
||||||
|
raise HTTPException(status_code=404, detail=str(e)) # 保持 404,让上层知道是文件问题
|
||||||
|
except ValueError as e: # 捕获配置字段缺失等问题
|
||||||
|
print(f"初始化检测器 '{model_identifier}' 失败 (配置值错误): {e}")
|
||||||
|
if current_model_identifier == model_identifier:
|
||||||
|
current_detector = None
|
||||||
|
current_model_identifier = None
|
||||||
|
raise HTTPException(status_code=400, detail=str(e)) # 配置问题用 400
|
||||||
|
except Exception as e: # 其他来自 RFDETRDetector 内部的初始化错误
|
||||||
|
print(f"初始化检测器 '{model_identifier}' 失败 (内部错误): {e}")
|
||||||
|
# import traceback
|
||||||
|
# traceback.print_exc() # 在服务器日志中打印完整堆栈,方便调试
|
||||||
|
if current_model_identifier == model_identifier:
|
||||||
|
current_detector = None
|
||||||
|
current_model_identifier = None
|
||||||
|
if model_identifier in available_models_info:
|
||||||
|
del available_models_info[model_identifier]
|
||||||
|
# 这些通常是服务器端问题或模型/库的兼容性问题,所以用500
|
||||||
|
raise HTTPException(status_code=500, detail=f"检测器 '{model_identifier}' 内部初始化失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_detector() -> Optional[RFDETRDetector]:
|
||||||
|
"""获取当前激活的检测器实例。"""
|
||||||
|
return current_detector
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_model_identifier() -> Optional[str]:
|
||||||
|
"""获取当前激活的模型标识符。"""
|
||||||
|
return current_model_identifier
|
||||||
|
|
||||||
|
|
||||||
|
def scan_and_load_available_models():
|
||||||
|
"""扫描配置目录,加载所有有效的模型配置。"""
|
||||||
|
global available_models_info
|
||||||
|
available_models_info = {}
|
||||||
|
# print(f"扫描目录 '{BASE_CONFIG_DIR}' 以查找配置文件...") # 减少日志冗余
|
||||||
|
if not os.path.exists(BASE_CONFIG_DIR):
|
||||||
|
# print(f"配置目录 '{BASE_CONFIG_DIR}' 不存在,跳过扫描。")
|
||||||
|
return
|
||||||
|
for filename in os.listdir(BASE_CONFIG_DIR):
|
||||||
|
if filename.endswith(".json"):
|
||||||
|
model_identifier = filename[:-5]
|
||||||
|
# print(f"找到配置文件: {filename},模型标识符: {model_identifier}")
|
||||||
|
config_data = load_config(model_identifier)
|
||||||
|
if config_data:
|
||||||
|
model_pth = config_data.get('model_pth_filename')
|
||||||
|
model_full_path = os.path.join(BASE_MODEL_DIR, model_pth) if model_pth else None
|
||||||
|
if model_pth and os.path.exists(model_full_path):
|
||||||
|
available_models_info[model_identifier] = config_data
|
||||||
|
# print(f"模型配置 '{model_identifier}' 加载成功。")
|
||||||
|
# elif not model_pth:
|
||||||
|
# print(f"警告:模型 '{model_identifier}' 的配置文件中未指定 'model_pth_filename'。")
|
||||||
|
# else:
|
||||||
|
# print(f"警告:模型 '{model_identifier}' 的模型文件 '{model_full_path}' 未找到。")
|
||||||
|
# else:
|
||||||
|
# print(f"警告:无法加载模型 '{model_identifier}' 的配置文件。")
|
||||||
|
print(f"可用模型配置已扫描/更新: {list(available_models_info.keys())}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Extracted Startup Logic (to be called by the main app) ---
|
||||||
|
async def initialize_default_model_on_startup():
|
||||||
|
"""应用启动时,扫描并加载可用模型,尝试激活第一个。"""
|
||||||
|
print("执行模型管理模块的启动初始化...")
|
||||||
|
scan_and_load_available_models()
|
||||||
|
global current_model_identifier, current_detector
|
||||||
|
|
||||||
|
if available_models_info:
|
||||||
|
first_model = sorted(list(available_models_info.keys()))[0]
|
||||||
|
print(f"尝试将第一个可用模型 '{first_model}' 设置为活动模型。")
|
||||||
|
try:
|
||||||
|
initialize_detector(first_model)
|
||||||
|
print(f"默认模型 '{current_model_identifier}' 加载并激活成功。")
|
||||||
|
except HTTPException as e: # 捕获 initialize_detector 抛出的HTTPException
|
||||||
|
print(f"加载默认模型 '{first_model}' 失败: {e.detail} (状态码: {e.status_code})")
|
||||||
|
current_model_identifier = None
|
||||||
|
current_detector = None
|
||||||
|
# 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException
|
||||||
|
else:
|
||||||
|
print("没有可用的模型配置,服务器启动但无默认模型激活。")
|
||||||
|
current_model_identifier = None
|
||||||
|
current_detector = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Pydantic Models for Request/Response ---
|
||||||
|
|
||||||
|
class ModelIdentifier(BaseModel):
|
||||||
|
model_identifier: str
|
||||||
|
|
||||||
|
|
||||||
|
# Define a standard response model for OpenAPI documentation
|
||||||
|
class StandardResponse(BaseModel):
|
||||||
|
code: int
|
||||||
|
data: Optional[Any] = None
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
# --- API Endpoints ---
|
||||||
|
|
||||||
|
@model_management_router.post("/upload_model_and_config/", response_model=StandardResponse)
|
||||||
|
async def upload_model_and_config(
|
||||||
|
model_identifier_form: str = Form(..., alias="model_identifier"),
|
||||||
|
config_file: UploadFile = File(...),
|
||||||
|
model_file: UploadFile = File(...)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上传模型文件 (.pth) 和配置文件 (.json)。
|
||||||
|
- **model_identifier**: 模型的唯一名称 (例如 "人车检测")。配置文件将以此名称保存 (e.g., "人车检测.json")。
|
||||||
|
- **config_file**: JSON 配置文件。
|
||||||
|
- **model_file**: Pytorch 模型文件 (.pth)。其文件名必须与配置文件中 'model_pth_filename' 字段指定的一致。
|
||||||
|
"""
|
||||||
|
config_filename = f"{model_identifier_form}.json"
|
||||||
|
config_path = os.path.join(BASE_CONFIG_DIR, config_filename)
|
||||||
|
model_path = None # 在 try 块外部声明,以便 finally 和 except 中可用
|
||||||
|
global current_model_identifier, current_detector # current_detector 实际上在这里主要由 initialize_detector 设置
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"开始上传模型 '{model_identifier_form}'...")
|
||||||
|
config_content = await config_file.read()
|
||||||
|
try:
|
||||||
|
config_data = json.loads(config_content.decode('utf-8'))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# raise HTTPException(status_code=400, detail="无效的JSON配置文件格式,请检查JSON语法。")
|
||||||
|
return JSONResponse(status_code=400,
|
||||||
|
content=Response.error(message="无效的JSON配置文件格式,请检查JSON语法。", code=400))
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
# raise HTTPException(status_code=400, detail="配置文件编码错误,请确保为UTF-8编码。")
|
||||||
|
return JSONResponse(status_code=400,
|
||||||
|
content=Response.error(message="配置文件编码错误,请确保为UTF-8编码。", code=400))
|
||||||
|
|
||||||
|
if 'model_pth_filename' not in config_data:
|
||||||
|
# raise HTTPException(status_code=400, detail="配置文件中必须包含 'model_pth_filename' 字段。")
|
||||||
|
return JSONResponse(status_code=400,
|
||||||
|
content=Response.error(message="配置文件中必须包含 'model_pth_filename' 字段。",
|
||||||
|
code=400))
|
||||||
|
|
||||||
|
target_model_filename = config_data['model_pth_filename']
|
||||||
|
if not isinstance(target_model_filename, str) or not target_model_filename.endswith(".pth"):
|
||||||
|
# raise HTTPException(status_code=400, detail="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。")
|
||||||
|
return JSONResponse(status_code=400, content=Response.error(
|
||||||
|
message="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。", code=400))
|
||||||
|
|
||||||
|
if model_file.filename != target_model_filename:
|
||||||
|
# raise HTTPException(
|
||||||
|
# status_code=400,
|
||||||
|
# detail=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。"
|
||||||
|
# )
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=Response.error(
|
||||||
|
message=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。",
|
||||||
|
code=400
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(config_path, 'wb') as f:
|
||||||
|
f.write(config_content)
|
||||||
|
print(f"配置文件 '{config_path}' 已保存。")
|
||||||
|
|
||||||
|
model_path = os.path.join(BASE_MODEL_DIR, target_model_filename)
|
||||||
|
with open(model_path, "wb") as buffer:
|
||||||
|
shutil.copyfileobj(model_file.file, buffer)
|
||||||
|
print(f"模型文件 '{model_path}' 已保存。")
|
||||||
|
|
||||||
|
# 在尝试初始化之前,将配置数据添加到 available_models_info
|
||||||
|
# 这样 initialize_detector 即使在缓存未命中的情况下也能通过 load_config 找到它
|
||||||
|
available_models_info[model_identifier_form] = config_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"尝试初始化并激活新上传的模型: '{model_identifier_form}'")
|
||||||
|
initialize_detector(model_identifier_form) # 此函数会设置全局的 current_detector 和 current_model_identifier
|
||||||
|
print(f"新上传的模型 '{model_identifier_form}' 验证并激活成功。")
|
||||||
|
# 成功后,available_models_info 已包含此模型,current_detector 和 current_model_identifier 已更新
|
||||||
|
# 无需在此处再次调用 scan_and_load_available_models(),因为它会重新扫描所有,可能覆盖内存中的一些状态或引入不必要的IO
|
||||||
|
except HTTPException as e: # 从 initialize_detector 抛出的错误
|
||||||
|
print(f"初始化新上传的模型 '{model_identifier_form}' 失败: {e.detail}")
|
||||||
|
# 清理已保存的文件和内存中的条目
|
||||||
|
if os.path.exists(config_path): os.remove(config_path)
|
||||||
|
if model_path and os.path.exists(model_path): os.remove(model_path)
|
||||||
|
if model_identifier_form in available_models_info: del available_models_info[model_identifier_form]
|
||||||
|
# scan_and_load_available_models() # 失败后重新扫描是好的,以确保 available_models_info 准确
|
||||||
|
# 但如果 initialize_detector 内部已经删除了 available_models_info 中的条目,这里可能不需要
|
||||||
|
# 为保持一致性,如果上面删除了,这里重新扫描一下比较稳妥
|
||||||
|
scan_and_load_available_models()
|
||||||
|
|
||||||
|
# 重新抛出为 422,并包含原始错误信息
|
||||||
|
# raise HTTPException(status_code=422,
|
||||||
|
# detail=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}")
|
||||||
|
return JSONResponse(status_code=422,
|
||||||
|
content=Response.error(
|
||||||
|
message=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}",
|
||||||
|
code=422
|
||||||
|
))
|
||||||
|
|
||||||
|
# 如果成功,确保全局模型列表是最新的(虽然 initialize_detector 已更新了 current_*,但列表可能需要刷新以供 /available_models 使用)
|
||||||
|
# scan_and_load_available_models() # 移除这个,因为当前模型已激活,列表会在下次调用 /available_models 时刷新
|
||||||
|
|
||||||
|
# return UploadResponse(
|
||||||
|
# message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。",
|
||||||
|
# model_identifier=model_identifier_form,
|
||||||
|
# config_filename=config_filename,
|
||||||
|
# model_filename=target_model_filename
|
||||||
|
# )
|
||||||
|
return Response.success(data={
|
||||||
|
"message": f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。", # Message is also in the wrapper
|
||||||
|
"model_identifier": model_identifier_form,
|
||||||
|
"config_filename": config_filename,
|
||||||
|
"model_filename": target_model_filename
|
||||||
|
}, message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。")
|
||||||
|
|
||||||
|
except HTTPException as e: # 捕获直接由 FastAPI 验证或其他地方抛出的 HTTPException
|
||||||
|
# This might catch exceptions from initialize_detector if they are not caught internally by the above try-except for initialize_detector
|
||||||
|
return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
|
||||||
|
except Exception as e:
|
||||||
|
identifier_for_log = model_identifier_form if 'model_identifier_form' in locals() else "unknown"
|
||||||
|
print(f"上传模型 '{identifier_for_log}' 过程中发生意外的严重错误: {e}")
|
||||||
|
# import traceback
|
||||||
|
# traceback.print_exc()
|
||||||
|
# 尝试清理,以防文件部分写入
|
||||||
|
if config_path and os.path.exists(config_path):
|
||||||
|
os.remove(config_path)
|
||||||
|
if model_path and os.path.exists(model_path):
|
||||||
|
os.remove(model_path)
|
||||||
|
if 'model_identifier_form' in locals() and model_identifier_form in available_models_info:
|
||||||
|
del available_models_info[model_identifier_form]
|
||||||
|
scan_and_load_available_models() # 出错后务必刷新列表
|
||||||
|
|
||||||
|
# raise HTTPException(status_code=500, detail=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}")
|
||||||
|
return JSONResponse(status_code=500, content=Response.error(
|
||||||
|
message=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}", code=500))
|
||||||
|
finally:
|
||||||
|
if config_file: await config_file.close()
|
||||||
|
if model_file: await model_file.close()
|
||||||
|
|
||||||
|
|
||||||
|
@model_management_router.post("/select_model/", response_model=StandardResponse)
|
||||||
|
# @model_management_router.post("/select_model/{model_identifier_path}", response_model=StandardResponse)
|
||||||
|
async def select_model(model_identifier_path: str):
|
||||||
|
"""
|
||||||
|
根据提供的标识符选择并激活一个已上传的模型。
|
||||||
|
路径参数 `model_identifier_path` 即为模型的唯一名称。
|
||||||
|
"""
|
||||||
|
global current_model_identifier, current_detector
|
||||||
|
|
||||||
|
# 总是先扫描以获取最新的可用模型列表,以防外部文件更改
|
||||||
|
scan_and_load_available_models()
|
||||||
|
if model_identifier_path not in available_models_info:
|
||||||
|
# raise HTTPException(status_code=404, detail=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。")
|
||||||
|
return JSONResponse(status_code=404, content=Response.error(
|
||||||
|
message=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。", code=404))
|
||||||
|
|
||||||
|
if current_model_identifier == model_identifier_path and current_detector is not None:
|
||||||
|
print(f"模型 '{model_identifier_path}' 已经是活动模型。")
|
||||||
|
# return SelectModelResponse(
|
||||||
|
# message=f"模型 '{model_identifier_path}' 已是当前活动模型。",
|
||||||
|
# active_model=current_model_identifier
|
||||||
|
# )
|
||||||
|
return Response.success(data={
|
||||||
|
"active_model": current_model_identifier
|
||||||
|
}, message=f"模型 '{model_identifier_path}' 已是当前活动模型。")
|
||||||
|
try:
|
||||||
|
print(f"尝试激活模型: {model_identifier_path}")
|
||||||
|
initialize_detector(model_identifier_path)
|
||||||
|
print(f"模型 '{current_model_identifier}' 已成功激活。")
|
||||||
|
# return SelectModelResponse(
|
||||||
|
# message=f"模型 '{model_identifier_path}' 启动成功。",
|
||||||
|
# active_model=current_model_identifier
|
||||||
|
# )
|
||||||
|
return Response.success(data={
|
||||||
|
"active_model": current_model_identifier
|
||||||
|
}, message=f"模型 '{model_identifier_path}' 启动成功。")
|
||||||
|
except HTTPException as e:
|
||||||
|
# initialize_detector 内部已经处理了 current_model_identifier 和 current_detector 的清理
|
||||||
|
# 以及 available_models_info 中对应条目的移除(如果是其内部错误)
|
||||||
|
print(f"激活模型 '{model_identifier_path}' 失败: {e.detail} (原始状态码: {e.status_code})")
|
||||||
|
# 此处不再需要手动清理 available_models_info,因为 initialize_detector 如果因内部错误(非文件找不到)
|
||||||
|
# 导致无法实例化 RFDETRDetector,它会自己删除条目。
|
||||||
|
# 如果是文件找不到 (404 from initialize_detector),条目可能还在,但下次 scan 会处理。
|
||||||
|
scan_and_load_available_models() # 确保 select 失败后,列表也是最新的
|
||||||
|
# raise HTTPException(status_code=e.status_code, detail=e.detail) # 重新抛出原始的 HTTPException
|
||||||
|
return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
|
||||||
|
# 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
@model_management_router.get("/available_models/", response_model=StandardResponse)
|
||||||
|
async def get_available_models():
|
||||||
|
"""列出所有当前可用的模型标识符。"""
|
||||||
|
scan_and_load_available_models()
|
||||||
|
# return list(available_models_info.keys())
|
||||||
|
return Response.success(data=list(available_models_info.keys()), message="成功获取可用模型列表。")
|
||||||
|
|
||||||
|
|
||||||
|
@model_management_router.get("/current_model/", response_model=StandardResponse)
|
||||||
|
async def get_current_model_endpoint():
|
||||||
|
"""获取当前激活的模型标识符。"""
|
||||||
|
# return current_model_identifier
|
||||||
|
if current_model_identifier:
|
||||||
|
return Response.success(data=current_model_identifier, message="成功获取当前激活的模型。")
|
||||||
|
else:
|
||||||
|
return Response.success(data=None, message="当前没有激活的模型。")
|
58
configs/人车.json
Normal file
58
configs/人车.json
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
{
|
||||||
|
"model_id": "1",
|
||||||
|
"model_pth_filename": "人车.pth",
|
||||||
|
"resolution": 448,
|
||||||
|
"classes_en": [
|
||||||
|
"pedestrian", "person", "bicycle", "car", "van", "truck", "tricycle", "awning-tricycle", "bus", "motor"
|
||||||
|
],
|
||||||
|
|
||||||
|
"classes_zh_map": {
|
||||||
|
"pedestrian":"行人",
|
||||||
|
"person": "人群",
|
||||||
|
"bicycle": "自行车",
|
||||||
|
"car": "小汽车",
|
||||||
|
"van": "面包车",
|
||||||
|
"truck": "卡车",
|
||||||
|
"tricycle":"三轮车",
|
||||||
|
"awning-tricycle":"篷式三轮车",
|
||||||
|
"bus": "公交车",
|
||||||
|
"motor":"摩托车"
|
||||||
|
},
|
||||||
|
|
||||||
|
"class_colors_hex": {
|
||||||
|
"pedestrian": "#470024",
|
||||||
|
"person": "#00FF00",
|
||||||
|
"bicycle": "#003153",
|
||||||
|
"car": "#002FA7",
|
||||||
|
"van": "#800080",
|
||||||
|
"truck": "#D44848",
|
||||||
|
"tricycle": "#003153",
|
||||||
|
"awning-tricycle": "#FBDC6A",
|
||||||
|
"bus": "#492D22",
|
||||||
|
"motor": "#01847F"
|
||||||
|
},
|
||||||
|
|
||||||
|
"detection_settings": {
|
||||||
|
"enabled_classes": {
|
||||||
|
"pedestrian": true,
|
||||||
|
"person": true,
|
||||||
|
"bicycle": false,
|
||||||
|
"car": true,
|
||||||
|
"van": true,
|
||||||
|
"truck": true,
|
||||||
|
"tricycle": false,
|
||||||
|
"awning-tricycle": false,
|
||||||
|
"bus": true,
|
||||||
|
"motor": true
|
||||||
|
},
|
||||||
|
"default_confidence_threshold": 0.7
|
||||||
|
},
|
||||||
|
|
||||||
|
"default_color_hex": "#00FF00",
|
||||||
|
"tracker_activation_threshold": 0.5,
|
||||||
|
"tracker_lost_buffer": 120,
|
||||||
|
"tracker_match_threshold": 0.85,
|
||||||
|
"tracker_frame_rate": 25,
|
||||||
|
"tracker_consecutive_frames": 2
|
||||||
|
}
|
||||||
|
|
248
data_pusher.py
Normal file
248
data_pusher.py
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
import base64
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, HTTPException, Body # 切换到 APIRouter
|
||||||
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
# import uvicorn # 不再由此文件运行uvicorn
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from typing import Optional # 用于类型提示
|
||||||
|
# 确保 RFDETRDetector 可以被导入,假设 rfdetr_core.py 在同一目录或 PYTHONPATH 中
|
||||||
|
# from rfdetr_core import RFDETRDetector # 在实际使用中取消注释并确保路径正确
|
||||||
|
|
||||||
|
# 配置日志记录
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# app = FastAPI(title="Data Pusher Service", version="1.0.0") # 移除独立的 FastAPI app
|
||||||
|
pusher_router = APIRouter() # 创建一个 APIRouter
|
||||||
|
|
||||||
|
class DataPusher:
|
||||||
|
def __init__(self, detector): # detector: RFDETRDetector
|
||||||
|
if detector is None:
|
||||||
|
logger.error("DataPusher initialized with a None detector. Push functionality will be impaired.")
|
||||||
|
# 仍然创建实例,但功能会受限,_get_data_payload 会处理 detector is None
|
||||||
|
self.detector = detector
|
||||||
|
self.scheduler = BackgroundScheduler(daemon=True)
|
||||||
|
self.push_job_id = "rt_push_job"
|
||||||
|
self.target_url = None
|
||||||
|
if not self.scheduler.running:
|
||||||
|
try:
|
||||||
|
self.scheduler.start()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error starting APScheduler in DataPusher: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def update_detector_instance(self, detector):
|
||||||
|
"""允许在运行时更新检测器实例,例如当主应用切换模型时"""
|
||||||
|
logger.info(f"DataPusher's detector instance is being updated.")
|
||||||
|
self.detector = detector
|
||||||
|
if detector is None:
|
||||||
|
logger.warning("DataPusher's detector instance updated to None.")
|
||||||
|
|
||||||
|
def _get_data_payload(self):
|
||||||
|
"""获取当前的类别计数和最新标注的帧"""
|
||||||
|
if self.detector is None:
|
||||||
|
logger.warning("DataPusher: Detector not available. Cannot get data payload.")
|
||||||
|
return { # 即使检测器不可用,也返回一个结构,包含空数据
|
||||||
|
# "timestamp": time.time(),
|
||||||
|
"category_counts": {},
|
||||||
|
"frame_base64": None,
|
||||||
|
"error": "Detector not available"
|
||||||
|
}
|
||||||
|
|
||||||
|
category_counts = getattr(self.detector, 'category_counts', {})
|
||||||
|
# 如果 detector 存在但没有 last_annotated_frame (例如模型刚加载还没处理第一帧)
|
||||||
|
last_frame_np = getattr(self.detector, 'last_annotated_frame', None)
|
||||||
|
|
||||||
|
frame_base64 = None
|
||||||
|
if last_frame_np is not None and isinstance(last_frame_np, np.ndarray):
|
||||||
|
try:
|
||||||
|
_, buffer = cv2.imencode('.jpg', last_frame_np)
|
||||||
|
frame_base64 = base64.b64encode(buffer).decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error encoding frame to base64: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
# "timestamp": time.time(),
|
||||||
|
"category_counts": category_counts,
|
||||||
|
"frame_base64": frame_base64
|
||||||
|
}
|
||||||
|
|
||||||
|
def _push_data_task(self):
|
||||||
|
"""执行数据推送的任务"""
|
||||||
|
if not self.target_url:
|
||||||
|
# logger.warning("Target URL not set. Skipping push task.") # 减少日志噪音,仅在初次设置时记录
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = self._get_data_payload()
|
||||||
|
# if payload is None: # _get_data_payload 现在总会返回一个字典
|
||||||
|
# logger.warning("No payload to push.")
|
||||||
|
# return
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(self.target_url, json=payload, timeout=5)
|
||||||
|
response.raise_for_status()
|
||||||
|
logger.debug(f"Data pushed successfully to {self.target_url}. Status: {response.status_code}") # 改为 debug 级别
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Error pushing data to {self.target_url}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"An unexpected error occurred during data push: {e}")
|
||||||
|
|
||||||
|
def setup_push_schedule(self, frequency: float, target_url: str):
|
||||||
|
"""设置或更新推送计划"""
|
||||||
|
if not isinstance(frequency, (int, float)) or frequency <= 0:
|
||||||
|
raise ValueError("Frequency must be a positive number (pushes per second).")
|
||||||
|
|
||||||
|
self.target_url = str(target_url)
|
||||||
|
interval_seconds = 1.0 / frequency
|
||||||
|
|
||||||
|
if not self.scheduler.running: # 确保调度器正在运行
|
||||||
|
try:
|
||||||
|
logger.info("APScheduler was not running. Attempting to start it now.")
|
||||||
|
self.scheduler.start()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start APScheduler in setup_push_schedule: {e}")
|
||||||
|
raise RuntimeError(f"APScheduler could not be started: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.scheduler.get_job(self.push_job_id):
|
||||||
|
self.scheduler.remove_job(self.push_job_id)
|
||||||
|
logger.info(f"Removed existing push job: {self.push_job_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error removing existing job: {e}")
|
||||||
|
|
||||||
|
first_run_time = datetime.datetime.now() + datetime.timedelta(seconds=10)
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self._push_data_task,
|
||||||
|
trigger='interval',
|
||||||
|
seconds=interval_seconds,
|
||||||
|
id=self.push_job_id,
|
||||||
|
next_run_time=first_run_time,
|
||||||
|
replace_existing=True
|
||||||
|
)
|
||||||
|
logger.info(f"Push task scheduled to {self.target_url} every {interval_seconds:.2f}s, starting in 10s.")
|
||||||
|
|
||||||
|
def stop_push_schedule(self):
|
||||||
|
"""停止数据推送任务"""
|
||||||
|
if self.scheduler.get_job(self.push_job_id):
|
||||||
|
try:
|
||||||
|
self.scheduler.remove_job(self.push_job_id)
|
||||||
|
logger.info(f"Push job {self.push_job_id} stopped successfully.")
|
||||||
|
self.target_url = None # 清除目标 URL
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error stopping push job {self.push_job_id}: {e}")
|
||||||
|
else:
|
||||||
|
logger.info("No active push job to stop.")
|
||||||
|
|
||||||
|
def shutdown_scheduler(self):
|
||||||
|
"""安全关闭调度器"""
|
||||||
|
if self.scheduler.running:
|
||||||
|
try:
|
||||||
|
self.scheduler.shutdown()
|
||||||
|
logger.info("DataPusher's APScheduler shut down successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error shutting down DataPusher's APScheduler: {e}")
|
||||||
|
|
||||||
|
def push_specific_payload(self, payload: dict):
|
||||||
|
"""推送一个特定的、已格式化的数据负载到配置的 target_url。"""
|
||||||
|
if not self.target_url:
|
||||||
|
logger.warning("DataPusher: Target URL not set. Cannot push specific payload.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not payload:
|
||||||
|
logger.warning("DataPusher: Received empty payload for specific push. Skipping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Attempting to push specific payload to {self.target_url}")
|
||||||
|
try:
|
||||||
|
response = requests.post(self.target_url, json=payload, timeout=10) # Increased timeout for one-off
|
||||||
|
response.raise_for_status()
|
||||||
|
logger.info(f"Specific payload pushed successfully to {self.target_url}. Status: {response.status_code}")
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Error pushing specific payload to {self.target_url}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"An unexpected error occurred during specific payload push: {e}")
|
||||||
|
|
||||||
|
# 全局 DataPusher 实例,将由主应用初始化
|
||||||
|
data_pusher_instance: Optional[DataPusher] = None
|
||||||
|
|
||||||
|
# --- FastAPI Request Body Model ---
|
||||||
|
class PushConfigRequest(BaseModel):
|
||||||
|
frequency: float
|
||||||
|
url: HttpUrl
|
||||||
|
|
||||||
|
# --- FastAPI HTTP Endpoint (using APIRouter) ---
|
||||||
|
@pusher_router.post("/setup_push", summary="配置数据推送任务")
|
||||||
|
async def handle_setup_push(config: PushConfigRequest = Body(...)):
|
||||||
|
global data_pusher_instance
|
||||||
|
if data_pusher_instance is None:
|
||||||
|
# 这个错误理论上不应该发生,如果主应用正确初始化了 data_pusher_instance
|
||||||
|
logger.error("CRITICAL: /setup_push called but data_pusher_instance is None. Main app did not initialize it.")
|
||||||
|
raise HTTPException(status_code=503, detail="DataPusher service not available. Initialization may have failed.")
|
||||||
|
|
||||||
|
if config.frequency <= 0: # Pydantic v2 中可以直接用 gt=0
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid frequency value. Must be a positive number.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_pusher_instance.setup_push_schedule(config.frequency, str(config.url))
|
||||||
|
return {
|
||||||
|
"message": "Push task configured successfully.",
|
||||||
|
"frequency_hz": config.frequency,
|
||||||
|
"interval_seconds": 1.0 / config.frequency,
|
||||||
|
"target_url": str(config.url),
|
||||||
|
"first_push_delay_seconds": 10
|
||||||
|
}
|
||||||
|
except ValueError as ve:
|
||||||
|
raise HTTPException(status_code=400, detail=str(ve))
|
||||||
|
except RuntimeError as re: # 例如 APScheduler 启动失败
|
||||||
|
logger.error(f"Runtime error during push schedule setup: {re}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(re))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error setting up push schedule: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
|
|
||||||
|
@pusher_router.post("/stop_push", summary="停止当前数据推送任务")
|
||||||
|
async def handle_stop_push():
|
||||||
|
global data_pusher_instance
|
||||||
|
if data_pusher_instance is None:
|
||||||
|
logger.error("CRITICAL: /stop_push called but data_pusher_instance is None.")
|
||||||
|
raise HTTPException(status_code=503, detail="DataPusher service not available.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_pusher_instance.stop_push_schedule()
|
||||||
|
return {"message": "Push task stopped successfully if it was running."}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error stopping push schedule: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"Internal server error while stopping push: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Initialization and Shutdown Functions for Main App ---
|
||||||
|
def initialize_data_pusher(detector_instance_param): # Renamed to avoid conflict
|
||||||
|
"""
|
||||||
|
由主应用程序调用以创建和配置 DataPusher 实例。
|
||||||
|
"""
|
||||||
|
global data_pusher_instance
|
||||||
|
if data_pusher_instance is None:
|
||||||
|
logger.info("Initializing DataPusher instance...")
|
||||||
|
data_pusher_instance = DataPusher(detector_instance_param)
|
||||||
|
else:
|
||||||
|
logger.info("DataPusher instance already initialized. Updating detector instance if provided.")
|
||||||
|
data_pusher_instance.update_detector_instance(detector_instance_param)
|
||||||
|
return data_pusher_instance
|
||||||
|
|
||||||
|
def get_data_pusher_instance() -> Optional[DataPusher]:
|
||||||
|
"""获取 DataPusher 实例 (主要用于主应用可能需要访问它的其他方法,如 shutdown)"""
|
||||||
|
return data_pusher_instance
|
||||||
|
|
||||||
|
# 移除 if __name__ == '__main__' 和 run_pusher_service,因为不再独立运行
|
||||||
|
# 示例代码可以移至主应用的文档或测试脚本中。
|
||||||
|
|
||||||
|
# 注意:
|
||||||
|
# RFDETRDetector 实例的生命周期由 api_server.py (current_detector) 管理。
|
||||||
|
# 当 api_server.py 中的模型切换时,需要有一种机制来更新 DataPusher 内部的 detector 引用。
|
||||||
|
# initialize_data_pusher 可以被多次调用 (例如,在模型切换后),它会更新 DataPusher 持有的 detector 实例。
|
BIN
font/MSYH.TTC
Normal file
BIN
font/MSYH.TTC
Normal file
Binary file not shown.
226
frame_transfer.py
Normal file
226
frame_transfer.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
import base64
|
||||||
|
import requests as http_client
|
||||||
|
import queue
|
||||||
|
import time # 确保 time 被导入,如果之前被误删
|
||||||
|
import cv2
|
||||||
|
import traceback
|
||||||
|
import threading # 确保导入 threading
|
||||||
|
import av # 重新导入 av
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# from fastapi import requests
|
||||||
|
|
||||||
|
# 从 rfdetr_core 导入 RFDETRDetector 仅用于类型提示 (可选)
|
||||||
|
from rfdetr_core import RFDETRDetector
|
||||||
|
|
||||||
|
|
||||||
|
# 目标检测处理函数
|
||||||
|
# 函数签名已更改:现在接受一个 detector_instance 作为参数
|
||||||
|
def yolo_frame(rtc_q: queue.Queue, yolo_q: queue.Queue, stream_detector_instance: RFDETRDetector):
|
||||||
|
thread_name = threading.current_thread().name # 获取线程名称用于日志
|
||||||
|
print(f"处理线程 '{thread_name}' 已启动。")
|
||||||
|
error_message_displayed_once = False
|
||||||
|
no_detector_message_displayed_once = False # 用于只提示一次没有检测器
|
||||||
|
|
||||||
|
if stream_detector_instance is None:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 未提供有效的检测器实例给yolo_frame。线程将无法处理视频。")
|
||||||
|
# 此线程实际上无法做任何有用的工作,可以考虑直接退出或进入一个安全循环
|
||||||
|
# 为简单起见,我们允许它进入主循环,但它会在每次迭代时打印警告
|
||||||
|
|
||||||
|
while True:
|
||||||
|
frame = None
|
||||||
|
# current_category_counts = {} # 将在获取后转换
|
||||||
|
try:
|
||||||
|
# 恢复队列长度打印
|
||||||
|
print(f"线程 '{thread_name}' - 原始队列长度: {rtc_q.qsize()}, 检测队列长度: {yolo_q.qsize()}")
|
||||||
|
|
||||||
|
frame = rtc_q.get(timeout=0.1)
|
||||||
|
if frame is None:
|
||||||
|
print(f"处理线程 '{thread_name}' 接收到停止信号,正在退出...")
|
||||||
|
# 发送包含None frame和空计数的字典作为停止信号
|
||||||
|
yolo_q.put({"frame": None, "category_counts": {}})
|
||||||
|
break
|
||||||
|
|
||||||
|
category_counts_for_packet = {}
|
||||||
|
if stream_detector_instance:
|
||||||
|
no_detector_message_displayed_once = False # 检测器有效,重置提示
|
||||||
|
annotated_frame = stream_detector_instance.detect_and_draw_count(frame)
|
||||||
|
error_message_displayed_once = False
|
||||||
|
|
||||||
|
# 获取英文键的类别计数
|
||||||
|
english_counts = stream_detector_instance.category_counts.copy() if hasattr(stream_detector_instance, 'category_counts') else {}
|
||||||
|
|
||||||
|
# 转换为中文键的类别计数
|
||||||
|
if hasattr(stream_detector_instance, 'VISDRONE_CLASSES_CHINESE'):
|
||||||
|
chinese_map = stream_detector_instance.VISDRONE_CLASSES_CHINESE
|
||||||
|
for eng_key, count_val in english_counts.items():
|
||||||
|
# 使用 get 提供一个默认值,以防某个英文类别在中文映射表中确实没有
|
||||||
|
chi_key = chinese_map.get(eng_key, eng_key)
|
||||||
|
category_counts_for_packet[chi_key] = count_val
|
||||||
|
else:
|
||||||
|
# 如果没有中文映射表,则直接使用英文计数 (或记录警告)
|
||||||
|
category_counts_for_packet = english_counts
|
||||||
|
# logger.warning(f"线程 '{thread_name}': stream_detector_instance 没有 VISDRONE_CLASSES_CHINESE 属性,将使用英文类别计数。")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 如果没有有效的检测器实例传递进来
|
||||||
|
if not no_detector_message_displayed_once:
|
||||||
|
print(f"警告 (线程 '{thread_name}'): 无有效检测器实例。将在帧上绘制提示。")
|
||||||
|
no_detector_message_displayed_once = True
|
||||||
|
|
||||||
|
annotated_frame = frame.copy()
|
||||||
|
cv2.putText(annotated_frame,
|
||||||
|
"No detector instance provided for this stream",
|
||||||
|
(30, 50), cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||||
|
category_counts_for_packet = {} # 无检测器,计数为空
|
||||||
|
|
||||||
|
# 将帧和类别计数一起放入队列
|
||||||
|
data_packet = {"frame": annotated_frame, "category_counts": category_counts_for_packet}
|
||||||
|
try:
|
||||||
|
yolo_q.put_nowait(data_packet)
|
||||||
|
except queue.Full:
|
||||||
|
# print(f"警告 (线程 '{thread_name}'): yolo_q 已满,丢弃帧。") # 避免刷屏
|
||||||
|
pass
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
if not error_message_displayed_once:
|
||||||
|
print(f"线程 '{thread_name}' (yolo_frame) 处理时发生严重错误: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
error_message_displayed_once = True
|
||||||
|
time.sleep(1)
|
||||||
|
if frame is not None:
|
||||||
|
try:
|
||||||
|
pass
|
||||||
|
except queue.Full:
|
||||||
|
pass
|
||||||
|
continue
|
||||||
|
print(f"处理线程 '{thread_name}' 已停止。")
|
||||||
|
|
||||||
|
|
||||||
|
def push_frame(yolo_q: queue.Queue, rtmp_url: str, gateway: str, frequency: int, push_url: str):
|
||||||
|
thread_name = threading.current_thread().name
|
||||||
|
print(f"推流线程 '{thread_name}' (RTMP: {rtmp_url}) 已启动。")
|
||||||
|
|
||||||
|
output_container = None
|
||||||
|
stream = None
|
||||||
|
first_frame_processed = False
|
||||||
|
last_push_time = 0 # 记录上次推送base64的时间
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
frame_to_push = None
|
||||||
|
received_category_counts = {} # 初始化为空字典
|
||||||
|
try:
|
||||||
|
data_packet = yolo_q.get(timeout=0.1)
|
||||||
|
if data_packet:
|
||||||
|
frame_to_push = data_packet.get("frame")
|
||||||
|
received_category_counts = data_packet.get("category_counts", {})
|
||||||
|
else: # data_packet is None (不太可能,除非队列明确放入None)
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if frame_to_push is None: # 这是通过 data_packet["frame"] is None 来判断的
|
||||||
|
print(f"推流线程 '{thread_name}' 接收到停止信号,正在清理并退出...")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not first_frame_processed:
|
||||||
|
if frame_to_push is not None:
|
||||||
|
try:
|
||||||
|
height, width, _ = frame_to_push.shape
|
||||||
|
print(f"线程 '{thread_name}': 首帧尺寸 {width}x{height},正在初始化RTMP推流器到 {rtmp_url}")
|
||||||
|
output_container = av.open(rtmp_url, 'w', format='flv')
|
||||||
|
stream = output_container.add_stream('libx264', rate=25)
|
||||||
|
stream.pix_fmt = 'yuv420p'
|
||||||
|
stream.width = width
|
||||||
|
stream.height = height
|
||||||
|
stream.options = {'preset': 'ultrafast', 'tune': 'zerolatency', 'crf': '25'}
|
||||||
|
print(f"线程 '{thread_name}': RTMP推流器初始化成功。")
|
||||||
|
first_frame_processed = True
|
||||||
|
except Exception as e_init:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 初始化PyAV推流容器/流失败: {e_init}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not output_container or not stream:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 推流器未初始化,无法推流。可能是首帧处理失败。")
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 持续推流到RTMP
|
||||||
|
try:
|
||||||
|
video_frame = av.VideoFrame.from_ndarray(frame_to_push, format='bgr24')
|
||||||
|
for packet in stream.encode(video_frame):
|
||||||
|
output_container.mux(packet)
|
||||||
|
except Exception as e_push:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 推送帧到RTMP时发生错误: {e_push}")
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
# 定时推送base64帧
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - last_push_time >= frequency:
|
||||||
|
# 将接收到的类别计数传递给 push_base64_frame
|
||||||
|
push_base64_frame(frame_to_push, gateway, push_url, thread_name, received_category_counts)
|
||||||
|
last_push_time = current_time
|
||||||
|
|
||||||
|
except Exception as e_outer:
|
||||||
|
print(f"推流线程 '{thread_name}' 发生严重外部错误: {e_outer}")
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
print(f"推流线程 '{thread_name}': 进入finally块,准备关闭推流器。")
|
||||||
|
if stream and output_container:
|
||||||
|
try:
|
||||||
|
print(f"推流线程 '{thread_name}': 正在编码流的剩余部分...")
|
||||||
|
for packet in stream.encode(None):
|
||||||
|
output_container.mux(packet)
|
||||||
|
print(f"推流线程 '{thread_name}': 编码剩余部分完成。")
|
||||||
|
except Exception as e_flush:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 关闭推流流时发生编码/刷新错误: {e_flush}")
|
||||||
|
traceback.print_exc()
|
||||||
|
if output_container:
|
||||||
|
try:
|
||||||
|
print(f"推流线程 '{thread_name}': 正在关闭推流容器...")
|
||||||
|
output_container.close()
|
||||||
|
print(f"推流线程 '{thread_name}': 推流容器已关闭。")
|
||||||
|
except Exception as e_close:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 关闭推流容器时发生错误: {e_close}")
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"推流线程 '{thread_name}' 已停止并完成清理。")
|
||||||
|
|
||||||
|
|
||||||
|
def push_base64_frame(frame: np.ndarray, gateway: str, push_url: str, thread_name: str, category_counts: dict):
|
||||||
|
"""将帧转换为base64并推送到指定URL"""
|
||||||
|
try:
|
||||||
|
# 转换为JPEG格式
|
||||||
|
_, buffer = cv2.imencode('.jpg', frame)
|
||||||
|
# 转换为base64字符串
|
||||||
|
frame_base64 = base64.b64encode(buffer).decode('utf-8')
|
||||||
|
|
||||||
|
# 构建JSON数据
|
||||||
|
data = {
|
||||||
|
"gateway": gateway,
|
||||||
|
"frame_base64": frame_base64,
|
||||||
|
"category_counts": category_counts # 使用传入的 category_counts
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"DEBUG push_base64_frame: Pushing data: {data.get('category_counts')}") # 调试打印,检查发送的数据
|
||||||
|
# 发送POST请求
|
||||||
|
response = http_client.post(push_url, json=data, timeout=5)
|
||||||
|
|
||||||
|
# 检查响应
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(f"线程 '{thread_name}': base64帧已成功推送到 {push_url}")
|
||||||
|
else:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 推送base64帧失败,状态码: {response.status_code}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误 (线程 '{thread_name}'): 处理或推送base64帧时发生错误: {e}")
|
5
main.py
Normal file
5
main.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import uvicorn
|
||||||
|
from web import app
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
BIN
models/人车.pth
Normal file
BIN
models/人车.pth
Normal file
Binary file not shown.
27
requirements - 副本.txt
Normal file
27
requirements - 副本.txt
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
aiohttp==3.11.14
|
||||||
|
aiortc==1.11.0
|
||||||
|
aiosignal==1.3.2
|
||||||
|
APScheduler==3.11.0
|
||||||
|
av==14.2.0
|
||||||
|
fastapi==0.115.11
|
||||||
|
huggingface-hub==0.30.1
|
||||||
|
numpy==2.1.1
|
||||||
|
nvidia-cuda-runtime-cu12==12.8.90
|
||||||
|
opencv-contrib-python==4.11.0.86
|
||||||
|
opencv-python==4.11.0.86
|
||||||
|
pillow==11.1.0
|
||||||
|
pillow_heif==0.22.0
|
||||||
|
pycuda==2025.1
|
||||||
|
pydantic==2.10.6
|
||||||
|
pydantic_core==2.27.2
|
||||||
|
requests==2.32.3
|
||||||
|
requests-toolbelt==1.0.0
|
||||||
|
rfdetr==1.1.0
|
||||||
|
safetensors==0.5.3
|
||||||
|
supervision==0.25.1
|
||||||
|
torch==2.6.0+cu126
|
||||||
|
torchaudio==2.6.0+cu126
|
||||||
|
torchvision==0.21.0+cu126
|
||||||
|
transformers==4.50.3
|
||||||
|
uvicorn==0.34.0
|
||||||
|
wandb==0.19.9
|
233
requirements.txt
Normal file
233
requirements.txt
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
absl-py==2.2.1
|
||||||
|
accelerate==1.6.0
|
||||||
|
addict==2.4.0
|
||||||
|
aiohappyeyeballs==2.6.1
|
||||||
|
aiohttp==3.11.14
|
||||||
|
aioice==0.9.0
|
||||||
|
aiortc==1.11.0
|
||||||
|
aiosignal==1.3.2
|
||||||
|
albucore==0.0.23
|
||||||
|
albumentations==2.0.5
|
||||||
|
altgraph==0.17.4
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.8.0
|
||||||
|
anywidget==0.9.18
|
||||||
|
appdirs==1.4.4
|
||||||
|
APScheduler==3.11.0
|
||||||
|
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
|
||||||
|
async-timeout==5.0.1
|
||||||
|
attrs==25.3.0
|
||||||
|
av==14.2.0
|
||||||
|
basicsr==1.4.2
|
||||||
|
bbox_visualizer==0.2.0
|
||||||
|
blind-watermark==0.4.4
|
||||||
|
certifi==2025.1.31
|
||||||
|
cffi==1.17.1
|
||||||
|
charset-normalizer==3.4.1
|
||||||
|
click==8.1.8
|
||||||
|
cmake==3.31.6
|
||||||
|
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
|
||||||
|
coloredlogs==15.0.1
|
||||||
|
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
|
||||||
|
contourpy==1.3.1
|
||||||
|
cryptography==44.0.2
|
||||||
|
cycler==0.12.1
|
||||||
|
Cython==3.0.12
|
||||||
|
debugpy @ file:///D:/bld/debugpy_1741148401445/work
|
||||||
|
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
|
||||||
|
deep-sort-realtime==1.3.2
|
||||||
|
defusedxml==0.7.1
|
||||||
|
dnspython==2.7.0
|
||||||
|
docker-pycreds==0.4.0
|
||||||
|
easydict==1.13
|
||||||
|
einops==0.8.1
|
||||||
|
et_xmlfile==2.0.0
|
||||||
|
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1733208806608/work
|
||||||
|
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1733569351617/work
|
||||||
|
facexlib==0.3.0
|
||||||
|
fairscale==0.4.13
|
||||||
|
fastapi==0.115.11
|
||||||
|
filelock==3.13.1
|
||||||
|
filetype==1.2.0
|
||||||
|
filterpy==1.4.5
|
||||||
|
fire==0.7.0
|
||||||
|
flatbuffers==25.2.10
|
||||||
|
fonttools==4.56.0
|
||||||
|
frozenlist==1.5.0
|
||||||
|
fsspec==2024.6.1
|
||||||
|
ftfy==6.3.1
|
||||||
|
future==1.0.0
|
||||||
|
gfpgan==1.3.8
|
||||||
|
gitdb==4.0.12
|
||||||
|
GitPython==3.1.44
|
||||||
|
google-crc32c==1.7.1
|
||||||
|
grpcio==1.71.0
|
||||||
|
h11==0.14.0
|
||||||
|
httptools==0.6.4
|
||||||
|
huggingface-hub==0.30.1
|
||||||
|
humanfriendly==10.0
|
||||||
|
idna==3.7
|
||||||
|
ifaddr==0.2.0
|
||||||
|
imageio==2.37.0
|
||||||
|
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work
|
||||||
|
iniconfig==2.1.0
|
||||||
|
insightface==0.7.3
|
||||||
|
ipykernel @ file:///D:/bld/ipykernel_1719845595208/work
|
||||||
|
ipython @ file:///D:/bld/bld/rattler-build_ipython_1740856913/work
|
||||||
|
ipywidgets==8.1.5
|
||||||
|
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
|
||||||
|
Jinja2==3.1.4
|
||||||
|
joblib==1.4.2
|
||||||
|
jupyter_bbox_widget==0.6.0
|
||||||
|
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
|
||||||
|
jupyter_core @ file:///D:/bld/jupyter_core_1727163532151/work
|
||||||
|
jupyterlab_widgets==3.0.13
|
||||||
|
kiwisolver==1.4.8
|
||||||
|
lap==0.5.12
|
||||||
|
lazy_loader==0.4
|
||||||
|
llvmlite==0.44.0
|
||||||
|
lmdb==1.6.2
|
||||||
|
Mako==1.3.9
|
||||||
|
Markdown==3.7
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
MarkupSafe==2.1.5
|
||||||
|
matplotlib==3.10.1
|
||||||
|
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
|
||||||
|
mdurl==0.1.2
|
||||||
|
memory-profiler==0.61.0
|
||||||
|
motmetrics==1.4.0
|
||||||
|
mpmath==1.3.0
|
||||||
|
multidict==6.2.0
|
||||||
|
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
|
||||||
|
networkx==3.3
|
||||||
|
ninja==1.11.1.3
|
||||||
|
Nuitka==2.6.7
|
||||||
|
numba==0.61.0
|
||||||
|
numpy==2.1.1
|
||||||
|
nvidia-cuda-runtime-cu12==12.8.90
|
||||||
|
onnx==1.16.1
|
||||||
|
onnx-graphsurgeon==0.5.7
|
||||||
|
onnxruntime-gpu==1.20.2
|
||||||
|
onnxsim==0.4.36
|
||||||
|
onnxslim==0.1.48
|
||||||
|
open_clip_torch==2.32.0
|
||||||
|
opencv-contrib-python==4.11.0.86
|
||||||
|
opencv-python==4.11.0.86
|
||||||
|
openpyxl==3.1.5
|
||||||
|
ordered-set==4.1.0
|
||||||
|
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1733203243479/work
|
||||||
|
pandas==2.2.3
|
||||||
|
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
|
||||||
|
pefile==2023.2.7
|
||||||
|
peft==0.15.1
|
||||||
|
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
|
||||||
|
pillow==11.1.0
|
||||||
|
pillow_heif==0.22.0
|
||||||
|
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1733232627818/work
|
||||||
|
pluggy==1.5.0
|
||||||
|
polygraphy==0.49.20
|
||||||
|
prettytable==3.15.1
|
||||||
|
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1737453357274/work
|
||||||
|
propcache==0.3.1
|
||||||
|
protobuf==3.20.2
|
||||||
|
psutil @ file:///D:/bld/psutil_1740663160591/work
|
||||||
|
psygnal==0.12.0
|
||||||
|
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
|
||||||
|
py-cpuinfo==9.0.0
|
||||||
|
pybboxes==0.1.6
|
||||||
|
pycocotools==2.0.8
|
||||||
|
pycparser==2.22
|
||||||
|
pycuda==2025.1
|
||||||
|
pydantic==2.10.6
|
||||||
|
pydantic_core==2.27.2
|
||||||
|
pyee==13.0.0
|
||||||
|
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work
|
||||||
|
pyinstaller==6.12.0
|
||||||
|
pyinstaller-hooks-contrib==2025.1
|
||||||
|
pylabel==0.1.55
|
||||||
|
pylibsrtp==0.11.0
|
||||||
|
PyMuPDF==1.25.4
|
||||||
|
pyOpenSSL==25.0.0
|
||||||
|
pyparsing==3.2.1
|
||||||
|
pyproj==3.7.1
|
||||||
|
pyreadline3==3.5.4
|
||||||
|
PySide6==6.8.2.1
|
||||||
|
PySide6_Addons==6.8.2.1
|
||||||
|
PySide6_Essentials==6.8.2.1
|
||||||
|
pytest==8.3.5
|
||||||
|
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
python-multipart==0.0.20
|
||||||
|
pytools==2025.1.2
|
||||||
|
pytz==2025.1
|
||||||
|
PyWavelets==1.8.0
|
||||||
|
pywin32==307
|
||||||
|
pywin32-ctypes==0.2.3
|
||||||
|
PyYAML==6.0.2
|
||||||
|
pyzmq @ file:///D:/bld/pyzmq_1738270977186/work
|
||||||
|
realesrgan==0.3.0
|
||||||
|
regex==2024.11.6
|
||||||
|
requests==2.32.3
|
||||||
|
requests-toolbelt==1.0.0
|
||||||
|
rf100vl==1.0.0
|
||||||
|
rfdetr==1.1.0
|
||||||
|
rich==14.0.0
|
||||||
|
roboflow==1.1.60
|
||||||
|
safetensors==0.5.3
|
||||||
|
sahi==0.11.22
|
||||||
|
scikit-image==0.25.2
|
||||||
|
scikit-learn==1.6.1
|
||||||
|
scipy==1.15.2
|
||||||
|
seaborn==0.13.2
|
||||||
|
sentry-sdk==2.25.1
|
||||||
|
setproctitle==1.3.5
|
||||||
|
shapely==2.0.7
|
||||||
|
shiboken6==6.8.2.1
|
||||||
|
simsimd==6.2.1
|
||||||
|
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
|
||||||
|
smmap==5.0.2
|
||||||
|
sniffio==1.3.1
|
||||||
|
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
|
||||||
|
starlette==0.46.1
|
||||||
|
stringzilla==3.12.2
|
||||||
|
supervision==0.25.1
|
||||||
|
sympy==1.13.1
|
||||||
|
tabulate==0.9.0
|
||||||
|
tb-nightly==2.20.0a20250326
|
||||||
|
tensorboard-data-server==0.7.2
|
||||||
|
tensorrt==10.9.0.34
|
||||||
|
tensorrt_cu12==10.9.0.34
|
||||||
|
tensorrt_cu12_bindings==10.9.0.34
|
||||||
|
tensorrt_cu12_libs==10.9.0.34
|
||||||
|
termcolor==2.5.0
|
||||||
|
terminaltables==3.1.10
|
||||||
|
threadpoolctl==3.5.0
|
||||||
|
tifffile==2025.2.18
|
||||||
|
timm==1.0.15
|
||||||
|
tokenizers==0.21.1
|
||||||
|
tomli==2.2.1
|
||||||
|
torch==2.6.0+cu126
|
||||||
|
torchaudio==2.6.0+cu126
|
||||||
|
torchvision==0.21.0+cu126
|
||||||
|
tornado @ file:///D:/bld/tornado_1732615925919/work
|
||||||
|
tqdm==4.67.1
|
||||||
|
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
|
||||||
|
transformers==4.50.3
|
||||||
|
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work
|
||||||
|
tzdata==2025.1
|
||||||
|
tzlocal==5.3.1
|
||||||
|
ultralytics==8.3.99
|
||||||
|
ultralytics-thop==2.0.14
|
||||||
|
urllib3==2.3.0
|
||||||
|
uvicorn==0.34.0
|
||||||
|
wandb==0.19.9
|
||||||
|
watchfiles==1.0.4
|
||||||
|
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
|
||||||
|
websockets==15.0.1
|
||||||
|
Werkzeug==3.1.3
|
||||||
|
widgetsnbextension==4.0.13
|
||||||
|
xmltodict==0.14.2
|
||||||
|
yapf==0.43.0
|
||||||
|
yarl==1.18.3
|
||||||
|
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
|
||||||
|
zstandard==0.23.0
|
21
result.py
Normal file
21
result.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# 定义返回值类
|
||||||
|
class Response:
|
||||||
|
def __init__(self, code, data, message):
|
||||||
|
self.code = code
|
||||||
|
self.data = data
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"code": self.code,
|
||||||
|
"data": self.data,
|
||||||
|
"message": self.message
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def success(cls, data=None, message="操作成功"):
|
||||||
|
return cls(200, data, message).to_dict()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error(cls, data=None, message="操作失败", code=400):
|
||||||
|
return cls(code, data, message).to_dict()
|
254
rfdetr_core.py
Normal file
254
rfdetr_core.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
import cv2
|
||||||
|
import supervision as sv
|
||||||
|
from rfdetr import RFDETRBase
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, Set
|
||||||
|
from PIL import Image, ImageDraw, ImageFont # 导入PIL库
|
||||||
|
import numpy as np # 导入numpy用于图像格式转换
|
||||||
|
import json # 新增
|
||||||
|
import os # 新增
|
||||||
|
|
||||||
|
class RFDETRDetector:
|
||||||
|
def __init__(self, config_name: str, base_model_dir="models", base_config_dir="configs", default_font_path="./font/MSYH.TTC", default_font_size=15):
|
||||||
|
self.config_path = os.path.join(base_config_dir, f"{config_name}.json")
|
||||||
|
if not os.path.exists(self.config_path):
|
||||||
|
raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
|
||||||
|
|
||||||
|
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||||
|
self.config = json.load(f)
|
||||||
|
|
||||||
|
model_path = os.path.join(base_model_dir, self.config['model_pth_filename'])
|
||||||
|
resolution = self.config['resolution']
|
||||||
|
|
||||||
|
# 从配置读取字体路径和大小,如果未提供则使用默认值
|
||||||
|
font_path = self.config.get('font_path', default_font_path)
|
||||||
|
font_size = self.config.get('font_size', default_font_size)
|
||||||
|
|
||||||
|
# 1. 初始化模型
|
||||||
|
self.model = RFDETRBase(
|
||||||
|
pretrain_weights=model_path,
|
||||||
|
# pretrain_weights=model_path or r"E:\A\rf-detr-main\output\pre-train1\checkpoint_best_ema.pth",
|
||||||
|
resolution=resolution
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 初始化跟踪器
|
||||||
|
self.tracker = sv.ByteTrack(
|
||||||
|
track_activation_threshold=self.config['tracker_activation_threshold'],
|
||||||
|
lost_track_buffer=self.config['tracker_lost_buffer'],
|
||||||
|
minimum_matching_threshold=self.config['tracker_match_threshold'],
|
||||||
|
minimum_consecutive_frames=self.config['tracker_consecutive_frames'],
|
||||||
|
frame_rate=self.config['tracker_frame_rate']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 类别定义
|
||||||
|
self.VISDRONE_CLASSES = self.config['classes_en']
|
||||||
|
self.VISDRONE_CLASSES_CHINESE = self.config['classes_zh_map']
|
||||||
|
|
||||||
|
# 新增:加载类别启用配置
|
||||||
|
self.detection_settings = self.config.get('detection_settings', {})
|
||||||
|
self.enabled_classes_filter = self.detection_settings.get('enabled_classes', {})
|
||||||
|
# 构建一个查找表,对于未在filter中指定的类别,默认为 True (启用)
|
||||||
|
self._active_classes_lookup = {
|
||||||
|
cls_name: self.enabled_classes_filter.get(cls_name, True)
|
||||||
|
for cls_name in self.VISDRONE_CLASSES
|
||||||
|
}
|
||||||
|
print(f"活动类别配置: {self._active_classes_lookup}")
|
||||||
|
|
||||||
|
# 4. 初始化字体
|
||||||
|
self.FONT_SIZE = font_size
|
||||||
|
try:
|
||||||
|
self.font = ImageFont.truetype(font_path, self.FONT_SIZE)
|
||||||
|
except IOError:
|
||||||
|
print(f"错误:无法加载字体 {font_path}。将使用默认字体。")
|
||||||
|
self.font = ImageFont.load_default() # 使用真正通用的默认字体
|
||||||
|
|
||||||
|
# 5. 类别计数器 (作为类属性)
|
||||||
|
self.class_tracks: Dict[str, Set[int]] = defaultdict(set)
|
||||||
|
self.category_counts: Dict[str, int] = defaultdict(int)
|
||||||
|
|
||||||
|
# 6. 初始化标注器
|
||||||
|
# 从配置加载默认颜色,如果失败则使用预设颜色
|
||||||
|
self.default_color_hex = self.config.get('default_color_hex', "#00FF00") # 默认绿色
|
||||||
|
self.bounding_box_thickness = self.config.get('bounding_box_thickness', 2)
|
||||||
|
|
||||||
|
# 加载颜色配置,用于 PIL 绘制
|
||||||
|
self.class_colors_hex = self.config.get('class_colors_hex', {})
|
||||||
|
self.last_annotated_frame: np.ndarray | None = None # 新增: 用于存储最新的标注帧
|
||||||
|
|
||||||
|
def _hex_to_rgb(self, hex_color: str) -> tuple:
|
||||||
|
"""将十六进制颜色字符串转换为RGB元组。"""
|
||||||
|
hex_color = hex_color.lstrip('#')
|
||||||
|
try:
|
||||||
|
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
|
||||||
|
except ValueError:
|
||||||
|
print(f"警告: 无法解析十六进制颜色 '{hex_color}', 将使用默认颜色。")
|
||||||
|
# 解析失败时返回一个默认颜色,例如红色
|
||||||
|
return self._hex_to_rgb(self.default_color_hex if self.default_color_hex != hex_color else "#00FF00")
|
||||||
|
|
||||||
|
def _update_counter(self, detections: sv.Detections):
|
||||||
|
"""更新类别计数器"""
|
||||||
|
# 只统计有 tracker_id 的检测结果
|
||||||
|
valid_indices = detections.tracker_id != None
|
||||||
|
if not np.any(valid_indices): # 处理 detections 为空或 tracker_id 都为 None 的情况
|
||||||
|
return
|
||||||
|
|
||||||
|
class_ids = detections.class_id[valid_indices]
|
||||||
|
track_ids = detections.tracker_id[valid_indices]
|
||||||
|
|
||||||
|
for class_id, track_id in zip(class_ids, track_ids):
|
||||||
|
if track_id is None: # 跳过没有 tracker_id 的项
|
||||||
|
continue
|
||||||
|
# 使用英文类别名作为内部 key
|
||||||
|
class_name = self.VISDRONE_CLASSES[class_id]
|
||||||
|
if track_id not in self.class_tracks[class_name]:
|
||||||
|
self.class_tracks[class_name].add(track_id)
|
||||||
|
self.category_counts[class_name] += 1
|
||||||
|
|
||||||
|
def _draw_frame(self, frame: np.ndarray, detections: sv.Detections) -> np.ndarray:
|
||||||
|
"""使用PIL绘制检测框、中文标签和计数信息"""
|
||||||
|
|
||||||
|
pil_image = Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB))
|
||||||
|
draw = ImageDraw.Draw(pil_image)
|
||||||
|
|
||||||
|
# --- 使用 PIL 绘制检测框和中文标签 ---
|
||||||
|
valid_indices = detections.tracker_id != None # 或直接使用 detections.xyxy 如果不过滤无 tracker_id 的
|
||||||
|
if np.any(valid_indices):
|
||||||
|
boxes = detections.xyxy[valid_indices]
|
||||||
|
class_ids = detections.class_id[valid_indices]
|
||||||
|
# tracker_ids = detections.tracker_id[valid_indices] # 如果需要 tracker_id
|
||||||
|
|
||||||
|
for box, class_id in zip(boxes, class_ids):
|
||||||
|
x1, y1, x2, y2 = map(int, box)
|
||||||
|
|
||||||
|
english_label = self.VISDRONE_CLASSES[class_id]
|
||||||
|
chinese_label = self.VISDRONE_CLASSES_CHINESE.get(english_label, english_label)
|
||||||
|
|
||||||
|
# 获取边界框颜色
|
||||||
|
box_color_hex = self.class_colors_hex.get(english_label, self.default_color_hex)
|
||||||
|
box_rgb_color = self._hex_to_rgb(box_color_hex)
|
||||||
|
|
||||||
|
# 绘制边界框
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline=box_rgb_color, width=self.bounding_box_thickness)
|
||||||
|
|
||||||
|
# 绘制中文标签 (与之前逻辑类似)
|
||||||
|
text_to_draw = f"{chinese_label}"
|
||||||
|
# 标签背景 (可选,使其更易读)
|
||||||
|
# label_text_bbox = draw.textbbox((0,0), text_to_draw, font=self.font)
|
||||||
|
# label_width = label_text_bbox[2] - label_text_bbox[0]
|
||||||
|
# label_height = label_text_bbox[3] - label_text_bbox[1]
|
||||||
|
# label_bg_y1 = y1 - label_height - 4 if y1 - label_height - 4 > 0 else y1 + 2
|
||||||
|
# draw.rectangle([x1, label_bg_y1, x1 + label_width + 4, label_bg_y1 + label_height + 2], fill=box_rgb_color)
|
||||||
|
# text_color = (255,255,255) if sum(box_rgb_color) < 382 else (0,0,0) # 简易对比色
|
||||||
|
text_color = (255, 255, 255) # 白色 (RGB)
|
||||||
|
|
||||||
|
text_x = x1 + 2 # 稍微偏移,避免紧贴边框
|
||||||
|
text_y = y1 - self.FONT_SIZE - 2
|
||||||
|
if text_y < 0: # 如果标签超出图像顶部
|
||||||
|
text_y = y1 + 2
|
||||||
|
|
||||||
|
draw.text((text_x, text_y), text_to_draw, font=self.font, fill=text_color)
|
||||||
|
|
||||||
|
# --- 绘制统计面板 (右上角) ---
|
||||||
|
stats_text_lines = [
|
||||||
|
f"{self.VISDRONE_CLASSES_CHINESE.get(cls, cls)}: {self.category_counts[cls]}"
|
||||||
|
for cls in self.VISDRONE_CLASSES if self.category_counts[cls] > 0
|
||||||
|
]
|
||||||
|
|
||||||
|
frame_height, frame_width, _ = frame.shape
|
||||||
|
stats_start_x = frame_width - self.config.get('stats_panel_width', 200)
|
||||||
|
stats_start_y = self.config.get('stats_panel_margin_y', 10)
|
||||||
|
line_height = self.FONT_SIZE + self.config.get('stats_line_spacing', 5)
|
||||||
|
|
||||||
|
stats_text_color_hex = self.config.get('stats_text_color_hex', "#FFFFFF")
|
||||||
|
stats_text_color = self._hex_to_rgb(stats_text_color_hex)
|
||||||
|
|
||||||
|
# 可选:为统计面板添加背景
|
||||||
|
if stats_text_lines:
|
||||||
|
panel_height = len(stats_text_lines) * line_height + 10
|
||||||
|
panel_y2 = stats_start_y + panel_height
|
||||||
|
# 半透明背景
|
||||||
|
# overlay = Image.new('RGBA', pil_image.size, (0,0,0,0))
|
||||||
|
# panel_draw = ImageDraw.Draw(overlay)
|
||||||
|
# panel_draw.rectangle(
|
||||||
|
# [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2],
|
||||||
|
# fill=(100, 100, 100, 128) # 半透明灰色
|
||||||
|
# )
|
||||||
|
# pil_image = Image.alpha_composite(pil_image.convert('RGBA'), overlay)
|
||||||
|
# draw = ImageDraw.Draw(pil_image) # 如果用了 alpha_composite, 需要重新获取 draw 对象
|
||||||
|
|
||||||
|
# 或者简单不透明背景
|
||||||
|
# draw.rectangle(
|
||||||
|
# [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2],
|
||||||
|
# fill=self._hex_to_rgb(self.config.get('stats_panel_bg_color_hex', "#808080")) # 例如灰色背景
|
||||||
|
# )
|
||||||
|
|
||||||
|
for i, line in enumerate(stats_text_lines):
|
||||||
|
text_pos = (stats_start_x, stats_start_y + i * line_height)
|
||||||
|
draw.text(text_pos, line, font=self.font, fill=stats_text_color)
|
||||||
|
|
||||||
|
final_annotated_frame = cv2.cvtColor(np.array(pil_image.convert('RGB')), cv2.COLOR_RGB2BGR)
|
||||||
|
return final_annotated_frame
|
||||||
|
|
||||||
|
def detect_and_draw_count(self, frame: np.ndarray, conf: float = -1.0) -> np.ndarray:
|
||||||
|
"""执行单帧检测、跟踪、计数并绘制结果(包含类别过滤)。"""
|
||||||
|
if conf == -1.0:
|
||||||
|
# 优先从 detection_settings 中获取,其次是顶层config,最后是硬编码默认值
|
||||||
|
effective_conf = float(
|
||||||
|
self.detection_settings.get('default_confidence_threshold',
|
||||||
|
self.config.get('default_confidence_threshold', 0.8))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
effective_conf = conf
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 执行检测
|
||||||
|
detections = self.model.predict(frame, threshold=effective_conf)
|
||||||
|
|
||||||
|
# 处理 detections 为 None 或空的情况
|
||||||
|
if detections is None or len(detections) == 0:
|
||||||
|
detections = sv.Detections.empty()
|
||||||
|
annotated_frame = self._draw_frame(frame, detections)
|
||||||
|
self.last_annotated_frame = annotated_frame.copy() # 新增
|
||||||
|
return annotated_frame
|
||||||
|
|
||||||
|
# 新增:根据配置过滤检测到的类别
|
||||||
|
if detections is not None and len(detections) > 0:
|
||||||
|
keep_indices = []
|
||||||
|
for i, class_id in enumerate(detections.class_id):
|
||||||
|
if class_id < len(self.VISDRONE_CLASSES): # 确保 class_id 有效
|
||||||
|
class_name = self.VISDRONE_CLASSES[class_id]
|
||||||
|
if self._active_classes_lookup.get(class_name, True): # 默认为 True
|
||||||
|
keep_indices.append(i)
|
||||||
|
else:
|
||||||
|
print(f"警告: 检测到无效的 class_id {class_id},超出了已知类别范围。")
|
||||||
|
|
||||||
|
if not keep_indices:
|
||||||
|
detections = sv.Detections.empty()
|
||||||
|
else:
|
||||||
|
detections = detections[keep_indices]
|
||||||
|
|
||||||
|
# 如果过滤后没有检测结果
|
||||||
|
if len(detections) == 0:
|
||||||
|
annotated_frame = self._draw_frame(frame, sv.Detections.empty())
|
||||||
|
self.last_annotated_frame = annotated_frame.copy() # 新增
|
||||||
|
return annotated_frame
|
||||||
|
|
||||||
|
# 2. 执行跟踪 (只对过滤后的结果进行跟踪)
|
||||||
|
detections = self.tracker.update_with_detections(detections)
|
||||||
|
|
||||||
|
# 3. 更新计数器 (只对过滤并跟踪后的结果进行计数)
|
||||||
|
self._update_counter(detections)
|
||||||
|
|
||||||
|
# 4. 绘制结果
|
||||||
|
annotated_frame = self._draw_frame(frame, detections)
|
||||||
|
self.last_annotated_frame = annotated_frame.copy() # 新增
|
||||||
|
|
||||||
|
return annotated_frame
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理帧时发生错误: {e}")
|
||||||
|
if frame is not None:
|
||||||
|
self.last_annotated_frame = frame.copy() # 新增
|
||||||
|
else:
|
||||||
|
self.last_annotated_frame = None # 新增
|
||||||
|
return frame
|
94
rtc_handler.py
Normal file
94
rtc_handler.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import asyncio
|
||||||
|
import queue
|
||||||
|
from fractions import Fraction
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, VideoStreamTrack
|
||||||
|
|
||||||
|
|
||||||
|
class DummyVideoTrack(VideoStreamTrack):
|
||||||
|
async def recv(self):
|
||||||
|
# 简洁初始化、返回固定颜色的帧
|
||||||
|
return np.full((480, 640, 3), (0, 0, 255), dtype=np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_video_frames(whep_url):
|
||||||
|
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
||||||
|
frames_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
pc.addTrack(DummyVideoTrack())
|
||||||
|
|
||||||
|
@pc.on("track")
|
||||||
|
def on_track(track):
|
||||||
|
if track.kind == "video":
|
||||||
|
asyncio.create_task(consume_track(track, frames_queue))
|
||||||
|
|
||||||
|
@pc.on("iceconnectionstatechange")
|
||||||
|
def on_ice_connection_state_change():
|
||||||
|
print(f"ICE 连接状态: {pc.iceConnectionState}")
|
||||||
|
|
||||||
|
offer = await pc.createOffer()
|
||||||
|
await pc.setLocalDescription(offer)
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/sdp"}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(whep_url, data=pc.localDescription.sdp, headers=headers) as response:
|
||||||
|
if response.status != 201:
|
||||||
|
raise Exception(f"服务器返回错误: {response.status}")
|
||||||
|
|
||||||
|
answer = RTCSessionDescription(sdp=await response.text(), type="answer")
|
||||||
|
await pc.setRemoteDescription(answer)
|
||||||
|
|
||||||
|
if "Location" in response.headers:
|
||||||
|
base_url = f"{urlparse(whep_url).scheme}://{urlparse(whep_url).netloc}"
|
||||||
|
print("ICE 协商 URL:", base_url + response.headers["Location"])
|
||||||
|
|
||||||
|
while pc.iceConnectionState not in ["connected", "completed"]:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
print("ICE 连接完成,开始接收视频流")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
frame = await frames_queue.get()
|
||||||
|
if frame is None:
|
||||||
|
break
|
||||||
|
yield frame
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
await pc.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def consume_track(track, frames_queue):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
frame = await track.recv()
|
||||||
|
if frame is None:
|
||||||
|
print("没有接收到有效的帧数据")
|
||||||
|
await frames_queue.put(None)
|
||||||
|
break
|
||||||
|
img = frame.to_ndarray(format="bgr24")
|
||||||
|
await frames_queue.put(img)
|
||||||
|
except Exception as e:
|
||||||
|
print("处理帧错误:", e)
|
||||||
|
await frames_queue.put(None)
|
||||||
|
|
||||||
|
|
||||||
|
def rtc_frame(url, frame_queue):
|
||||||
|
async def main():
|
||||||
|
async for frame in receive_video_frames(url):
|
||||||
|
try:
|
||||||
|
frame_queue.put_nowait(frame)
|
||||||
|
except queue.Full:
|
||||||
|
frame_queue.get_nowait()
|
||||||
|
frame_queue.put_nowait(frame)
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_until_complete(main())
|
||||||
|
loop.close()
|
276
web.py
Normal file
276
web.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
from multiprocessing import Process, Event, Queue as MpQueue
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import FastAPI, Request, HTTPException
|
||||||
|
|
||||||
|
from api_server import model_management_router, initialize_default_model_on_startup, get_active_detector, \
|
||||||
|
get_active_model_identifier
|
||||||
|
from data_pusher import pusher_router, initialize_data_pusher, get_data_pusher_instance
|
||||||
|
from frame_transfer import yolo_frame, push_frame
|
||||||
|
from result import Response
|
||||||
|
from rfdetr_core import RFDETRDetector
|
||||||
|
from rtc_handler import rtc_frame
|
||||||
|
|
||||||
|
app = FastAPI(title="Real-time Video Processing, Model Management, and Data Pusher API")
|
||||||
|
|
||||||
|
app.include_router(model_management_router, prefix="/api", tags=["Model Management"])
|
||||||
|
app.include_router(pusher_router, prefix="/api/pusher", tags=["Data Pusher"])
|
||||||
|
|
||||||
|
process_map: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def web_app_startup_event():
|
||||||
|
print("主应用服务启动中...")
|
||||||
|
await initialize_default_model_on_startup()
|
||||||
|
active_detector_instance = get_active_detector()
|
||||||
|
if active_detector_instance:
|
||||||
|
print(f"主应用启动:检测到活动模型 '{get_active_model_identifier()}',将用于初始化 DataPusher。")
|
||||||
|
else:
|
||||||
|
print("主应用启动:未检测到活动模型。DataPusher 将以无检测器状态初始化。")
|
||||||
|
initialize_data_pusher(active_detector_instance)
|
||||||
|
print("DataPusher 服务已初始化。")
|
||||||
|
print("主应用启动完成。")
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def web_app_shutdown_event():
|
||||||
|
print("主应用服务关闭中...")
|
||||||
|
pusher = get_data_pusher_instance()
|
||||||
|
if pusher:
|
||||||
|
print("正在关闭 DataPusher 调度器...")
|
||||||
|
pusher.shutdown_scheduler()
|
||||||
|
else:
|
||||||
|
print("DataPusher 实例未找到,跳过调度器关闭。")
|
||||||
|
|
||||||
|
print("正在尝试终止所有活动的视频处理子进程...")
|
||||||
|
for url, task_info in list(process_map.items()):
|
||||||
|
process: Process = task_info['process']
|
||||||
|
stop_event: Event = task_info['stop_event']
|
||||||
|
data_q: MpQueue = task_info['data_queue']
|
||||||
|
|
||||||
|
if process.is_alive():
|
||||||
|
print(f"向进程 {url} 发送停止信号...")
|
||||||
|
stop_event.set()
|
||||||
|
try:
|
||||||
|
process.join(timeout=15)
|
||||||
|
if process.is_alive():
|
||||||
|
print(f"进程 {url} 在优雅关闭超时后仍然存活,尝试 terminate。")
|
||||||
|
process.terminate()
|
||||||
|
process.join(timeout=5)
|
||||||
|
if process.is_alive():
|
||||||
|
print(f"进程 {url} 在 terminate 后仍然存活,尝试 kill。")
|
||||||
|
process.kill()
|
||||||
|
process.join(timeout=2)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"关闭/终止进程 {url} 时发生错误: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not data_q.empty():
|
||||||
|
pass
|
||||||
|
data_q.close()
|
||||||
|
data_q.join_thread()
|
||||||
|
except Exception as e_q_cleanup:
|
||||||
|
print(f"清理进程 {url} 的数据队列时出错: {e_q_cleanup}")
|
||||||
|
|
||||||
|
del process_map[url]
|
||||||
|
print("所有视频处理子进程已尝试终止和清理。")
|
||||||
|
print("主应用服务关闭完成。")
|
||||||
|
|
||||||
|
|
||||||
|
def start_video_processing(url: str, rtmp_url: str, model_config_name: str, stop_event: Event, data_queue: MpQueue,
|
||||||
|
gateway: str,frequency:int, push_url:str):
|
||||||
|
print(f"视频处理子进程启动 (URL: {url}, Model: {model_config_name})")
|
||||||
|
detector_instance_for_stream: RFDETRDetector = None
|
||||||
|
producer_thread, transfer_thread, consumer_thread = None, None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"正在为流 {url} 初始化模型: {model_config_name}...")
|
||||||
|
detector_instance_for_stream = RFDETRDetector(config_name=model_config_name)
|
||||||
|
print(f"模型 {model_config_name} 为流 {url} 初始化成功。")
|
||||||
|
rtc_q = queue.Queue(maxsize=10000)
|
||||||
|
yolo_q = queue.Queue(maxsize=10000)
|
||||||
|
import threading
|
||||||
|
producer_thread = threading.Thread(target=rtc_frame, args=(url, rtc_q), name=f"RTC-{url[:20]}", daemon=True)
|
||||||
|
transfer_thread = threading.Thread(target=yolo_frame, args=(rtc_q, yolo_q, detector_instance_for_stream),
|
||||||
|
name=f"YOLO-{url[:20]}", daemon=True)
|
||||||
|
consumer_thread = threading.Thread(target=push_frame, args=(yolo_q, rtmp_url,gateway,frequency,push_url), name=f"Push-{url[:20]}",
|
||||||
|
daemon=True)
|
||||||
|
|
||||||
|
producer_thread.start()
|
||||||
|
transfer_thread.start()
|
||||||
|
consumer_thread.start()
|
||||||
|
|
||||||
|
stop_event.wait()
|
||||||
|
print(f"子进程 {url}: 收到停止信号。准备关闭线程...")
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(f"错误 (视频进程 {url}): 模型配置文件 '{model_config_name}.json' 未找到。错误: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误 (视频进程 {url}): 初始化或运行时错误。错误: {e}")
|
||||||
|
finally:
|
||||||
|
print(f"视频处理子进程 {url} 进入 finally 块。")
|
||||||
|
|
||||||
|
if producer_thread and producer_thread.is_alive():
|
||||||
|
print(f"子进程 {url}: producer_thread is still alive (daemon).")
|
||||||
|
if transfer_thread and transfer_thread.is_alive():
|
||||||
|
print(f"子进程 {url}: transfer_thread is still alive (daemon).")
|
||||||
|
if consumer_thread and consumer_thread.is_alive():
|
||||||
|
print(f"子进程 {url}: consumer_thread is still alive (daemon).")
|
||||||
|
|
||||||
|
if detector_instance_for_stream:
|
||||||
|
print(f"子进程 {url}: 收集最后数据...")
|
||||||
|
final_counts = getattr(detector_instance_for_stream, 'category_counts', {})
|
||||||
|
final_frame_np = getattr(detector_instance_for_stream, 'last_annotated_frame', None)
|
||||||
|
frame_base64 = None
|
||||||
|
if final_frame_np is not None and isinstance(final_frame_np, np.ndarray):
|
||||||
|
try:
|
||||||
|
_, buffer = cv2.imencode('.jpg', final_frame_np)
|
||||||
|
frame_base64 = base64.b64encode(buffer).decode('utf-8')
|
||||||
|
except Exception as e_encode:
|
||||||
|
print(f"子进程 {url}: 帧编码错误: {e_encode}")
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"category_counts": final_counts,
|
||||||
|
"frame_base64": frame_base64,
|
||||||
|
"source_url": url,
|
||||||
|
"event": "task_stopped_final_data"
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
data_queue.put(payload, timeout=5)
|
||||||
|
print(f"子进程 {url}: 已将最终数据放入队列。")
|
||||||
|
except queue.Full:
|
||||||
|
print(f"子进程 {url}: 无法将最终数据放入队列 (队列已满或超时)。")
|
||||||
|
except Exception as e_put:
|
||||||
|
print(f"子进程 {url}: 将最终数据放入队列时发生错误: {e_put}")
|
||||||
|
else:
|
||||||
|
print(f"子进程 {url}: 检测器实例不可用,无法发送最终数据。")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_queue.close()
|
||||||
|
except Exception as e_q_close:
|
||||||
|
print(f"子进程 {url}: 关闭数据队列时出错: {e_q_close}")
|
||||||
|
print(f"视频处理子进程 {url} 执行完毕。")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/start_video", tags=["Video Processing"])
|
||||||
|
async def start_video(request: Request):
|
||||||
|
data = await request.json()
|
||||||
|
url = data.get("url")
|
||||||
|
model_identifier_to_use = data.get("model_identifier")
|
||||||
|
host = data.get("host")
|
||||||
|
rtmp_port = data.get("rtmp_port")
|
||||||
|
rtc_port = data.get("rtc_port")
|
||||||
|
gateway = data.get("gateway")
|
||||||
|
frequency = data.get("frequency")
|
||||||
|
push_url = data.get("push_url")
|
||||||
|
|
||||||
|
# 生成MD5
|
||||||
|
md5_hash = hashlib.md5(url.encode()).hexdigest()
|
||||||
|
rtmp_url = f"rtmp://{host}:{rtmp_port}/live/{md5_hash}"
|
||||||
|
rtc_url = f"http://{host}:{rtc_port}/rtc/v1/whep/?{md5_hash}"
|
||||||
|
if not url or not rtmp_url:
|
||||||
|
raise HTTPException(status_code=400, detail="'url' 和 'rtmp_url' 字段是必须的。")
|
||||||
|
|
||||||
|
if not model_identifier_to_use:
|
||||||
|
print(f"请求中未指定 model_identifier,尝试使用全局激活的模型。")
|
||||||
|
model_identifier_to_use = get_active_model_identifier()
|
||||||
|
if not model_identifier_to_use:
|
||||||
|
raise HTTPException(status_code=400, detail="请求中未指定 'model_identifier',且当前无全局激活的默认模型。")
|
||||||
|
print(f"将为流 {url} 使用当前全局激活的模型: {model_identifier_to_use}")
|
||||||
|
|
||||||
|
if url in process_map and process_map[url]['process'].is_alive():
|
||||||
|
raise HTTPException(status_code=409, detail=f"视频处理进程已在运行: {url}")
|
||||||
|
|
||||||
|
print(f"请求启动视频处理: URL = {url}, RTMP = {rtmp_url}, Model = {model_identifier_to_use}")
|
||||||
|
|
||||||
|
stop_event = Event()
|
||||||
|
data_queue = MpQueue(maxsize=1)
|
||||||
|
|
||||||
|
process = Process(target=start_video_processing,
|
||||||
|
args=(url, rtmp_url, model_identifier_to_use, stop_event, data_queue, gateway, frequency, push_url))
|
||||||
|
process.start()
|
||||||
|
process_map[url] = {'process': process, 'stop_event': stop_event, 'data_queue': data_queue}
|
||||||
|
return Response.success(message=f"视频处理已为 URL '{url}' 使用模型 '{model_identifier_to_use}' 启动。",
|
||||||
|
data=rtc_url)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/stop_video", tags=["Video Processing"])
|
||||||
|
async def stop_video(request: Request):
|
||||||
|
data = await request.json()
|
||||||
|
url = data.get("url")
|
||||||
|
if not url:
|
||||||
|
raise HTTPException(status_code=400, detail="'url' 字段是必须的。")
|
||||||
|
|
||||||
|
task_info = process_map.get(url)
|
||||||
|
if not task_info:
|
||||||
|
raise HTTPException(status_code=404, detail=f"没有找到与 URL '{url}' 匹配的活动视频处理进程。")
|
||||||
|
|
||||||
|
process: Process = task_info['process']
|
||||||
|
stop_event: Event = task_info['stop_event']
|
||||||
|
data_q: MpQueue = task_info['data_queue']
|
||||||
|
|
||||||
|
final_data_pushed = False
|
||||||
|
if process.is_alive():
|
||||||
|
print(f"请求停止视频处理: {url}. 发送停止信号...")
|
||||||
|
stop_event.set()
|
||||||
|
process.join(timeout=20)
|
||||||
|
|
||||||
|
if process.is_alive():
|
||||||
|
print(f"警告: 视频处理进程 {url} 在超时后未能正常终止,尝试强制结束。")
|
||||||
|
process.terminate()
|
||||||
|
process.join(timeout=5)
|
||||||
|
if process.is_alive():
|
||||||
|
print(f"错误: 视频处理进程 {url} 强制结束后仍然存在。尝试 kill。")
|
||||||
|
process.kill()
|
||||||
|
process.join(timeout=2)
|
||||||
|
else:
|
||||||
|
print(f"进程 {url} 已优雅停止。尝试获取最后数据...")
|
||||||
|
try:
|
||||||
|
final_payload = data_q.get(timeout=10)
|
||||||
|
print(f"从停止的任务 {url} 收到最终数据。")
|
||||||
|
|
||||||
|
pusher_instance = get_data_pusher_instance()
|
||||||
|
if pusher_instance and pusher_instance.target_url:
|
||||||
|
print(f"准备将任务 {url} 的最后数据推送到 {pusher_instance.target_url}")
|
||||||
|
pusher_instance.push_specific_payload(final_payload)
|
||||||
|
final_data_pushed = True
|
||||||
|
elif pusher_instance:
|
||||||
|
print(
|
||||||
|
f"DataPusher 服务已配置,但未设置目标URL (pusher.target_url is None)。无法推送任务 {url} 的最后数据。")
|
||||||
|
else:
|
||||||
|
print(f"DataPusher 服务未初始化或不可用。无法推送任务 {url} 的最后数据。")
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
print(f"警告: 任务 {url} 优雅停止后,未从其数据队列中获取到最终数据 (队列为空或超时)。")
|
||||||
|
except Exception as e_q_get:
|
||||||
|
print(f"获取或处理来自任务 {url} 的最终数据时发生错误: {e_q_get}")
|
||||||
|
else:
|
||||||
|
print(f"视频处理进程先前已停止或已结束: {url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not data_q.empty():
|
||||||
|
try:
|
||||||
|
data_q.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
data_q.close()
|
||||||
|
data_q.join_thread()
|
||||||
|
except Exception as e_q_final_cleanup:
|
||||||
|
print(f"清理任务 {url} 的数据队列的最后步骤中发生错误: {e_q_final_cleanup}")
|
||||||
|
|
||||||
|
del process_map[url]
|
||||||
|
message = f"视频处理已为 URL '{url}' 停止。"
|
||||||
|
if final_data_pushed:
|
||||||
|
message += " 已尝试推送最后的数据。"
|
||||||
|
elif process.exitcode == 0:
|
||||||
|
message += " 进程已退出,但未确认最后数据推送 (可能未配置推送或队列问题)。"
|
||||||
|
|
||||||
|
return Response.success(message=message)
|
151
yolo_core.py
Normal file
151
yolo_core.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
from ultralytics import YOLO
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import tensorrt as trt
|
||||||
|
import pycuda.driver as cuda
|
||||||
|
import pycuda.autoinit
|
||||||
|
|
||||||
|
class YOLODetector:
|
||||||
|
def __init__(self, model_path='models/best.engine'):
|
||||||
|
# 加载 TensorRT 模型
|
||||||
|
self.model = YOLO(model_path, task="detect")
|
||||||
|
# 英文类别名称到中文的映射
|
||||||
|
self.class_name_mapping = {
|
||||||
|
'pedestrian': '行人',
|
||||||
|
'people': '人群',
|
||||||
|
'bicycle': '自行车',
|
||||||
|
'car': '轿车',
|
||||||
|
'van': '面包车',
|
||||||
|
'truck': '卡车',
|
||||||
|
'tricycle': '三轮车',
|
||||||
|
'awning-tricycle': '篷式三轮车',
|
||||||
|
'bus': '公交车',
|
||||||
|
'motor': '摩托车'
|
||||||
|
}
|
||||||
|
# 为每个类别设置固定的RGB颜色
|
||||||
|
self.color_mapping = {
|
||||||
|
'pedestrian': (71, 0, 36), # 勃艮第红
|
||||||
|
'people': (0, 255, 0), # 绿色
|
||||||
|
'bicycle': (0, 49, 83), # 普鲁士蓝
|
||||||
|
'car': (0, 47, 167), # 克莱茵蓝
|
||||||
|
'van': (128, 0, 128), # 紫色
|
||||||
|
'truck': (212, 72, 72), # 缇香红
|
||||||
|
'tricycle': (0, 49, 83), # 橙色
|
||||||
|
'awning-tricycle': (251, 220, 106), # 申布伦黄
|
||||||
|
'bus': (73, 45, 34), # 凡戴克棕
|
||||||
|
'motor': (1, 132, 127) # 马尔斯绿
|
||||||
|
}
|
||||||
|
# 初始化类别计数器
|
||||||
|
self.class_counts = {cls_name: 0 for cls_name in self.class_name_mapping.keys()}
|
||||||
|
# 初始化字体
|
||||||
|
try:
|
||||||
|
self.font = ImageFont.truetype("simhei.ttf", 20)
|
||||||
|
except IOError:
|
||||||
|
self.font = ImageFont.load_default()
|
||||||
|
|
||||||
|
def detect_and_draw_English(self, frame, conf=0.3, iou=0.5):
|
||||||
|
"""
|
||||||
|
对输入帧进行目标检测并返回绘制结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: 输入的图像帧(BGR格式)
|
||||||
|
conf: 置信度阈值
|
||||||
|
iou: IOU阈值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
annotated_frame: 绘制了检测结果的图像帧
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 进行 YOLO 目标检测
|
||||||
|
results = self.model(
|
||||||
|
frame,
|
||||||
|
conf=conf,
|
||||||
|
iou=iou,
|
||||||
|
half=True,
|
||||||
|
)
|
||||||
|
result = results[0]
|
||||||
|
|
||||||
|
# 使用YOLO自带的绘制功能
|
||||||
|
annotated_frame = result.plot()
|
||||||
|
|
||||||
|
return annotated_frame
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Detection error: {e}")
|
||||||
|
return frame
|
||||||
|
|
||||||
|
def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3):
|
||||||
|
"""
|
||||||
|
对输入帧进行目标检测并绘制中文标注
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: 输入的图像帧(BGR格式)
|
||||||
|
conf: 置信度阈值
|
||||||
|
iou: IOU阈值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
annotated_frame: 绘制了检测结果的图像帧
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 进行 YOLO 目标检测
|
||||||
|
results = self.model(
|
||||||
|
frame,
|
||||||
|
conf=conf,
|
||||||
|
iou=iou,
|
||||||
|
# half=True,
|
||||||
|
)
|
||||||
|
result = results[0]
|
||||||
|
|
||||||
|
# 获取原始帧的副本
|
||||||
|
img = frame.copy()
|
||||||
|
|
||||||
|
# 转换为PIL图像以绘制中文
|
||||||
|
pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||||
|
draw = ImageDraw.Draw(pil_img)
|
||||||
|
|
||||||
|
# 绘制检测结果
|
||||||
|
for box in result.boxes:
|
||||||
|
# 获取边框坐标
|
||||||
|
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||||
|
|
||||||
|
# 获取类别ID和置信度
|
||||||
|
cls_id = int(box.cls[0].item())
|
||||||
|
conf = box.conf[0].item()
|
||||||
|
|
||||||
|
# 获取类别名称并转换为中文
|
||||||
|
cls_name = result.names[cls_id]
|
||||||
|
chinese_name = self.class_name_mapping.get(cls_name, cls_name)
|
||||||
|
|
||||||
|
# 获取该类别的颜色
|
||||||
|
color = self.color_mapping.get(cls_name, (255, 255, 255))
|
||||||
|
|
||||||
|
# 绘制边框
|
||||||
|
draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)
|
||||||
|
|
||||||
|
# 准备标签文本
|
||||||
|
text = f"{chinese_name} {conf:.2f}"
|
||||||
|
text_size = draw.textbbox((0, 0), text, font=self.font)
|
||||||
|
text_width = text_size[2] - text_size[0]
|
||||||
|
text_height = text_size[3] - text_size[1]
|
||||||
|
|
||||||
|
# 绘制标签背景(使用与边框相同的颜色)
|
||||||
|
draw.rectangle(
|
||||||
|
[(x1, y1 - text_height - 4), (x1 + text_width, y1)],
|
||||||
|
fill=color
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制白色文本
|
||||||
|
draw.text(
|
||||||
|
(x1, y1 - text_height - 2),
|
||||||
|
text,
|
||||||
|
fill=(255, 255, 255), # 白色文本
|
||||||
|
font=self.font
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换回OpenCV格式
|
||||||
|
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Detection error: {e}")
|
||||||
|
return frame
|
Reference in New Issue
Block a user