"""OCR-based forbidden-word detection.

``load_model`` loads the forbidden-word list and a RapidOCR engine into module
state. ``detect`` then runs OCR on a frame and reports any confidently
recognised text that contains a forbidden word.
"""

import logging
import numbers
import os
from typing import Iterable, List, Optional, Tuple

import cv2  # noqa: F401 -- not referenced in this module; import kept unchanged
from rapidocr import RapidOCR

from service.sensitive_service import get_all_sensitive_words

logger = logging.getLogger(__name__)

# Module-level state, populated by load_model().
_ocr_engine = None
# Distinct, non-blank forbidden words in the order the source returned them.
_forbidden_words = ()
# Recognitions whose confidence is below this value are ignored by detect().
_conf_threshold = 0.5

ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")


def _normalize_words(words: Optional[Iterable]) -> Tuple[str, ...]:
    """Return the distinct non-blank string words, preserving first-seen order.

    Blank entries are dropped because ``"" in text`` is always True and would
    flag every recognised string. Non-string entries are dropped because
    ``non_str in text`` raises ``TypeError``.
    """
    if words is None:
        return ()
    return tuple(dict.fromkeys(w for w in words if isinstance(w, str) and w.strip()))


def load_model() -> bool:
    """Load the forbidden-word list and initialise the OCR engine.

    The word list is loaded first. A failure there is logged and the previous
    list is kept (empty on the first call, which makes ``detect`` refuse to
    run), so that the OCR engine can still be initialised.

    Returns:
        True if the OCR engine was created; False if the config file is
        missing or RapidOCR failed to initialise.
    """
    global _ocr_engine, _forbidden_words

    try:
        _forbidden_words = _normalize_words(get_all_sensitive_words())
    except Exception:
        logger.exception("Forbidden words load error")
    else:
        if not _forbidden_words:
            logger.warning("Forbidden word list is empty; detect() will not run")

    if not os.path.exists(ocr_config_path):
        logger.error("OCR config not found: %s", ocr_config_path)
        return False

    try:
        _ocr_engine = RapidOCR(config_path=ocr_config_path)
    except Exception:
        logger.exception("OCR model load failed")
        return False

    return _ocr_engine is not None


def _extract_pairs(txts, scores) -> Optional[List[Tuple[str, float]]]:
    """Pair each recognised text with its score, dropping unusable entries.

    Filtering is done per pair, so dropping one text never shifts the
    remaining scores onto the wrong texts.

    Returns:
        ``(stripped_text, score)`` tuples for entries whose text is a
        non-blank string and whose score is a real number (Python or numpy
        scalar), or ``None`` when the two sequences differ in length.
    """
    txts = () if txts is None else tuple(txts)
    scores = () if scores is None else tuple(scores)
    if len(txts) != len(scores):
        return None

    pairs = []
    for txt, score in zip(txts, scores):
        if not isinstance(txt, str) or not isinstance(score, numbers.Real):
            continue
        txt = txt.strip()
        if txt:
            pairs.append((txt, float(score)))
    return pairs


def _find_violations(pairs, words, threshold) -> List[str]:
    """Describe every recognition at or above *threshold* that contains forbidden words."""
    vio_info = []
    for txt, conf in pairs:
        # "not >=" rather than "<": every comparison with NaN is False, so
        # "conf < threshold" would let a NaN score through.
        if not conf >= threshold:
            continue
        matched = [w for w in words if w in txt]
        if matched:
            vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
    return vio_info


def detect(frame) -> Tuple[bool, str]:
    """Run OCR on *frame* and check the recognised text for forbidden words.

    Args:
        frame: image handed straight to the RapidOCR engine. Presumably a
            numpy image array (only ``None`` and ``.size == 0`` are checked
            here) -- confirm with callers.

    Returns:
        ``(True, details)`` when at least one recognition with confidence
        >= ``_conf_threshold`` contains a forbidden word. Otherwise
        ``(False, reason)``, where *reason* is a short status message
        (not initialised / invalid frame, OCR error, no result, malformed
        result, no text, or no forbidden word).
    """
    if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
        return (False, "未初始化或无效帧")

    try:
        ocr_res = _ocr_engine(frame)
    except Exception as e:
        # Per-frame path: log the message without a traceback to avoid flooding.
        logger.error("OCR detect error: %s", e)
        return (False, f"检测错误: {str(e)}")

    if not ocr_res or not hasattr(ocr_res, "txts") or not hasattr(ocr_res, "scores"):
        return (False, "无OCR结果")

    pairs = _extract_pairs(ocr_res.txts, ocr_res.scores)
    if pairs is None:
        return (False, "OCR结果格式异常")
    if not pairs:
        return (False, "未识别到文本")

    vio_info = _find_violations(pairs, _forbidden_words, _conf_threshold)
    if vio_info:
        return (True, "; ".join(vio_info))
    return (False, "未检测到违禁词")