ocr1.0
This commit is contained in:
17
ocr/forbidden_words.txt
Normal file
17
ocr/forbidden_words.txt
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
打倒习近平
|
||||||
|
打到习近平
|
||||||
|
打倒毛泽东
|
||||||
|
打到毛泽东
|
||||||
|
打到主席
|
||||||
|
打倒主席
|
||||||
|
打到共产主义
|
||||||
|
打倒共产主义
|
||||||
|
打到共产党
|
||||||
|
打倒共产党
|
||||||
|
胖猫
|
||||||
|
法轮功
|
||||||
|
法轮大法
|
||||||
|
法轮大法好
|
||||||
|
法轮功大法好
|
||||||
|
法轮
|
||||||
|
李洪志
|
44
ocr/logger_config.py
Normal file
44
ocr/logger_config.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
#日志文件
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def setup_logger():
    """Configure and return the shared "ViolationDetectorLogger" logger.

    The logger itself is set to DEBUG and gets two handlers:

    * a console handler writing to ``sys.stdout`` at INFO and above;
    * a file handler appending to ``violation_detector.log`` (UTF-8) at
      DEBUG and above.

    The function is safe to call repeatedly: once the logger owns handlers,
    later calls return it unchanged so messages are not printed twice.

    Returns:
        logging.Logger: the configured logger instance.
    """
    logger = logging.getLogger("ViolationDetectorLogger")
    logger.setLevel(logging.DEBUG)  # lowest level the logger lets through

    # Look at this logger's OWN handler list. ``Logger.hasHandlers()`` also
    # walks up the parent chain, so a handler installed on the root logger
    # (e.g. by ``logging.basicConfig``) would make this function return early
    # and leave the logger without its console/file handlers.
    if logger.handlers:
        return logger

    # --- Console handler: INFO and above only ---
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(console_formatter)

    # --- File handler: everything from DEBUG upwards, appended ---
    file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(file_formatter)

    # Attach both handlers to the logger.
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger


# Create and export the module-level logger instance.
logger = setup_logger()
|
136
ocr/ocr_violation_detector.py
Normal file
136
ocr/ocr_violation_detector.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
from rapidocr import RapidOCR
|
||||||
|
from logger_config import logger
|
||||||
|
|
||||||
|
|
||||||
|
class OCRViolationDetector:
    """Wrap a RapidOCR engine to detect forbidden words in image frames."""

    def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5):
        """Load the forbidden-word list and initialize the OCR engine.

        Args:
            forbidden_words_path (str): path to the ``.txt`` file holding the
                forbidden words, one per line.
            ocr_confidence_threshold (float): minimum OCR confidence a
                recognized text needs before it is checked. The default of
                0.5 is deliberately low to raise the detection rate.
        """
        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
        self.ocr_engine = self._initialize_ocr()
        self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold

    def _load_forbidden_words(self, path):
        """Load the forbidden words from a text file.

        Each line is stripped of surrounding whitespace and blank lines are
        dropped. Comment-looking lines are NOT filtered out; they are treated
        as words like any other line.

        Args:
            path: location of the word list file.

        Returns:
            set: the loaded words. Empty when the file is missing or cannot
            be read; the problem is logged and detection is then skipped.
        """
        words = set()
        if not os.path.exists(path):
            logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
            return words

        try:
            # 'utf-8-sig' reads plain UTF-8 and also drops a leading BOM.
            # With plain 'utf-8' a BOM would stay glued to the first word
            # ('\ufeff...'), and that word could never match OCR output.
            with open(path, 'r', encoding='utf-8-sig') as f:
                stripped_lines = (line.strip() for line in f)
                words = {entry for entry in stripped_lines if entry}
            logger.info(f"成功加载 {len(words)} 个违禁词。")
        except Exception as e:
            # Best effort: a broken word list disables detection, not the app.
            logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
        return words

    def _initialize_ocr(self):
        """Create the RapidOCR engine from the ``config/1.yaml`` file.

        Returns:
            The engine instance, or ``None`` when construction fails (the
            failure is logged; ``detect`` then reports no violations).
        """
        logger.info("正在初始化RapidOCR引擎...")

        # Built with os.path.join so the relative path is valid on every OS;
        # on Windows this yields the same string as the literal ".\config\1.yaml".
        config_path = os.path.join(".", "config", "1.yaml")
        try:
            engine = RapidOCR(
                config_path=config_path
            )
            logger.info("RapidOCR引擎初始化成功。")
            return engine
        except Exception as e:
            logger.error(f"RapidOCR引擎初始化失败: {e}")
            return None

    def detect(self, frame):
        """Run OCR on one frame and collect every forbidden word found.

        Args:
            frame: the image passed straight to the OCR engine call.

        Returns:
            tuple: ``(has_violation, words, confidences)`` where ``words``
            lists each forbidden word occurrence and ``confidences`` holds the
            confidence of the text it was found in, index-aligned. Returns
            ``(False, [], [])`` when the engine or the word list is
            unavailable, or when OCR yields nothing.
        """
        if not self.ocr_engine or not self.forbidden_words:
            return False, [], []

        all_prohibited = []   # every forbidden word detected
        all_confidences = []  # confidence matching each detected word

        try:
            result = self.ocr_engine(frame)
            logger.debug(f"RapidOCR 原始返回结果: {result}")

            if result is None:
                return False, [], []

            # The result object may lack these attributes, or (presumably,
            # when nothing was recognized -- confirm against the installed
            # RapidOCR version) carry them as None. Treat both as "no text"
            # instead of letting zip(None, ...) raise and be logged as an
            # error for every text-free frame.
            texts = getattr(result, 'txts', None)
            confidences = getattr(result, 'scores', None)
            if texts is None or confidences is None:
                return False, [], []

            for text, conf in zip(texts, confidences):
                if conf < self.OCR_CONFIDENCE_THRESHOLD:
                    logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
                    continue

                # One recognized text may contain several forbidden words.
                for word in self.forbidden_words:
                    if word in text:
                        logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
                        all_prohibited.append(word)
                        all_confidences.append(conf)

        except Exception as e:
            # Keep whatever was collected before the failure and report it.
            logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)

        return len(all_prohibited) > 0, all_prohibited, all_confidences
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def test_single_image():
|
||||||
|
# """测试单张图片的OCR违规检测(显示所有违禁词)"""
|
||||||
|
# TEST_IMAGE_PATH = r"ocr/images/img_7.png" # 修正路径格式
|
||||||
|
# FORBIDDEN_WORDS_PATH = r"ocr/forbidden_words.txt"
|
||||||
|
# CONFIDENCE_THRESHOLD = 0.5
|
||||||
|
#
|
||||||
|
# detector = OCRViolationDetector(
|
||||||
|
# forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||||
|
# ocr_confidence_threshold=CONFIDENCE_THRESHOLD
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# if not os.path.exists(TEST_IMAGE_PATH):
|
||||||
|
# print(f"错误:图片文件不存在 - {TEST_IMAGE_PATH}")
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
# frame = cv2.imread(TEST_IMAGE_PATH)
|
||||||
|
# if frame is None:
|
||||||
|
# print(f"错误:无法读取图片 - {TEST_IMAGE_PATH}")
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
# # 执行检测
|
||||||
|
# has_violation, words, confidences = detector.detect(frame)
|
||||||
|
#
|
||||||
|
# # 输出所有检测到的违禁词
|
||||||
|
# if has_violation:
|
||||||
|
# print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
|
||||||
|
# for word, conf in zip(words, confidences):
|
||||||
|
# print(f"- {word}(置信度:{conf:.4f})")
|
||||||
|
# else:
|
||||||
|
# print("测试结果:图片中未检测到违禁词")
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# print("开始单张图片OCR违规检测测试...")
|
||||||
|
# test_single_image()
|
||||||
|
# print("测试完成")
|
Reference in New Issue
Block a user