From d83923d06baa6aeb6e84050e5bfce4f0fa1fb468 Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Wed, 3 Sep 2025 16:22:21 +0800 Subject: [PATCH] ocr1.0 --- core/rtmp.py | 83 +++++++-------- ocr/face_recognizer.py | 180 ++++++++++++-------------------- ocr/model_violation_detector.py | 113 ++++++++++---------- ocr/ocr_violation_detector.py | 170 +++++++++--------------------- ocr/yolo_violation_detector.py | 7 +- 5 files changed, 211 insertions(+), 342 deletions(-) diff --git a/core/rtmp.py b/core/rtmp.py index 02da436..6de5447 100644 --- a/core/rtmp.py +++ b/core/rtmp.py @@ -2,109 +2,101 @@ import asyncio import logging import cv2 import time -from ocr.ocr_violation_detector import OCRViolationDetector +from ocr.model_violation_detector import MultiModelViolationDetector -import logging + +# 配置文件相对路径(根据实际目录结构调整) +YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹 +FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt" +OCR_CONFIG_PATH = "../ocr/config/1.yaml" +KNOWN_FACES_DIR = "../ocr/known_faces" # 创建检测器实例 -detector = OCRViolationDetector( - forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", - ocr_confidence_threshold=0.7, - log_level=logging.INFO, - log_file="ocr_detection.log" +detector = MultiModelViolationDetector( + forbidden_words_path=FORBIDDEN_WORDS_PATH, + ocr_config_path=OCR_CONFIG_PATH, + yolo_model_path=YOLO_MODEL_PATH, + known_faces_dir=KNOWN_FACES_DIR, + ocr_confidence_threshold=0.5 ) -# 配置日志(与WHEP代码保持一致的日志风格) +# 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger("rtmp_video_puller") async def rtmp_pull_video_stream(rtmp_url): """ - 通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息 - 功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理 - - Args: - rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key) + 通过RTMP从指定URL拉取视频流并进行违规检测 """ cap = None # 初始化视频捕获对象 try: - # 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环) + # 异步打开RTMP流 cap = await asyncio.to_thread( cv2.VideoCapture, rtmp_url, - cv2.CAP_FFMPEG # 必须指定FFmpeg后端,RTMP协议依赖该后端解析 + cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性 ) - # 2. 检查RTMP流是否成功打开 + # 检查RTMP流是否成功打开 is_opened = await asyncio.to_thread(cap.isOpened) if not is_opened: raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)") - # 3. 异步获取RTMP流基础信息(分辨率、帧率) + # 获取RTMP流基础信息 width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH) height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT) fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS) - # 处理异常情况:部分RTMP流未返回帧率时默认30FPS + # 处理异常情况 fps = fps if fps > 0 else 30.0 - # 分辨率转为整数(视频尺寸必然是整数) width, height = int(width), int(height) - # 打印流初始化成功信息(与WHEP连接成功信息风格一致) + # 打印流初始化成功信息 print(f"RTMP流状态: 已成功连接") print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS") print("开始接收视频帧...(按 Ctrl+C 中断)") - # 4. 初始化帧统计参数 - frame_count = 0 # 总接收帧数 - start_time = time.time() # 统计起始时间 + # 初始化帧统计参数 + frame_count = 0 + start_time = time.time() - # 5. 循环异步读取视频帧(核心逻辑) + # 循环读取视频帧 while True: - # 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境) ret, frame = await asyncio.to_thread(cap.read) - # 检查帧是否读取成功(流中断/结束时ret为False) if not ret: print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)") break - # 帧计数累加 frame_count += 1 - # 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐) + # 打印当前帧信息 print(f"收到帧 (第{frame_count}帧)") print(f" 帧尺寸: {width}x{height}") print(f" 配置帧率: {fps:.2f} FPS") - has_violation, violations, confidences = OCRViolationDetector.detect(frame) - - # 输出检测结果 - if has_violation: - detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") - for word, conf in zip(violations, confidences): - detector.logger.info(f"- {word} (置信度: {conf:.4f})") + if frame is not None: + has_violation, violation_type, details = detector.detect_violations(frame) + if has_violation: + print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") + else: + print("未检测到任何违规内容") else: - detector.logger.info("图片中未检测到违禁词") - # 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致) + print(f"无法读取测试图像") + + # 每100帧统计一次实际接收帧率 if frame_count % 100 == 0: elapsed_time = time.time() - start_time - actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率) + actual_fps = frame_count / elapsed_time print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----") - # (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑 - # 示例:yield frame (若需生成器模式,可调整函数为异步生成器) - - # 8. 异常处理(覆盖用户中断、通用错误) except KeyboardInterrupt: print(f"\n用户操作: 已通过 Ctrl+C 中断程序") except Exception as e: - # 日志记录详细错误(便于问题排查),同时打印用户可见信息 logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True) print(f"错误信息: {str(e)}") finally: - # 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏) if cap is not None: await asyncio.to_thread(cap.release) print(f"\n资源释放: RTMP流已关闭") @@ -114,8 +106,7 @@ async def rtmp_pull_video_stream(rtmp_url): if __name__ == "__main__": RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416" - # 运行RTMP拉流任务(与WHEP一致的异步执行方式) try: asyncio.run(rtmp_pull_video_stream(RTMP_URL)) except Exception as e: - print(f"程序启动失败: {str(e)}") \ No newline at end of file + print(f"程序启动失败: {str(e)}") diff --git a/ocr/face_recognizer.py b/ocr/face_recognizer.py index d3c2aa7..1947748 100644 --- a/ocr/face_recognizer.py +++ b/ocr/face_recognizer.py @@ -4,10 +4,12 @@ import numpy as np import insightface from insightface.app import FaceAnalysis + class FaceRecognizer: """ 封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。 """ + def __init__(self, known_faces_dir: str): self.known_faces_dir = known_faces_dir self.app = self._initialize_insightface() @@ -16,40 +18,30 @@ class FaceRecognizer: self._load_known_faces() def _initialize_insightface(self): - """ - 初始化InsightFace FaceAnalysis应用。 - 默认使用CPU,如果检测到CUDA,会自动使用GPU。 - """ - print("正在初始化InsightFace人脸识别引擎...") + """初始化InsightFace FaceAnalysis应用""" + print("初始化InsightFace引擎...") try: - # 默认模型是 'buffalo_l',包含检测、对齐、识别功能 - # 如果需要更小的模型,可以尝试 'buffalo_s' 或 'buffalo_m' - # ctx_id=0 表示使用GPU,ctx_id=-1 表示使用CPU - # InsightFace会自动检测CUDA并选择GPU,所以通常不需要手动设置ctx_id - app = FaceAnalysis(name='buffalo_l', root='~/.insightface') # 模型下载到用户目录 - app.prepare(ctx_id=0, det_size=(640, 640)) # det_size影响检测性能和精度 - print("InsightFace人脸识别引擎初始化成功。") + app = FaceAnalysis(name='buffalo_l', root='~/.insightface') + app.prepare(ctx_id=0, det_size=(640, 640)) + print("InsightFace引擎初始化完成") return app except Exception as e: - print(f"InsightFace人脸识别引擎初始化失败: {e}") - print("请确保已安装insightface和onnxruntime,并且模型文件已下载或可访问。") + print(f"InsightFace初始化失败: {e}") + print("请检查依赖是否安装及模型是否可访问") return None def _load_known_faces(self): - """ - 扫描已知人脸目录,加载每个人的照片并计算人脸特征。 - """ + """加载已知人脸特征""" if not os.path.exists(self.known_faces_dir): - print(f"警告: 已知人脸目录 '{self.known_faces_dir}' 不存在。请创建并放入照片。") + print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}") os.makedirs(self.known_faces_dir, exist_ok=True) return - print(f"正在加载已知人脸特征从: '{self.known_faces_dir}'...") + print(f"从目录加载人脸特征: {self.known_faces_dir}") for person_name in os.listdir(self.known_faces_dir): - person_dir = os.path.join(self.known_faces_dir, person_name) if os.path.isdir(person_dir): - print(f" 加载人物: {person_name}") + print(f"处理人物: {person_name}") embeddings = [] for filename in os.listdir(person_dir): if filename.lower().endswith(('.png', '.jpg', '.jpeg')): @@ -57,131 +49,91 @@ class FaceRecognizer: try: img = cv2.imread(image_path) if img is None: - print(f" 警告: 无法读取图片 '{image_path}',已跳过。") + print(f"无法读取图片: {image_path},已跳过") continue - - # 查找人脸并提取特征 + faces = self.app.get(img) if faces: - # 通常一张照片只有一个人脸,取第一个 embeddings.append(faces[0].embedding) - print(f" 成功提取 '{filename}' 的人脸特征。") + print(f"提取特征成功: {filename}") else: - print(f" 警告: 在图片 '{filename}' 中未检测到人脸,已跳过。") + print(f"未检测到人脸: {filename},已跳过") except Exception as e: - print(f" 处理图片 '{image_path}' 时发生错误: {e}") - + print(f"处理图片出错 {image_path}: {e}") + if embeddings: - # 将多张照片的特征取平均,作为该人物的最终特征 self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0) self.known_faces_names.append(person_name) - print(f" 人物 '{person_name}' 加载完成,共 {len(embeddings)} 张照片。") + print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片") else: - print(f" 警告: 人物 '{person_name}' 没有有效的人脸特征,已跳过。") - print(f"已知人脸加载完成。共 {len(self.known_faces_names)} 个人物。") + print(f"人物 {person_name} 无有效特征,已跳过") + print(f"人脸加载完成,共 {len(self.known_faces_names)} 人") def recognize(self, frame, threshold=0.4): - """ - 在视频帧中识别人脸。 - - Args: - frame: 输入的图像帧 (NumPy数组, BGR格式)。 - threshold (float): 识别相似度阈值。0.0到1.0,越高越严格。 - - Returns: - tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度。 - """ + """识别人脸并返回结果""" if not self.app or not self.known_faces_names: return False, None, None - faces = self.app.get(frame) # 在帧中检测并提取所有人的脸 + faces = self.app.get(frame) if not faces: return False, None, None for face in faces: - # 遍历已知人脸,进行比对 for known_name in self.known_faces_names: known_embedding = self.known_faces_embeddings[known_name] - - # --- 关键修改:手动计算余弦相似度 --- - # 确保embedding是float32类型,避免潜在的类型不匹配问题 + embedding1 = face.embedding.astype(np.float32) embedding2 = known_embedding.astype(np.float32) - # 计算点积 dot_product = np.dot(embedding1, embedding2) - # 计算L2范数(向量长度) norm_embedding1 = np.linalg.norm(embedding1) norm_embedding2 = np.linalg.norm(embedding2) - # 避免除以零 - if norm_embedding1 == 0 or norm_embedding2 == 0: - similarity = 0.0 - else: - similarity = dot_product / (norm_embedding1 * norm_embedding2) - # ------------------------------------- - + similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else ( + dot_product / (norm_embedding1 * norm_embedding2) + ) + if similarity >= threshold: - print(f"!!! 人脸识别检测到已知人物: '{known_name}' (相似度: {similarity:.4f}) !!!") - return True, known_name, similarity # 只要检测到一个就立即返回 - - return False, None, None # 没有检测到已知人脸 + print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})") + return True, known_name, similarity + return False, None, None - # def test_single_image(self, image_path: str, threshold=0.4): - # """ - # 测试单张图片的人脸识别效果 - # - # Args: - # image_path: 图片路径 - # threshold: 识别阈值 - # - # Returns: - # tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度 - # """ - # if not os.path.exists(image_path): - # print(f"错误: 图片 '{image_path}' 不存在") - # return False, None, None - # - # # 读取图片 - # frame = cv2.imread(image_path) - # if frame is None: - # print(f"错误: 无法读取图片 '{image_path}'") - # return False, None, None - # - # # 调用识别方法 - # result, name, similarity = self.recognize(frame, threshold) - # - # # 显示结果 - # if result: - # print(f"测试结果: 在图片中识别到 {name},相似度: {similarity:.4f}") - # - # # 绘制识别结果并显示图片 - # faces = self.app.get(frame) - # for face in faces: - # bbox = face.bbox.astype(int) - # # 绘制 bounding box - # cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2) - # # 绘制姓名和相似度 - # text = f"{name}: {similarity:.2f}" - # cv2.putText(frame, text, (bbox[0], bbox[1] - 10), - # cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) - # - # # 显示图片 - # cv2.imshow('Recognition Result', frame) - # print("按任意键关闭图片窗口...") - # cv2.waitKey(0) - # cv2.destroyAllWindows() - # else: - # print("测试结果: 未在图片中识别到已知人脸") - # - # return result, name, similarity + def test_single_image(self, image_path: str, threshold=0.4): + """测试单张图片识别""" + if not os.path.exists(image_path): + print(f"图片不存在: {image_path}") + return False, None, None + frame = cv2.imread(image_path) + if frame is None: + print(f"无法读取图片: {image_path}") + return False, None, None + + result, name, similarity = self.recognize(frame, threshold) + + if result: + print(f"识别结果: {name} (相似度: {similarity:.4f})") + + faces = self.app.get(frame) + for face in faces: + bbox = face.bbox.astype(int) + cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2) + text = f"{name}: {similarity:.2f}" + cv2.putText(frame, text, (bbox[0], bbox[1] - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) + + cv2.imshow('识别结果', frame) + print("按任意键关闭窗口...") + cv2.waitKey(0) + cv2.destroyAllWindows() + else: + print("未识别到已知人脸") + + return result, name, similarity -# if __name__ == "__main__": -# # 初始化人脸识别器,指定已知人脸目录 -# recognizer = FaceRecognizer(known_faces_dir="known_faces") # -# # 测试单张图片 -# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg" # 替换为你的测试图片路径 +# if __name__ == "__main__": +# recognizer = FaceRecognizer(known_faces_dir="known_faces") +# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg" # recognizer.test_single_image(test_image_path, threshold=0.4) \ No newline at end of file diff --git a/ocr/model_violation_detector.py b/ocr/model_violation_detector.py index 92a19c7..d7958ae 100644 --- a/ocr/model_violation_detector.py +++ b/ocr/model_violation_detector.py @@ -1,35 +1,30 @@ -import cv2 -from logger_config import logger -from ocr_violation_detector import OCRViolationDetector -from yolo_violation_detector import ViolationDetector as YoloViolationDetector -from face_recognizer import FaceRecognizer +import os +import cv2 +import yaml +from pathlib import Path +from .ocr_violation_detector import OCRViolationDetector +from .yolo_violation_detector import ViolationDetector as YoloViolationDetector +from .face_recognizer import FaceRecognizer class MultiModelViolationDetector: """ - 多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型(调整为YOLO最后检测),任一模型检测到违规即返回结果 + 多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果 """ def __init__(self, forbidden_words_path: str, - ocr_config_path: str, # 新增OCR配置文件路径参数 + ocr_config_path: str, yolo_model_path: str, known_faces_dir: str, ocr_confidence_threshold: float = 0.5): """ 初始化所有检测模型 - - Args: - forbidden_words_path: 违禁词文件路径 - ocr_config_path: OCR配置文件(1.yaml)路径 - yolo_model_path: YOLO模型文件路径 - known_faces_dir: 已知人脸目录路径 - ocr_confidence_threshold: OCR置信度阈值 """ - # 初始化OCR检测器(传入配置文件路径) + # 初始化OCR检测器 self.ocr_detector = OCRViolationDetector( forbidden_words_path=forbidden_words_path, - ocr_config_path=ocr_config_path, # 传递配置文件路径 + ocr_config_path=ocr_config_path, ocr_confidence_threshold=ocr_confidence_threshold ) @@ -38,22 +33,16 @@ class MultiModelViolationDetector: known_faces_dir=known_faces_dir ) - # 初始化YOLO检测器(调整为最后初始化) + # 初始化YOLO检测器 self.yolo_detector = YoloViolationDetector( model_path=yolo_model_path ) - logger.info("多模型违规检测器初始化完成") + print("多模型违规检测器初始化完成") def detect_violations(self, frame): """ 串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果 - Args: - frame: 输入视频帧 (NumPy数组, BGR格式) - Returns: - tuple: (是否有违规, 违规类型, 违规详情) - 违规类型: 'ocr' | 'yolo' | 'face' | None - 违规详情: 对应模型的检测结果 """ # 1. 首先进行OCR违禁词检测 try: @@ -63,10 +52,10 @@ class MultiModelViolationDetector: "words": ocr_words, "confidences": ocr_confs } - logger.warning(f"OCR检测到违禁内容: {details}") + print(f"警告: OCR检测到违禁内容: {details}") return (True, "ocr", details) except Exception as e: - logger.error(f"OCR检测出错: {str(e)}", exc_info=True) + print(f"错误: OCR检测出错: {str(e)}") # 2. 接着进行人脸识别检测 try: @@ -76,58 +65,72 @@ class MultiModelViolationDetector: "name": face_name, "similarity": face_similarity } - logger.warning(f"人脸识别到违规人员: {details}") + print(f"警告: 人脸识别到违规人员: {details}") return (True, "face", details) except Exception as e: - logger.error(f"人脸识别出错: {str(e)}", exc_info=True) + print(f"错误: 人脸识别出错: {str(e)}") - # 3. 最后进行YOLO目标检测(调整为最后检测) + # 3. 最后进行YOLO目标检测 try: yolo_results = self.yolo_detector.detect(frame) - # 检查是否有检测结果(根据实际业务定义何为违规目标) if len(yolo_results.boxes) > 0: - # 提取检测到的目标信息 details = { "classes": yolo_results.names, - "boxes": yolo_results.boxes.xyxy.tolist(), # 边界框坐标 - "confidences": yolo_results.boxes.conf.tolist(), # 置信度 - "class_ids": yolo_results.boxes.cls.tolist() # 类别ID + "boxes": yolo_results.boxes.xyxy.tolist(), + "confidences": yolo_results.boxes.conf.tolist(), + "class_ids": yolo_results.boxes.cls.tolist() } - logger.warning(f"YOLO检测到违规目标: {details}") + print(f"警告: YOLO检测到违规目标: {details}") return (True, "yolo", details) except Exception as e: - logger.error(f"YOLO检测出错: {str(e)}", exc_info=True) + print(f"错误: YOLO检测出错: {str(e)}") # 所有检测均未发现违规 return (False, None, None) -# # 使用示例 +def load_config(config_path: str) -> dict: + """加载YAML配置文件""" + try: + with open(config_path, 'r', encoding='utf-8') as f: + return yaml.safe_load(f) + except FileNotFoundError: + print(f"错误: 配置文件未找到: {config_path}") + raise + except yaml.YAMLError as e: + print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}") + raise + except Exception as e: + print(f"错误: 加载配置文件出错: {str(e)}") + raise + + +# 使用示例 # if __name__ == "__main__": -# # 配置文件路径(根据实际情况修改) -# FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt" -# OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" # 新增OCR配置文件路径 -# YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" -# KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces" +# # 加载配置文件 +# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改 # # # 初始化多模型检测器 # detector = MultiModelViolationDetector( -# forbidden_words_path=FORBIDDEN_WORDS_PATH, -# ocr_config_path=OCR_CONFIG_PATH, # 传入OCR配置文件路径 -# yolo_model_path=YOLO_MODEL_PATH, -# known_faces_dir=KNOWN_FACES_DIR, -# ocr_confidence_threshold=0.5 +# forbidden_words_path=config["forbidden_words_path"], +# ocr_config_path=config["ocr_config_path"], +# yolo_model_path=config["yolo_model_path"], +# known_faces_dir=config["known_faces_dir"], +# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5) # ) # # # 读取测试图像(可替换为视频帧读取逻辑) -# test_image_path = r"D:\Git\bin\video\ocr\images\img.png" -# frame = cv2.imread(test_image_path) +# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径 +# if test_image_path: +# frame = cv2.imread(test_image_path) # -# if frame is not None: -# has_violation, violation_type, details = detector.detect_violations(frame) -# if has_violation: -# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") +# if frame is not None: +# has_violation, violation_type, details = detector.detect_violations(frame) +# if has_violation: +# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") +# else: +# print("未检测到任何违规内容") # else: -# print("未检测到任何违规内容") +# print(f"无法读取测试图像: {test_image_path}") # else: -# print(f"无法读取测试图像: {test_image_path}") \ No newline at end of file +# print("配置文件中未指定测试图像路径") \ No newline at end of file diff --git a/ocr/ocr_violation_detector.py b/ocr/ocr_violation_detector.py index bfb3407..ebbb068 100644 --- a/ocr/ocr_violation_detector.py +++ b/ocr/ocr_violation_detector.py @@ -1,6 +1,5 @@ import os import cv2 -import logging from rapidocr import RapidOCR @@ -13,153 +12,85 @@ class OCRViolationDetector: def __init__(self, forbidden_words_path: str, ocr_config_path: str, - ocr_confidence_threshold: float = 0.5, - log_level: int = logging.INFO, - log_file: str = None): + ocr_confidence_threshold: float = 0.5): """ - 初始化OCR引擎、违禁词列表和日志配置。 + 初始化OCR引擎和违禁词列表。 Args: forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 ocr_config_path (str): OCR配置文件(如1.yaml)的路径。 ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。 - log_level (int): 日志级别,默认为logging.INFO。 - log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。 """ - # 初始化日志(确保先初始化日志,后续操作可正常打日志) - self.logger = self._setup_logger(log_level, log_file) - - # 加载违禁词(优先级:先加载配置,再初始化引擎) + # 加载违禁词 self.forbidden_words = self._load_forbidden_words(forbidden_words_path) - # 初始化RapidOCR引擎(传入配置文件路径) + # 初始化RapidOCR引擎 self.ocr_engine = self._initialize_ocr(ocr_config_path) # 校验核心依赖是否就绪 self._check_dependencies() - # 设置置信度阈值(限制在0~1范围,避免非法值) + # 设置置信度阈值(限制在0~1范围) self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0)) - self.logger.info(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}") - - def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger: - """ - 配置日志系统(避免重复添加处理器,支持控制台+文件双输出) - - Args: - log_level: 日志级别(如logging.DEBUG、logging.INFO)。 - log_file: 日志文件路径,为None时仅输出到控制台。 - - Returns: - logging.Logger: 配置好的日志实例。 - """ - logger = logging.getLogger('OCRViolationDetector') - logger.setLevel(log_level) - - # 避免重复添加处理器(防止日志重复输出) - if logger.handlers: - return logger - - # 定义日志格式(包含时间、模块名、级别、内容) - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - - # 1. 添加控制台处理器 - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - # 2. 若指定日志文件,添加文件处理器(自动创建目录) - if log_file: - try: - log_dir = os.path.dirname(log_file) - # 若日志目录不存在,自动创建 - if log_dir and not os.path.exists(log_dir): - os.makedirs(log_dir, exist_ok=True) - self.logger.debug(f"自动创建日志目录: {log_dir}") - - file_handler = logging.FileHandler(log_file, encoding='utf-8') - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - logger.info(f"日志文件已配置: {log_file}") - except Exception as e: - logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}") - - return logger + print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}") def _load_forbidden_words(self, path: str) -> set: """ 从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码) - - Args: - path (str): 违禁词TXT文件路径。 - - Returns: - set: 去重后的违禁词集合(空集合表示加载失败)。 """ forbidden_words = set() - # 第一步:检查文件是否存在 + # 检查文件是否存在 if not os.path.exists(path): - self.logger.error(f"违禁词文件不存在: {path}") + print(f"错误:违禁词文件不存在: {path}") return forbidden_words - # 第二步:读取文件并处理内容 + # 读取文件并处理内容 try: with open(path, 'r', encoding='utf-8') as f: - # 过滤空行、去除首尾空格、去重 forbidden_words = { line.strip() for line in f if line.strip() # 跳过空行或纯空格行 } - self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)") - self.logger.debug(f"违禁词列表: {forbidden_words}") + print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)") except UnicodeDecodeError: - self.logger.error(f"违禁词文件编码错误(需UTF-8): {path}") + print(f"错误:违禁词文件编码错误(需UTF-8): {path}") except PermissionError: - self.logger.error(f"无权限读取违禁词文件: {path}") + print(f"错误:无权限读取违禁词文件: {path}") except Exception as e: - self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True) + print(f"错误:加载违禁词失败: {str(e)}") return forbidden_words def _initialize_ocr(self, config_path: str) -> RapidOCR | None: """ 初始化RapidOCR引擎(校验配置文件、捕获初始化异常) - - Args: - config_path (str): RapidOCR配置文件(如1.yaml)路径。 - - Returns: - RapidOCR | None: OCR引擎实例(None表示初始化失败)。 """ - self.logger.info("开始初始化RapidOCR引擎...") + print("开始初始化RapidOCR引擎...") - # 第一步:检查配置文件是否存在 + # 检查配置文件是否存在 if not os.path.exists(config_path): - self.logger.error(f"OCR配置文件不存在: {config_path}") + print(f"错误:OCR配置文件不存在: {config_path}") return None - # 第二步:初始化OCR引擎(捕获RapidOCR相关异常) + # 初始化OCR引擎 try: ocr_engine = RapidOCR(config_path=config_path) - self.logger.info("RapidOCR引擎初始化成功") + print("RapidOCR引擎初始化成功") return ocr_engine except ImportError: - self.logger.error("RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)") + print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)") except Exception as e: - self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True) + print(f"错误:RapidOCR初始化失败: {str(e)}") return None def _check_dependencies(self) -> None: - """校验OCR引擎和违禁词列表是否就绪(输出警告日志)""" + """校验OCR引擎和违禁词列表是否就绪""" if not self.ocr_engine: - self.logger.warning("⚠️ OCR引擎未就绪,违禁词检测功能将禁用") + print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用") if not self.forbidden_words: - self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用") + print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用") def detect(self, frame) -> tuple[bool, list, list]: """ @@ -179,76 +110,69 @@ class OCRViolationDetector: violation_words = [] violation_confs = [] - # 前置校验:1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在 + # 前置校验 if frame is None or frame.size == 0: - self.logger.warning("输入图像帧为空或无效,跳过OCR检测") + print("警告:输入图像帧为空或无效,跳过OCR检测") return has_violation, violation_words, violation_confs if not self.ocr_engine or not self.forbidden_words: - self.logger.debug("OCR引擎未就绪或违禁词为空,跳过OCR检测") + print("OCR引擎未就绪或违禁词为空,跳过OCR检测") return has_violation, violation_words, violation_confs try: - # 1. 执行OCR识别(获取RapidOCR原始结果) - self.logger.debug("开始执行OCR识别...") + # 执行OCR识别 + print("开始执行OCR识别...") ocr_result = self.ocr_engine(frame) - self.logger.debug(f"RapidOCR原始结果: {ocr_result}") + print(f"RapidOCR原始结果: {ocr_result}") - # 2. 校验OCR结果是否有效(避免None或格式异常) + # 校验OCR结果是否有效 if ocr_result is None: - self.logger.debug("OCR识别未返回任何结果(图像无文本或识别失败)") + print("OCR识别未返回任何结果(图像无文本或识别失败)") return has_violation, violation_words, violation_confs - # 3. 检查txts和scores是否存在且不为None + # 检查txts和scores是否存在且不为None if not hasattr(ocr_result, 'txts') or ocr_result.txts is None: - self.logger.warning("OCR结果中txts为None或不存在") + print("警告:OCR结果中txts为None或不存在") return has_violation, violation_words, violation_confs if not hasattr(ocr_result, 'scores') or ocr_result.scores is None: - self.logger.warning("OCR结果中scores为None或不存在") + print("警告:OCR结果中scores为None或不存在") return has_violation, violation_words, violation_confs - # 4. 转为列表并去None(防止单个元素为None) - # 确保txts是可迭代的,如果不是则转为空列表 + # 转为列表并去None if not isinstance(ocr_result.txts, (list, tuple)): - self.logger.warning(f"OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}") + print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}") texts = [] else: texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)] - # 确保scores是可迭代的,如果不是则转为空列表 if not isinstance(ocr_result.scores, (list, tuple)): - self.logger.warning(f"OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}") + print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}") confidences = [] else: confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))] - # 5. 校验文本和置信度列表长度是否一致(避免zip迭代错误) + # 校验文本和置信度列表长度是否一致 if len(texts) != len(confidences): - self.logger.warning( - f"OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测") + print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测") return has_violation, violation_words, violation_confs if len(texts) == 0: - self.logger.debug("OCR未识别到任何有效文本") + print("OCR未识别到任何有效文本") return has_violation, violation_words, violation_confs - # 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤) - self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})") + # 遍历识别结果,筛选违禁词 + print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})") for text, conf in zip(texts, confidences): - # 过滤低置信度结果 if conf < self.OCR_CONFIDENCE_THRESHOLD: - self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过") + print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过") continue - # 检查当前文本是否包含违禁词(支持一个文本含多个违禁词) matched_words = [word for word in self.forbidden_words if word in text] if matched_words: has_violation = True - # 记录所有匹配的违禁词和对应置信度 violation_words.extend(matched_words) - violation_confs.extend([conf] * len(matched_words)) # 一个文本对应多个违禁词时,置信度复用 - self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})") + violation_confs.extend([conf] * len(matched_words)) + print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})") except Exception as e: - # 捕获所有异常,确保不中断上层调用 - self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True) + print(f"错误:OCR检测过程异常: {str(e)}") - return has_violation, violation_words, violation_confs + return has_violation, violation_words, violation_confs \ No newline at end of file diff --git a/ocr/yolo_violation_detector.py b/ocr/yolo_violation_detector.py index be50afc..144503f 100644 --- a/ocr/yolo_violation_detector.py +++ b/ocr/yolo_violation_detector.py @@ -1,6 +1,5 @@ from ultralytics import YOLO import cv2 -from logger_config import logger class ViolationDetector: """ @@ -13,9 +12,9 @@ class ViolationDetector: Args: model_path (str): YOLO .pt模型的路径。 """ - logger.info(f"正在从 '{model_path}' 加载YOLO模型...") + print(f"正在从 '{model_path}' 加载YOLO模型...") self.model = YOLO(model_path) - logger.info("YOLO模型加载成功。") + print("YOLO模型加载成功。") def detect(self, frame): """ @@ -45,4 +44,4 @@ class ViolationDetector: """ # 使用YOLO自带的plot功能,方便快捷 annotated_frame = result.plot() - return annotated_frame + return annotated_frame \ No newline at end of file