Files
video/ocr/ocr_violation_detector.py
2025-09-02 19:49:54 +08:00

136 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
from rapidocr import RapidOCR
from logger_config import logger
class OCRViolationDetector:
    """
    Wrap the RapidOCR engine to detect forbidden words in image frames.
    """

    # Default RapidOCR YAML config, resolved relative to the current working
    # directory. Built with os.path.join so it is ".\\config\\1.yaml" on
    # Windows (the previously hard-coded value) and "./config/1.yaml" on
    # POSIX, where a literal backslash path would name a single odd file.
    DEFAULT_OCR_CONFIG_PATH = os.path.join(".", "config", "1.yaml")

    def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,  # lowered threshold to improve recall
                 *, ocr_config_path: str = DEFAULT_OCR_CONFIG_PATH):
        """
        Initialize the OCR engine and load the forbidden-word list.

        Args:
            forbidden_words_path (str): path to the forbidden-word .txt file.
            ocr_confidence_threshold (float): OCR lines whose recognition score
                is below this value are ignored by detect().
            ocr_config_path (str): RapidOCR YAML config file; defaults to
                DEFAULT_OCR_CONFIG_PATH.
        """
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        self.ocr_engine = self._initialize_ocr(ocr_config_path)
        self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold

    def _load_forbidden_words(self, path):
        """
        Load the forbidden-word list from a text file, one word per line.

        Kept consistent with rapidocr_test.py: each line is stripped of
        surrounding whitespace and blank lines are dropped. Comment-like lines
        are deliberately NOT filtered out.

        Args:
            path (str): path to the .txt word list.

        Returns:
            set[str]: the loaded words. The set is empty if the file is missing
            or cannot be read, in which case detect() skips detection.
        """
        words = set()
        if not os.path.exists(path):
            logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
            return words
        try:
            # "utf-8-sig" reads plain UTF-8 identically but also drops a leading
            # BOM. str.strip() does not remove U+FEFF, so with plain "utf-8" the
            # first word of a BOM-prefixed file would carry an invisible prefix
            # and never match any OCR text.
            with open(path, 'r', encoding='utf-8-sig') as f:
                words = {line.strip() for line in f if line.strip()}
            logger.info(f"成功加载 {len(words)} 个违禁词。")
        except Exception as e:
            logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
        return words

    def _initialize_ocr(self, config_path=None):
        """
        Create the RapidOCR engine.

        Args:
            config_path (str | None): RapidOCR YAML config file. None means
                DEFAULT_OCR_CONFIG_PATH.

        Returns:
            The RapidOCR engine, or None if construction raised. With no
            engine, detect() always reports no violation.
        """
        logger.info("正在初始化RapidOCR引擎...")
        if config_path is None:
            config_path = self.DEFAULT_OCR_CONFIG_PATH
        try:
            engine = RapidOCR(config_path=config_path)
            logger.info("RapidOCR引擎初始化成功。")
            return engine
        except Exception as e:
            logger.error(f"RapidOCR引擎初始化失败: {e}")
            return None

    def detect(self, frame):
        """
        Run OCR on a single frame and collect every forbidden word found.

        Args:
            frame: image passed unchanged to the RapidOCR engine (presumably a
                cv2/numpy image; confirm with the caller).

        Returns:
            tuple[bool, list[str], list[float]]:
                (any violation found, forbidden words found, confidences).
            Each confidence is the score of the OCR text line the word came
            from. A word is reported once per text line that contains it, so
            it may repeat. Returns (False, [], []) when the engine is
            unavailable, no words are loaded, nothing is recognized, or OCR
            raises.
        """
        if not self.ocr_engine or not self.forbidden_words:
            return False, [], []
        all_prohibited = []  # every forbidden word detected
        all_confidences = []  # confidence for each entry in all_prohibited
        try:
            result = self.ocr_engine(frame)
            logger.debug(f"RapidOCR 原始返回结果: {result}")
            if result is None:
                return False, [], []
            # The result object can exist while carrying None for txts/scores
            # (RapidOCR does this when no text is found). Treat that as "no
            # text" instead of letting zip(None, None) raise TypeError and log
            # an error for every text-free frame.
            texts = getattr(result, 'txts', None)
            confidences = getattr(result, 'scores', None)
            if texts is None or confidences is None:
                return False, [], []
            for text, conf in zip(texts, confidences):
                if conf < self.OCR_CONFIDENCE_THRESHOLD:
                    logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
                    continue
                # A single text line may contain several forbidden words.
                for word in self.forbidden_words:
                    if word in text:
                        logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
                        all_prohibited.append(word)
                        all_confidences.append(conf)
        except Exception as e:
            logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)
        return len(all_prohibited) > 0, all_prohibited, all_confidences
# def test_single_image():
# """测试单张图片的OCR违规检测显示所有违禁词"""
# TEST_IMAGE_PATH = r"ocr/images/img_7.png" # 修正路径格式
# FORBIDDEN_WORDS_PATH = r"ocr/forbidden_words.txt"
# CONFIDENCE_THRESHOLD = 0.5
#
# detector = OCRViolationDetector(
# forbidden_words_path=FORBIDDEN_WORDS_PATH,
# ocr_confidence_threshold=CONFIDENCE_THRESHOLD
# )
#
# if not os.path.exists(TEST_IMAGE_PATH):
# print(f"错误:图片文件不存在 - {TEST_IMAGE_PATH}")
# return
#
# frame = cv2.imread(TEST_IMAGE_PATH)
# if frame is None:
# print(f"错误:无法读取图片 - {TEST_IMAGE_PATH}")
# return
#
# # 执行检测
# has_violation, words, confidences = detector.detect(frame)
#
# # 输出所有检测到的违禁词
# if has_violation:
# print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
# for word, conf in zip(words, confidences):
# print(f"- {word}(置信度:{conf:.4f})")
# else:
# print("测试结果:图片中未检测到违禁词")
#
#
# if __name__ == "__main__":
# print("开始单张图片OCR违规检测测试...")
# test_single_image()
# print("测试完成")