Files
AI-TianDong/utils/pho_detector.py

163 lines
6.7 KiB
Python
Raw Normal View History

2025-07-24 12:45:27 +08:00
# utils/pho_detector.py
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import logging
from typing import List, Dict, Any
# --- Basic Configuration ---
# Module-scoped logger; with default settings its records propagate to the root logger.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the *root* logger as a side effect of importing
# this module. Per the logging docs, basicConfig() does nothing when the root
# logger already has handlers, so if the application calls basicConfig() only
# after importing this module, the application's call becomes the no-op.
logging.basicConfig(level=logging.INFO)
# --- Module-Specific Configuration ---
# Configuration for the first model (detects 'pho' and 'shelves').
MODEL_1_CONFIG = {
    "name": "pho_model",  # also the cache key used for the loaded model
    # IMPORTANT: Update this path to your actual model location.
    # NOTE(review): relative path — presumably resolved against the process's
    # current working directory rather than this file's location; confirm.
    "weights_path": "models/photovoltaic_model/pho_model/best.pt",
    # Class order must match the model's training IDs (0: 'pho', 1: 'shelves').
    "class_names": ['pho', 'shelves'],
    # Minimum score for a detection; passed to the SAHI model wrapper at load time.
    "confidence_threshold": 0.7,
}

# Configuration for the second model (detects 'pile' and 'hole').
MODEL_2_CONFIG = {
    "name": "pile_model",  # also the cache key used for the loaded model
    # IMPORTANT: Update this path to your actual model location.
    # NOTE(review): relative path — see the note on MODEL_1_CONFIG's weights_path.
    "weights_path": "models/photovoltaic_model/pile_model/best.pt",
    # Class order must match the model's training IDs (0: 'pile', 1: 'hole').
    "class_names": ['pile', 'hole'],
    # Minimum score for a detection; passed to the SAHI model wrapper at load time.
    "confidence_threshold": 0.7,
}

# All classes this module can detect. Derived from the per-model class lists so
# the two can never drift apart; order is model 1's classes, then model 2's.
ALL_SUPPORTED_CLASSES = [*MODEL_1_CONFIG["class_names"], *MODEL_2_CONFIG["class_names"]]

# Unified SAHI slicing parameters for all tasks in this module
# (tile size in pixels and fractional overlap between neighbouring tiles).
SAHI_PARAMS = {
    "slice_height": 800,
    "slice_width": 800,
    "overlap_height_ratio": 0.2,
    "overlap_width_ratio": 0.2,
}
# Process-wide cache of loaded models, keyed by each model config's "name".
# Filled lazily by initialize_model(); a failed load leaves no entry behind.
detection_models: Dict[str, AutoDetectionModel] = dict()
# --- Model Management Functions ---
def initialize_model(model_config: Dict[str, Any]) -> AutoDetectionModel:
    """Return the cached detection model for *model_config*, loading it on first use.

    Args:
        model_config: A model configuration dict providing ``"name"`` (the cache
            key), ``"weights_path"`` and ``"confidence_threshold"``.

    Returns:
        The ``AutoDetectionModel`` stored in ``detection_models`` under
        ``model_config["name"]``.

    Raises:
        RuntimeError: If loading the weights fails. The underlying exception is
            chained as ``__cause__``. Failed loads are not cached, so a later
            call retries the load.
    """
    model_name = model_config["name"]
    if model_name in detection_models:
        return detection_models[model_name]

    logger.info("Loading detection model '%s' from %s...", model_name, model_config["weights_path"])
    try:
        model = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=model_config["weights_path"],
            confidence_threshold=model_config["confidence_threshold"],
        )
    except Exception as e:
        # Broad catch is deliberate at this boundary: log with traceback, then
        # re-raise as the single exception type callers handle.
        logger.exception("Error loading model '%s': %s", model_name, e)
        raise RuntimeError(f"Could not load model '{model_name}': {e}") from e

    # The module-level dict is mutated, not rebound, so no `global` is needed.
    detection_models[model_name] = model
    logger.info("Model '%s' loaded successfully.", model_name)
    return model
def initialize_all_pho_models():
    """Eagerly load every model configured in this module.

    Meant to be called once at application startup; the loaded instances are
    cached by initialize_model() and reused by later inference calls. A load
    failure propagates as the RuntimeError raised by initialize_model().
    """
    logger.info("Pre-initializing all photovoltaic-related models...")
    for config in (MODEL_1_CONFIG, MODEL_2_CONFIG):
        initialize_model(config)
    logger.info("All photovoltaic models have been initialized.")
# --- Inference and Core Logic ---
def _prediction_to_target(pred: Any, class_name: str) -> Dict[str, Any]:
    """Convert one SAHI object prediction into this module's output dict.

    ``size`` is ``[width, height]`` computed from the float bbox and then
    truncated to int; ``leftTopPoint`` is ``[minx, miny]`` truncated to int;
    ``score`` is rounded to 4 decimal places.
    """
    box = pred.bbox
    return {
        "type": class_name,
        "size": [int(box.maxx - box.minx), int(box.maxy - box.miny)],
        "leftTopPoint": [int(box.minx), int(box.miny)],
        "score": round(float(pred.score.value), 4),
    }


def _run_inference_for_model(image_cv: np.ndarray, model_config: Dict[str, Any]) -> List[Dict]:
    """Run SAHI sliced prediction with a single configured model.

    Args:
        image_cv: The input image array, forwarded unchanged to SAHI's
            ``get_sliced_prediction``.
        model_config: One of this module's model configuration dicts. Its
            ``"class_names"`` list maps the model's local category IDs to names.

    Returns:
        One dict per detection with keys ``"type"``, ``"size"``,
        ``"leftTopPoint"`` and ``"score"``. Detections whose category ID is
        outside ``"class_names"`` are logged and dropped. If the model cannot
        be loaded or prediction raises, the error is logged and an empty list
        is returned, so callers see the same result as "nothing found".
    """
    model_name = model_config["name"]
    try:
        model = initialize_model(model_config)
    except RuntimeError as e:
        logger.error("Cannot run inference; model '%s' failed to initialize: %s", model_name, e)
        return []

    logger.info("Running SAHI prediction with model: %s...", model_name)
    # NOTE(review): SAHI's Ultralytics wrapper appears to treat numpy input as RGB
    # (it flips channels before calling YOLO). If `image_cv` is BGR as returned by
    # cv2.imread, it may need cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB) first —
    # confirm against the installed SAHI version and how callers load images.
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            **SAHI_PARAMS,
            verbose=0,
        )
    except Exception as e:
        # Best-effort: a failed prediction yields no detections instead of propagating.
        logger.exception("SAHI prediction failed for model '%s': %s", model_name, e)
        return []

    class_names = model_config["class_names"]
    targets: List[Dict] = []
    for pred in sahi_result.object_prediction_list:
        cat_id = int(pred.category.id)
        # Map the model's local ID to its global class name; skip anything unknown.
        if not 0 <= cat_id < len(class_names):
            logger.warning(
                "Model '%s' detected an out-of-range class ID: %s. Ignoring.", model_name, cat_id
            )
            continue
        targets.append(_prediction_to_target(pred, class_names[cat_id]))

    logger.info("Model '%s' found %d raw objects.", model_name, len(targets))
    return targets
def detect_objects_by_class(image_cv: np.ndarray, target_classes: List[str]) -> List[Dict]:
    """
    Main public function. Detect the requested classes, running only the model(s) that own them.

    Args:
        image_cv: The input image array, forwarded unchanged to each model that runs.
        target_classes: Class names to detect (e.g. ['pho', 'hole']). Any iterable of
            strings is accepted. Names that neither model handles are logged and ignored.
    Returns:
        Detection dicts for the requested classes only. Model 1's results come first,
        then model 2's. The list is empty if nothing was requested or found, or if a
        model failed; failures are logged by the inference helper.
    """
    # Materialize once so a one-shot iterator (e.g. a generator) is not exhausted
    # by the first model's filter, which would silently skip the second model.
    requested = list(target_classes)

    known_classes = set(MODEL_1_CONFIG["class_names"]) | set(MODEL_2_CONFIG["class_names"])
    unsupported = [cls for cls in requested if cls not in known_classes]
    if unsupported:
        logger.warning("Ignoring unsupported target classes: %s", unsupported)

    final_results: List[Dict] = []
    for model_index, model_config in enumerate((MODEL_1_CONFIG, MODEL_2_CONFIG), start=1):
        # Requested classes owned by this model, in the caller's order.
        model_targets = [cls for cls in requested if cls in model_config["class_names"]]
        # Run the model only if it's needed.
        if not model_targets:
            continue
        logger.info(
            "Model %d ('%s') triggered for targets: %s",
            model_index, model_config["name"], model_targets,
        )
        model_results = _run_inference_for_model(image_cv, model_config)
        # Filter the raw results to only include what the user asked for.
        wanted = set(model_targets)
        final_results.extend(res for res in model_results if res["type"] in wanted)

    logger.info("Photovoltaic detection complete. Returning %d filtered targets.", len(final_results))
    return final_results
# --- Public export consumed by main.py ---
# main.py imports this name to decide which detection tasks to route to this module.
# It is an alias of ALL_SUPPORTED_CLASSES: the same list object, not a copy.
PHO_DETECTOR_SUPPORTED_CLASSES: List[str] = ALL_SUPPORTED_CLASSES