"""OCR-based forbidden-word detection.

Text is recognised in an image frame with PaddleOCR and then scanned for
forbidden words with an Aho-Corasick automaton (multi-pattern string
matching), so the matching cost does not grow with the number of words.
"""
import time
from collections import deque
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np

# Compatibility shim for the deprecated/removed ``np.int`` alias.
# NOTE(review): this has to run before ``paddleocr`` is imported, presumably
# because some paddleocr releases still reference ``np.int`` - keep the order.
if not hasattr(np, 'int'):
    np.int = int

from paddleocr import PaddleOCR  # noqa: E402
from service.sensitive_service import get_all_sensitive_words  # noqa: E402


class AhoNode:
    """A single trie node of the Aho-Corasick automaton."""

    # One node is allocated per trie character, so skip the per-instance dict.
    __slots__ = ("children", "fail", "is_end", "word")

    def __init__(self):
        self.children = {}    # child mapping: character -> AhoNode
        self.fail = None      # failure link (plays the role of KMP's "next")
        self.is_end = False   # True when a forbidden word ends at this node
        self.word = None      # the complete forbidden word ending here


class AhoCorasick:
    """Aho-Corasick automaton for multi-pattern string matching."""

    def __init__(self):
        self.root = AhoNode()
        # True while words were added since the last ``build_fail`` call, i.e.
        # some nodes still have no failure link.
        self._needs_build = False

    def add_word(self, word):
        """Insert *word* into the trie.

        Non-string, empty and whitespace-only values are ignored. The failure
        links become stale after an insertion; they are rebuilt by
        ``build_fail`` (``match`` triggers that automatically when needed).
        """
        if not isinstance(word, str) or not word.strip():
            return

        node = self.root
        for char in word:
            child = node.children.get(char)
            if child is None:
                child = node.children[char] = AhoNode()
            node = child
        node.is_end = True
        node.word = word
        self._needs_build = True

    def build_fail(self):
        """(Re)compute the failure link of every node with a BFS traversal."""
        root = self.root
        queue = deque()

        # Depth-1 nodes always fall back to the root.
        for child in root.children.values():
            child.fail = root
            queue.append(child)

        while queue:
            current = queue.popleft()
            for char, child in current.children.items():
                # Walk up the parent's failure chain until a node that can
                # consume ``char`` is found (or the chain is exhausted).
                fallback = current.fail
                while fallback is not None and char not in fallback.children:
                    fallback = fallback.fail
                child.fail = root if fallback is None else fallback.children[char]
                queue.append(child)

        self._needs_build = False

    def match(self, text):
        """Return every forbidden word occurring in *text*.

        Each word appears once, ordered by its first hit in *text*. An empty
        or ``None`` text yields an empty list.
        """
        if not text:
            return []
        if self._needs_build:
            # Words were added after the last build; without this the new
            # nodes would have ``fail is None`` and the chain walk would crash.
            self.build_fail()

        root = self.root
        found = {}  # insertion-ordered set of matched words
        node = root
        for char in text:
            # Follow failure links until some node can consume ``char``.
            while node is not None and char not in node.children:
                node = node.fail
            node = root if node is None else node.children[char]

            # Every node on the failure chain is a suffix of the current
            # prefix, so each terminal node on it is a match ending here.
            hit = node
            while hit is not root:
                if hit.is_end:
                    found.setdefault(hit.word, None)
                hit = hit.fail
        return list(found)


# Module state
_ocr_engine = None       # PaddleOCR instance, set by init_ocr_engine()
_ac_automaton = None     # AhoCorasick instance holding the forbidden words
_conf_threshold = 0.5    # default minimum OCR confidence used by detect()


def _build_automaton(words: Iterable) -> Tuple[AhoCorasick, int]:
    """Build a ready-to-use automaton from *words*.

    Returns the automaton together with the number of valid (non-blank
    string) entries that were accepted.
    """
    valid_words = [w for w in words if isinstance(w, str) and w.strip()]
    automaton = AhoCorasick()
    for word in valid_words:
        automaton.add_word(word.strip())
    automaton.build_fail()
    return automaton, len(valid_words)


def set_forbidden_words(new_words):
    """Replace the forbidden-word list with *new_words*.

    Args:
        new_words: a set, list or tuple of words; non-string and blank
            entries are skipped.

    Raises:
        TypeError: if *new_words* is not a set, list or tuple.
    """
    global _ac_automaton
    if not isinstance(new_words, (set, list, tuple)):
        raise TypeError("新违禁词必须是集合、列表或元组类型")

    # Build completely before publishing so the module-level automaton is
    # never observed half-constructed.
    automaton, count = _build_automaton(new_words)
    _ac_automaton = automaton
    print(f"已通过函数更新违禁词,当前数量: {count}")


def load_forbidden_words():
    """Load the forbidden words from the sensitive-word service.

    Returns:
        True on success, False if loading or building failed. On failure the
        previously installed automaton (if any) is left untouched.
    """
    global _ac_automaton
    try:
        automaton, count = _build_automaton(get_all_sensitive_words())
        _ac_automaton = automaton
        print(f"加载的违禁词数量: {count}")
        return True
    except Exception as e:
        print(f"Forbidden words load error: {e}")
        return False


def init_ocr_engine():
    """Create the PaddleOCR engine and load the forbidden words.

    Returns:
        True when the engine was created (a failed word load only prints a
        warning), False when engine creation raised; in that case the
        module-level engine is reset to ``None``.
    """
    global _ocr_engine
    try:
        _ocr_engine = PaddleOCR(
            use_angle_cls=True,
            lang="ch",
            show_log=False,
            use_gpu=True,
            max_text_length=1024,
        )
        if not load_forbidden_words():
            print("警告:违禁词加载失败,可能影响检测功能")
        print("OCR引擎初始化完成")
        return True
    except Exception as e:
        print(f"OCR引擎初始化错误: {e}")
        _ocr_engine = None
        return False


def _is_box(item) -> bool:
    """Return True when *item* is a list of four ``[x, y]`` numeric points."""
    return (
        isinstance(item, list)
        and len(item) == 4
        and all(
            isinstance(point, list)
            and len(point) == 2
            and all(isinstance(coord, (int, float)) for coord in point)
            for point in item
        )
    )


def _as_text_conf(candidate) -> Optional[Tuple[str, float]]:
    """Return ``(stripped_text, confidence)`` for a ``(str, number)`` 2-tuple, else None."""
    if isinstance(candidate, tuple) and len(candidate) == 2:
        text, conf = candidate
        if isinstance(text, str) and isinstance(conf, (int, float)):
            return text.strip(), float(conf)
    return None


def _extract_text_items(ocr_res) -> List[Tuple[str, float]]:
    """Flatten an OCR result into ``(text, confidence)`` pairs.

    Handles lines that are lists of ``[box, (text, conf)]`` entries as well
    as lines that directly contain a box followed by a ``(text, conf)``
    tuple. A bare string in the text slot gets confidence 1.0.
    """
    pairs = []
    for line in ocr_res:
        if line is None:
            continue
        entries = line if isinstance(line, list) else [line]
        for item in entries:
            if _is_box(item):
                continue  # pure coordinate data carries no text

            if isinstance(item, tuple) and len(item) == 2:
                pair = _as_text_conf(item)
            elif isinstance(item, list) and len(item) >= 2:
                text_data = item[1]
                if isinstance(text_data, str):
                    pair = (text_data.strip(), 1.0)
                else:
                    pair = _as_text_conf(text_data)
            else:
                pair = None

            if pair is not None:
                pairs.append(pair)
    return pairs


def detect(frame, conf_threshold=None):
    """Check whether the text recognised in *frame* contains forbidden words.

    Args:
        frame: the image passed to the OCR engine's ``ocr`` method.
        conf_threshold: minimum OCR confidence a text must reach to be
            scanned. ``None`` (the default) uses the module-level
            ``_conf_threshold``.

    Returns:
        A ``(has_violation, message)`` tuple. On a hit, ``message`` is the
        comma-separated list of distinct forbidden words in first-seen order;
        otherwise it describes why nothing was reported. Errors are caught
        and reported as ``(False, message)``.

    The time spent in the OCR/parsing stage, the matching stage and in total
    is printed after every call.
    """
    print("开始进行OCR检测...")
    total_start = time.perf_counter()
    ocr_time = 0.0    # OCR + result parsing
    match_time = 0.0  # forbidden-word matching
    threshold = _conf_threshold if conf_threshold is None else conf_threshold

    try:
        # Snapshot the globals so a concurrent reload cannot swap them mid-call.
        engine, automaton = _ocr_engine, _ac_automaton
        if engine is None or automaton is None:
            return (False, "OCR引擎或违禁词库未初始化")

        # 1. OCR recognition and result parsing
        ocr_start = time.perf_counter()
        ocr_res = engine.ocr(frame, cls=True)
        if not ocr_res or not isinstance(ocr_res, list):
            return (False, "无OCR结果")
        text_items = _extract_text_items(ocr_res)
        ocr_time = time.perf_counter() - ocr_start

        # 2. Forbidden-word matching
        match_start = time.perf_counter()
        violations: Dict[str, None] = {}  # insertion-ordered set
        for text, conf in text_items:
            if conf < threshold:
                continue
            for word in automaton.match(text):
                violations.setdefault(word, None)
        match_time = time.perf_counter() - match_start

        # Low-confidence texts still count as "text was recognised".
        if not text_items:
            return (False, "未识别到文本")
        if violations:
            return (True, ", ".join(violations))
        return (False, "未检测到违禁词")
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {e}")
    finally:
        # Always report the per-stage timings, even on early return or error.
        total_time = time.perf_counter() - total_start
        print("当前帧耗时明细:")
        print(f" OCR识别及解析:{ocr_time:.8f}秒")
        print(f" 违禁词匹配:{match_time:.8f}秒")
        print(f" 总耗时:{total_time:.8f}秒")