Files
video_detect/service/ocr_service.py

237 lines
8.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
import numpy as np
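# Compatibility shim: np.int was deprecated and later removed from NumPy; re-create it as the builtin int when it is missing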
if not hasattr(np, 'int'):
np.int = int
from paddleocr import PaddleOCR
from service.sensitive_service import get_all_sensitive_words
# Aho-Corasick automaton node definition
class AhoNode:
    """Single trie node of the Aho-Corasick automaton."""

    def __init__(self):
        self.children = {}    # outgoing edges: character -> AhoNode
        self.fail = None      # failure link, analogous to KMP's "next" table
        self.is_end = False   # True when some pattern ends at this node
        self.word = None      # the complete pattern ending here, if any


# Aho-Corasick automaton: multi-pattern string matching
class AhoCorasick:
    """Trie with failure links that finds every stored word inside a text.

    Typical use: call ``add_word`` for each pattern, then ``build_fail`` once,
    then ``match`` as often as needed. ``match`` rebuilds the failure links on
    its own when words were added since the last ``build_fail``.
    """

    def __init__(self):
        self.root = AhoNode()
        # True while the trie contains words whose failure links are stale.
        self._dirty = False

    def add_word(self, word):
        """Insert ``word`` into the trie.

        Non-string, empty and whitespace-only values are ignored. The word is
        stored exactly as given (no stripping is done here).
        """
        if not isinstance(word, str) or not word.strip():
            return
        node = self.root
        for char in word:
            child = node.children.get(char)
            if child is None:
                child = node.children[char] = AhoNode()
            node = child
        node.is_end = True
        node.word = word
        # New nodes have no failure link yet; match() must rebuild first.
        self._dirty = True

    def build_fail(self):
        """Compute the failure link of every node with a breadth-first walk."""
        from collections import deque  # O(1) popleft; list.pop(0) is O(n)

        root = self.root
        queue = deque()
        # Depth-1 nodes always fall back to the root.
        for child in root.children.values():
            child.fail = root
            queue.append(child)

        while queue:
            current = queue.popleft()
            for char, child in current.children.items():
                # Follow the parent's failure chain until a node with an
                # outgoing edge for `char` is found, or the chain ends.
                fallback = current.fail
                while fallback is not None and char not in fallback.children:
                    fallback = fallback.fail
                child.fail = fallback.children[char] if fallback is not None else root
                queue.append(child)
        self._dirty = False

    def match(self, text):
        """Return the distinct stored words that occur in ``text``.

        The result is a list without duplicates, ordered by where each word is
        first completed in the text (at the same end position the longer word
        comes first). An empty or ``None`` text yields ``[]``.
        """
        if not text:
            return []
        if self._dirty:
            # Without up-to-date failure links the suffix walk below would
            # run into None and miss or crash on matches.
            self.build_fail()

        root = self.root
        found = {}  # dict keys keep insertion order -> deterministic output
        node = root
        for char in text:
            # Fall back along failure links until `char` can be consumed.
            while node is not None and char not in node.children:
                node = node.fail
            node = node.children[char] if node is not None else root
            # Report the word ending here and every word that is a suffix of it.
            hit = node
            while hit is not None and hit is not root:
                if hit.is_end:
                    found.setdefault(hit.word, None)
                hit = hit.fail
        return list(found)
# Module-level state shared by the functions in this file
_ocr_engine = None # PaddleOCR instance assigned by init_ocr_engine(); None until then and after a failed init
_ac_automaton = None # AhoCorasick matcher holding the forbidden words (replaces the former _forbidden_words set)
_conf_threshold = 0.5 # confidence cutoff read by detect(): recognized text scoring below it is not matched
def set_forbidden_words(new_words):
    """Replace the forbidden-word matcher with one built from ``new_words``.

    Args:
        new_words: A set, list or tuple of candidate words. Entries that are
            not strings, or that are empty / whitespace-only, are ignored;
            the remaining ones are stripped of surrounding whitespace.

    Raises:
        TypeError: If ``new_words`` is not a set, list or tuple.
    """
    global _ac_automaton
    if not isinstance(new_words, (set, list, tuple)):
        raise TypeError("新违禁词必须是集合、列表或元组类型")

    valid_words = [
        word.strip()
        for word in new_words
        if isinstance(word, str) and word.strip()
    ]

    # Build into a local and publish only the finished automaton, so that a
    # failure half-way through, or a reader running while this update is in
    # progress, never observes an empty or partially linked trie.
    automaton = AhoCorasick()
    for word in valid_words:
        automaton.add_word(word)
    automaton.build_fail()  # failure links are required before matching
    _ac_automaton = automaton

    print(f"已通过函数更新违禁词,当前数量: {len(valid_words)}")
def load_forbidden_words():
    """Load the forbidden words from the sensitive-word service.

    Calls ``get_all_sensitive_words()`` (per the original note, expected to
    return ``list[str]``), keeps the non-blank string entries, and installs a
    freshly built AhoCorasick automaton as the module-level matcher.

    Returns:
        True when the automaton was rebuilt and installed; False when any
        step raised, in which case the previously installed matcher (if any)
        is left untouched.
    """
    global _ac_automaton
    try:
        sensitive_words = get_all_sensitive_words()
        valid_words = [
            word.strip()
            for word in sensitive_words
            if isinstance(word, str) and word.strip()
        ]

        # Assemble the automaton locally; the global is rebound only once the
        # trie and its failure links are complete, so an error above or below
        # cannot leave a half-built matcher in place.
        automaton = AhoCorasick()
        for word in valid_words:
            automaton.add_word(word)
        automaton.build_fail()
        _ac_automaton = automaton

        print(f"加载的违禁词数量: {len(valid_words)}")
        return True
    except Exception as e:
        # Best-effort loader: report the problem and let the caller decide.
        print(f"Forbidden words load error: {e}")
        return False
def init_ocr_engine(*, use_gpu=True):
    """Create the module-level PaddleOCR engine and load the forbidden words.

    Args:
        use_gpu: Forwarded to ``PaddleOCR(use_gpu=...)``. Defaults to True,
            the value that used to be hard-coded, so existing callers are
            unaffected; pass False on hosts without a usable GPU.

    Returns:
        True when the engine was created (a failed forbidden-word load only
        prints a warning); False when engine creation raised, in which case
        the module-level engine is reset to None.
    """
    global _ocr_engine
    try:
        _ocr_engine = PaddleOCR(
            use_angle_cls=True,
            lang="ch",
            show_log=False,
            use_gpu=use_gpu,
            max_text_length=1024,
        )
        # A missing word list degrades detection but does not abort the init.
        if not load_forbidden_words():
            print("警告:违禁词加载失败,可能影响检测功能")
        print("OCR引擎初始化完成")
        return True
    except Exception as e:
        print(f"OCR引擎初始化错误: {e}")
        _ocr_engine = None
        return False
def detect(frame, conf_threshold=None):
    """Run OCR on ``frame`` and report forbidden words found in its text.

    Args:
        frame: Image passed unchanged to ``_ocr_engine.ocr(frame, cls=True)``.
        conf_threshold: Minimum OCR confidence a text must have to be
            matched. ``None`` (the default) selects the module-level
            ``_conf_threshold``, which is the value this function has always
            applied. NOTE: the earlier signature default of 0.8 was never
            consulted; an explicitly passed value is now honoured.

    Returns:
        A ``(flag, message)`` tuple. ``flag`` is True only when at least one
        forbidden word was found, and ``message`` is then the words joined
        by ", ". Otherwise ``flag`` is False and ``message`` explains why
        (not initialised, no OCR result, no text, no violation, or an error).

    Timing for the OCR/parsing stage, the matching stage and the whole call
    is printed on every exit path.
    """
    print("开始进行OCR检测...")
    # perf_counter is monotonic, so durations cannot go negative on clock changes.
    total_start = time.perf_counter()
    ocr_time = 0.0    # OCR call plus result parsing
    match_time = 0.0  # forbidden-word matching
    try:
        # Snapshot the globals so a concurrent reload cannot swap them mid-call.
        engine, automaton = _ocr_engine, _ac_automaton
        if not engine or not automaton:
            return (False, "OCR引擎或违禁词库未初始化")
        threshold = _conf_threshold if conf_threshold is None else conf_threshold

        # 1. OCR and result parsing
        ocr_start = time.perf_counter()
        try:
            ocr_res = engine.ocr(frame, cls=True)
            if not ocr_res or not isinstance(ocr_res, list):
                return (False, "无OCR结果")
            recognized = _extract_ocr_texts(ocr_res)
        finally:
            # Recorded here so the early return and errors report real OCR time.
            ocr_time = time.perf_counter() - ocr_start

        if not recognized:
            return (False, "未识别到文本")

        # 2. Forbidden-word matching
        match_start = time.perf_counter()
        violations = {}  # dict keys: de-duplicated, first-seen order, O(1) lookup
        for text, conf in recognized:
            if conf < threshold:
                continue
            for word in automaton.match(text):
                violations.setdefault(word, None)
        match_time = time.perf_counter() - match_start

        if violations:
            return (True, ", ".join(violations))
        return (False, "未检测到违禁词")
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")
    finally:
        # Per-frame timing breakdown
        total_time = time.perf_counter() - total_start
        print("当前帧耗时明细:")
        print(f" OCR识别及解析{ocr_time:.8f}")
        print(f" 违禁词匹配:{match_time:.8f}")
        print(f" 总耗时:{total_time:.8f}")


def _extract_ocr_texts(ocr_res):
    """Flatten a raw OCR result into a list of ``(text, confidence)`` pairs."""
    # NumPy scalars are accepted too: np.float32 is not a subclass of float,
    # so a score of that type would otherwise be dropped silently.
    number_types = (int, float, np.integer, np.floating)
    pairs = []
    for line in ocr_res:
        if line is None:
            continue
        items = line if isinstance(line, list) else [line]
        for item in items:
            # Skip bounding boxes: a list of four [x, y] points.
            if isinstance(item, list) and len(item) == 4 and all(
                isinstance(point, list)
                and len(point) == 2
                and all(isinstance(c, (int, float)) for c in point)
                for point in item
            ):
                continue

            # The text payload is either the item itself, as a (text, conf)
            # tuple, or the second element of a [box, payload, ...] list.
            if isinstance(item, tuple) and len(item) == 2:
                payload = item
            elif isinstance(item, list) and len(item) >= 2:
                payload = item[1]
            else:
                continue

            if isinstance(payload, tuple) and len(payload) == 2:
                text, conf = payload
                if isinstance(text, str) and isinstance(conf, number_types):
                    pairs.append((text.strip(), float(conf)))
            elif isinstance(payload, str):
                # Bare text without a score is treated as fully confident.
                pairs.append((payload.strip(), 1.0))
    return pairs