| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | import time | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | import numpy as np | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # 修复np.int已弃用的问题 | 
					
						
							|  |  |  |  | if not hasattr(np, 'int'): | 
					
						
							|  |  |  |  |     np.int = int | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from paddleocr import PaddleOCR | 
					
						
							|  |  |  |  | from service.sensitive_service import get_all_sensitive_words | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | # AC自动机节点定义 | 
					
						
							|  |  |  |  | class AhoNode: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         self.children = {}  # 子节点映射(字符->节点) | 
					
						
							|  |  |  |  |         self.fail = None  # 失败指针(类似KMP的next数组) | 
					
						
							|  |  |  |  |         self.is_end = False  # 标记是否为某个模式串的结尾 | 
					
						
							|  |  |  |  |         self.word = None  # 存储当前结尾对应的完整违禁词 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # AC自动机实现(多模式字符串匹配) | 
					
						
							|  |  |  |  | class AhoCorasick: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         self.root = AhoNode()  # 根节点 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def add_word(self, word): | 
					
						
							|  |  |  |  |         """添加违禁词到Trie树""" | 
					
						
							|  |  |  |  |         if not isinstance(word, str) or not word.strip(): | 
					
						
							|  |  |  |  |             return  # 过滤无效词 | 
					
						
							|  |  |  |  |         node = self.root | 
					
						
							|  |  |  |  |         for char in word: | 
					
						
							|  |  |  |  |             if char not in node.children: | 
					
						
							|  |  |  |  |                 node.children[char] = AhoNode() | 
					
						
							|  |  |  |  |             node = node.children[char] | 
					
						
							|  |  |  |  |         node.is_end = True | 
					
						
							|  |  |  |  |         node.word = word  # 记录完整词 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def build_fail(self): | 
					
						
							|  |  |  |  |         """构建失败指针(BFS遍历)""" | 
					
						
							|  |  |  |  |         queue = [] | 
					
						
							|  |  |  |  |         # 根节点的子节点失败指针指向根节点 | 
					
						
							|  |  |  |  |         for child in self.root.children.values(): | 
					
						
							|  |  |  |  |             child.fail = self.root | 
					
						
							|  |  |  |  |             queue.append(child) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # BFS处理其他节点 | 
					
						
							|  |  |  |  |         while queue: | 
					
						
							|  |  |  |  |             current_node = queue.pop(0) | 
					
						
							|  |  |  |  |             # 遍历当前节点的所有子节点 | 
					
						
							|  |  |  |  |             for char, child in current_node.children.items(): | 
					
						
							|  |  |  |  |                 # 寻找失败指针目标节点 | 
					
						
							|  |  |  |  |                 fail_node = current_node.fail | 
					
						
							|  |  |  |  |                 while fail_node is not None and char not in fail_node.children: | 
					
						
							|  |  |  |  |                     fail_node = fail_node.fail | 
					
						
							|  |  |  |  |                 # 确定失败指针指向 | 
					
						
							|  |  |  |  |                 child.fail = fail_node.children[char] if (fail_node and char in fail_node.children) else self.root | 
					
						
							|  |  |  |  |                 queue.append(child) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def match(self, text): | 
					
						
							|  |  |  |  |         """匹配文本中所有出现的违禁词(去重)""" | 
					
						
							|  |  |  |  |         result = set() | 
					
						
							|  |  |  |  |         node = self.root | 
					
						
							|  |  |  |  |         for char in text: | 
					
						
							|  |  |  |  |             # 沿失败链查找可用节点 | 
					
						
							|  |  |  |  |             while node is not None and char not in node.children: | 
					
						
							|  |  |  |  |                 node = node.fail | 
					
						
							|  |  |  |  |             # 重置到根节点(如果没找到) | 
					
						
							|  |  |  |  |             node = node.children[char] if (node and char in node.children) else self.root | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             # 收集所有匹配的违禁词(包括失败链上的) | 
					
						
							|  |  |  |  |             temp = node | 
					
						
							|  |  |  |  |             while temp != self.root: | 
					
						
							|  |  |  |  |                 if temp.is_end: | 
					
						
							|  |  |  |  |                     result.add(temp.word) | 
					
						
							|  |  |  |  |                 temp = temp.fail | 
					
						
							|  |  |  |  |         return list(result) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # 全局变量 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | _ocr_engine = None | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | _ac_automaton = None  # 替换原有的_forbidden_words集合 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | _conf_threshold = 0.5 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | def set_forbidden_words(new_words): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """更新违禁词(使用AC自动机存储)""" | 
					
						
							|  |  |  |  |     global _ac_automaton | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     if not isinstance(new_words, (set, list, tuple)): | 
					
						
							|  |  |  |  |         raise TypeError("新违禁词必须是集合、列表或元组类型") | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     # 初始化AC自动机并添加有效词 | 
					
						
							|  |  |  |  |     _ac_automaton = AhoCorasick() | 
					
						
							|  |  |  |  |     valid_words = [word for word in new_words if isinstance(word, str) and word.strip()] | 
					
						
							|  |  |  |  |     for word in valid_words: | 
					
						
							|  |  |  |  |         _ac_automaton.add_word(word.strip()) | 
					
						
							|  |  |  |  |     # 构建失败指针(关键步骤) | 
					
						
							|  |  |  |  |     _ac_automaton.build_fail() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     print(f"已通过函数更新违禁词,当前数量: {len(valid_words)}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | def load_forbidden_words(): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """从敏感词服务加载违禁词并初始化AC自动机""" | 
					
						
							|  |  |  |  |     global _ac_automaton | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         sensitive_words = get_all_sensitive_words()  # 保持原接口不变(返回list[str]) | 
					
						
							|  |  |  |  |         _ac_automaton = AhoCorasick() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 添加所有有效敏感词 | 
					
						
							|  |  |  |  |         valid_words = [word for word in sensitive_words if isinstance(word, str) and word.strip()] | 
					
						
							|  |  |  |  |         for word in valid_words: | 
					
						
							|  |  |  |  |             _ac_automaton.add_word(word.strip()) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 构建失败指针 | 
					
						
							|  |  |  |  |         _ac_automaton.build_fail() | 
					
						
							|  |  |  |  |         print(f"加载的违禁词数量: {len(valid_words)}") | 
					
						
							|  |  |  |  |         return True | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         print(f"Forbidden words load error: {e}") | 
					
						
							|  |  |  |  |         return False | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def init_ocr_engine(): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """初始化OCR引擎和违禁词自动机""" | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     global _ocr_engine | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         _ocr_engine = PaddleOCR( | 
					
						
							|  |  |  |  |             use_angle_cls=True, | 
					
						
							|  |  |  |  |             lang="ch", | 
					
						
							|  |  |  |  |             show_log=False, | 
					
						
							|  |  |  |  |             use_gpu=True, | 
					
						
							|  |  |  |  |             max_text_length=1024 | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  |         load_result = load_forbidden_words() | 
					
						
							|  |  |  |  |         if not load_result: | 
					
						
							|  |  |  |  |             print("警告:违禁词加载失败,可能影响检测功能") | 
					
						
							|  |  |  |  |         print("OCR引擎初始化完成") | 
					
						
							|  |  |  |  |         return True | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         print(f"OCR引擎初始化错误: {e}") | 
					
						
							|  |  |  |  |         _ocr_engine = None | 
					
						
							|  |  |  |  |         return False | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def detect(frame, conf_threshold=0.8): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """检测帧中的文本是否包含违禁词(拆分OCR和匹配时间)""" | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     print("开始进行OCR检测...") | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     total_start = time.time()  # 总耗时开始 | 
					
						
							|  |  |  |  |     ocr_time = 0.0  # OCR及结果解析耗时 | 
					
						
							|  |  |  |  |     match_time = 0.0  # 违禁词匹配耗时 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         if not _ocr_engine or not _ac_automaton: | 
					
						
							|  |  |  |  |             return (False, "OCR引擎或违禁词库未初始化") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 1. OCR识别及结果解析阶段 | 
					
						
							|  |  |  |  |         ocr_start = time.time() | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         ocr_res = _ocr_engine.ocr(frame, cls=True) | 
					
						
							|  |  |  |  |         if not ocr_res or not isinstance(ocr_res, list): | 
					
						
							|  |  |  |  |             return (False, "无OCR结果") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         texts = [] | 
					
						
							|  |  |  |  |         confs = [] | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 解析OCR结果 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         for line in ocr_res: | 
					
						
							|  |  |  |  |             if line is None: | 
					
						
							|  |  |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             items_to_process = line if isinstance(line, list) else [line] | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |             for item in items_to_process: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |                 # 过滤坐标类数据 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |                 if isinstance(item, list) and len(item) == 4: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |                     is_coordinate = all(isinstance(p, list) and len(p) == 2 and | 
					
						
							|  |  |  |  |                                         all(isinstance(c, (int, float)) for c in p) | 
					
						
							|  |  |  |  |                                         for p in item) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |                     if is_coordinate: | 
					
						
							|  |  |  |  |                         continue | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |                 # 提取文本和置信度 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |                 if isinstance(item, tuple) and len(item) == 2: | 
					
						
							|  |  |  |  |                     text, conf = item | 
					
						
							|  |  |  |  |                     if isinstance(text, str) and isinstance(conf, (int, float)): | 
					
						
							|  |  |  |  |                         texts.append(text.strip()) | 
					
						
							|  |  |  |  |                         confs.append(float(conf)) | 
					
						
							|  |  |  |  |                         continue | 
					
						
							|  |  |  |  |                 if isinstance(item, list) and len(item) >= 2: | 
					
						
							|  |  |  |  |                     text_data = item[1] | 
					
						
							|  |  |  |  |                     if isinstance(text_data, tuple) and len(text_data) == 2: | 
					
						
							|  |  |  |  |                         text, conf = text_data | 
					
						
							|  |  |  |  |                         if isinstance(text, str) and isinstance(conf, (int, float)): | 
					
						
							|  |  |  |  |                             texts.append(text.strip()) | 
					
						
							|  |  |  |  |                             confs.append(float(conf)) | 
					
						
							|  |  |  |  |                             continue | 
					
						
							|  |  |  |  |                     elif isinstance(text_data, str): | 
					
						
							|  |  |  |  |                         texts.append(text_data.strip()) | 
					
						
							|  |  |  |  |                         confs.append(1.0) | 
					
						
							|  |  |  |  |                         continue | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         ocr_end = time.time() | 
					
						
							|  |  |  |  |         ocr_time = ocr_end - ocr_start  # 计算OCR阶段耗时 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         if len(texts) != len(confs): | 
					
						
							|  |  |  |  |             return (False, "OCR结果格式异常") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 2. 违禁词匹配阶段 | 
					
						
							|  |  |  |  |         match_start = time.time() | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         vio_words = [] | 
					
						
							|  |  |  |  |         for txt, conf in zip(texts, confs): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             if conf < _conf_threshold: | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             # 用AC自动机匹配当前文本中的所有违禁词 | 
					
						
							|  |  |  |  |             matched_words = _ac_automaton.match(txt) | 
					
						
							|  |  |  |  |             # 全局去重并保持顺序 | 
					
						
							|  |  |  |  |             for word in matched_words: | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |                 if word not in vio_words: | 
					
						
							|  |  |  |  |                     vio_words.append(word) | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         match_end = time.time() | 
					
						
							|  |  |  |  |         match_time = match_end - match_start  # 计算匹配阶段耗时 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         has_text = len(texts) > 0 | 
					
						
							|  |  |  |  |         has_violation = len(vio_words) > 0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not has_text: | 
					
						
							|  |  |  |  |             return (False, "未识别到文本") | 
					
						
							|  |  |  |  |         elif has_violation: | 
					
						
							|  |  |  |  |             return (True, ", ".join(vio_words)) | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             return (False, "未检测到违禁词") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         print(f"OCR detect error: {e}") | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         return (False, f"检测错误: {str(e)}") | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         # 打印各阶段耗时 | 
					
						
							|  |  |  |  |         total_time = time.time() - total_start | 
					
						
							|  |  |  |  |         print(f"当前帧耗时明细:") | 
					
						
							|  |  |  |  |         print(f"  OCR识别及解析:{ocr_time:.8f}秒") | 
					
						
							|  |  |  |  |         print(f"  违禁词匹配:{match_time:.8f}秒") | 
					
						
							|  |  |  |  |         print(f"  总耗时:{total_time:.8f}秒") |