# utils/pho_detector.py
"""Photovoltaic-site object detection ('pho', 'shelves', 'pile', 'hole').

Detection is delegated to two models, each responsible for a subset of the
supported classes, and run with SAHI sliced inference.
"""
import logging
from typing import Any, Dict, List

import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# --- Basic Configuration ---
# NOTE: basicConfig configures the *root* logger as an import-time side effect.
# It is a no-op when the root logger already has handlers (e.g. configured by
# the application before this module is imported).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Module-Specific Configuration ---

# Configuration for the first model (detects 'pho' and 'shelves').
MODEL_1_CONFIG = {
    "name": "pho_model",
    # IMPORTANT: Update this path to your actual model location.
    "weights_path": "models/photovoltaic_model/pho_model/best.pt",
    # Class order must match the model's training IDs (0: 'pho', 1: 'shelves').
    "class_names": ['pho', 'shelves'],
    "confidence_threshold": 0.7,
}

# Configuration for the second model (detects 'pile' and 'hole').
MODEL_2_CONFIG = {
    "name": "pile_model",
    # IMPORTANT: Update this path to your actual model location.
    "weights_path": "models/photovoltaic_model/pile_model/best.pt",
    # Class order must match the model's training IDs (0: 'pile', 1: 'hole').
    "class_names": ['pile', 'hole'],
    "confidence_threshold": 0.7,
}

# All classes this module can detect: ['pho', 'shelves', 'pile', 'hole'].
# Derived from the model configs (rather than duplicated by hand) so that the
# advertised class list can never drift out of sync with what the models emit.
ALL_SUPPORTED_CLASSES = [
    *MODEL_1_CONFIG["class_names"],
    *MODEL_2_CONFIG["class_names"],
]

# Unified SAHI slicing parameters for all tasks in this module.
SAHI_PARAMS = {
    "slice_height": 800,
    "slice_width": 800,
    "overlap_height_ratio": 0.2,
    "overlap_width_ratio": 0.2,
}

# Global dictionary to cache loaded model instances.
detection_models: Dict[str, AutoDetectionModel] = {}  # keyed by model_config["name"]


# --- Model Management Functions ---

def initialize_model(model_config: Dict[str, Any]) -> AutoDetectionModel:
    """Load the model described by ``model_config`` once and return the cached instance.

    Args:
        model_config: Model configuration dict; reads the keys ``"name"``
            (cache key), ``"weights_path"`` and ``"confidence_threshold"``.

    Returns:
        The SAHI ``AutoDetectionModel`` cached under ``model_config["name"]``.

    Raises:
        RuntimeError: If loading fails; the underlying exception is chained
            as ``__cause__``.
    """
    model_name = model_config["name"]
    if model_name in detection_models:
        return detection_models[model_name]

    logger.info("Loading detection model '%s' from %s...", model_name, model_config["weights_path"])
    try:
        model = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=model_config["weights_path"],
            confidence_threshold=model_config["confidence_threshold"],
        )
    except Exception as e:
        logger.error("Error loading model '%s': %s", model_name, e, exc_info=True)
        raise RuntimeError(f"Could not load model '{model_name}': {e}") from e

    detection_models[model_name] = model
    logger.info("Model '%s' loaded successfully.", model_name)
    return model


def initialize_all_pho_models():
    """Initialize all models defined in this module. Called once at application startup.

    Raises:
        RuntimeError: If any model fails to load (propagated from ``initialize_model``).
    """
    logger.info("Pre-initializing all photovoltaic-related models...")
    for config in (MODEL_1_CONFIG, MODEL_2_CONFIG):
        initialize_model(config)
    logger.info("All photovoltaic models have been initialized.")


# --- Inference and Core Logic ---

def _to_rgb(image_cv: np.ndarray) -> np.ndarray:
    """Convert an OpenCV-style image (BGR / BGRA / grayscale) to the RGB layout SAHI expects.

    SAHI interprets ndarray input as RGB, while OpenCV images are BGR. Grayscale
    images of shape (H, W) or (H, W, 1) are replicated to three channels; for
    4-channel (BGRA) input the alpha channel is dropped. Any other layout is
    returned unchanged so that SAHI can accept or reject it.
    """
    image = np.asarray(image_cv)
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.ndim != 3:
        return image
    channels = image.shape[2]
    if channels == 1:
        return np.repeat(image, 3, axis=2)
    if channels in (3, 4):
        # Channels 2, 1, 0 -> R, G, B; a trailing alpha channel (index 3) is discarded.
        return np.ascontiguousarray(image[:, :, 2::-1])
    return image


def _run_inference_for_model(image_cv: np.ndarray, model_config: Dict[str, Any]) -> List[Dict]:
    """Run SAHI sliced prediction for a single model and map its class IDs to names.

    Best-effort: if the model cannot be initialized or prediction fails, the
    error is logged and an empty list is returned.

    Args:
        image_cv: Input image in OpenCV (BGR) channel order; converted to RGB here.
        model_config: One of the ``MODEL_*_CONFIG`` dicts.

    Returns:
        Detection dicts with keys ``"type"`` (class name), ``"size"``
        ([width, height], truncated to int), ``"leftTopPoint"`` ([x, y],
        truncated to int) and ``"score"`` (rounded to 4 decimals).
    """
    model_name = model_config["name"]
    try:
        model = initialize_model(model_config)
    except RuntimeError as e:
        logger.error("Cannot run inference; model '%s' failed to initialize: %s", model_name, e)
        return []

    logger.info("Running SAHI prediction with model: %s...", model_name)
    try:
        # Converted inside the try so that a malformed image is logged and
        # yields [] exactly like any other prediction failure.
        image_rgb = _to_rgb(image_cv)
        sahi_result = get_sliced_prediction(
            image=image_rgb,
            detection_model=model,
            verbose=0,
            **SAHI_PARAMS,
        )
    except Exception as e:
        logger.error("SAHI prediction failed for model '%s': %s", model_name, e, exc_info=True)
        return []

    # The model's local class IDs index into this model-specific name list.
    model_class_names = model_config["class_names"]
    targets_output_list = []
    for pred in sahi_result.object_prediction_list:
        cat_id = int(pred.category.id)
        if not 0 <= cat_id < len(model_class_names):
            logger.warning("Model '%s' detected an out-of-range class ID: %s. Ignoring.", model_name, cat_id)
            continue
        bbox = pred.bbox
        targets_output_list.append({
            "type": model_class_names[cat_id],
            "size": [int(bbox.maxx - bbox.minx), int(bbox.maxy - bbox.miny)],
            "leftTopPoint": [int(bbox.minx), int(bbox.miny)],
            "score": round(float(pred.score.value), 4),
        })

    logger.info("Model '%s' found %d raw objects.", model_name, len(targets_output_list))
    return targets_output_list


def detect_objects_by_class(image_cv: np.ndarray, target_classes: List[str]) -> List[Dict]:
    """Detect the requested classes, running only the model(s) that can produce them.

    Args:
        image_cv: The input OpenCV image (BGR channel order).
        target_classes: Class names to detect (e.g. ``['pho', 'hole']``).
            ``None`` or empty means nothing is requested. Names no model
            supports are logged as a warning and ignored.

    Returns:
        Detection dicts (see ``_run_inference_for_model``) whose ``"type"`` is
        in ``target_classes``. Results of the first model come before those of
        the second.
    """
    requested = list(target_classes or [])
    model_configs = (MODEL_1_CONFIG, MODEL_2_CONFIG)

    supported = {name for config in model_configs for name in config["class_names"]}
    unsupported = [cls for cls in requested if cls not in supported]
    if unsupported:
        logger.warning("Ignoring unsupported target classes: %s", unsupported)

    final_results = []
    for index, config in enumerate(model_configs, start=1):
        model_targets = [cls for cls in requested if cls in config["class_names"]]
        if not model_targets:
            # Skip the (expensive) sliced inference when this model isn't needed.
            continue
        logger.info("Model %d ('%s') triggered for targets: %s", index, config["name"], model_targets)
        raw_results = _run_inference_for_model(image_cv, config)
        # A model may emit classes the caller did not ask for; keep only requested ones.
        wanted = set(model_targets)
        final_results.extend(res for res in raw_results if res["type"] in wanted)

    logger.info("Photovoltaic detection complete. Returning %d filtered targets.", len(final_results))
    return final_results


# --- Export for main.py ---
# This list is imported by main.py to know which tasks to delegate to this module.
PHO_DETECTOR_SUPPORTED_CLASSES = ALL_SUPPORTED_CLASSES