Files
video_detect/service/model_service.py

190 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import HTTPException
import numpy as np
import torch
from MySQLdb import MySQLError
from ultralytics import YOLO
import os
from ds.db import db
from service.file_service import get_absolute_path
# 全局变量初始化时为None无模型时保持None
current_yolo_model = None
current_model_absolute_path = None # 存储模型绝对路径不依赖model实例
ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
def load_yolo_model():
"""
加载模型并存储绝对路径
无有效模型路径/模型文件不存在/加载失败时,跳过加载(不抛出异常)
"""
global current_yolo_model, current_model_absolute_path
# 1. 获取数据库中的模型路径无模型时返回None
model_rel_path = get_enabled_model_rel_path()
# 2. 无模型路径时,跳过加载
if not model_rel_path:
print("[模型初始化] 未获取到有效模型路径,已跳过模型加载")
current_yolo_model = None
current_model_absolute_path = None
return None
# 3. 有模型路径时,执行正常加载流程
print(f"[模型初始化] 加载模型:{model_rel_path}")
try:
# 计算绝对路径(避免路径处理异常)
current_model_absolute_path = get_absolute_path(model_rel_path)
print(f"[模型初始化] 模型绝对路径:{current_model_absolute_path}")
# 检查模型文件是否存在
if not os.path.exists(current_model_absolute_path):
print(f"[模型初始化] 警告:模型文件不存在({current_model_absolute_path}),已跳过加载")
current_yolo_model = None
current_model_absolute_path = None
return None
# 加载YOLO模型
new_model = YOLO(current_model_absolute_path)
# 设备分配GPU/CPU
if torch.cuda.is_available():
new_model.to('cuda')
print("[模型初始化] 模型已移动到GPU设备")
else:
print("[模型初始化] 未检测到GPU使用CPU进行推理")
# 更新全局模型变量
current_yolo_model = new_model
print(f"[模型初始化] 成功加载模型:{current_model_absolute_path}")
return current_yolo_model
# 捕获所有加载异常,避免中断项目启动
except Exception as e:
print(f"[模型初始化] 警告:模型加载失败({str(e)}),已跳过加载")
current_yolo_model = None
current_model_absolute_path = None
return None
def get_current_model():
"""
获取当前模型实例
无模型时返回None不抛出异常避免中断流程
"""
return current_yolo_model
def detect(image_np, conf_threshold=0.8):
"""
执行YOLO检测
无模型时返回明确提示,不崩溃;有模型时正常返回检测结果
"""
# 优先检查模型是否已加载
model = get_current_model()
if not model:
error_msg = "检测失败未加载任何YOLO模型数据库中无默认模型或模型加载失败"
print(f"[检测流程] {error_msg}")
return False, error_msg # 返回False+错误提示而非None
# 2. 输入格式验证(保留原逻辑,格式错误仍抛异常,属于参数问题)
if not isinstance(image_np, np.ndarray):
raise ValueError("输入必须是numpy数组BGR图像格式")
if image_np.ndim != 3 or image_np.shape[-1] != 3:
raise ValueError(f"输入图像格式错误,需为 (高度, 宽度, 3) 的BGR数组当前shape: {image_np.shape}")
detection_results = []
try:
# 3. 检测配置
device = "cuda" if torch.cuda.is_available() else "cpu"
img_height, img_width = image_np.shape[:2]
print(f"[检测流程] 设备:{device} | 置信度阈值:{conf_threshold} | 图像尺寸:{img_width}x{img_height}")
# 4. 执行YOLO预测
print("[检测流程] 开始执行YOLO检测")
results = model.predict(
image_np,
conf=conf_threshold,
device=device,
show=False, # 不显示检测窗口
verbose=False # 关闭YOLO内部日志可选减少冗余输出
)
# 5. 整理检测结果(仅保留置信度达标结果,原逻辑保留)
for box in results[0].boxes:
class_id = int(box.cls[0])
class_name = model.names[class_id]
confidence = float(box.conf[0])
# 转换为整数坐标x1, y1, x2, y2
bbox = tuple(map(int, box.xyxy[0]))
# 过滤条件:置信度达标
if confidence >= conf_threshold and 0 <= class_id <= 5:
detection_results.append({
"class": class_name,
"confidence": round(confidence, 4), # 保留4位小数优化输出
"bbox": bbox
})
# 6. 判断是否检测到目标
has_content = len(detection_results) > 0
print(f"[检测流程] 检测完成:共检测到 {len(detection_results)} 个目标")
return has_content, detection_results
# 7. 捕获检测过程异常,返回明确错误信息
except Exception as e:
error_msg = f"检测过程出错:{str(e)}"
print(f"[检测流程] {error_msg}")
return False, error_msg
def get_enabled_model_rel_path():
"""
从数据库获取启用的默认模型相对路径
无模型/数据库错误时返回None仅记录警告日志
"""
conn = None
cursor = None
try:
# 建立数据库连接
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询默认模型is_default=1
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
cursor.execute(query)
result = cursor.fetchone()
# 有有效路径则返回否则返回None
if result and isinstance(result.get('path'), str) and result['path'].strip():
model_path = result['path'].strip()
print(f"找到默认模型路径:{model_path}")
return model_path
else:
print("警告:未找到启用的默认模型")
return None
# 捕获MySQL相关错误
except MySQLError as e:
print(f"警告:查询默认模型时发生数据库错误({str(e)}")
return None
# 捕获其他通用错误
except Exception as e:
print(f"[数据库查询] 警告:获取默认模型路径失败({str(e)}")
return None
# 确保数据库连接和游标关闭
finally:
if cursor:
try:
cursor.close()
print("游标已关闭")
except Exception as e:
print(f"关闭游标时出错:{str(e)}")
# 关闭连接(允许重复关闭,无需检查是否已关闭)
if conn:
try:
conn.close()
print("数据库连接已关闭")
except Exception as e:
print(f"关闭数据库连接时出错:{str(e)}")