This commit is contained in:
2025-09-03 16:22:21 +08:00
parent b7773f5f00
commit d83923d06b
5 changed files with 211 additions and 342 deletions

View File

@ -1,6 +1,5 @@
import os
import cv2
import logging
from rapidocr import RapidOCR
@ -13,153 +12,85 @@ class OCRViolationDetector:
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
ocr_confidence_threshold: float = 0.5,
log_level: int = logging.INFO,
log_file: str = None):
ocr_confidence_threshold: float = 0.5):
"""
初始化OCR引擎违禁词列表和日志配置
初始化OCR引擎违禁词列表。
Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_config_path (str): OCR配置文件如1.yaml的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值0~1
log_level (int): 日志级别默认为logging.INFO。
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。
"""
# 初始化日志(确保先初始化日志,后续操作可正常打日志)
self.logger = self._setup_logger(log_level, log_file)
# 加载违禁词(优先级:先加载配置,再初始化引擎)
# 加载违禁词
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化RapidOCR引擎(传入配置文件路径)
# 初始化RapidOCR引擎
self.ocr_engine = self._initialize_ocr(ocr_config_path)
# 校验核心依赖是否就绪
self._check_dependencies()
# 设置置信度阈值限制在0~1范围,避免非法值
# 设置置信度阈值限制在0~1范围
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
self.logger.info(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
"""
配置日志系统(避免重复添加处理器,支持控制台+文件双输出)
Args:
log_level: 日志级别如logging.DEBUG、logging.INFO
log_file: 日志文件路径为None时仅输出到控制台。
Returns:
logging.Logger: 配置好的日志实例。
"""
logger = logging.getLogger('OCRViolationDetector')
logger.setLevel(log_level)
# 避免重复添加处理器(防止日志重复输出)
if logger.handlers:
return logger
# 定义日志格式(包含时间、模块名、级别、内容)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# 1. 添加控制台处理器
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 2. 若指定日志文件,添加文件处理器(自动创建目录)
if log_file:
try:
log_dir = os.path.dirname(log_file)
# 若日志目录不存在,自动创建
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
self.logger.debug(f"自动创建日志目录: {log_dir}")
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.info(f"日志文件已配置: {log_file}")
except Exception as e:
logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}")
return logger
print(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _load_forbidden_words(self, path: str) -> set:
"""
从TXT文件加载违禁词去重、过滤空行支持UTF-8编码
Args:
path (str): 违禁词TXT文件路径。
Returns:
set: 去重后的违禁词集合(空集合表示加载失败)。
"""
forbidden_words = set()
# 第一步:检查文件是否存在
# 检查文件是否存在
if not os.path.exists(path):
self.logger.error(f"违禁词文件不存在: {path}")
print(f"错误:违禁词文件不存在: {path}")
return forbidden_words
# 第二步:读取文件并处理内容
# 读取文件并处理内容
try:
with open(path, 'r', encoding='utf-8') as f:
# 过滤空行、去除首尾空格、去重
forbidden_words = {
line.strip() for line in f
if line.strip() # 跳过空行或纯空格行
}
self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
self.logger.debug(f"违禁词列表: {forbidden_words}")
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
except UnicodeDecodeError:
self.logger.error(f"违禁词文件编码错误需UTF-8: {path}")
print(f"错误:违禁词文件编码错误需UTF-8: {path}")
except PermissionError:
self.logger.error(f"无权限读取违禁词文件: {path}")
print(f"错误:无权限读取违禁词文件: {path}")
except Exception as e:
self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True)
print(f"错误:加载违禁词失败: {str(e)}")
return forbidden_words
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
"""
初始化RapidOCR引擎校验配置文件、捕获初始化异常
Args:
config_path (str): RapidOCR配置文件如1.yaml路径。
Returns:
RapidOCR | None: OCR引擎实例None表示初始化失败
"""
self.logger.info("开始初始化RapidOCR引擎...")
print("开始初始化RapidOCR引擎...")
# 第一步:检查配置文件是否存在
# 检查配置文件是否存在
if not os.path.exists(config_path):
self.logger.error(f"OCR配置文件不存在: {config_path}")
print(f"错误:OCR配置文件不存在: {config_path}")
return None
# 第二步:初始化OCR引擎捕获RapidOCR相关异常
# 初始化OCR引擎
try:
ocr_engine = RapidOCR(config_path=config_path)
self.logger.info("RapidOCR引擎初始化成功")
print("RapidOCR引擎初始化成功")
return ocr_engine
except ImportError:
self.logger.error("RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
print("错误:RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
except Exception as e:
self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True)
print(f"错误:RapidOCR初始化失败: {str(e)}")
return None
def _check_dependencies(self) -> None:
"""校验OCR引擎和违禁词列表是否就绪(输出警告日志)"""
"""校验OCR引擎和违禁词列表是否就绪"""
if not self.ocr_engine:
self.logger.warning("⚠️ OCR引擎未就绪违禁词检测功能将禁用")
print("警告:⚠️ OCR引擎未就绪违禁词检测功能将禁用")
if not self.forbidden_words:
self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用")
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
def detect(self, frame) -> tuple[bool, list, list]:
"""
@ -179,76 +110,69 @@ class OCRViolationDetector:
violation_words = []
violation_confs = []
# 前置校验1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在
# 前置校验
if frame is None or frame.size == 0:
self.logger.warning("输入图像帧为空或无效跳过OCR检测")
print("警告:输入图像帧为空或无效跳过OCR检测")
return has_violation, violation_words, violation_confs
if not self.ocr_engine or not self.forbidden_words:
self.logger.debug("OCR引擎未就绪或违禁词为空跳过OCR检测")
print("OCR引擎未就绪或违禁词为空跳过OCR检测")
return has_violation, violation_words, violation_confs
try:
# 1. 执行OCR识别获取RapidOCR原始结果
self.logger.debug("开始执行OCR识别...")
# 执行OCR识别
print("开始执行OCR识别...")
ocr_result = self.ocr_engine(frame)
self.logger.debug(f"RapidOCR原始结果: {ocr_result}")
print(f"RapidOCR原始结果: {ocr_result}")
# 2. 校验OCR结果是否有效避免None或格式异常
# 校验OCR结果是否有效
if ocr_result is None:
self.logger.debug("OCR识别未返回任何结果图像无文本或识别失败")
print("OCR识别未返回任何结果图像无文本或识别失败")
return has_violation, violation_words, violation_confs
# 3. 检查txts和scores是否存在且不为None
# 检查txts和scores是否存在且不为None
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
self.logger.warning("OCR结果中txts为None或不存在")
print("警告:OCR结果中txts为None或不存在")
return has_violation, violation_words, violation_confs
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
self.logger.warning("OCR结果中scores为None或不存在")
print("警告:OCR结果中scores为None或不存在")
return has_violation, violation_words, violation_confs
# 4. 转为列表并去None防止单个元素为None
# 确保txts是可迭代的如果不是则转为空列表
# 转为列表并去None
if not isinstance(ocr_result.txts, (list, tuple)):
self.logger.warning(f"OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
print(f"警告:OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
texts = []
else:
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
# 确保scores是可迭代的如果不是则转为空列表
if not isinstance(ocr_result.scores, (list, tuple)):
self.logger.warning(f"OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
print(f"警告:OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
confidences = []
else:
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
# 5. 校验文本和置信度列表长度是否一致避免zip迭代错误
# 校验文本和置信度列表长度是否一致
if len(texts) != len(confidences):
self.logger.warning(
f"OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
print(f"警告OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
return has_violation, violation_words, violation_confs
if len(texts) == 0:
self.logger.debug("OCR未识别到任何有效文本")
print("OCR未识别到任何有效文本")
return has_violation, violation_words, violation_confs
# 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤)
self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
# 遍历识别结果,筛选违禁词
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
for text, conf in zip(texts, confidences):
# 过滤低置信度结果
if conf < self.OCR_CONFIDENCE_THRESHOLD:
self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
continue
# 检查当前文本是否包含违禁词(支持一个文本含多个违禁词)
matched_words = [word for word in self.forbidden_words if word in text]
if matched_words:
has_violation = True
# 记录所有匹配的违禁词和对应置信度
violation_words.extend(matched_words)
violation_confs.extend([conf] * len(matched_words)) # 一个文本对应多个违禁词时,置信度复用
self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
violation_confs.extend([conf] * len(matched_words))
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
except Exception as e:
# 捕获所有异常,确保不中断上层调用
self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True)
print(f"错误OCR检测过程异常: {str(e)}")
return has_violation, violation_words, violation_confs
return has_violation, violation_words, violation_confs