优化ocr检测时间,加载默认模型
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from http.client import HTTPException
|
||||
from fastapi import HTTPException
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -9,7 +9,7 @@ import os
|
||||
from ds.db import db
|
||||
from service.file_service import get_absolute_path
|
||||
|
||||
# 全局变量
|
||||
# 全局变量:初始化时为None,无模型时保持None
|
||||
current_yolo_model = None
|
||||
current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例
|
||||
|
||||
@ -18,114 +18,173 @@ MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
|
||||
def load_yolo_model():
|
||||
"""加载模型并存储绝对路径"""
|
||||
"""
|
||||
加载模型并存储绝对路径
|
||||
无有效模型路径/模型文件不存在/加载失败时,跳过加载(不抛出异常)
|
||||
"""
|
||||
global current_yolo_model, current_model_absolute_path
|
||||
# 1. 获取数据库中的模型路径(无模型时返回None)
|
||||
model_rel_path = get_enabled_model_rel_path()
|
||||
|
||||
# 2. 无模型路径时,跳过加载
|
||||
if not model_rel_path:
|
||||
print("[模型初始化] 未获取到有效模型路径,已跳过模型加载")
|
||||
current_yolo_model = None
|
||||
current_model_absolute_path = None
|
||||
return None
|
||||
|
||||
# 3. 有模型路径时,执行正常加载流程
|
||||
print(f"[模型初始化] 加载模型:{model_rel_path}")
|
||||
|
||||
# 计算并存储绝对路径
|
||||
current_model_absolute_path = get_absolute_path(model_rel_path)
|
||||
print(f"[模型初始化] 绝对路径:{current_model_absolute_path}")
|
||||
|
||||
# 检查模型文件
|
||||
if not os.path.exists(current_model_absolute_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}")
|
||||
|
||||
try:
|
||||
# 计算绝对路径(避免路径处理异常)
|
||||
current_model_absolute_path = get_absolute_path(model_rel_path)
|
||||
print(f"[模型初始化] 模型绝对路径:{current_model_absolute_path}")
|
||||
|
||||
# 检查模型文件是否存在
|
||||
if not os.path.exists(current_model_absolute_path):
|
||||
print(f"[模型初始化] 警告:模型文件不存在({current_model_absolute_path}),已跳过加载")
|
||||
current_yolo_model = None
|
||||
current_model_absolute_path = None
|
||||
return None
|
||||
|
||||
# 加载YOLO模型
|
||||
new_model = YOLO(current_model_absolute_path)
|
||||
# 设备分配(GPU/CPU)
|
||||
if torch.cuda.is_available():
|
||||
new_model.to('cuda')
|
||||
print("模型已移动到GPU")
|
||||
print("[模型初始化] 模型已移动到GPU设备")
|
||||
else:
|
||||
print("使用CPU进行推理")
|
||||
print("[模型初始化] 未检测到GPU,使用CPU进行推理")
|
||||
|
||||
# 更新全局模型变量
|
||||
current_yolo_model = new_model
|
||||
print(f"成功加载模型: {current_model_absolute_path}")
|
||||
print(f"[模型初始化] 成功加载模型:{current_model_absolute_path}")
|
||||
return current_yolo_model
|
||||
|
||||
# 捕获所有加载异常,避免中断项目启动
|
||||
except Exception as e:
|
||||
print(f"模型加载失败:{str(e)}")
|
||||
raise
|
||||
print(f"[模型初始化] 警告:模型加载失败({str(e)}),已跳过加载")
|
||||
current_yolo_model = None
|
||||
current_model_absolute_path = None
|
||||
return None
|
||||
|
||||
|
||||
def get_current_model():
|
||||
"""获取当前模型实例"""
|
||||
if current_yolo_model is None:
|
||||
raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型")
|
||||
"""
|
||||
获取当前模型实例
|
||||
无模型时返回None(不抛出异常,避免中断流程)
|
||||
"""
|
||||
return current_yolo_model
|
||||
|
||||
|
||||
def detect(image_np, conf_threshold=0.8):
|
||||
# 1. 输入格式验证
|
||||
"""
|
||||
执行YOLO检测
|
||||
无模型时返回明确提示,不崩溃;有模型时正常返回检测结果
|
||||
"""
|
||||
# 优先检查模型是否已加载
|
||||
model = get_current_model()
|
||||
if not model:
|
||||
error_msg = "检测失败:未加载任何YOLO模型(数据库中无默认模型或模型加载失败)"
|
||||
print(f"[检测流程] {error_msg}")
|
||||
return False, error_msg # 返回False+错误提示,而非None
|
||||
|
||||
# 2. 输入格式验证(保留原逻辑,格式错误仍抛异常,属于参数问题)
|
||||
if not isinstance(image_np, np.ndarray):
|
||||
raise ValueError("输入必须是numpy数组(BGR图像)")
|
||||
raise ValueError("输入必须是numpy数组(BGR图像格式)")
|
||||
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||
raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}")
|
||||
raise ValueError(f"输入图像格式错误,需为 (高度, 宽度, 3) 的BGR数组,当前shape: {image_np.shape}")
|
||||
|
||||
detection_results = []
|
||||
try:
|
||||
model = get_current_model()
|
||||
if not current_model_absolute_path:
|
||||
raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型")
|
||||
# 3. 检测配置
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"检测设备:{device} | 置信度阈值:{conf_threshold}")
|
||||
|
||||
# 图像尺寸信息
|
||||
img_height, img_width = image_np.shape[:2]
|
||||
print(f"输入图像尺寸:{img_width}x{img_height}")
|
||||
print(f"[检测流程] 设备:{device} | 置信度阈值:{conf_threshold} | 图像尺寸:{img_width}x{img_height}")
|
||||
|
||||
# YOLO检测
|
||||
print("执行YOLO检测")
|
||||
# 4. 执行YOLO预测
|
||||
print("[检测流程] 开始执行YOLO检测")
|
||||
results = model.predict(
|
||||
image_np,
|
||||
conf=conf_threshold,
|
||||
device=device,
|
||||
show=False,
|
||||
show=False, # 不显示检测窗口
|
||||
verbose=False # 关闭YOLO内部日志(可选,减少冗余输出)
|
||||
)
|
||||
|
||||
# 4. 整理检测结果(仅保留Chest类别,ID=2)
|
||||
# 5. 整理检测结果(仅保留置信度达标结果,原逻辑保留)
|
||||
for box in results[0].boxes:
|
||||
class_id = int(box.cls[0]) # 类别ID
|
||||
class_id = int(box.cls[0])
|
||||
class_name = model.names[class_id]
|
||||
confidence = float(box.conf[0])
|
||||
# 转换为整数坐标(x1, y1, x2, y2)
|
||||
bbox = tuple(map(int, box.xyxy[0]))
|
||||
|
||||
# 过滤条件:置信度达标 + 类别为Chest(class_id=2)
|
||||
# and class_id == 2
|
||||
if confidence >= conf_threshold:
|
||||
# 过滤条件:置信度达标
|
||||
if confidence >= conf_threshold and 0 <= class_id <= 5:
|
||||
detection_results.append({
|
||||
"class": class_name,
|
||||
"confidence": confidence,
|
||||
"confidence": round(confidence, 4), # 保留4位小数,优化输出
|
||||
"bbox": bbox
|
||||
})
|
||||
|
||||
# 判断是否有目标
|
||||
# 6. 判断是否检测到目标
|
||||
has_content = len(detection_results) > 0
|
||||
print(f"[检测流程] 检测完成:共检测到 {len(detection_results)} 个目标")
|
||||
return has_content, detection_results
|
||||
|
||||
# 7. 捕获检测过程异常,返回明确错误信息
|
||||
except Exception as e:
|
||||
error_msg = f"检测过程出错:{str(e)}"
|
||||
print(error_msg)
|
||||
return False, None
|
||||
print(f"[检测流程] {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
|
||||
def get_enabled_model_rel_path():
|
||||
"""获取数据库中启用的模型相对路径"""
|
||||
"""
|
||||
从数据库获取启用的默认模型相对路径
|
||||
无模型/数据库错误时返回None,仅记录警告日志
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 建立数据库连接
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
# 查询默认模型(is_default=1)
|
||||
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
|
||||
cursor.execute(query)
|
||||
result = cursor.fetchone()
|
||||
|
||||
if not result or not result.get('path'):
|
||||
raise HTTPException(status_code=404, detail="未找到启用的默认模型")
|
||||
# 有有效路径则返回,否则返回None
|
||||
if result and isinstance(result.get('path'), str) and result['path'].strip():
|
||||
model_path = result['path'].strip()
|
||||
print(f"找到默认模型路径:{model_path}")
|
||||
return model_path
|
||||
else:
|
||||
print("警告:未找到启用的默认模型")
|
||||
return None
|
||||
|
||||
return result['path']
|
||||
# 捕获MySQL相关错误
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e
|
||||
print(f"警告:查询默认模型时发生数据库错误({str(e)})")
|
||||
return None
|
||||
# 捕获其他通用错误
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e
|
||||
print(f"[数据库查询] 警告:获取默认模型路径失败({str(e)})")
|
||||
return None
|
||||
# 确保数据库连接和游标关闭
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
if cursor:
|
||||
try:
|
||||
cursor.close()
|
||||
print("游标已关闭")
|
||||
except Exception as e:
|
||||
print(f"关闭游标时出错:{str(e)}")
|
||||
# 关闭连接(允许重复关闭,无需检查是否已关闭)
|
||||
if conn:
|
||||
try:
|
||||
conn.close()
|
||||
print("数据库连接已关闭")
|
||||
except Exception as e:
|
||||
print(f"关闭数据库连接时出错:{str(e)}")
|
||||
Reference in New Issue
Block a user