Files
video/ocr/ocr_violation_detector.py
2025-09-02 23:06:36 +08:00

160 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import logging
from rapidocr import RapidOCR
class OCRViolationDetector:
    """Detect forbidden words in image frames using a RapidOCR engine.

    The detector loads a forbidden-word list from a text file, builds a
    RapidOCR engine from a YAML config file, and exposes :meth:`detect`,
    which reports every forbidden word found in the text recognized on a
    frame.
    """

    # Config file used when the caller does not supply ``ocr_config_path``.
    # Kept so existing callers that relied on the hard-coded path still work.
    DEFAULT_OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"

    def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,
                 log_level: int = logging.INFO, log_file: str = None,
                 ocr_config_path: str = None):
        """Set up logging, the forbidden-word list and the OCR engine.

        Args:
            forbidden_words_path: Path to the forbidden-word ``.txt`` file
                (one word per line).
            ocr_confidence_threshold: Recognized text whose score is below
                this value is ignored by :meth:`detect`.
            log_level: Logging level; defaults to ``logging.INFO``.
            log_file: Optional log file path. When omitted, log records are
                only written to the console handler.
            ocr_config_path: Optional path to the RapidOCR YAML config.
                When omitted, ``DEFAULT_OCR_CONFIG_PATH`` is used.
        """
        self.logger = self._setup_logger(log_level, log_file)
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        self.ocr_engine = self._initialize_ocr(ocr_config_path)
        self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
        self.logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}")

    def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
        """Configure and return the logger named ``'OCRViolationDetector'``.

        Args:
            log_level: Level applied to the logger.
            log_file: Optional log file path; ``None`` means console only.

        Returns:
            The configured ``logging.Logger`` instance.
        """
        logger = logging.getLogger('OCRViolationDetector')
        logger.setLevel(log_level)

        # logging.getLogger returns the same object for the same name, so a
        # logger that already has handlers is returned as-is to avoid
        # attaching duplicates. In that case ``log_file`` is not applied.
        if logger.handlers:
            return logger

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

        if log_file:
            try:
                # exist_ok=True avoids the check-then-create race of a
                # separate os.path.exists() test.
                log_dir = os.path.dirname(log_file)
                if log_dir:
                    os.makedirs(log_dir, exist_ok=True)
                file_handler = logging.FileHandler(log_file, encoding='utf-8')
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
                logger.info(f"日志文件将保存至: {log_file}")
            except (OSError, ValueError) as e:
                # Best effort: a broken log path must not prevent detection.
                logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台")
        return logger

    def _load_forbidden_words(self, path) -> set:
        """Load the forbidden-word list from a text file.

        Each non-blank line, stripped of surrounding whitespace, is one word.

        Args:
            path: Path to the word-list file.

        Returns:
            A set of forbidden words; empty when the file is missing or
            cannot be read (a warning/error is logged in that case).
        """
        words = set()
        if not os.path.exists(path):
            self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
            return words
        try:
            # 'utf-8-sig' transparently drops a leading BOM; with plain
            # 'utf-8' the BOM would stick to the first word and that word
            # would never match any recognized text.
            with open(path, 'r', encoding='utf-8-sig') as f:
                words = {line.strip() for line in f if line.strip()}
            self.logger.info(f"成功加载 {len(words)} 个违禁词。")
        except (OSError, UnicodeError) as e:
            self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
        return words

    def _initialize_ocr(self, config_path: str = None):
        """Create the RapidOCR engine from a YAML config file.

        Args:
            config_path: Path to the RapidOCR config. ``None`` selects
                ``DEFAULT_OCR_CONFIG_PATH``.

        Returns:
            The ``RapidOCR`` engine, or ``None`` when the config file does
            not exist or engine construction raises.
        """
        self.logger.info("正在初始化RapidOCR引擎...")
        if config_path is None:
            config_path = self.DEFAULT_OCR_CONFIG_PATH
        try:
            if not os.path.exists(config_path):
                self.logger.error(f"RapidOCR配置文件不存在: {config_path}")
                return None
            engine = RapidOCR(config_path=config_path)
            self.logger.info("RapidOCR引擎初始化成功。")
            return engine
        except Exception as e:
            # Engine construction is a third-party boundary; any failure is
            # logged and reported as "no engine" so detect() becomes a no-op.
            self.logger.error(f"RapidOCR引擎初始化失败: {e}")
            return None

    def detect(self, frame):
        """Run OCR on one frame and collect every forbidden word found.

        Args:
            frame: The image passed straight to the OCR engine.

        Returns:
            A tuple ``(found, words, confidences)``: ``found`` is True when
            at least one forbidden word was seen, ``words`` lists each hit
            (one entry per word per recognized text), and ``confidences``
            holds the score of the text each hit came from, in the same
            order.
        """
        # Per-frame trace goes through the logger instead of stdout so it
        # can be silenced by log level.
        self.logger.debug("收到帧")
        if not self.ocr_engine or not self.forbidden_words:
            return False, [], []

        all_prohibited = []   # every forbidden word detected on this frame
        all_confidences = []  # score of the text each word was found in
        try:
            result = self.ocr_engine(frame)
            # Lazy %-formatting: the result is only stringified when DEBUG
            # logging is actually enabled.
            self.logger.debug("RapidOCR 原始返回结果: %s", result)
            if result is None:
                return False, [], []

            # The attributes may be absent or set to None; both mean
            # "nothing to check" rather than an error.
            texts = getattr(result, 'txts', None)
            confidences = getattr(result, 'scores', None)
            if texts is None or confidences is None:
                return False, [], []

            for text, conf in zip(texts, confidences):
                if conf < self.OCR_CONFIDENCE_THRESHOLD:
                    self.logger.debug("文本 '%s' 置信度 %.4f 低于阈值,跳过", text, conf)
                    continue
                # A single recognized text may contain several forbidden words.
                for word in self.forbidden_words:
                    if word in text:
                        self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
                        all_prohibited.append(word)
                        all_confidences.append(conf)
        except Exception as e:
            # Keep the processing loop alive; hits collected before the
            # failure are still returned.
            self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)

        return len(all_prohibited) > 0, all_prohibited, all_confidences