Files
AI-TianDong/utils/hardhat_detector copy.py
2025-07-24 12:45:27 +08:00

111 lines
5.4 KiB
Python

import cv2
import numpy as np
import supervision as sv
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# import base64 # No longer needed for primary output if base64 image isn't returned
import logging
# Module-level logger; records are emitted under this module's name.
logger = logging.getLogger(__name__)

# Class labels in the order the helmet model emits category ids
# (category id i is looked up as CLASS_NAMES_HELMET_MODEL_OUTPUT[i]).
CLASS_NAMES_HELMET_MODEL_OUTPUT = ["helmet", "nohelmet", "vast", "novast"]
# Only detections carrying this label are included in the returned targets.
TARGET_CLASS_FOR_REPORTING = "nohelmet"
# Weights file handed to SAHI's AutoDetectionModel (model_type="ultralytics").
MODEL_WEIGHTS_PATH_HELMET = "models/helmet_model/best.pt"

# SAHI slicing configurations. The large-slice pass runs first; the
# small-slice pass is only used when the large pass yields no predictions.
SAHI_PARAMS_LARGE = dict(
    slice_height=1200,
    slice_width=1200,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
SAHI_PARAMS_SMALL = dict(
    slice_height=700,
    slice_width=700,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)

# Lazily loaded SAHI detection model; populated by initialize_helmet_model().
detection_model_helmet = None
def initialize_helmet_model():
    """Load the SAHI helmet-detection model once and cache it in a module global.

    If ``detection_model_helmet`` is already set, it is returned without
    reloading.

    Returns:
        The SAHI detection model stored in ``detection_model_helmet``.

    Raises:
        RuntimeError: If loading fails. The original exception is attached
            as ``__cause__`` so the root error stays visible in tracebacks.
    """
    global detection_model_helmet
    if detection_model_helmet is not None:
        return detection_model_helmet

    logger.info("Loading helmet detection model from %s...", MODEL_WEIGHTS_PATH_HELMET)
    try:
        detection_model_helmet = AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=MODEL_WEIGHTS_PATH_HELMET,
            # SAHI's confidence_threshold; presumably predictions scoring
            # below 0.8 are dropped by SAHI — confirm against SAHI docs.
            confidence_threshold=0.8,
        )
    except Exception as e:
        # Log with traceback, then re-raise with explicit chaining so callers
        # see a RuntimeError whose __cause__ is the real loading failure.
        logger.error("Error loading helmet model: %s", e, exc_info=True)
        raise RuntimeError(f"Could not load helmet model: {e}") from e

    logger.info("Helmet detection model loaded successfully.")
    return detection_model_helmet
# encode_image_to_base64 is no longer strictly needed if not returning annotated image
def _run_sahi_and_process_results(image_cv: np.ndarray, model, sahi_params: dict, model_class_names: list):
    """Run SAHI sliced inference and keep only predictions with a known class id.

    Args:
        image_cv: Image passed unchanged to SAHI's ``get_sliced_prediction``.
        model: SAHI detection model used for inference.
        sahi_params: Mapping providing ``slice_height``, ``slice_width``,
            ``overlap_height_ratio`` and ``overlap_width_ratio``.
        model_class_names: Class labels indexed by category id.

    Returns:
        The SAHI ``object_prediction_list`` entries whose integer category id
        is a valid index into ``model_class_names``. Entries with an
        out-of-range id are dropped with a warning. If the SAHI call (or a
        lookup of ``sahi_params``) raises, the error is logged and an empty
        list is returned instead of propagating.
    """
    try:
        sliced = get_sliced_prediction(
            image=image_cv,
            detection_model=model,
            slice_height=sahi_params["slice_height"],
            slice_width=sahi_params["slice_width"],
            overlap_height_ratio=sahi_params["overlap_height_ratio"],
            overlap_width_ratio=sahi_params["overlap_width_ratio"],
            verbose=0,
        )
    except Exception as exc:
        logger.error(f"SAHI prediction failed with params {sahi_params}: {exc}", exc_info=True)
        return []

    class_count = len(model_class_names)
    kept = []
    for prediction in sliced.object_prediction_list:
        category_id = int(prediction.category.id)
        if not 0 <= category_id < class_count:
            # The model emitted an id with no label; skip it rather than
            # fail later on the class-name lookup.
            logger.warning(
                f"SAHI Raw: Detected class ID {category_id} for {model.__class__.__name__} is out of range for model_class_names: {model_class_names}. Ignoring."
            )
            continue
        kept.append(prediction)
    return kept
def _collect_helmet_predictions(image_cv: np.ndarray, model) -> list:
    """Run SAHI with large slices, falling back to small slices only if nothing is found."""
    logger.info("Attempting detection with SAHI (Large Slices)...")
    large_preds = _run_sahi_and_process_results(
        image_cv, model, SAHI_PARAMS_LARGE, CLASS_NAMES_HELMET_MODEL_OUTPUT
    )
    if large_preds:
        logger.info(f"Initial SAHI detections (large slices): {len(large_preds)}")
        return large_preds

    logger.info("No SAHI detections (large slices). Retrying with small slices...")
    small_preds = _run_sahi_and_process_results(
        image_cv, model, SAHI_PARAMS_SMALL, CLASS_NAMES_HELMET_MODEL_OUTPUT
    )
    if small_preds:
        logger.info(f"Initial SAHI detections (small slices): {len(small_preds)}")
        return small_preds

    logger.info("No initial SAHI detections found even with small slices.")
    return []


def _helmet_prediction_to_target(pred, class_name: str) -> dict:
    """Convert one SAHI ObjectPrediction into the API's target dict."""
    box = pred.bbox
    return {
        "type": class_name,
        # Width/height are truncated from the float box extent, not derived
        # from the truncated corner coordinates.
        "size": [int(box.maxx - box.minx), int(box.maxy - box.miny)],
        "leftTopPoint": [int(box.minx), int(box.miny)],
        "score": round(float(pred.score.value), 4),
    }


def detect_hardhat_with_sahi(image_cv: np.ndarray, extract: bool = False):
    """Run helmet-model detection on an image and report only the target class.

    Inference first uses ``SAHI_PARAMS_LARGE``. If that pass produces no valid
    predictions, it is repeated with ``SAHI_PARAMS_SMALL``. Only predictions
    labelled ``TARGET_CLASS_FOR_REPORTING`` are returned.

    Args:
        image_cv: Image array handed to SAHI (presumably an OpenCV-style
            image given the name — confirm with callers).
        extract: Currently only logged; it does not affect the result.

    Returns:
        A list of dicts, each with keys ``"type"`` (class label), ``"size"``
        (``[width, height]``), ``"leftTopPoint"`` (``[minx, miny]``) and
        ``"score"`` (rounded to 4 decimals). If the model is still ``None``
        after attempting initialization, returns
        ``{"error": "Helmet model is not available."}`` instead.

    Raises:
        RuntimeError: Propagated from ``initialize_helmet_model`` when the
            model cannot be loaded.
    """
    model = detection_model_helmet
    if model is None:
        logger.warning("Helmet model was not loaded. Attempting to initialize now.")
        # Use the returned model rather than re-reading the module global.
        model = initialize_helmet_model()
    if model is None:
        # Defensive: initialize_helmet_model normally raises on failure.
        logger.error("Helmet model could not be initialized for detection.")
        return {"error": "Helmet model is not available."}

    logger.info(f"Executing helmet task (reporting only '{TARGET_CLASS_FOR_REPORTING}'), extract flag is: {extract}")

    predictions = _collect_helmet_predictions(image_cv, model)

    # Ids are already range-checked by _run_sahi_and_process_results, so the
    # lookup below cannot go out of bounds.
    targets_output_list = []
    for pred in predictions:
        class_name = CLASS_NAMES_HELMET_MODEL_OUTPUT[int(pred.category.id)]
        if class_name == TARGET_CLASS_FOR_REPORTING:
            targets_output_list.append(_helmet_prediction_to_target(pred, class_name))

    logger.info(f"Prepared {len(targets_output_list)} '{TARGET_CLASS_FOR_REPORTING}' targets for API response.")
    return targets_output_list