import cv2
import numpy as np
# import supervision as sv # No longer directly needed
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output
import logging

logger = logging.getLogger(__name__)

# Label for each model class id: CLASS_NAMES_FIRE_SMOKE[class_id].
CLASS_NAMES_FIRE_SMOKE = ["fire", "smoggy"]
MODEL_WEIGHTS_PATH_FIRE_SMOKE = "models/fire_smoke_model/best.pt"

# SAHI slicing configurations. Only SAHI_PARAMS_LARGE is used by the detection
# path below; the small-slice fallback pass is currently skipped.
SAHI_PARAMS_LARGE = {"slice_height": 2000, "slice_width": 2000, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}
SAHI_PARAMS_SMALL = {"slice_height": 256, "slice_width": 256, "overlap_height_ratio": 0.2, "overlap_width_ratio": 0.2}

# Module-level model instance, created lazily by initialize_fire_smoke_model().
detection_model_fire_smoke = None


def initialize_fire_smoke_model():
    """Load the fire/smoke detection model once and return it.

    The model is cached in the module-level ``detection_model_fire_smoke``;
    later calls return the cached instance without reloading.

    Returns:
        The SAHI detection model built by ``AutoDetectionModel.from_pretrained``.

    Raises:
        RuntimeError: If loading the weights fails. The underlying exception
            is chained as ``__cause__``.
    """
    global detection_model_fire_smoke
    if detection_model_fire_smoke is None:
        logger.info("Loading fire_smoke detection model from %s...", MODEL_WEIGHTS_PATH_FIRE_SMOKE)
        try:
            detection_model_fire_smoke = AutoDetectionModel.from_pretrained(
                model_type="ultralytics",
                model_path=MODEL_WEIGHTS_PATH_FIRE_SMOKE,
                # presumably SAHI discards predictions scoring below this — confirm in SAHI docs
                confidence_threshold=0.8,
            )
        except Exception as e:
            logger.error("Error loading fire_smoke model: %s", e, exc_info=True)
            # Chain the original error so the real load failure stays visible.
            raise RuntimeError(f"Could not load fire_smoke model: {e}") from e
        logger.info("Fire_smoke detection model loaded successfully.")
    return detection_model_fire_smoke


def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
    """Run SAHI sliced inference and keep predictions whose class id is known.

    Args:
        image_cv: Image array handed directly to ``get_sliced_prediction``.
        model: SAHI detection model to run.
        sahi_params: Dict providing ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        model_class_names: Labels indexed by class id. Predictions whose id
            falls outside ``range(len(model_class_names))`` are logged and
            dropped.

    Returns:
        List of SAHI object predictions with in-range class ids. If SAHI
        inference raises, the error is logged and an empty list is returned.
    """
    # NOTE(review): SAHI's Ultralytics wrapper appears to treat numpy input as
    # RGB. If callers pass cv2.imread (BGR) frames, the colour channels may be
    # swapped before inference — confirm the caller's channel order.
    try:
        sahi_result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
        )
    except Exception as e:
        logger.error("SAHI prediction failed with params %s: %s", sahi_params, e, exc_info=True)
        return []

    num_classes = len(model_class_names)
    valid_predictions = []
    for pred in sahi_result.object_prediction_list:
        cat_id = int(pred.category.id)
        if 0 <= cat_id < num_classes:
            valid_predictions.append(pred)
        else:
            # Guard against a model whose head has more classes than we have labels for.
            logger.warning(
                "SAHI Raw: Detected class ID %s for %s is out of range for model_class_names: %s. Ignoring.",
                cat_id,
                model.__class__.__name__,
                model_class_names,
            )
    return valid_predictions


def _prediction_to_target(pred) -> dict:
    """Convert one SAHI object prediction into the API's target dict."""
    bbox = pred.bbox
    return {
        "type": CLASS_NAMES_FIRE_SMOKE[int(pred.category.id)],
        # Width/height and corner coordinates are truncated toward zero by int().
        "size": [int(bbox.maxx - bbox.minx), int(bbox.maxy - bbox.miny)],
        "leftTopPoint": [int(bbox.minx), int(bbox.miny)],
        "score": round(float(pred.score.value), 4),
    }


def detect_fire_smoke_with_sahi(image_cv: np.ndarray, extract: bool = False):
    """Detect fire and smoke in an image using a single large-slice SAHI pass.

    Args:
        image_cv: Image array forwarded to SAHI.
        extract: Currently only logged; it does not affect the output.

    Returns:
        A list of target dicts. Each has the keys ``type`` (class label),
        ``size`` ([width, height]), ``leftTopPoint`` ([minx, miny]) and
        ``score`` (rounded to 4 decimals). The list is empty when nothing is
        detected or SAHI inference fails. Returns
        ``{"error": "Fire_smoke model is not available."}`` only if
        initialization completes without producing a model.

    Raises:
        RuntimeError: Propagated from ``initialize_fire_smoke_model`` when the
            model has not been loaded yet and loading fails.
    """
    model = detection_model_fire_smoke
    if model is None:
        logger.warning("Fire_smoke model was not loaded. Attempting to initialize now.")
        model = initialize_fire_smoke_model()  # raises RuntimeError on load failure
        if model is None:  # defensive: initialize either returns a model or raises
            logger.error("Fire_smoke model could not be initialized for detection.")
            return {"error": "Fire_smoke model is not available."}

    logger.info("Executing fire_smoke detection, extract flag is: %s", extract)
    logger.info("Attempting fire_smoke detection with SAHI (Large Slices)...")
    predictions = _run_sahi_and_process_results(
        image_cv, model, SAHI_PARAMS_LARGE, CLASS_NAMES_FIRE_SMOKE
    )
    if not predictions:
        logger.info("No SAHI fire_smoke detections (large slices). Skipping small slices attempt as requested.")

    targets = [_prediction_to_target(pred) for pred in predictions]
    logger.info("Prepared %d fire/smoke targets for API response.", len(targets))
    return targets