93 lines
4.4 KiB
Python
93 lines
4.4 KiB
Python
import cv2
|
|
import numpy as np
|
|
# import supervision as sv # No longer directly needed
|
|
from sahi import AutoDetectionModel
|
|
from sahi.predict import get_sliced_prediction
|
|
# import base64 # No longer needed for primary output
|
|
import logging
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Label for each model class id: entry i is the label of class id i.
# These labels are emitted verbatim as the "type" field of every detection.
# NOTE(review): "smoggy" (rather than "smoke") presumably mirrors the label the
# model was trained with -- confirm before renaming, since API consumers see it.
CLASS_NAMES_FIRE_SMOKE = [
    "fire",
    "smoggy",
]

# Weights file handed to AutoDetectionModel.from_pretrained() on first load.
MODEL_WEIGHTS_PATH_FIRE_SMOKE = "models/fire_smoke_model/best.pt"

# Keyword settings for SAHI's get_sliced_prediction(): slice size and the
# fractional overlap between neighbouring slices.
SAHI_PARAMS_LARGE = dict(
    slice_height=2000,
    slice_width=2000,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
# Not used by the detection entry point in this file: the small-slice retry is
# skipped there ("Skipping small slices attempt as requested").
SAHI_PARAMS_SMALL = dict(
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)

# Lazily-loaded detection model; stays None until initialize_fire_smoke_model()
# succeeds.
detection_model_fire_smoke = None
|
|
|
|
def initialize_fire_smoke_model():
    """Load the fire/smoke detection model once and return the cached instance.

    The model is built with SAHI's ``AutoDetectionModel.from_pretrained`` using
    the ``"ultralytics"`` backend, the weights at
    ``MODEL_WEIGHTS_PATH_FIRE_SMOKE`` and a confidence threshold of 0.8. The
    result is stored in the module-level ``detection_model_fire_smoke``, so
    later calls return it without reloading.

    Returns:
        The loaded SAHI detection model (the value of
        ``detection_model_fire_smoke``).

    Raises:
        RuntimeError: If loading fails. The original exception is attached as
            ``__cause__`` and its traceback is logged.
    """
    global detection_model_fire_smoke

    # Already loaded: reuse the cached model.
    if detection_model_fire_smoke is not None:
        return detection_model_fire_smoke

    logger.info("Loading fire_smoke detection model from %s...", MODEL_WEIGHTS_PATH_FIRE_SMOKE)
    try:
        detection_model_fire_smoke = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=MODEL_WEIGHTS_PATH_FIRE_SMOKE,
            confidence_threshold=0.8,
        )
    except Exception as e:
        # Log with traceback, then surface a single RuntimeError to callers
        # while keeping the underlying error as the explicit cause.
        logger.exception("Error loading fire_smoke model: %s", e)
        raise RuntimeError(f"Could not load fire_smoke model: {e}") from e

    logger.info("Fire_smoke detection model loaded successfully.")
    return detection_model_fire_smoke
|
|
|
|
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
    """Run SAHI sliced inference on one image and keep predictions with known class ids.

    Args:
        image_cv: Image passed unchanged to ``get_sliced_prediction``.
            NOTE(review): presumably an OpenCV image array -- confirm with callers.
        model: SAHI detection model applied to every slice.
        sahi_params: Must contain ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        model_class_names: Label list. A prediction is kept only if its integer
            category id is a valid non-negative index into this list.

    Returns:
        The in-range SAHI object predictions, in SAHI's order. If the SAHI call
        (including reading the keys from ``sahi_params``) raises anything, the
        error is logged and an empty list is returned.
    """
    try:
        result = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
        )
    except Exception as err:
        # Best effort: a failed inference is reported to the caller as "no detections".
        logger.error(f"SAHI prediction failed with params {sahi_params}: {err}", exc_info=True)
        return []

    n_classes = len(model_class_names)
    kept = []
    for prediction in result.object_prediction_list:
        class_id = int(prediction.category.id)
        if not 0 <= class_id < n_classes:
            # An id outside the label list cannot be named downstream; drop it.
            logger.warning(f"SAHI Raw: Detected class ID {class_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring.")
            continue
        kept.append(prediction)
    return kept
|
|
|
|
def detect_fire_smoke_with_sahi(image_cv: np.ndarray, extract: bool = False):
    """Detect fire and smoke in an image and return API-ready target dicts.

    Runs a single SAHI pass with ``SAHI_PARAMS_LARGE`` slices. The small-slice
    retry is intentionally skipped.

    Args:
        image_cv: Image to analyse; passed unchanged to SAHI.
        extract: Only logged; it does not affect the result.

    Returns:
        A list with one dict per detection:

        * ``"type"`` - label from ``CLASS_NAMES_FIRE_SMOKE``
        * ``"size"`` - ``[width, height]`` of the box, each truncated with ``int``
        * ``"leftTopPoint"`` - ``[minx, miny]``, each truncated with ``int``
        * ``"score"`` - confidence rounded to 4 decimals

        If the model is still ``None`` after attempting initialization, the dict
        ``{"error": "Fire_smoke model is not available."}`` is returned instead.

    Raises:
        RuntimeError: Propagated from ``initialize_fire_smoke_model()`` when the
            model cannot be loaded.
    """
    if detection_model_fire_smoke is None:
        logger.warning("Fire_smoke model was not loaded. Attempting to initialize now.")
        # NOTE(review): initialize_fire_smoke_model() raises RuntimeError on a
        # failed load, so the error-dict branch below is only reached if loading
        # yields None rather than raising.
        initialize_fire_smoke_model()

    # Read the module-level model again: initialization may have just set it.
    model = detection_model_fire_smoke
    if model is None:
        logger.error("Fire_smoke model could not be initialized for detection.")
        return {"error": "Fire_smoke model is not available."}

    logger.info("Executing fire_smoke detection, extract flag is: %s", extract)
    logger.info("Attempting fire_smoke detection with SAHI (Large Slices)...")
    predictions = _run_sahi_and_process_results(
        image_cv, model, SAHI_PARAMS_LARGE, CLASS_NAMES_FIRE_SMOKE
    )
    if not predictions:
        logger.info("No SAHI fire_smoke detections (large slices). Skipping small slices attempt as requested.")

    # Class ids were already range-checked by _run_sahi_and_process_results,
    # so indexing CLASS_NAMES_FIRE_SMOKE is safe here.
    targets = [
        {
            "type": CLASS_NAMES_FIRE_SMOKE[int(pred.category.id)],
            "size": [int(pred.bbox.maxx - pred.bbox.minx), int(pred.bbox.maxy - pred.bbox.miny)],
            "leftTopPoint": [int(pred.bbox.minx), int(pred.bbox.miny)],
            "score": round(float(pred.score.value), 4),
        }
        for pred in predictions
    ]

    logger.info("Prepared %d fire/smoke targets for API response.", len(targets))
    return targets