Files
video/ocr/ocr_violation_detector.py
2025-09-03 14:38:42 +08:00

255 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import logging
from rapidocr import RapidOCR
class OCRViolationDetector:
"""
封装RapidOCR引擎用于检测图像帧中的违禁词。
核心功能加载违禁词、初始化OCR引擎、单帧图像违禁词检测
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
ocr_confidence_threshold: float = 0.5,
log_level: int = logging.INFO,
log_file: str = None):
"""
初始化OCR引擎、违禁词列表和日志配置。
Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_config_path (str): OCR配置文件如1.yaml的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值0~1
log_level (int): 日志级别默认为logging.INFO。
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。
"""
# 初始化日志(确保先初始化日志,后续操作可正常打日志)
self.logger = self._setup_logger(log_level, log_file)
# 加载违禁词(优先级:先加载配置,再初始化引擎)
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化RapidOCR引擎传入配置文件路径
self.ocr_engine = self._initialize_ocr(ocr_config_path)
# 校验核心依赖是否就绪
self._check_dependencies()
# 设置置信度阈值限制在0~1范围避免非法值
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
self.logger.info(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
"""
配置日志系统(避免重复添加处理器,支持控制台+文件双输出)
Args:
log_level: 日志级别如logging.DEBUG、logging.INFO
log_file: 日志文件路径为None时仅输出到控制台。
Returns:
logging.Logger: 配置好的日志实例。
"""
logger = logging.getLogger('OCRViolationDetector')
logger.setLevel(log_level)
# 避免重复添加处理器(防止日志重复输出)
if logger.handlers:
return logger
# 定义日志格式(包含时间、模块名、级别、内容)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# 1. 添加控制台处理器
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 2. 若指定日志文件,添加文件处理器(自动创建目录)
if log_file:
try:
log_dir = os.path.dirname(log_file)
# 若日志目录不存在,自动创建
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
self.logger.debug(f"自动创建日志目录: {log_dir}")
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.info(f"日志文件已配置: {log_file}")
except Exception as e:
logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}")
return logger
def _load_forbidden_words(self, path: str) -> set:
"""
从TXT文件加载违禁词去重、过滤空行支持UTF-8编码
Args:
path (str): 违禁词TXT文件路径。
Returns:
set: 去重后的违禁词集合(空集合表示加载失败)。
"""
forbidden_words = set()
# 第一步:检查文件是否存在
if not os.path.exists(path):
self.logger.error(f"违禁词文件不存在: {path}")
return forbidden_words
# 第二步:读取文件并处理内容
try:
with open(path, 'r', encoding='utf-8') as f:
# 过滤空行、去除首尾空格、去重
forbidden_words = {
line.strip() for line in f
if line.strip() # 跳过空行或纯空格行
}
self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
self.logger.debug(f"违禁词列表: {forbidden_words}")
except UnicodeDecodeError:
self.logger.error(f"违禁词文件编码错误需UTF-8: {path}")
except PermissionError:
self.logger.error(f"无权限读取违禁词文件: {path}")
except Exception as e:
self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True)
return forbidden_words
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
"""
初始化RapidOCR引擎校验配置文件、捕获初始化异常
Args:
config_path (str): RapidOCR配置文件如1.yaml路径。
Returns:
RapidOCR | None: OCR引擎实例None表示初始化失败
"""
self.logger.info("开始初始化RapidOCR引擎...")
# 第一步:检查配置文件是否存在
if not os.path.exists(config_path):
self.logger.error(f"OCR配置文件不存在: {config_path}")
return None
# 第二步初始化OCR引擎捕获RapidOCR相关异常
try:
ocr_engine = RapidOCR(config_path=config_path)
self.logger.info("RapidOCR引擎初始化成功")
return ocr_engine
except ImportError:
self.logger.error("RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
except Exception as e:
self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True)
return None
def _check_dependencies(self) -> None:
"""校验OCR引擎和违禁词列表是否就绪输出警告日志"""
if not self.ocr_engine:
self.logger.warning("⚠️ OCR引擎未就绪违禁词检测功能将禁用")
if not self.forbidden_words:
self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用")
def detect(self, frame) -> tuple[bool, list, list]:
"""
对单帧图像进行OCR违禁词检测核心方法
Args:
frame: 输入图像帧NumPy数组BGR格式cv2读取的图像
Returns:
tuple[bool, list, list]:
- 第一个元素是否检测到违禁词True/False
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
"""
# 初始化返回结果
has_violation = False
violation_words = []
violation_confs = []
# 前置校验1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在
if frame is None or frame.size == 0:
self.logger.warning("输入图像帧为空或无效跳过OCR检测")
return has_violation, violation_words, violation_confs
if not self.ocr_engine or not self.forbidden_words:
self.logger.debug("OCR引擎未就绪或违禁词为空跳过OCR检测")
return has_violation, violation_words, violation_confs
try:
# 1. 执行OCR识别获取RapidOCR原始结果
self.logger.debug("开始执行OCR识别...")
ocr_result = self.ocr_engine(frame)
self.logger.debug(f"RapidOCR原始结果: {ocr_result}")
# 2. 校验OCR结果是否有效避免None或格式异常
if ocr_result is None:
self.logger.debug("OCR识别未返回任何结果图像无文本或识别失败")
return has_violation, violation_words, violation_confs
# 3. 检查txts和scores是否存在且不为None
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
self.logger.warning("OCR结果中txts为None或不存在")
return has_violation, violation_words, violation_confs
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
self.logger.warning("OCR结果中scores为None或不存在")
return has_violation, violation_words, violation_confs
# 4. 转为列表并去None防止单个元素为None
# 确保txts是可迭代的如果不是则转为空列表
if not isinstance(ocr_result.txts, (list, tuple)):
self.logger.warning(f"OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
texts = []
else:
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
# 确保scores是可迭代的如果不是则转为空列表
if not isinstance(ocr_result.scores, (list, tuple)):
self.logger.warning(f"OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
confidences = []
else:
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
# 5. 校验文本和置信度列表长度是否一致避免zip迭代错误
if len(texts) != len(confidences):
self.logger.warning(
f"OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
return has_violation, violation_words, violation_confs
if len(texts) == 0:
self.logger.debug("OCR未识别到任何有效文本")
return has_violation, violation_words, violation_confs
# 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤)
self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
for text, conf in zip(texts, confidences):
# 过滤低置信度结果
if conf < self.OCR_CONFIDENCE_THRESHOLD:
self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
continue
# 检查当前文本是否包含违禁词(支持一个文本含多个违禁词)
matched_words = [word for word in self.forbidden_words if word in text]
if matched_words:
has_violation = True
# 记录所有匹配的违禁词和对应置信度
violation_words.extend(matched_words)
violation_confs.extend([conf] * len(matched_words)) # 一个文本对应多个违禁词时,置信度复用
self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
except Exception as e:
# 捕获所有异常,确保不中断上层调用
self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True)
return has_violation, violation_words, violation_confs