ocr1.0
This commit is contained in:
133
ocr/model_violation_detector.py
Normal file
133
ocr/model_violation_detector.py
Normal file
@ -0,0 +1,133 @@
|
||||
import cv2
|
||||
from logger_config import logger
|
||||
from ocr_violation_detector import OCRViolationDetector
|
||||
from yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
||||
from face_recognizer import FaceRecognizer
|
||||
|
||||
|
||||
class MultiModelViolationDetector:
|
||||
"""
|
||||
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型(调整为YOLO最后检测),任一模型检测到违规即返回结果
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str, # 新增OCR配置文件路径参数
|
||||
yolo_model_path: str,
|
||||
known_faces_dir: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化所有检测模型
|
||||
|
||||
Args:
|
||||
forbidden_words_path: 违禁词文件路径
|
||||
ocr_config_path: OCR配置文件(1.yaml)路径
|
||||
yolo_model_path: YOLO模型文件路径
|
||||
known_faces_dir: 已知人脸目录路径
|
||||
ocr_confidence_threshold: OCR置信度阈值
|
||||
"""
|
||||
# 初始化OCR检测器(传入配置文件路径)
|
||||
self.ocr_detector = OCRViolationDetector(
|
||||
forbidden_words_path=forbidden_words_path,
|
||||
ocr_config_path=ocr_config_path, # 传递配置文件路径
|
||||
ocr_confidence_threshold=ocr_confidence_threshold
|
||||
)
|
||||
|
||||
# 初始化人脸识别器
|
||||
self.face_recognizer = FaceRecognizer(
|
||||
known_faces_dir=known_faces_dir
|
||||
)
|
||||
|
||||
# 初始化YOLO检测器(调整为最后初始化)
|
||||
self.yolo_detector = YoloViolationDetector(
|
||||
model_path=yolo_model_path
|
||||
)
|
||||
|
||||
logger.info("多模型违规检测器初始化完成")
|
||||
|
||||
def detect_violations(self, frame):
|
||||
"""
|
||||
串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果
|
||||
Args:
|
||||
frame: 输入视频帧 (NumPy数组, BGR格式)
|
||||
Returns:
|
||||
tuple: (是否有违规, 违规类型, 违规详情)
|
||||
违规类型: 'ocr' | 'yolo' | 'face' | None
|
||||
违规详情: 对应模型的检测结果
|
||||
"""
|
||||
# 1. 首先进行OCR违禁词检测
|
||||
try:
|
||||
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
|
||||
if ocr_has_violation:
|
||||
details = {
|
||||
"words": ocr_words,
|
||||
"confidences": ocr_confs
|
||||
}
|
||||
logger.warning(f"OCR检测到违禁内容: {details}")
|
||||
return (True, "ocr", details)
|
||||
except Exception as e:
|
||||
logger.error(f"OCR检测出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 2. 接着进行人脸识别检测
|
||||
try:
|
||||
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
|
||||
if face_has_violation:
|
||||
details = {
|
||||
"name": face_name,
|
||||
"similarity": face_similarity
|
||||
}
|
||||
logger.warning(f"人脸识别到违规人员: {details}")
|
||||
return (True, "face", details)
|
||||
except Exception as e:
|
||||
logger.error(f"人脸识别出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 3. 最后进行YOLO目标检测(调整为最后检测)
|
||||
try:
|
||||
yolo_results = self.yolo_detector.detect(frame)
|
||||
# 检查是否有检测结果(根据实际业务定义何为违规目标)
|
||||
if len(yolo_results.boxes) > 0:
|
||||
# 提取检测到的目标信息
|
||||
details = {
|
||||
"classes": yolo_results.names,
|
||||
"boxes": yolo_results.boxes.xyxy.tolist(), # 边界框坐标
|
||||
"confidences": yolo_results.boxes.conf.tolist(), # 置信度
|
||||
"class_ids": yolo_results.boxes.cls.tolist() # 类别ID
|
||||
}
|
||||
logger.warning(f"YOLO检测到违规目标: {details}")
|
||||
return (True, "yolo", details)
|
||||
except Exception as e:
|
||||
logger.error(f"YOLO检测出错: {str(e)}", exc_info=True)
|
||||
|
||||
# 所有检测均未发现违规
|
||||
return (False, None, None)
|
||||
|
||||
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 配置文件路径(根据实际情况修改)
|
||||
# FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
|
||||
# OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" # 新增OCR配置文件路径
|
||||
# YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
|
||||
# KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
|
||||
#
|
||||
# # 初始化多模型检测器
|
||||
# detector = MultiModelViolationDetector(
|
||||
# forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
# ocr_config_path=OCR_CONFIG_PATH, # 传入OCR配置文件路径
|
||||
# yolo_model_path=YOLO_MODEL_PATH,
|
||||
# known_faces_dir=KNOWN_FACES_DIR,
|
||||
# ocr_confidence_threshold=0.5
|
||||
# )
|
||||
#
|
||||
# # 读取测试图像(可替换为视频帧读取逻辑)
|
||||
# test_image_path = r"D:\Git\bin\video\ocr\images\img.png"
|
||||
# frame = cv2.imread(test_image_path)
|
||||
#
|
||||
# if frame is not None:
|
||||
# has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
# if has_violation:
|
||||
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
# else:
|
||||
# print("未检测到任何违规内容")
|
||||
# else:
|
||||
# print(f"无法读取测试图像: {test_image_path}")
|
Reference in New Issue
Block a user