163 lines
6.7 KiB
Python
163 lines
6.7 KiB
Python
# utils/pho_detector.py
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from sahi import AutoDetectionModel
|
|
from sahi.predict import get_sliced_prediction
|
|
import logging
|
|
from typing import List, Dict, Any
|
|
|
|
# --- Basic Configuration ---
# Configure the root logger at INFO level and create this module's named logger.
# NOTE(review): calling basicConfig from a library module configures logging for
# the whole process (it is a no-op if the root logger already has handlers);
# consider moving it to the application entry point.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
# --- Module-Specific Configuration ---

# Configuration for the first model (detects 'pho' and 'shelves').
MODEL_1_CONFIG = {
    # Cache key used by initialize_model().
    "name": "pho_model",
    # IMPORTANT: Update this path to your actual model location.
    "weights_path": "models/photovoltaic_model/pho_model/best.pt",
    # Class order must match the model's training IDs (0: 'pho', 1: 'shelves').
    "class_names": ['pho', 'shelves'],
    # Passed to AutoDetectionModel.from_pretrained(confidence_threshold=...).
    "confidence_threshold": 0.7,
}

# Configuration for the second model (detects 'pile' and 'hole').
MODEL_2_CONFIG = {
    # Cache key used by initialize_model().
    "name": "pile_model",
    # IMPORTANT: Update this path to your actual model location.
    "weights_path": "models/photovoltaic_model/pile_model/best.pt",
    # Class order must match the model's training IDs (0: 'pile', 1: 'hole').
    "class_names": ['pile', 'hole'],
    # Passed to AutoDetectionModel.from_pretrained(confidence_threshold=...).
    "confidence_threshold": 0.7,
}

# All classes this module can detect. Derived from the per-model configs (model 1
# classes first, then model 2) so this list cannot drift out of sync with them.
ALL_SUPPORTED_CLASSES = MODEL_1_CONFIG["class_names"] + MODEL_2_CONFIG["class_names"]

# Unified SAHI slicing parameters for all tasks in this module; unpacked into
# get_sliced_prediction() for every model.
SAHI_PARAMS = {
    "slice_height": 800,
    "slice_width": 800,
    "overlap_height_ratio": 0.2,
    "overlap_width_ratio": 0.2,
}
|
|
|
|
# Module-level cache of loaded model instances, keyed by each config's "name".
# Populated lazily by initialize_model(); entries are never evicted here.
detection_models: Dict[str, AutoDetectionModel] = dict()
|
|
|
|
# --- Model Management Functions ---
|
|
|
|
def initialize_model(model_config: Dict[str, Any]) -> AutoDetectionModel:
    """Return the cached model for ``model_config``, loading it on first use.

    Models are cached in ``detection_models`` by ``model_config["name"]`` only, so a
    later call with the same name returns the cached instance even if the other
    config fields differ.

    Args:
        model_config: Mapping with the keys ``"name"`` (cache key),
            ``"weights_path"`` and ``"confidence_threshold"``.

    Returns:
        The instance produced by ``AutoDetectionModel.from_pretrained``
        (loaded with ``model_type="ultralytics"``).

    Raises:
        KeyError: If ``"name"`` or ``"weights_path"`` is missing from ``model_config``.
        RuntimeError: If loading the model fails; the underlying exception is
            chained as ``__cause__``.
    """
    model_name = model_config["name"]

    # The cache dict is only mutated, never rebound, so no `global` is needed.
    if model_name in detection_models:
        return detection_models[model_name]

    weights_path = model_config["weights_path"]
    logger.info("Loading detection model '%s' from %s...", model_name, weights_path)
    try:
        model = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=weights_path,
            confidence_threshold=model_config["confidence_threshold"],
        )
    except Exception as e:
        logger.exception("Error loading model '%s': %s", model_name, e)
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Could not load model '{model_name}': {e}") from e

    detection_models[model_name] = model
    logger.info("Model '%s' loaded successfully.", model_name)
    return model
|
|
|
|
def initialize_all_pho_models():
    """Eagerly load every model configured in this module (intended for application startup).

    Models already present in the cache are not reloaded. A load failure
    propagates as the RuntimeError raised by ``initialize_model`` and stops
    the remaining models from loading.
    """
    logger.info("Pre-initializing all photovoltaic-related models...")
    for config in (MODEL_1_CONFIG, MODEL_2_CONFIG):
        initialize_model(config)
    logger.info("All photovoltaic models have been initialized.")
|
|
|
|
# --- Inference and Core Logic ---
|
|
|
|
def _run_inference_for_model(image_cv: np.ndarray, model_config: Dict[str, Any]) -> List[Dict]:
    """Run SAHI sliced prediction with one configured model and convert the hits.

    Args:
        image_cv: Image array handed unchanged to ``get_sliced_prediction``.
        model_config: One of the module's model configs; ``"class_names"`` maps
            the model's class IDs (list index) to output class names.

    Returns:
        A list of dicts with keys ``"type"`` (class name), ``"size"``
        ([width, height], truncated to int), ``"leftTopPoint"`` ([x, y],
        truncated to int) and ``"score"`` (rounded to 4 decimals). Returns an
        empty list (after logging) when the model fails to initialize or the
        prediction raises. Predictions whose class ID is outside
        ``class_names`` are logged and skipped.
    """
    name = model_config["name"]

    try:
        detector = initialize_model(model_config)
    except RuntimeError as err:
        logger.error(f"Cannot run inference; model '{name}' failed to initialize: {err}")
        return []

    logger.info(f"Running SAHI prediction with model: {name}...")
    # NOTE(review): SAHI's model wrappers document numpy input as RGB. If callers
    # pass a cv2.imread() result (BGR) here, the channels reach the model swapped;
    # consider cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB) — confirm what callers pass.
    try:
        prediction = get_sliced_prediction(
            image=image_cv,
            detection_model=detector,
            verbose=0,
            **SAHI_PARAMS,
        )
    except Exception as err:
        logger.error(f"SAHI prediction failed for model '{name}': {err}", exc_info=True)
        return []

    # Class IDs are local to each model, so translate them with this model's own list.
    labels = model_config["class_names"]
    n_labels = len(labels)

    targets = []
    for obj in prediction.object_prediction_list:
        label_idx = int(obj.category.id)
        if not 0 <= label_idx < n_labels:
            logger.warning(f"Model '{name}' detected an out-of-range class ID: {label_idx}. Ignoring.")
            continue

        box = obj.bbox
        targets.append({
            "type": labels[label_idx],
            "size": [int(box.maxx - box.minx), int(box.maxy - box.miny)],
            "leftTopPoint": [int(box.minx), int(box.miny)],
            "score": round(float(obj.score.value), 4),
        })

    logger.info(f"Model '{name}' found {len(targets)} raw objects.")
    return targets
|
|
|
|
def detect_objects_by_class(image_cv: np.ndarray, target_classes: List[str]) -> List[Dict]:
    """
    Main public function. Detects objects by dispatching to the model(s) that own the requested classes.

    Each configured model runs at most once, and only when at least one requested
    class belongs to it; its raw detections are then filtered down to the
    requested classes. Results from MODEL_1_CONFIG come before those from
    MODEL_2_CONFIG.

    Args:
        image_cv: The input image array, passed unchanged to the SAHI prediction.
        target_classes: Class names to detect (e.g., ['pho', 'hole']). Names that
            no configured model handles are logged as a warning and ignored.
            ``None`` is treated as an empty list.

    Returns:
        A list of detection dictionaries for the requested classes. A model whose
        inference yields no results (including on load/prediction failure)
        contributes nothing.
    """
    requested = list(target_classes) if target_classes is not None else []
    model_configs = (MODEL_1_CONFIG, MODEL_2_CONFIG)

    # Surface typos / unsupported names instead of silently returning nothing for them.
    known_classes = {cls for cfg in model_configs for cls in cfg["class_names"]}
    unsupported = [cls for cls in requested if cls not in known_classes]
    if unsupported:
        logger.warning("Ignoring unsupported target classes: %s", unsupported)

    final_results: List[Dict] = []
    for model_number, model_config in enumerate(model_configs, start=1):
        model_targets = [cls for cls in requested if cls in model_config["class_names"]]
        if not model_targets:
            # Only run a model when at least one of its classes was requested.
            continue

        logger.info(
            "Model %d ('%s') triggered for targets: %s",
            model_number, model_config["name"], model_targets,
        )
        wanted = set(model_targets)
        model_results = _run_inference_for_model(image_cv, model_config)
        # The model reports all of its classes; keep only what the caller asked for.
        final_results.extend(res for res in model_results if res["type"] in wanted)

    logger.info("Photovoltaic detection complete. Returning %d filtered targets.", len(final_results))
    return final_results
|
|
|
|
# --- Export for main.py ---
# main.py imports this name to know which detection tasks to delegate to this
# module. It is an alias of ALL_SUPPORTED_CLASSES (same list object, not a copy).
PHO_DETECTOR_SUPPORTED_CLASSES: List[str] = ALL_SUPPORTED_CLASSES