目前可以成功动态更换模型运行的
This commit is contained in:
41
core/all.py
41
core/all.py
@ -57,50 +57,39 @@ def save_db(model_type, client_ip, result):
|
||||
|
||||
|
||||
|
||||
# 修正后的 detect 函数关键部分
|
||||
def detect(client_ip, frame):
|
||||
"""
|
||||
执行模型检测,检测到违规时按指定格式保存图片
|
||||
参数:
|
||||
frame: 待检测的图像帧(OpenCV格式,numpy.ndarray类型)
|
||||
返回:
|
||||
(检测结果布尔值, 检测详情, 检测模型类型)
|
||||
"""
|
||||
# 1. YOLO检测(优先级1)
|
||||
# 1. YOLO检测
|
||||
yolo_flag, yolo_result = yoloDetect(frame)
|
||||
print(f"YOLO检测结果:{yolo_result}")
|
||||
if yolo_flag:
|
||||
# model_type 传入 "yolo"(正确)
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
# 打印时使用「显示用短路径」,符合需求格式
|
||||
print(f"✅ YOLO违规图片已保存:{display_path}")
|
||||
print(f"✅ yolo违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
|
||||
return (True, yolo_result, "yolo")
|
||||
#
|
||||
# # 2. 人脸检测(优先级2)
|
||||
|
||||
# 2. 人脸检测
|
||||
face_flag, face_result = faceDetect(frame)
|
||||
print(f"人脸检测结果:{face_result}")
|
||||
if face_flag:
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip) # 这里改了
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
# 打印时使用「显示用短路径」,符合需求格式
|
||||
print(f"✅ face违规图片已保存:{display_path}")
|
||||
print(f"✅ face违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
|
||||
return (True, face_result, "face")
|
||||
|
||||
# 3. OCR检测(优先级3)
|
||||
# 3. OCR检测
|
||||
ocr_flag, ocr_result = ocrDetect(frame)
|
||||
print(f"OCR检测结果:{ocr_result}")
|
||||
if ocr_flag:
|
||||
# 解构元组,保存用完整路径,打印用短路径
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
# 打印时使用「显示用短路径」,符合需求格式
|
||||
print(f"✅ ocr违规图片已保存:{display_path}")
|
||||
print(f"✅ ocr违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
|
||||
return (True, ocr_result, "ocr")
|
||||
|
||||
# 4. 无违规内容(不保存图片)
|
||||
print(f"❌ 未检测到任何违规内容,不保存图片")
|
||||
return (False, "未检测到任何内容", "none")
|
47
core/yolo.py
47
core/yolo.py
@ -1,37 +1,43 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
from service.model_service import get_current_yolo_model # 从模型管理模块获取模型
|
||||
|
||||
# 全局变量
|
||||
# 全局模型变量
|
||||
_yolo_model = None
|
||||
|
||||
|
||||
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载YOLO目标检测模型"""
|
||||
def load_model(model_path=None):
|
||||
"""加载YOLO模型(优先使用模型管理模块的默认模型)"""
|
||||
global _yolo_model
|
||||
|
||||
if model_path is None:
|
||||
_yolo_model = get_current_yolo_model()
|
||||
return _yolo_model is not None
|
||||
|
||||
try:
|
||||
_yolo_model = YOLO(model_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"YOLO model load failed: {e}")
|
||||
print(f"YOLO模型加载失败(指定路径):{str(e)}")
|
||||
return False
|
||||
|
||||
return True if _yolo_model else False
|
||||
|
||||
|
||||
def detect(frame, conf_threshold=0.2):
|
||||
"""YOLO目标检测、返回(是否识别到, 结果字符串)"""
|
||||
"""执行目标检测,返回(是否成功, 结果字符串)"""
|
||||
global _yolo_model
|
||||
|
||||
if not _yolo_model or frame is None:
|
||||
return (False, "未初始化或无效帧")
|
||||
# 确保模型已加载
|
||||
if not _yolo_model:
|
||||
if not load_model():
|
||||
return (False, "模型未初始化")
|
||||
|
||||
if frame is None:
|
||||
return (False, "无效输入帧")
|
||||
|
||||
try:
|
||||
results = _yolo_model(frame, conf=conf_threshold)
|
||||
# 检查是否有检测结果
|
||||
# 执行检测(frame应为numpy数组)
|
||||
results = _yolo_model(frame, conf=conf_threshold, verbose=False)
|
||||
has_results = len(results[0].boxes) > 0 if results else False
|
||||
|
||||
if not has_results:
|
||||
@ -42,13 +48,12 @@ def detect(frame, conf_threshold=0.2):
|
||||
for box in results[0].boxes:
|
||||
cls = int(box.cls[0])
|
||||
conf = float(box.conf[0])
|
||||
bbox = [float(x) for x in box.xyxy[0]]
|
||||
bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
|
||||
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
||||
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
|
||||
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})")
|
||||
|
||||
result_str = "; ".join(result_parts)
|
||||
return (has_results, result_str)
|
||||
return (True, "; ".join(result_parts))
|
||||
|
||||
except Exception as e:
|
||||
print(f"YOLO detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
print(f"检测过程出错:{str(e)}")
|
||||
return (False, f"检测错误:{str(e)}")
|
||||
|
Reference in New Issue
Block a user