# File: video/ocr/ocr_violation_detector.py
# (178 lines, 7.5 KiB, Python; captured from a repository history/blame view,
#  earliest commit shown: 2025-09-02 19:49:54 +08:00)
import numbers
import os

import cv2
from rapidocr import RapidOCR
class OCRViolationDetector:
    """
    Wrap a RapidOCR engine to detect forbidden words in image frames.

    Core features: load the forbidden-word list, initialize the OCR engine,
    and check a single frame for forbidden words.
    """

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_config_path: str,
                 ocr_confidence_threshold: float = 0.5):
        """
        Initialize the OCR engine and the forbidden-word list.

        Args:
            forbidden_words_path (str): Path to a .txt file containing one
                forbidden word per line.
            ocr_config_path (str): Path to the OCR config file (e.g. 1.yaml).
            ocr_confidence_threshold (float): Minimum OCR confidence for a
                recognized text line to be checked; clamped to [0, 1].
        """
        # Load the forbidden words (empty set on any failure).
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)

        # Initialize the RapidOCR engine (None on any failure).
        self.ocr_engine = self._initialize_ocr(ocr_config_path)

        # Warn if either core dependency is missing; detection is then a no-op.
        self._check_dependencies()

        # Clamp the confidence threshold into the 0~1 range.
        self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
        print(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")

    def _load_forbidden_words(self, path: str) -> set:
        """
        Load forbidden words from a text file.

        Each non-blank line is stripped and added to a set, so duplicates are
        removed. The file is decoded as UTF-8; a leading UTF-8 BOM (as written
        by some Windows editors) is discarded. Without that, the first word
        would carry a '\\ufeff' prefix and could never match.

        Returns:
            set: The loaded words; empty if the file is missing or unreadable.
        """
        forbidden_words = set()
        # Check that the file exists.
        if not os.path.exists(path):
            print(f"错误:违禁词文件不存在: {path}")
            return forbidden_words

        # Read the file and collect the non-blank lines.
        try:
            # 'utf-8-sig' decodes plain UTF-8 identically and strips an optional BOM.
            with open(path, 'r', encoding='utf-8-sig') as f:
                forbidden_words = {
                    line.strip() for line in f
                    if line.strip()  # skip empty / whitespace-only lines
                }
            print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
        except UnicodeDecodeError:
            print(f"错误违禁词文件编码错误需UTF-8: {path}")
        except PermissionError:
            print(f"错误:无权限读取违禁词文件: {path}")
        except Exception as e:
            print(f"错误:加载违禁词失败: {str(e)}")

        return forbidden_words

    def _initialize_ocr(self, config_path: str) -> "RapidOCR | None":
        """
        Create the RapidOCR engine from a config file.

        Returns:
            The engine instance, or None if the config file is missing or
            engine construction raised an exception.
        """
        print("开始初始化RapidOCR引擎...")

        # Check that the config file exists.
        if not os.path.exists(config_path):
            print(f"错误OCR配置文件不存在: {config_path}")
            return None

        # Build the OCR engine.
        try:
            ocr_engine = RapidOCR(config_path=config_path)
            print("RapidOCR引擎初始化成功")
            return ocr_engine
        except ImportError:
            print("错误RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
        except Exception as e:
            print(f"错误RapidOCR初始化失败: {str(e)}")

        return None

    def _check_dependencies(self) -> None:
        """Print a warning if the OCR engine or the forbidden-word list is not ready."""
        if not self.ocr_engine:
            print("警告:⚠️ OCR引擎未就绪违禁词检测功能将禁用")
        if not self.forbidden_words:
            print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")

    @staticmethod
    def _extract_text_score_pairs(ocr_result) -> list[tuple[str, float]]:
        """
        Turn a raw OCR result into validated (text, confidence) pairs.

        Texts and scores are zipped *before* any filtering. Dropping an
        invalid entry therefore can never shift a confidence onto a different
        text.

        Returns:
            list[tuple[str, float]]: Stripped non-empty texts with their scores.
            The list is empty (after printing the reason) when the result is
            None, lacks `txts`/`scores`, has mismatched lengths, or contains no
            usable text.
        """
        if ocr_result is None:
            print("OCR识别未返回任何结果图像无文本或识别失败")
            return []
        # The result is expected to expose parallel `txts` and `scores` sequences.
        if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
            print("警告OCR结果中txts为None或不存在")
            return []
        if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
            print("警告OCR结果中scores为None或不存在")
            return []

        raw_txts = ocr_result.txts
        if not isinstance(raw_txts, (list, tuple)):
            print(f"警告OCR txts不是可迭代类型实际类型: {type(raw_txts)}")
            raw_txts = ()
        raw_scores = ocr_result.scores
        if not isinstance(raw_scores, (list, tuple)):
            print(f"警告OCR scores不是可迭代类型实际类型: {type(raw_scores)}")
            raw_scores = ()

        # Texts and scores must correspond one-to-one.
        if len(raw_txts) != len(raw_scores):
            print(f"警告OCR文本与置信度数量不匹配文本{len(raw_txts)}个,置信度{len(raw_scores)}个),跳过检测")
            return []

        pairs = []
        for txt, conf in zip(raw_txts, raw_scores):
            # numbers.Real also accepts NumPy floating scalars (e.g. np.float32),
            # which an isinstance(..., float) check would wrongly reject.
            if not isinstance(txt, str) or not isinstance(conf, numbers.Real):
                continue
            text = txt.strip()
            if text:
                pairs.append((text, float(conf)))

        if not pairs:
            print("OCR未识别到任何有效文本")
        return pairs

    def detect(self, frame) -> tuple[bool, list, list]:
        """
        Run OCR on a single frame and look for forbidden words (core method).

        A recognized text line is checked only if its confidence is at least
        OCR_CONFIDENCE_THRESHOLD. A forbidden word matches when it occurs as a
        substring of that line. Matches within one line are reported in sorted
        order.

        Args:
            frame: Input image frame. It is expected to be a NumPy array (e.g.
                a BGR image read with cv2); only `.size` is inspected here.

        Returns:
            tuple[bool, list, list]:
                - whether any forbidden word was detected;
                - the detected forbidden words (empty list if none);
                - the confidence for each detected word, index-aligned with
                  the word list.
        """
        # Default (no violation) result.
        has_violation = False
        violation_words = []
        violation_confs = []

        # Pre-checks.
        if frame is None or frame.size == 0:
            print("警告输入图像帧为空或无效跳过OCR检测")
            return has_violation, violation_words, violation_confs
        if not self.ocr_engine or not self.forbidden_words:
            print("OCR引擎未就绪或违禁词为空跳过OCR检测")
            return has_violation, violation_words, violation_confs
        try:
            # Run OCR recognition.
            print("开始执行OCR识别...")
            ocr_result = self.ocr_engine(frame)
            print(f"RapidOCR原始结果: {ocr_result}")

            pairs = self._extract_text_score_pairs(ocr_result)
            if not pairs:
                return has_violation, violation_words, violation_confs

            threshold = self.OCR_CONFIDENCE_THRESHOLD
            # Iterating a set of str has no stable order across interpreter runs
            # (hash randomization); sort once per frame for reproducible output.
            candidate_words = sorted(self.forbidden_words)

            # Walk the recognized lines and collect forbidden words.
            print(f"开始筛选违禁词(阈值{threshold:.4f}")
            for text, conf in pairs:
                # Written as `not >=` so that a NaN confidence is also rejected.
                if not (conf >= threshold):
                    print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
                    continue
                matched_words = [word for word in candidate_words if word in text]
                if matched_words:
                    has_violation = True
                    violation_words.extend(matched_words)
                    violation_confs.extend([conf] * len(matched_words))
                    print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
        except Exception as e:
            # Best-effort: report and return whatever was collected so far.
            print(f"错误OCR检测过程异常: {str(e)}")

        return has_violation, violation_words, violation_confs