ocr1.0
This commit is contained in:
		
							
								
								
									
										81
									
								
								core/rtmp.py
									
									
									
									
									
								
							
							
						
						
									
										81
									
								
								core/rtmp.py
									
									
									
									
									
								
							| @ -2,109 +2,101 @@ import asyncio | |||||||
| import logging | import logging | ||||||
| import cv2 | import cv2 | ||||||
| import time | import time | ||||||
| from ocr.ocr_violation_detector import OCRViolationDetector | from ocr.model_violation_detector import MultiModelViolationDetector | ||||||
|  |  | ||||||
| import logging |  | ||||||
|  | # 配置文件相对路径(根据实际目录结构调整) | ||||||
|  | YOLO_MODEL_PATH = "../ocr/models/best.pt"  # 关键修正:从core目录向上一级找ocr文件夹 | ||||||
|  | FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt" | ||||||
|  | OCR_CONFIG_PATH = "../ocr/config/1.yaml" | ||||||
|  | KNOWN_FACES_DIR = "../ocr/known_faces" | ||||||
|  |  | ||||||
| # 创建检测器实例 | # 创建检测器实例 | ||||||
| detector = OCRViolationDetector( | detector = MultiModelViolationDetector( | ||||||
|         forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", |     forbidden_words_path=FORBIDDEN_WORDS_PATH, | ||||||
|         ocr_confidence_threshold=0.7, |     ocr_config_path=OCR_CONFIG_PATH, | ||||||
|         log_level=logging.INFO, |     yolo_model_path=YOLO_MODEL_PATH, | ||||||
|         log_file="ocr_detection.log" |     known_faces_dir=KNOWN_FACES_DIR, | ||||||
|  |     ocr_confidence_threshold=0.5 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| # 配置日志(与WHEP代码保持一致的日志风格) | # 配置日志 | ||||||
| logging.basicConfig(level=logging.INFO) | logging.basicConfig(level=logging.INFO) | ||||||
| logger = logging.getLogger("rtmp_video_puller") | logger = logging.getLogger("rtmp_video_puller") | ||||||
|  |  | ||||||
|  |  | ||||||
| async def rtmp_pull_video_stream(rtmp_url): | async def rtmp_pull_video_stream(rtmp_url): | ||||||
|     """ |     """ | ||||||
|     通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息 |     通过RTMP从指定URL拉取视频流并进行违规检测 | ||||||
|     功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理 |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key) |  | ||||||
|     """ |     """ | ||||||
|     cap = None  # 初始化视频捕获对象 |     cap = None  # 初始化视频捕获对象 | ||||||
|     try: |     try: | ||||||
|         # 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环) |         # 异步打开RTMP流 | ||||||
|         cap = await asyncio.to_thread( |         cap = await asyncio.to_thread( | ||||||
|             cv2.VideoCapture, |             cv2.VideoCapture, | ||||||
|             rtmp_url, |             rtmp_url, | ||||||
|             cv2.CAP_FFMPEG  # 必须指定FFmpeg后端,RTMP协议依赖该后端解析 |             cv2.CAP_FFMPEG  # 指定FFmpeg后端确保RTMP兼容性 | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # 2. 检查RTMP流是否成功打开 |         # 检查RTMP流是否成功打开 | ||||||
|         is_opened = await asyncio.to_thread(cap.isOpened) |         is_opened = await asyncio.to_thread(cap.isOpened) | ||||||
|         if not is_opened: |         if not is_opened: | ||||||
|             raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)") |             raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)") | ||||||
|  |  | ||||||
|         # 3. 异步获取RTMP流基础信息(分辨率、帧率) |         # 获取RTMP流基础信息 | ||||||
|         width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH) |         width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH) | ||||||
|         height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT) |         height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT) | ||||||
|         fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS) |         fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS) | ||||||
|  |  | ||||||
|         # 处理异常情况:部分RTMP流未返回帧率时默认30FPS |         # 处理异常情况 | ||||||
|         fps = fps if fps > 0 else 30.0 |         fps = fps if fps > 0 else 30.0 | ||||||
|         # 分辨率转为整数(视频尺寸必然是整数) |  | ||||||
|         width, height = int(width), int(height) |         width, height = int(width), int(height) | ||||||
|  |  | ||||||
|         # 打印流初始化成功信息(与WHEP连接成功信息风格一致) |         # 打印流初始化成功信息 | ||||||
|         print(f"RTMP流状态: 已成功连接") |         print(f"RTMP流状态: 已成功连接") | ||||||
|         print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS") |         print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS") | ||||||
|         print("开始接收视频帧...(按 Ctrl+C 中断)") |         print("开始接收视频帧...(按 Ctrl+C 中断)") | ||||||
|  |  | ||||||
|         # 4. 初始化帧统计参数 |         # 初始化帧统计参数 | ||||||
|         frame_count = 0  # 总接收帧数 |         frame_count = 0 | ||||||
|         start_time = time.time()  # 统计起始时间 |         start_time = time.time() | ||||||
|  |  | ||||||
|         # 5. 循环异步读取视频帧(核心逻辑) |         # 循环读取视频帧 | ||||||
|         while True: |         while True: | ||||||
|             # 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境) |  | ||||||
|             ret, frame = await asyncio.to_thread(cap.read) |             ret, frame = await asyncio.to_thread(cap.read) | ||||||
|  |  | ||||||
|             # 检查帧是否读取成功(流中断/结束时ret为False) |  | ||||||
|             if not ret: |             if not ret: | ||||||
|                 print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)") |                 print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)") | ||||||
|                 break |                 break | ||||||
|  |  | ||||||
|             # 帧计数累加 |  | ||||||
|             frame_count += 1 |             frame_count += 1 | ||||||
|  |  | ||||||
|             # 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐) |             # 打印当前帧信息 | ||||||
|             print(f"收到帧 (第{frame_count}帧)") |             print(f"收到帧 (第{frame_count}帧)") | ||||||
|             print(f"  帧尺寸: {width}x{height}") |             print(f"  帧尺寸: {width}x{height}") | ||||||
|             print(f"  配置帧率: {fps:.2f} FPS") |             print(f"  配置帧率: {fps:.2f} FPS") | ||||||
|  |  | ||||||
|             has_violation, violations, confidences = OCRViolationDetector.detect(frame) |             if frame is not None: | ||||||
|  |                 has_violation, violation_type, details = detector.detect_violations(frame) | ||||||
|             # 输出检测结果 |                 if has_violation: | ||||||
|             if has_violation: |                     print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") | ||||||
|                 detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") |                 else: | ||||||
|                 for word, conf in zip(violations, confidences): |                     print("未检测到任何违规内容") | ||||||
|                     detector.logger.info(f"- {word} (置信度: {conf:.4f})") |  | ||||||
|             else: |             else: | ||||||
|                 detector.logger.info("图片中未检测到违禁词") |                 print(f"无法读取测试图像") | ||||||
|             # 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致) |  | ||||||
|  |             # 每100帧统计一次实际接收帧率 | ||||||
|             if frame_count % 100 == 0: |             if frame_count % 100 == 0: | ||||||
|                 elapsed_time = time.time() - start_time |                 elapsed_time = time.time() - start_time | ||||||
|                 actual_fps = frame_count / elapsed_time  # 实际接收帧率(可能低于配置帧率) |                 actual_fps = frame_count / elapsed_time | ||||||
|                 print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----") |                 print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----") | ||||||
|  |  | ||||||
|             # (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑 |  | ||||||
|             # 示例:yield frame (若需生成器模式,可调整函数为异步生成器) |  | ||||||
|  |  | ||||||
|     # 8. 异常处理(覆盖用户中断、通用错误) |  | ||||||
|     except KeyboardInterrupt: |     except KeyboardInterrupt: | ||||||
|         print(f"\n用户操作: 已通过 Ctrl+C 中断程序") |         print(f"\n用户操作: 已通过 Ctrl+C 中断程序") | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         # 日志记录详细错误(便于问题排查),同时打印用户可见信息 |  | ||||||
|         logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True) |         logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True) | ||||||
|         print(f"错误信息: {str(e)}") |         print(f"错误信息: {str(e)}") | ||||||
|     finally: |     finally: | ||||||
|         # 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏) |  | ||||||
|         if cap is not None: |         if cap is not None: | ||||||
|             await asyncio.to_thread(cap.release) |             await asyncio.to_thread(cap.release) | ||||||
|             print(f"\n资源释放: RTMP流已关闭") |             print(f"\n资源释放: RTMP流已关闭") | ||||||
| @ -114,7 +106,6 @@ async def rtmp_pull_video_stream(rtmp_url): | |||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416" |     RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416" | ||||||
|  |  | ||||||
|     # 运行RTMP拉流任务(与WHEP一致的异步执行方式) |  | ||||||
|     try: |     try: | ||||||
|         asyncio.run(rtmp_pull_video_stream(RTMP_URL)) |         asyncio.run(rtmp_pull_video_stream(RTMP_URL)) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|  | |||||||
| @ -4,10 +4,12 @@ import numpy as np | |||||||
| import insightface | import insightface | ||||||
| from insightface.app import FaceAnalysis | from insightface.app import FaceAnalysis | ||||||
|  |  | ||||||
|  |  | ||||||
| class FaceRecognizer: | class FaceRecognizer: | ||||||
|     """ |     """ | ||||||
|     封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。 |     封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。 | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, known_faces_dir: str): |     def __init__(self, known_faces_dir: str): | ||||||
|         self.known_faces_dir = known_faces_dir |         self.known_faces_dir = known_faces_dir | ||||||
|         self.app = self._initialize_insightface() |         self.app = self._initialize_insightface() | ||||||
| @ -16,40 +18,30 @@ class FaceRecognizer: | |||||||
|         self._load_known_faces() |         self._load_known_faces() | ||||||
|  |  | ||||||
|     def _initialize_insightface(self): |     def _initialize_insightface(self): | ||||||
|         """ |         """初始化InsightFace FaceAnalysis应用""" | ||||||
|         初始化InsightFace FaceAnalysis应用。 |         print("初始化InsightFace引擎...") | ||||||
|         默认使用CPU,如果检测到CUDA,会自动使用GPU。 |  | ||||||
|         """ |  | ||||||
|         print("正在初始化InsightFace人脸识别引擎...") |  | ||||||
|         try: |         try: | ||||||
|             # 默认模型是 'buffalo_l',包含检测、对齐、识别功能 |             app = FaceAnalysis(name='buffalo_l', root='~/.insightface') | ||||||
|             # 如果需要更小的模型,可以尝试 'buffalo_s' 或 'buffalo_m' |             app.prepare(ctx_id=0, det_size=(640, 640)) | ||||||
|             # ctx_id=0 表示使用GPU,ctx_id=-1 表示使用CPU |             print("InsightFace引擎初始化完成") | ||||||
|             # InsightFace会自动检测CUDA并选择GPU,所以通常不需要手动设置ctx_id |  | ||||||
|             app = FaceAnalysis(name='buffalo_l', root='~/.insightface') # 模型下载到用户目录 |  | ||||||
|             app.prepare(ctx_id=0, det_size=(640, 640)) # det_size影响检测性能和精度 |  | ||||||
|             print("InsightFace人脸识别引擎初始化成功。") |  | ||||||
|             return app |             return app | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"InsightFace人脸识别引擎初始化失败: {e}") |             print(f"InsightFace初始化失败: {e}") | ||||||
|             print("请确保已安装insightface和onnxruntime,并且模型文件已下载或可访问。") |             print("请检查依赖是否安装及模型是否可访问") | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     def _load_known_faces(self): |     def _load_known_faces(self): | ||||||
|         """ |         """加载已知人脸特征""" | ||||||
|         扫描已知人脸目录,加载每个人的照片并计算人脸特征。 |  | ||||||
|         """ |  | ||||||
|         if not os.path.exists(self.known_faces_dir): |         if not os.path.exists(self.known_faces_dir): | ||||||
|             print(f"警告: 已知人脸目录 '{self.known_faces_dir}' 不存在。请创建并放入照片。") |             print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}") | ||||||
|             os.makedirs(self.known_faces_dir, exist_ok=True) |             os.makedirs(self.known_faces_dir, exist_ok=True) | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         print(f"正在加载已知人脸特征从: '{self.known_faces_dir}'...") |         print(f"从目录加载人脸特征: {self.known_faces_dir}") | ||||||
|         for person_name in os.listdir(self.known_faces_dir): |         for person_name in os.listdir(self.known_faces_dir): | ||||||
|  |  | ||||||
|             person_dir = os.path.join(self.known_faces_dir, person_name) |             person_dir = os.path.join(self.known_faces_dir, person_name) | ||||||
|             if os.path.isdir(person_dir): |             if os.path.isdir(person_dir): | ||||||
|                 print(f"  加载人物: {person_name}") |                 print(f"处理人物: {person_name}") | ||||||
|                 embeddings = [] |                 embeddings = [] | ||||||
|                 for filename in os.listdir(person_dir): |                 for filename in os.listdir(person_dir): | ||||||
|                     if filename.lower().endswith(('.png', '.jpg', '.jpeg')): |                     if filename.lower().endswith(('.png', '.jpg', '.jpeg')): | ||||||
| @ -57,131 +49,91 @@ class FaceRecognizer: | |||||||
|                         try: |                         try: | ||||||
|                             img = cv2.imread(image_path) |                             img = cv2.imread(image_path) | ||||||
|                             if img is None: |                             if img is None: | ||||||
|                                 print(f"    警告: 无法读取图片 '{image_path}',已跳过。") |                                 print(f"无法读取图片: {image_path},已跳过") | ||||||
|                                 continue |                                 continue | ||||||
|  |  | ||||||
|                             # 查找人脸并提取特征 |  | ||||||
|                             faces = self.app.get(img) |                             faces = self.app.get(img) | ||||||
|                             if faces: |                             if faces: | ||||||
|                                 # 通常一张照片只有一个人脸,取第一个 |  | ||||||
|                                 embeddings.append(faces[0].embedding) |                                 embeddings.append(faces[0].embedding) | ||||||
|                                 print(f"    成功提取 '{filename}' 的人脸特征。") |                                 print(f"提取特征成功: {filename}") | ||||||
|                             else: |                             else: | ||||||
|                                 print(f"    警告: 在图片 '{filename}' 中未检测到人脸,已跳过。") |                                 print(f"未检测到人脸: {filename},已跳过") | ||||||
|                         except Exception as e: |                         except Exception as e: | ||||||
|                             print(f"    处理图片 '{image_path}' 时发生错误: {e}") |                             print(f"处理图片出错 {image_path}: {e}") | ||||||
|  |  | ||||||
|                 if embeddings: |                 if embeddings: | ||||||
|                     # 将多张照片的特征取平均,作为该人物的最终特征 |  | ||||||
|                     self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0) |                     self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0) | ||||||
|                     self.known_faces_names.append(person_name) |                     self.known_faces_names.append(person_name) | ||||||
|                     print(f"  人物 '{person_name}' 加载完成,共 {len(embeddings)} 张照片。") |                     print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片") | ||||||
|                 else: |                 else: | ||||||
|                     print(f"  警告: 人物 '{person_name}' 没有有效的人脸特征,已跳过。") |                     print(f"人物 {person_name} 无有效特征,已跳过") | ||||||
|         print(f"已知人脸加载完成。共 {len(self.known_faces_names)} 个人物。") |         print(f"人脸加载完成,共 {len(self.known_faces_names)} 人") | ||||||
|  |  | ||||||
|     def recognize(self, frame, threshold=0.4): |     def recognize(self, frame, threshold=0.4): | ||||||
|         """ |         """识别人脸并返回结果""" | ||||||
|         在视频帧中识别人脸。 |  | ||||||
|  |  | ||||||
|         Args: |  | ||||||
|             frame: 输入的图像帧 (NumPy数组, BGR格式)。 |  | ||||||
|             threshold (float): 识别相似度阈值。0.0到1.0,越高越严格。 |  | ||||||
|  |  | ||||||
|         Returns: |  | ||||||
|             tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度。 |  | ||||||
|         """ |  | ||||||
|         if not self.app or not self.known_faces_names: |         if not self.app or not self.known_faces_names: | ||||||
|             return False, None, None |             return False, None, None | ||||||
|  |  | ||||||
|         faces = self.app.get(frame) # 在帧中检测并提取所有人的脸 |         faces = self.app.get(frame) | ||||||
|         if not faces: |         if not faces: | ||||||
|             return False, None, None |             return False, None, None | ||||||
|  |  | ||||||
|         for face in faces: |         for face in faces: | ||||||
|             # 遍历已知人脸,进行比对 |  | ||||||
|             for known_name in self.known_faces_names: |             for known_name in self.known_faces_names: | ||||||
|                 known_embedding = self.known_faces_embeddings[known_name] |                 known_embedding = self.known_faces_embeddings[known_name] | ||||||
|  |  | ||||||
|                 # --- 关键修改:手动计算余弦相似度 --- |  | ||||||
|                 # 确保embedding是float32类型,避免潜在的类型不匹配问题 |  | ||||||
|                 embedding1 = face.embedding.astype(np.float32) |                 embedding1 = face.embedding.astype(np.float32) | ||||||
|                 embedding2 = known_embedding.astype(np.float32) |                 embedding2 = known_embedding.astype(np.float32) | ||||||
|  |  | ||||||
|                 # 计算点积 |  | ||||||
|                 dot_product = np.dot(embedding1, embedding2) |                 dot_product = np.dot(embedding1, embedding2) | ||||||
|                 # 计算L2范数(向量长度) |  | ||||||
|                 norm_embedding1 = np.linalg.norm(embedding1) |                 norm_embedding1 = np.linalg.norm(embedding1) | ||||||
|                 norm_embedding2 = np.linalg.norm(embedding2) |                 norm_embedding2 = np.linalg.norm(embedding2) | ||||||
|  |  | ||||||
|                 # 避免除以零 |                 similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else ( | ||||||
|                 if norm_embedding1 == 0 or norm_embedding2 == 0: |                         dot_product / (norm_embedding1 * norm_embedding2) | ||||||
|                     similarity = 0.0 |                 ) | ||||||
|                 else: |  | ||||||
|                     similarity = dot_product / (norm_embedding1 * norm_embedding2) |  | ||||||
|                 # ------------------------------------- |  | ||||||
|  |  | ||||||
|                 if similarity >= threshold: |                 if similarity >= threshold: | ||||||
|                     print(f"!!! 人脸识别检测到已知人物: '{known_name}' (相似度: {similarity:.4f}) !!!") |                     print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})") | ||||||
|                     return True, known_name, similarity # 只要检测到一个就立即返回 |                     return True, known_name, similarity | ||||||
|  |  | ||||||
|         return False, None, None # 没有检测到已知人脸 |         return False, None, None | ||||||
|  |  | ||||||
|  |     def test_single_image(self, image_path: str, threshold=0.4): | ||||||
|  |         """测试单张图片识别""" | ||||||
|  |         if not os.path.exists(image_path): | ||||||
|  |             print(f"图片不存在: {image_path}") | ||||||
|  |             return False, None, None | ||||||
|  |  | ||||||
|     # def test_single_image(self, image_path: str, threshold=0.4): |         frame = cv2.imread(image_path) | ||||||
|     #     """ |         if frame is None: | ||||||
|     #     测试单张图片的人脸识别效果 |             print(f"无法读取图片: {image_path}") | ||||||
|     # |             return False, None, None | ||||||
|     #     Args: |  | ||||||
|     #         image_path: 图片路径 |  | ||||||
|     #         threshold: 识别阈值 |  | ||||||
|     # |  | ||||||
|     #     Returns: |  | ||||||
|     #         tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度 |  | ||||||
|     #     """ |  | ||||||
|     #     if not os.path.exists(image_path): |  | ||||||
|     #         print(f"错误: 图片 '{image_path}' 不存在") |  | ||||||
|     #         return False, None, None |  | ||||||
|     # |  | ||||||
|     #     # 读取图片 |  | ||||||
|     #     frame = cv2.imread(image_path) |  | ||||||
|     #     if frame is None: |  | ||||||
|     #         print(f"错误: 无法读取图片 '{image_path}'") |  | ||||||
|     #         return False, None, None |  | ||||||
|     # |  | ||||||
|     #     # 调用识别方法 |  | ||||||
|     #     result, name, similarity = self.recognize(frame, threshold) |  | ||||||
|     # |  | ||||||
|     #     # 显示结果 |  | ||||||
|     #     if result: |  | ||||||
|     #         print(f"测试结果: 在图片中识别到 {name},相似度: {similarity:.4f}") |  | ||||||
|     # |  | ||||||
|     #         # 绘制识别结果并显示图片 |  | ||||||
|     #         faces = self.app.get(frame) |  | ||||||
|     #         for face in faces: |  | ||||||
|     #             bbox = face.bbox.astype(int) |  | ||||||
|     #             # 绘制 bounding box |  | ||||||
|     #             cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2) |  | ||||||
|     #             # 绘制姓名和相似度 |  | ||||||
|     #             text = f"{name}: {similarity:.2f}" |  | ||||||
|     #             cv2.putText(frame, text, (bbox[0], bbox[1] - 10), |  | ||||||
|     #                         cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) |  | ||||||
|     # |  | ||||||
|     #         # 显示图片 |  | ||||||
|     #         cv2.imshow('Recognition Result', frame) |  | ||||||
|     #         print("按任意键关闭图片窗口...") |  | ||||||
|     #         cv2.waitKey(0) |  | ||||||
|     #         cv2.destroyAllWindows() |  | ||||||
|     #     else: |  | ||||||
|     #         print("测试结果: 未在图片中识别到已知人脸") |  | ||||||
|     # |  | ||||||
|     #     return result, name, similarity |  | ||||||
|  |  | ||||||
|  |         result, name, similarity = self.recognize(frame, threshold) | ||||||
|  |  | ||||||
|  |         if result: | ||||||
|  |             print(f"识别结果: {name} (相似度: {similarity:.4f})") | ||||||
|  |  | ||||||
|  |             faces = self.app.get(frame) | ||||||
|  |             for face in faces: | ||||||
|  |                 bbox = face.bbox.astype(int) | ||||||
|  |                 cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2) | ||||||
|  |                 text = f"{name}: {similarity:.2f}" | ||||||
|  |                 cv2.putText(frame, text, (bbox[0], bbox[1] - 10), | ||||||
|  |                             cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) | ||||||
|  |  | ||||||
|  |             cv2.imshow('识别结果', frame) | ||||||
|  |             print("按任意键关闭窗口...") | ||||||
|  |             cv2.waitKey(0) | ||||||
|  |             cv2.destroyAllWindows() | ||||||
|  |         else: | ||||||
|  |             print("未识别到已知人脸") | ||||||
|  |  | ||||||
|  |         return result, name, similarity | ||||||
|  |  | ||||||
| # if __name__ == "__main__": |  | ||||||
| #     # 初始化人脸识别器,指定已知人脸目录 |  | ||||||
| #     recognizer = FaceRecognizer(known_faces_dir="known_faces") |  | ||||||
| # | # | ||||||
| #     # 测试单张图片 | # if __name__ == "__main__": | ||||||
| #     test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"  # 替换为你的测试图片路径 | #     recognizer = FaceRecognizer(known_faces_dir="known_faces") | ||||||
|  | #     test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg" | ||||||
| #     recognizer.test_single_image(test_image_path, threshold=0.4) | #     recognizer.test_single_image(test_image_path, threshold=0.4) | ||||||
| @ -1,35 +1,30 @@ | |||||||
| import cv2 | import os | ||||||
| from logger_config import logger |  | ||||||
| from ocr_violation_detector import OCRViolationDetector |  | ||||||
| from yolo_violation_detector import ViolationDetector as YoloViolationDetector |  | ||||||
| from face_recognizer import FaceRecognizer |  | ||||||
|  |  | ||||||
|  | import cv2 | ||||||
|  | import yaml | ||||||
|  | from pathlib import Path | ||||||
|  | from .ocr_violation_detector import OCRViolationDetector | ||||||
|  | from .yolo_violation_detector import ViolationDetector as YoloViolationDetector | ||||||
|  | from .face_recognizer import FaceRecognizer | ||||||
|  |  | ||||||
| class MultiModelViolationDetector: | class MultiModelViolationDetector: | ||||||
|     """ |     """ | ||||||
|     多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型(调整为YOLO最后检测),任一模型检测到违规即返回结果 |     多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果 | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, |     def __init__(self, | ||||||
|                  forbidden_words_path: str, |                  forbidden_words_path: str, | ||||||
|                  ocr_config_path: str,  # 新增OCR配置文件路径参数 |                  ocr_config_path: str, | ||||||
|                  yolo_model_path: str, |                  yolo_model_path: str, | ||||||
|                  known_faces_dir: str, |                  known_faces_dir: str, | ||||||
|                  ocr_confidence_threshold: float = 0.5): |                  ocr_confidence_threshold: float = 0.5): | ||||||
|         """ |         """ | ||||||
|         初始化所有检测模型 |         初始化所有检测模型 | ||||||
|  |  | ||||||
|         Args: |  | ||||||
|             forbidden_words_path: 违禁词文件路径 |  | ||||||
|             ocr_config_path: OCR配置文件(1.yaml)路径 |  | ||||||
|             yolo_model_path: YOLO模型文件路径 |  | ||||||
|             known_faces_dir: 已知人脸目录路径 |  | ||||||
|             ocr_confidence_threshold: OCR置信度阈值 |  | ||||||
|         """ |         """ | ||||||
|         # 初始化OCR检测器(传入配置文件路径) |         # 初始化OCR检测器 | ||||||
|         self.ocr_detector = OCRViolationDetector( |         self.ocr_detector = OCRViolationDetector( | ||||||
|             forbidden_words_path=forbidden_words_path, |             forbidden_words_path=forbidden_words_path, | ||||||
|             ocr_config_path=ocr_config_path,  # 传递配置文件路径 |             ocr_config_path=ocr_config_path, | ||||||
|             ocr_confidence_threshold=ocr_confidence_threshold |             ocr_confidence_threshold=ocr_confidence_threshold | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -38,22 +33,16 @@ class MultiModelViolationDetector: | |||||||
|             known_faces_dir=known_faces_dir |             known_faces_dir=known_faces_dir | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # 初始化YOLO检测器(调整为最后初始化) |         # 初始化YOLO检测器 | ||||||
|         self.yolo_detector = YoloViolationDetector( |         self.yolo_detector = YoloViolationDetector( | ||||||
|             model_path=yolo_model_path |             model_path=yolo_model_path | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         logger.info("多模型违规检测器初始化完成") |         print("多模型违规检测器初始化完成") | ||||||
|  |  | ||||||
|     def detect_violations(self, frame): |     def detect_violations(self, frame): | ||||||
|         """ |         """ | ||||||
|         串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果 |         串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果 | ||||||
|         Args: |  | ||||||
|             frame: 输入视频帧 (NumPy数组, BGR格式) |  | ||||||
|         Returns: |  | ||||||
|             tuple: (是否有违规, 违规类型, 违规详情) |  | ||||||
|                   违规类型: 'ocr' | 'yolo' | 'face' | None |  | ||||||
|                   违规详情: 对应模型的检测结果 |  | ||||||
|         """ |         """ | ||||||
|         # 1. 首先进行OCR违禁词检测 |         # 1. 首先进行OCR违禁词检测 | ||||||
|         try: |         try: | ||||||
| @ -63,10 +52,10 @@ class MultiModelViolationDetector: | |||||||
|                     "words": ocr_words, |                     "words": ocr_words, | ||||||
|                     "confidences": ocr_confs |                     "confidences": ocr_confs | ||||||
|                 } |                 } | ||||||
|                 logger.warning(f"OCR检测到违禁内容: {details}") |                 print(f"警告: OCR检测到违禁内容: {details}") | ||||||
|                 return (True, "ocr", details) |                 return (True, "ocr", details) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.error(f"OCR检测出错: {str(e)}", exc_info=True) |             print(f"错误: OCR检测出错: {str(e)}") | ||||||
|  |  | ||||||
|         # 2. 接着进行人脸识别检测 |         # 2. 接着进行人脸识别检测 | ||||||
|         try: |         try: | ||||||
| @ -76,58 +65,72 @@ class MultiModelViolationDetector: | |||||||
|                     "name": face_name, |                     "name": face_name, | ||||||
|                     "similarity": face_similarity |                     "similarity": face_similarity | ||||||
|                 } |                 } | ||||||
|                 logger.warning(f"人脸识别到违规人员: {details}") |                 print(f"警告: 人脸识别到违规人员: {details}") | ||||||
|                 return (True, "face", details) |                 return (True, "face", details) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.error(f"人脸识别出错: {str(e)}", exc_info=True) |             print(f"错误: 人脸识别出错: {str(e)}") | ||||||
|  |  | ||||||
|         # 3. 最后进行YOLO目标检测(调整为最后检测) |         # 3. 最后进行YOLO目标检测 | ||||||
|         try: |         try: | ||||||
|             yolo_results = self.yolo_detector.detect(frame) |             yolo_results = self.yolo_detector.detect(frame) | ||||||
|             # 检查是否有检测结果(根据实际业务定义何为违规目标) |  | ||||||
|             if len(yolo_results.boxes) > 0: |             if len(yolo_results.boxes) > 0: | ||||||
|                 # 提取检测到的目标信息 |  | ||||||
|                 details = { |                 details = { | ||||||
|                     "classes": yolo_results.names, |                     "classes": yolo_results.names, | ||||||
|                     "boxes": yolo_results.boxes.xyxy.tolist(),  # 边界框坐标 |                     "boxes": yolo_results.boxes.xyxy.tolist(), | ||||||
|                     "confidences": yolo_results.boxes.conf.tolist(),  # 置信度 |                     "confidences": yolo_results.boxes.conf.tolist(), | ||||||
|                     "class_ids": yolo_results.boxes.cls.tolist()  # 类别ID |                     "class_ids": yolo_results.boxes.cls.tolist() | ||||||
|                 } |                 } | ||||||
|                 logger.warning(f"YOLO检测到违规目标: {details}") |                 print(f"警告: YOLO检测到违规目标: {details}") | ||||||
|                 return (True, "yolo", details) |                 return (True, "yolo", details) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.error(f"YOLO检测出错: {str(e)}", exc_info=True) |             print(f"错误: YOLO检测出错: {str(e)}") | ||||||
|  |  | ||||||
|         # 所有检测均未发现违规 |         # 所有检测均未发现违规 | ||||||
|         return (False, None, None) |         return (False, None, None) | ||||||
|  |  | ||||||
|  |  | ||||||
| # # 使用示例 | def load_config(config_path: str) -> dict: | ||||||
|  |     """加载YAML配置文件""" | ||||||
|  |     try: | ||||||
|  |         with open(config_path, 'r', encoding='utf-8') as f: | ||||||
|  |             return yaml.safe_load(f) | ||||||
|  |     except FileNotFoundError: | ||||||
|  |         print(f"错误: 配置文件未找到: {config_path}") | ||||||
|  |         raise | ||||||
|  |     except yaml.YAMLError as e: | ||||||
|  |         print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}") | ||||||
|  |         raise | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"错误: 加载配置文件出错: {str(e)}") | ||||||
|  |         raise | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 使用示例 | ||||||
| # if __name__ == "__main__": | # if __name__ == "__main__": | ||||||
| #     # 配置文件路径(根据实际情况修改) | #     # 加载配置文件 | ||||||
| #     FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt" | #     config = load_config("config.yaml")  # 配置文件路径,可根据实际情况修改 | ||||||
| #     OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"  # 新增OCR配置文件路径 |  | ||||||
| #     YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" |  | ||||||
| #     KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces" |  | ||||||
| # | # | ||||||
| #     # 初始化多模型检测器 | #     # 初始化多模型检测器 | ||||||
| #     detector = MultiModelViolationDetector( | #     detector = MultiModelViolationDetector( | ||||||
| #         forbidden_words_path=FORBIDDEN_WORDS_PATH, | #         forbidden_words_path=config["forbidden_words_path"], | ||||||
| #         ocr_config_path=OCR_CONFIG_PATH,  # 传入OCR配置文件路径 | #         ocr_config_path=config["ocr_config_path"], | ||||||
| #         yolo_model_path=YOLO_MODEL_PATH, | #         yolo_model_path=config["yolo_model_path"], | ||||||
| #         known_faces_dir=KNOWN_FACES_DIR, | #         known_faces_dir=config["known_faces_dir"], | ||||||
| #         ocr_confidence_threshold=0.5 | #         ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5) | ||||||
| #     ) | #     ) | ||||||
| # | # | ||||||
| #     # 读取测试图像(可替换为视频帧读取逻辑) | #     # 读取测试图像(可替换为视频帧读取逻辑) | ||||||
| #     test_image_path = r"D:\Git\bin\video\ocr\images\img.png" | #     test_image_path = config.get("test_image_path")  # 从配置文件获取测试图片路径 | ||||||
| #     frame = cv2.imread(test_image_path) | #     if test_image_path: | ||||||
|  | #         frame = cv2.imread(test_image_path) | ||||||
| # | # | ||||||
| #     if frame is not None: | #         if frame is not None: | ||||||
| #         has_violation, violation_type, details = detector.detect_violations(frame) | #             has_violation, violation_type, details = detector.detect_violations(frame) | ||||||
| #         if has_violation: | #             if has_violation: | ||||||
| #             print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") | #                 print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") | ||||||
|  | #             else: | ||||||
|  | #                 print("未检测到任何违规内容") | ||||||
| #         else: | #         else: | ||||||
| #             print("未检测到任何违规内容") | #             print(f"无法读取测试图像: {test_image_path}") | ||||||
| #     else: | #     else: | ||||||
| #         print(f"无法读取测试图像: {test_image_path}") | #         print("配置文件中未指定测试图像路径") | ||||||
| @ -1,6 +1,5 @@ | |||||||
| import os | import os | ||||||
| import cv2 | import cv2 | ||||||
| import logging |  | ||||||
| from rapidocr import RapidOCR | from rapidocr import RapidOCR | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -13,153 +12,85 @@ class OCRViolationDetector: | |||||||
|     def __init__(self, |     def __init__(self, | ||||||
|                  forbidden_words_path: str, |                  forbidden_words_path: str, | ||||||
|                  ocr_config_path: str, |                  ocr_config_path: str, | ||||||
|                  ocr_confidence_threshold: float = 0.5, |                  ocr_confidence_threshold: float = 0.5): | ||||||
|                  log_level: int = logging.INFO, |  | ||||||
|                  log_file: str = None): |  | ||||||
|         """ |         """ | ||||||
|         初始化OCR引擎、违禁词列表和日志配置。 |         初始化OCR引擎和违禁词列表。 | ||||||
|  |  | ||||||
|         Args: |         Args: | ||||||
|             forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 |             forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 | ||||||
|             ocr_config_path (str): OCR配置文件(如1.yaml)的路径。 |             ocr_config_path (str): OCR配置文件(如1.yaml)的路径。 | ||||||
|             ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。 |             ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。 | ||||||
|             log_level (int): 日志级别,默认为logging.INFO。 |  | ||||||
|             log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。 |  | ||||||
|         """ |         """ | ||||||
|         # 初始化日志(确保先初始化日志,后续操作可正常打日志) |         # 加载违禁词 | ||||||
|         self.logger = self._setup_logger(log_level, log_file) |  | ||||||
|  |  | ||||||
|         # 加载违禁词(优先级:先加载配置,再初始化引擎) |  | ||||||
|         self.forbidden_words = self._load_forbidden_words(forbidden_words_path) |         self.forbidden_words = self._load_forbidden_words(forbidden_words_path) | ||||||
|  |  | ||||||
|         # 初始化RapidOCR引擎(传入配置文件路径) |         # 初始化RapidOCR引擎 | ||||||
|         self.ocr_engine = self._initialize_ocr(ocr_config_path) |         self.ocr_engine = self._initialize_ocr(ocr_config_path) | ||||||
|  |  | ||||||
|         # 校验核心依赖是否就绪 |         # 校验核心依赖是否就绪 | ||||||
|         self._check_dependencies() |         self._check_dependencies() | ||||||
|  |  | ||||||
|         # 设置置信度阈值(限制在0~1范围,避免非法值) |         # 设置置信度阈值(限制在0~1范围) | ||||||
|         self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0)) |         self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0)) | ||||||
|         self.logger.info(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}") |         print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}") | ||||||
|  |  | ||||||
|     def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger: |  | ||||||
|         """ |  | ||||||
|         配置日志系统(避免重复添加处理器,支持控制台+文件双输出) |  | ||||||
|  |  | ||||||
|         Args: |  | ||||||
|             log_level: 日志级别(如logging.DEBUG、logging.INFO)。 |  | ||||||
|             log_file: 日志文件路径,为None时仅输出到控制台。 |  | ||||||
|  |  | ||||||
|         Returns: |  | ||||||
|             logging.Logger: 配置好的日志实例。 |  | ||||||
|         """ |  | ||||||
|         logger = logging.getLogger('OCRViolationDetector') |  | ||||||
|         logger.setLevel(log_level) |  | ||||||
|  |  | ||||||
|         # 避免重复添加处理器(防止日志重复输出) |  | ||||||
|         if logger.handlers: |  | ||||||
|             return logger |  | ||||||
|  |  | ||||||
|         # 定义日志格式(包含时间、模块名、级别、内容) |  | ||||||
|         formatter = logging.Formatter( |  | ||||||
|             '%(asctime)s - %(name)s - %(levelname)s - %(message)s', |  | ||||||
|             datefmt='%Y-%m-%d %H:%M:%S' |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # 1. 添加控制台处理器 |  | ||||||
|         console_handler = logging.StreamHandler() |  | ||||||
|         console_handler.setFormatter(formatter) |  | ||||||
|         logger.addHandler(console_handler) |  | ||||||
|  |  | ||||||
|         # 2. 若指定日志文件,添加文件处理器(自动创建目录) |  | ||||||
|         if log_file: |  | ||||||
|             try: |  | ||||||
|                 log_dir = os.path.dirname(log_file) |  | ||||||
|                 # 若日志目录不存在,自动创建 |  | ||||||
|                 if log_dir and not os.path.exists(log_dir): |  | ||||||
|                     os.makedirs(log_dir, exist_ok=True) |  | ||||||
|                     self.logger.debug(f"自动创建日志目录: {log_dir}") |  | ||||||
|  |  | ||||||
|                 file_handler = logging.FileHandler(log_file, encoding='utf-8') |  | ||||||
|                 file_handler.setFormatter(formatter) |  | ||||||
|                 logger.addHandler(file_handler) |  | ||||||
|                 logger.info(f"日志文件已配置: {log_file}") |  | ||||||
|             except Exception as e: |  | ||||||
|                 logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}") |  | ||||||
|  |  | ||||||
|         return logger |  | ||||||
|  |  | ||||||
|     def _load_forbidden_words(self, path: str) -> set: |     def _load_forbidden_words(self, path: str) -> set: | ||||||
|         """ |         """ | ||||||
|         从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码) |         从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码) | ||||||
|  |  | ||||||
|         Args: |  | ||||||
|             path (str): 违禁词TXT文件路径。 |  | ||||||
|  |  | ||||||
|         Returns: |  | ||||||
|             set: 去重后的违禁词集合(空集合表示加载失败)。 |  | ||||||
|         """ |         """ | ||||||
|         forbidden_words = set() |         forbidden_words = set() | ||||||
|  |  | ||||||
|         # 第一步:检查文件是否存在 |         # 检查文件是否存在 | ||||||
|         if not os.path.exists(path): |         if not os.path.exists(path): | ||||||
|             self.logger.error(f"违禁词文件不存在: {path}") |             print(f"错误:违禁词文件不存在: {path}") | ||||||
|             return forbidden_words |             return forbidden_words | ||||||
|  |  | ||||||
|         # 第二步:读取文件并处理内容 |         # 读取文件并处理内容 | ||||||
|         try: |         try: | ||||||
|             with open(path, 'r', encoding='utf-8') as f: |             with open(path, 'r', encoding='utf-8') as f: | ||||||
|                 # 过滤空行、去除首尾空格、去重 |  | ||||||
|                 forbidden_words = { |                 forbidden_words = { | ||||||
|                     line.strip() for line in f |                     line.strip() for line in f | ||||||
|                     if line.strip()  # 跳过空行或纯空格行 |                     if line.strip()  # 跳过空行或纯空格行 | ||||||
|                 } |                 } | ||||||
|             self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)") |             print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)") | ||||||
|             self.logger.debug(f"违禁词列表: {forbidden_words}") |  | ||||||
|         except UnicodeDecodeError: |         except UnicodeDecodeError: | ||||||
|             self.logger.error(f"违禁词文件编码错误(需UTF-8): {path}") |             print(f"错误:违禁词文件编码错误(需UTF-8): {path}") | ||||||
|         except PermissionError: |         except PermissionError: | ||||||
|             self.logger.error(f"无权限读取违禁词文件: {path}") |             print(f"错误:无权限读取违禁词文件: {path}") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True) |             print(f"错误:加载违禁词失败: {str(e)}") | ||||||
|  |  | ||||||
|         return forbidden_words |         return forbidden_words | ||||||
|  |  | ||||||
|     def _initialize_ocr(self, config_path: str) -> RapidOCR | None: |     def _initialize_ocr(self, config_path: str) -> RapidOCR | None: | ||||||
|         """ |         """ | ||||||
|         初始化RapidOCR引擎(校验配置文件、捕获初始化异常) |         初始化RapidOCR引擎(校验配置文件、捕获初始化异常) | ||||||
|  |  | ||||||
|         Args: |  | ||||||
|             config_path (str): RapidOCR配置文件(如1.yaml)路径。 |  | ||||||
|  |  | ||||||
|         Returns: |  | ||||||
|             RapidOCR | None: OCR引擎实例(None表示初始化失败)。 |  | ||||||
|         """ |         """ | ||||||
|         self.logger.info("开始初始化RapidOCR引擎...") |         print("开始初始化RapidOCR引擎...") | ||||||
|  |  | ||||||
|         # 第一步:检查配置文件是否存在 |         # 检查配置文件是否存在 | ||||||
|         if not os.path.exists(config_path): |         if not os.path.exists(config_path): | ||||||
|             self.logger.error(f"OCR配置文件不存在: {config_path}") |             print(f"错误:OCR配置文件不存在: {config_path}") | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|         # 第二步:初始化OCR引擎(捕获RapidOCR相关异常) |         # 初始化OCR引擎 | ||||||
|         try: |         try: | ||||||
|             ocr_engine = RapidOCR(config_path=config_path) |             ocr_engine = RapidOCR(config_path=config_path) | ||||||
|             self.logger.info("RapidOCR引擎初始化成功") |             print("RapidOCR引擎初始化成功") | ||||||
|             return ocr_engine |             return ocr_engine | ||||||
|         except ImportError: |         except ImportError: | ||||||
|             self.logger.error("RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)") |             print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True) |             print(f"错误:RapidOCR初始化失败: {str(e)}") | ||||||
|  |  | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     def _check_dependencies(self) -> None: |     def _check_dependencies(self) -> None: | ||||||
|         """校验OCR引擎和违禁词列表是否就绪(输出警告日志)""" |         """校验OCR引擎和违禁词列表是否就绪""" | ||||||
|         if not self.ocr_engine: |         if not self.ocr_engine: | ||||||
|             self.logger.warning("⚠️ OCR引擎未就绪,违禁词检测功能将禁用") |             print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用") | ||||||
|         if not self.forbidden_words: |         if not self.forbidden_words: | ||||||
|             self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用") |             print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用") | ||||||
|  |  | ||||||
|     def detect(self, frame) -> tuple[bool, list, list]: |     def detect(self, frame) -> tuple[bool, list, list]: | ||||||
|         """ |         """ | ||||||
| @ -179,76 +110,69 @@ class OCRViolationDetector: | |||||||
|         violation_words = [] |         violation_words = [] | ||||||
|         violation_confs = [] |         violation_confs = [] | ||||||
|  |  | ||||||
|         # 前置校验:1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在 |         # 前置校验 | ||||||
|         if frame is None or frame.size == 0: |         if frame is None or frame.size == 0: | ||||||
|             self.logger.warning("输入图像帧为空或无效,跳过OCR检测") |             print("警告:输入图像帧为空或无效,跳过OCR检测") | ||||||
|             return has_violation, violation_words, violation_confs |             return has_violation, violation_words, violation_confs | ||||||
|         if not self.ocr_engine or not self.forbidden_words: |         if not self.ocr_engine or not self.forbidden_words: | ||||||
|             self.logger.debug("OCR引擎未就绪或违禁词为空,跳过OCR检测") |             print("OCR引擎未就绪或违禁词为空,跳过OCR检测") | ||||||
|             return has_violation, violation_words, violation_confs |             return has_violation, violation_words, violation_confs | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             # 1. 执行OCR识别(获取RapidOCR原始结果) |             # 执行OCR识别 | ||||||
|             self.logger.debug("开始执行OCR识别...") |             print("开始执行OCR识别...") | ||||||
|             ocr_result = self.ocr_engine(frame) |             ocr_result = self.ocr_engine(frame) | ||||||
|             self.logger.debug(f"RapidOCR原始结果: {ocr_result}") |             print(f"RapidOCR原始结果: {ocr_result}") | ||||||
|  |  | ||||||
|             # 2. 校验OCR结果是否有效(避免None或格式异常) |             # 校验OCR结果是否有效 | ||||||
|             if ocr_result is None: |             if ocr_result is None: | ||||||
|                 self.logger.debug("OCR识别未返回任何结果(图像无文本或识别失败)") |                 print("OCR识别未返回任何结果(图像无文本或识别失败)") | ||||||
|                 return has_violation, violation_words, violation_confs |                 return has_violation, violation_words, violation_confs | ||||||
|  |  | ||||||
|             # 3. 检查txts和scores是否存在且不为None |             # 检查txts和scores是否存在且不为None | ||||||
|             if not hasattr(ocr_result, 'txts') or ocr_result.txts is None: |             if not hasattr(ocr_result, 'txts') or ocr_result.txts is None: | ||||||
|                 self.logger.warning("OCR结果中txts为None或不存在") |                 print("警告:OCR结果中txts为None或不存在") | ||||||
|                 return has_violation, violation_words, violation_confs |                 return has_violation, violation_words, violation_confs | ||||||
|  |  | ||||||
|             if not hasattr(ocr_result, 'scores') or ocr_result.scores is None: |             if not hasattr(ocr_result, 'scores') or ocr_result.scores is None: | ||||||
|                 self.logger.warning("OCR结果中scores为None或不存在") |                 print("警告:OCR结果中scores为None或不存在") | ||||||
|                 return has_violation, violation_words, violation_confs |                 return has_violation, violation_words, violation_confs | ||||||
|  |  | ||||||
|             # 4. 转为列表并去None(防止单个元素为None) |             # 转为列表并去None | ||||||
|             # 确保txts是可迭代的,如果不是则转为空列表 |  | ||||||
|             if not isinstance(ocr_result.txts, (list, tuple)): |             if not isinstance(ocr_result.txts, (list, tuple)): | ||||||
|                 self.logger.warning(f"OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}") |                 print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}") | ||||||
|                 texts = [] |                 texts = [] | ||||||
|             else: |             else: | ||||||
|                 texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)] |                 texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)] | ||||||
|  |  | ||||||
|             # 确保scores是可迭代的,如果不是则转为空列表 |  | ||||||
|             if not isinstance(ocr_result.scores, (list, tuple)): |             if not isinstance(ocr_result.scores, (list, tuple)): | ||||||
|                 self.logger.warning(f"OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}") |                 print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}") | ||||||
|                 confidences = [] |                 confidences = [] | ||||||
|             else: |             else: | ||||||
|                 confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))] |                 confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))] | ||||||
|  |  | ||||||
|             # 5. 校验文本和置信度列表长度是否一致(避免zip迭代错误) |             # 校验文本和置信度列表长度是否一致 | ||||||
|             if len(texts) != len(confidences): |             if len(texts) != len(confidences): | ||||||
|                 self.logger.warning( |                 print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测") | ||||||
|                     f"OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测") |  | ||||||
|                 return has_violation, violation_words, violation_confs |                 return has_violation, violation_words, violation_confs | ||||||
|             if len(texts) == 0: |             if len(texts) == 0: | ||||||
|                 self.logger.debug("OCR未识别到任何有效文本") |                 print("OCR未识别到任何有效文本") | ||||||
|                 return has_violation, violation_words, violation_confs |                 return has_violation, violation_words, violation_confs | ||||||
|  |  | ||||||
|             # 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤) |             # 遍历识别结果,筛选违禁词 | ||||||
|             self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})") |             print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})") | ||||||
|             for text, conf in zip(texts, confidences): |             for text, conf in zip(texts, confidences): | ||||||
|                 # 过滤低置信度结果 |  | ||||||
|                 if conf < self.OCR_CONFIDENCE_THRESHOLD: |                 if conf < self.OCR_CONFIDENCE_THRESHOLD: | ||||||
|                     self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过") |                     print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过") | ||||||
|                     continue |                     continue | ||||||
|                 # 检查当前文本是否包含违禁词(支持一个文本含多个违禁词) |  | ||||||
|                 matched_words = [word for word in self.forbidden_words if word in text] |                 matched_words = [word for word in self.forbidden_words if word in text] | ||||||
|                 if matched_words: |                 if matched_words: | ||||||
|                     has_violation = True |                     has_violation = True | ||||||
|                     # 记录所有匹配的违禁词和对应置信度 |  | ||||||
|                     violation_words.extend(matched_words) |                     violation_words.extend(matched_words) | ||||||
|                     violation_confs.extend([conf] * len(matched_words))  # 一个文本对应多个违禁词时,置信度复用 |                     violation_confs.extend([conf] * len(matched_words)) | ||||||
|                     self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})") |                     print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})") | ||||||
|  |  | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             # 捕获所有异常,确保不中断上层调用 |             print(f"错误:OCR检测过程异常: {str(e)}") | ||||||
|             self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True) |  | ||||||
|  |  | ||||||
|         return has_violation, violation_words, violation_confs |         return has_violation, violation_words, violation_confs | ||||||
| @ -1,6 +1,5 @@ | |||||||
| from ultralytics import YOLO | from ultralytics import YOLO | ||||||
| import cv2 | import cv2 | ||||||
| from logger_config import logger |  | ||||||
|  |  | ||||||
| class ViolationDetector: | class ViolationDetector: | ||||||
|     """ |     """ | ||||||
| @ -13,9 +12,9 @@ class ViolationDetector: | |||||||
|         Args: |         Args: | ||||||
|             model_path (str): YOLO .pt模型的路径。 |             model_path (str): YOLO .pt模型的路径。 | ||||||
|         """ |         """ | ||||||
|         logger.info(f"正在从 '{model_path}' 加载YOLO模型...") |         print(f"正在从 '{model_path}' 加载YOLO模型...") | ||||||
|         self.model = YOLO(model_path) |         self.model = YOLO(model_path) | ||||||
|         logger.info("YOLO模型加载成功。") |         print("YOLO模型加载成功。") | ||||||
|  |  | ||||||
|     def detect(self, frame): |     def detect(self, frame): | ||||||
|         """ |         """ | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user