Files
video/ocr/ocr_violation_detector.py
2025-09-03 13:52:24 +08:00

233 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import logging
from rapidocr import RapidOCR
class OCRViolationDetector:
    """
    Wrap the RapidOCR engine to detect forbidden words in image frames.

    A forbidden-word list is loaded from a text file (one word per line). Each
    frame given to ``detect`` is run through OCR, and every forbidden word that
    appears as a substring of a recognised text line is reported. Only lines
    whose score reaches the configured threshold are checked.
    """

    # Fallback RapidOCR config used when no ``config_path`` is supplied.
    # Kept identical to the previously hard-coded path for backward compatibility.
    DEFAULT_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"

    def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,
                 log_level: int = logging.INFO, log_file: str = None,
                 config_path: str = None):
        """
        Initialise logging, the forbidden-word list and the OCR engine.

        Args:
            forbidden_words_path (str): Path to the forbidden-word ``.txt`` file.
            ocr_confidence_threshold (float): Minimum OCR score a text line needs
                before it is checked for forbidden words.
            log_level (int): Logging level, ``logging.INFO`` by default.
            log_file (str, optional): Log file path; when omitted, logs go to
                the console only.
            config_path (str, optional): RapidOCR YAML config path; defaults to
                ``DEFAULT_CONFIG_PATH``.
        """
        # Logger first: the loaders below report through it.
        self.logger = self._setup_logger(log_level, log_file)
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        self.ocr_engine = self._initialize_ocr(config_path)
        self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
        self.logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}")

    def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
        """
        Configure and return the shared ``OCRViolationDetector`` logger.

        Handlers are attached only the first time the named logger is set up.
        Later instances reuse it as-is (their ``log_file`` is then ignored) but
        still update its level.

        Args:
            log_level: Logging level to apply.
            log_file: Optional log file path; ``None`` means console only.

        Returns:
            The configured ``logging.Logger``.
        """
        logger = logging.getLogger('OCRViolationDetector')
        logger.setLevel(log_level)
        # The logger is process-global; avoid stacking duplicate handlers.
        if logger.handlers:
            return logger

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

        if log_file:
            try:
                log_dir = os.path.dirname(log_file)
                if log_dir:
                    # exist_ok avoids a check-then-create race with other processes.
                    os.makedirs(log_dir, exist_ok=True)
                file_handler = logging.FileHandler(log_file, encoding='utf-8')
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
                logger.info(f"日志文件将保存至: {log_file}")
            except OSError as e:
                # Best effort: keep console logging if the file cannot be opened.
                logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台")
        return logger

    def _load_forbidden_words(self, path):
        """
        Load the forbidden-word set from a text file.

        Each line is stripped and blank lines are ignored. A missing or
        unreadable file yields an empty set, which disables detection.

        Args:
            path: Path to the word-list file.

        Returns:
            set[str]: The forbidden words.
        """
        words = set()
        if not os.path.exists(path):
            self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
            return words
        try:
            # 'utf-8-sig' strips a leading BOM (added by e.g. Windows Notepad);
            # with plain 'utf-8' the first word would keep '\ufeff' and never match.
            with open(path, 'r', encoding='utf-8-sig') as f:
                words = {word for word in (line.strip() for line in f) if word}
            self.logger.info(f"成功加载 {len(words)} 个违禁词。")
        except (OSError, UnicodeError) as e:
            self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
        return words

    def _initialize_ocr(self, config_path: str = None):
        """
        Create the RapidOCR engine from a YAML config file.

        Args:
            config_path: Config file path; ``None`` uses ``DEFAULT_CONFIG_PATH``.

        Returns:
            The RapidOCR engine, or ``None`` if the config is missing or
            initialisation fails (``detect`` then always reports no violations).
        """
        self.logger.info("正在初始化RapidOCR引擎...")
        if config_path is None:
            config_path = self.DEFAULT_CONFIG_PATH
        try:
            if not os.path.exists(config_path):
                self.logger.error(f"RapidOCR配置文件不存在: {config_path}")
                return None
            engine = RapidOCR(config_path=config_path)
            self.logger.info("RapidOCR引擎初始化成功。")
            return engine
        except Exception as e:
            # Third-party init may raise arbitrary errors; degrade to "no engine".
            self.logger.error(f"RapidOCR引擎初始化失败: {e}")
            return None

    def detect(self, frame):
        """
        Run OCR on a single frame and collect every forbidden-word hit.

        Args:
            frame: Image passed straight to the RapidOCR engine (presumably a
                ``numpy`` image as produced by ``cv2`` — confirm with callers).

        Returns:
            tuple: ``(has_violation, words, confidences)``. ``words`` lists each
            forbidden word once per qualifying text line that contains it, and
            ``confidences`` holds that line's OCR score for each entry.
            ``(False, [], [])`` is returned when the engine or word list is
            unavailable, when no text is recognised, or when OCR fails.
        """
        # Per-frame trace goes to the logger (debug level) instead of stdout.
        self.logger.debug("收到帧")
        if not self.ocr_engine or not self.forbidden_words:
            return False, [], []

        all_prohibited = []    # forbidden words found, in detection order
        all_confidences = []   # OCR score of the line each word came from

        try:
            result = self.ocr_engine(frame)
            self.logger.debug(f"RapidOCR 原始返回结果: {result}")
            if result is None:
                return False, [], []

            texts = getattr(result, 'txts', None)
            confidences = getattr(result, 'scores', None)
            if texts is None or confidences is None:
                # RapidOCR leaves txts/scores as None when no text is found;
                # zip(None, ...) would raise and log a spurious error per frame.
                return False, [], []

            threshold = self.OCR_CONFIDENCE_THRESHOLD
            for text, conf in zip(texts, confidences):
                if conf < threshold:
                    self.logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
                    continue
                # One text line may contain several forbidden words.
                for word in self.forbidden_words:
                    if word in text:
                        self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
                        all_prohibited.append(word)
                        all_confidences.append(conf)
        except Exception as e:
            # Frame-level boundary: a bad frame must not break the caller's loop.
            self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)

        return len(all_prohibited) > 0, all_prohibited, all_confidences
# def test_image(self, image_path: str, show_image: bool = True) -> tuple:
# """
# 对单张图片进行OCR违禁词检测并展示结果
#
# Args:
# image_path (str): 图片文件路径
# show_image (bool): 是否显示图片默认为True
#
# Returns:
# tuple: (是否有违禁词, 违禁词列表, 对应的置信度列表)
# """
# # 检查图片文件是否存在
# if not os.path.exists(image_path):
# self.logger.error(f"图片文件不存在: {image_path}")
# return False, [], []
#
# try:
# # 读取图片
# frame = cv2.imread(image_path)
# if frame is None:
# self.logger.error(f"无法读取图片: {image_path}")
# return False, [], []
#
# self.logger.info(f"开始处理图片: {image_path}")
#
# # 调用检测方法
# has_violation, violations, confidences = self.detect(frame)
#
# # 输出检测结果
# if has_violation:
# self.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
# for word, conf in zip(violations, confidences):
# self.logger.info(f"- {word} (置信度: {conf:.4f})")
# else:
# self.logger.info("图片中未检测到违禁词")
# # 显示图片(如果需要)
# if show_image:
# # 调整图片大小以便于显示(如果太大)
# height, width = frame.shape[:2]
# max_size = 800
# if max(height, width) > max_size:
# scale = max_size / max(height, width)
# frame = cv2.resize(frame, None, fx=scale, fy=scale)
#
# cv2.imshow(f"OCR检测结果: {'发现违禁词' if has_violation else '未发现违禁词'}", frame)
# cv2.waitKey(0) # 等待用户按键
# cv2.destroyAllWindows()
#
# return has_violation, violations, confidences
#
# except Exception as e:
# self.logger.error(f"处理图片时发生错误: {str(e)}", exc_info=True)
# return False, [], []
#
#
# # 使用示例
# if __name__ == "__main__":
# # 配置参数
# forbidden_words_path = "forbidden_words.txt" # 违禁词文件路径
# test_image_path = r"D:\Git\bin\video\ocr\images\img_7.png" # 测试图片路径
# ocr_threshold = 0.6 # OCR置信度阈值
#
# # 创建检测器实例
# detector = OCRViolationDetector(
# forbidden_words_path=forbidden_words_path,
# ocr_confidence_threshold=ocr_threshold,
# log_level=logging.INFO,
# log_file="ocr_detection.log"
# )
#
# # 测试图片
# detector.test_image(test_image_path, show_image=True)