Files
video/ocr/ocr_violation_detector.py
2025-09-03 16:22:21 +08:00

178 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numbers
import os

import cv2
from rapidocr import RapidOCR
class OCRViolationDetector:
    """
    Wrap the RapidOCR engine to detect forbidden words in image frames.

    Responsibilities: load the forbidden-word list, initialize the OCR engine,
    and run forbidden-word detection on a single frame.
    """

    def __init__(self,
                 forbidden_words_path: str,
                 ocr_config_path: str,
                 ocr_confidence_threshold: float = 0.5):
        """
        Initialize the OCR engine and the forbidden-word list.

        Args:
            forbidden_words_path (str): Path to a .txt file with one forbidden
                word per line.
            ocr_config_path (str): Path to the RapidOCR config file (e.g. a .yaml).
            ocr_confidence_threshold (float): Minimum OCR confidence a text line
                needs before it is checked; clamped into [0, 1].
        """
        # Load forbidden words (empty set on any failure; see _load_forbidden_words).
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        # Initialize the RapidOCR engine (None on any failure).
        self.ocr_engine = self._initialize_ocr(ocr_config_path)
        # Warn (but do not raise) if either dependency is missing.
        self._check_dependencies()
        # Clamp the threshold into the 0~1 range.
        self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
        print(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")

    def _load_forbidden_words(self, path: str) -> set:
        """
        Load forbidden words from a UTF-8 text file.

        Each non-blank line (after stripping surrounding whitespace) is one word;
        duplicates are removed. Errors are printed and yield an empty set.

        Args:
            path (str): Path to the word-list file.

        Returns:
            set: The unique forbidden words (possibly empty).
        """
        forbidden_words = set()
        if not os.path.exists(path):
            print(f"错误:违禁词文件不存在: {path}")
            return forbidden_words
        try:
            # 'utf-8-sig' reads plain UTF-8 unchanged but also strips a leading
            # BOM (as written by e.g. Windows Notepad); with plain 'utf-8' the BOM
            # would stay attached to the first word so it could never match.
            with open(path, 'r', encoding='utf-8-sig') as f:
                forbidden_words = {
                    line.strip() for line in f
                    if line.strip()  # skip empty / whitespace-only lines
                }
            print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
        except UnicodeDecodeError:
            print(f"错误违禁词文件编码错误需UTF-8: {path}")
        except PermissionError:
            print(f"错误:无权限读取违禁词文件: {path}")
        except Exception as e:
            print(f"错误:加载违禁词失败: {str(e)}")
        return forbidden_words

    def _initialize_ocr(self, config_path: str) -> "RapidOCR | None":
        """
        Create the RapidOCR engine from a config file.

        Args:
            config_path (str): Path to the RapidOCR config file.

        Returns:
            The RapidOCR engine, or None if the config is missing or
            construction fails (the error is printed).
        """
        print("开始初始化RapidOCR引擎...")
        if not os.path.exists(config_path):
            print(f"错误OCR配置文件不存在: {config_path}")
            return None
        try:
            ocr_engine = RapidOCR(config_path=config_path)
            print("RapidOCR引擎初始化成功")
            return ocr_engine
        except ImportError:
            # This module imports the `rapidocr` package, which needs an
            # inference backend such as onnxruntime installed alongside it.
            print("错误RapidOCR依赖未安装需执行pip install rapidocr onnxruntime")
        except Exception as e:
            print(f"错误RapidOCR初始化失败: {str(e)}")
        return None

    def _check_dependencies(self) -> None:
        """Print a warning for each missing dependency (OCR engine, word list)."""
        if not self.ocr_engine:
            print("警告:⚠️ OCR引擎未就绪违禁词检测功能将禁用")
        if not self.forbidden_words:
            print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")

    def _extract_text_score_pairs(self, ocr_result) -> list[tuple[str, float]]:
        """
        Turn a RapidOCR result into aligned (text, confidence) pairs.

        Reads the ``txts`` and ``scores`` attributes of ``ocr_result``. A pair
        is kept only if its text is a non-blank ``str`` and its score is a real
        number; the text is stripped and the score converted to ``float``.
        A diagnostic is printed and ``[]`` returned when the result is missing,
        malformed, length-mismatched, or yields no usable text.
        """
        if ocr_result is None:
            print("OCR识别未返回任何结果图像无文本或识别失败")
            return []
        raw_texts = getattr(ocr_result, 'txts', None)
        if raw_texts is None:
            print("警告OCR结果中txts为None或不存在")
            return []
        raw_scores = getattr(ocr_result, 'scores', None)
        if raw_scores is None:
            print("警告OCR结果中scores为None或不存在")
            return []
        if not isinstance(raw_texts, (list, tuple)):
            print(f"警告OCR txts不是可迭代类型实际类型: {type(raw_texts)}")
            raw_texts = []
        if not isinstance(raw_scores, (list, tuple)):
            print(f"警告OCR scores不是可迭代类型实际类型: {type(raw_scores)}")
            raw_scores = []
        # texts[i] and scores[i] describe the same detected line, so check the
        # raw lengths and filter them *together*; filtering each list separately
        # (e.g. dropping an empty text or a 0.0 score) would shift one against
        # the other and pair texts with the wrong confidences.
        if len(raw_texts) != len(raw_scores):
            print(f"警告OCR文本与置信度数量不匹配文本{len(raw_texts)}个,置信度{len(raw_scores)}个),跳过检测")
            return []
        pairs = []
        for txt, conf in zip(raw_texts, raw_scores):
            # numbers.Real also accepts NumPy scalars such as np.float32,
            # which are not subclasses of Python float.
            if not isinstance(txt, str) or not isinstance(conf, numbers.Real):
                continue
            txt = txt.strip()
            if txt:
                pairs.append((txt, float(conf)))
        if not pairs:
            print("OCR未识别到任何有效文本")
        return pairs

    def detect(self, frame) -> tuple[bool, list, list]:
        """
        Run OCR on one frame and report any forbidden words found.

        Args:
            frame: Input image; expected to be an image array as produced by
                cv2 (anything without a non-zero ``size`` is treated as invalid).

        Returns:
            tuple[bool, list, list]:
                - whether any forbidden word was detected;
                - the detected forbidden words (empty list if none);
                - the confidence of the source text line for each detected word
                  (parallel to the word list).

        Any exception raised during OCR/matching is printed and the results
        gathered so far are returned.
        """
        has_violation = False
        violation_words = []
        violation_confs = []
        # getattr guards against inputs that are not arrays (no ``size``).
        if frame is None or getattr(frame, "size", 0) == 0:
            print("警告输入图像帧为空或无效跳过OCR检测")
            return has_violation, violation_words, violation_confs
        if not self.ocr_engine or not self.forbidden_words:
            print("OCR引擎未就绪或违禁词为空跳过OCR检测")
            return has_violation, violation_words, violation_confs
        try:
            print("开始执行OCR识别...")
            ocr_result = self.ocr_engine(frame)
            print(f"RapidOCR原始结果: {ocr_result}")
            pairs = self._extract_text_score_pairs(ocr_result)
            if not pairs:
                return has_violation, violation_words, violation_confs
            # Sorted once per frame so the reported word order is deterministic
            # (iteration order of a set of str depends on hash randomization).
            words = sorted(self.forbidden_words)
            threshold = self.OCR_CONFIDENCE_THRESHOLD
            print(f"开始筛选违禁词(阈值{threshold:.4f}")
            for text, conf in pairs:
                # `not >=` rather than `<` so that NaN scores are rejected too.
                if not conf >= threshold:
                    print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
                    continue
                matched_words = [word for word in words if word in text]
                if matched_words:
                    has_violation = True
                    violation_words.extend(matched_words)
                    violation_confs.extend([conf] * len(matched_words))
                    print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
        except Exception as e:
            print(f"错误OCR检测过程异常: {str(e)}")
        return has_violation, violation_words, violation_confs