131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
|
|
from http.client import HTTPException
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from MySQLdb import MySQLError
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
from ds.db import db
|
|||
|
|
from service.file_service import get_absolute_path
|
|||
|
|
|
|||
|
|
# 全局变量
|
|||
|
|
current_yolo_model = None
|
|||
|
|
current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例
|
|||
|
|
|
|||
|
|
ALLOWED_MODEL_EXT = {"pt"}
|
|||
|
|
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_yolo_model():
|
|||
|
|
"""加载模型并存储绝对路径"""
|
|||
|
|
global current_yolo_model, current_model_absolute_path
|
|||
|
|
model_rel_path = get_enabled_model_rel_path()
|
|||
|
|
print(f"[模型初始化] 加载模型:{model_rel_path}")
|
|||
|
|
|
|||
|
|
# 计算并存储绝对路径
|
|||
|
|
current_model_absolute_path = get_absolute_path(model_rel_path)
|
|||
|
|
print(f"[模型初始化] 绝对路径:{current_model_absolute_path}")
|
|||
|
|
|
|||
|
|
# 检查模型文件
|
|||
|
|
if not os.path.exists(current_model_absolute_path):
|
|||
|
|
raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
new_model = YOLO(current_model_absolute_path)
|
|||
|
|
if torch.cuda.is_available():
|
|||
|
|
new_model.to('cuda')
|
|||
|
|
print("模型已移动到GPU")
|
|||
|
|
else:
|
|||
|
|
print("使用CPU进行推理")
|
|||
|
|
current_yolo_model = new_model
|
|||
|
|
print(f"成功加载模型: {current_model_absolute_path}")
|
|||
|
|
return current_yolo_model
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"模型加载失败:{str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_current_model():
|
|||
|
|
"""获取当前模型实例"""
|
|||
|
|
if current_yolo_model is None:
|
|||
|
|
raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型")
|
|||
|
|
return current_yolo_model
|
|||
|
|
|
|||
|
|
|
|||
|
|
def detect(image_np, conf_threshold=0.8):
|
|||
|
|
# 1. 输入格式验证
|
|||
|
|
if not isinstance(image_np, np.ndarray):
|
|||
|
|
raise ValueError("输入必须是numpy数组(BGR图像)")
|
|||
|
|
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
|||
|
|
raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}")
|
|||
|
|
detection_results = []
|
|||
|
|
try:
|
|||
|
|
model = get_current_model()
|
|||
|
|
if not current_model_absolute_path:
|
|||
|
|
raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型")
|
|||
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|||
|
|
print(f"检测设备:{device} | 置信度阈值:{conf_threshold}")
|
|||
|
|
|
|||
|
|
# 图像尺寸信息
|
|||
|
|
img_height, img_width = image_np.shape[:2]
|
|||
|
|
print(f"输入图像尺寸:{img_width}x{img_height}")
|
|||
|
|
|
|||
|
|
# YOLO检测
|
|||
|
|
print("执行YOLO检测")
|
|||
|
|
results = model.predict(
|
|||
|
|
image_np,
|
|||
|
|
conf=conf_threshold,
|
|||
|
|
device=device,
|
|||
|
|
show=False,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 4. 整理检测结果(仅保留Chest类别,ID=2)
|
|||
|
|
for box in results[0].boxes:
|
|||
|
|
class_id = int(box.cls[0]) # 类别ID
|
|||
|
|
class_name = model.names[class_id]
|
|||
|
|
confidence = float(box.conf[0])
|
|||
|
|
bbox = tuple(map(int, box.xyxy[0]))
|
|||
|
|
|
|||
|
|
# 过滤条件:置信度达标 + 类别为Chest(class_id=2)
|
|||
|
|
# and class_id == 2
|
|||
|
|
if confidence >= conf_threshold:
|
|||
|
|
detection_results.append({
|
|||
|
|
"class": class_name,
|
|||
|
|
"confidence": confidence,
|
|||
|
|
"bbox": bbox
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# 判断是否有目标
|
|||
|
|
has_content = len(detection_results) > 0
|
|||
|
|
return has_content, detection_results
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
error_msg = f"检测过程出错:{str(e)}"
|
|||
|
|
print(error_msg)
|
|||
|
|
return False, None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_enabled_model_rel_path():
|
|||
|
|
"""获取数据库中启用的模型相对路径"""
|
|||
|
|
conn = None
|
|||
|
|
cursor = None
|
|||
|
|
try:
|
|||
|
|
conn = db.get_connection()
|
|||
|
|
cursor = conn.cursor(dictionary=True)
|
|||
|
|
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
|
|||
|
|
cursor.execute(query)
|
|||
|
|
result = cursor.fetchone()
|
|||
|
|
|
|||
|
|
if not result or not result.get('path'):
|
|||
|
|
raise HTTPException(status_code=404, detail="未找到启用的默认模型")
|
|||
|
|
|
|||
|
|
return result['path']
|
|||
|
|
except MySQLError as e:
|
|||
|
|
raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e
|
|||
|
|
except Exception as e:
|
|||
|
|
if isinstance(e, HTTPException):
|
|||
|
|
raise e
|
|||
|
|
raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e
|
|||
|
|
finally:
|
|||
|
|
db.close_connection(conn, cursor)
|