ocr1.0
This commit is contained in:
@ -7,227 +7,248 @@ from rapidocr import RapidOCR
|
||||
class OCRViolationDetector:
|
||||
"""
|
||||
封装RapidOCR引擎,用于检测图像帧中的违禁词。
|
||||
核心功能:加载违禁词、初始化OCR引擎、单帧图像违禁词检测
|
||||
"""
|
||||
|
||||
def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,
|
||||
log_level: int = logging.INFO, log_file: str = None):
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
ocr_confidence_threshold: float = 0.5,
|
||||
log_level: int = logging.INFO,
|
||||
log_file: str = None):
|
||||
"""
|
||||
初始化OCR引擎、违禁词列表和日志配置。
|
||||
|
||||
Args:
|
||||
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
|
||||
ocr_confidence_threshold (float): OCR识别结果的置信度阈值。
|
||||
log_level (int): 日志级别,默认为logging.INFO
|
||||
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台
|
||||
ocr_config_path (str): OCR配置文件(如1.yaml)的路径。
|
||||
ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。
|
||||
log_level (int): 日志级别,默认为logging.INFO。
|
||||
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。
|
||||
"""
|
||||
# 初始化日志
|
||||
# 初始化日志(确保先初始化日志,后续操作可正常打日志)
|
||||
self.logger = self._setup_logger(log_level, log_file)
|
||||
|
||||
# 加载违禁词
|
||||
# 加载违禁词(优先级:先加载配置,再初始化引擎)
|
||||
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
|
||||
|
||||
# 初始化OCR引擎
|
||||
self.ocr_engine = self._initialize_ocr()
|
||||
# 初始化RapidOCR引擎(传入配置文件路径)
|
||||
self.ocr_engine = self._initialize_ocr(ocr_config_path)
|
||||
|
||||
# 设置置信度阈值
|
||||
self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
|
||||
self.logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}")
|
||||
# 校验核心依赖是否就绪
|
||||
self._check_dependencies()
|
||||
|
||||
# 设置置信度阈值(限制在0~1范围,避免非法值)
|
||||
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
|
||||
self.logger.info(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
|
||||
|
||||
def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
|
||||
"""
|
||||
配置日志系统
|
||||
配置日志系统(避免重复添加处理器,支持控制台+文件双输出)
|
||||
|
||||
Args:
|
||||
log_level: 日志级别
|
||||
log_file: 日志文件路径,如为None则只输出到控制台
|
||||
log_level: 日志级别(如logging.DEBUG、logging.INFO)。
|
||||
log_file: 日志文件路径,为None时仅输出到控制台。
|
||||
|
||||
Returns:
|
||||
配置好的logger实例
|
||||
logging.Logger: 配置好的日志实例。
|
||||
"""
|
||||
# 创建logger
|
||||
logger = logging.getLogger('OCRViolationDetector')
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# 避免重复添加处理器
|
||||
# 避免重复添加处理器(防止日志重复输出)
|
||||
if logger.handlers:
|
||||
return logger
|
||||
|
||||
# 定义日志格式
|
||||
# 定义日志格式(包含时间、模块名、级别、内容)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
# 添加控制台处理器
|
||||
# 1. 添加控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# 如果提供了日志文件路径,则添加文件处理器
|
||||
# 2. 若指定日志文件,添加文件处理器(自动创建目录)
|
||||
if log_file:
|
||||
try:
|
||||
# 确保日志目录存在
|
||||
log_dir = os.path.dirname(log_file)
|
||||
# 若日志目录不存在,自动创建
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
self.logger.debug(f"自动创建日志目录: {log_dir}")
|
||||
|
||||
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
logger.info(f"日志文件将保存至: {log_file}")
|
||||
logger.info(f"日志文件已配置: {log_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台")
|
||||
logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}")
|
||||
|
||||
return logger
|
||||
|
||||
def _load_forbidden_words(self, path):
|
||||
"""从txt文件加载违禁词列表"""
|
||||
words = set()
|
||||
if not os.path.exists(path):
|
||||
self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
|
||||
return words
|
||||
def _load_forbidden_words(self, path: str) -> set:
|
||||
"""
|
||||
从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码)
|
||||
|
||||
Args:
|
||||
path (str): 违禁词TXT文件路径。
|
||||
|
||||
Returns:
|
||||
set: 去重后的违禁词集合(空集合表示加载失败)。
|
||||
"""
|
||||
forbidden_words = set()
|
||||
|
||||
# 第一步:检查文件是否存在
|
||||
if not os.path.exists(path):
|
||||
self.logger.error(f"违禁词文件不存在: {path}")
|
||||
return forbidden_words
|
||||
|
||||
# 第二步:读取文件并处理内容
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
# 去除每行首尾空格和换行符,过滤空行
|
||||
words = {line.strip() for line in f if line.strip()}
|
||||
self.logger.info(f"成功加载 {len(words)} 个违禁词。")
|
||||
# 过滤空行、去除首尾空格、去重
|
||||
forbidden_words = {
|
||||
line.strip() for line in f
|
||||
if line.strip() # 跳过空行或纯空格行
|
||||
}
|
||||
self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
|
||||
self.logger.debug(f"违禁词列表: {forbidden_words}")
|
||||
except UnicodeDecodeError:
|
||||
self.logger.error(f"违禁词文件编码错误(需UTF-8): {path}")
|
||||
except PermissionError:
|
||||
self.logger.error(f"无权限读取违禁词文件: {path}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
|
||||
return words
|
||||
self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True)
|
||||
|
||||
def _initialize_ocr(self):
|
||||
"""初始化RapidOCR引擎"""
|
||||
self.logger.info("正在初始化RapidOCR引擎...")
|
||||
return forbidden_words
|
||||
|
||||
config_path = r"D:\Git\bin\video\ocr\config\1.yaml"
|
||||
try:
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
self.logger.error(f"RapidOCR配置文件不存在: {config_path}")
|
||||
return None
|
||||
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
|
||||
"""
|
||||
初始化RapidOCR引擎(校验配置文件、捕获初始化异常)
|
||||
|
||||
engine = RapidOCR(
|
||||
config_path=config_path
|
||||
)
|
||||
self.logger.info("RapidOCR引擎初始化成功。")
|
||||
return engine
|
||||
except Exception as e:
|
||||
self.logger.error(f"RapidOCR引擎初始化失败: {e}")
|
||||
Args:
|
||||
config_path (str): RapidOCR配置文件(如1.yaml)路径。
|
||||
|
||||
Returns:
|
||||
RapidOCR | None: OCR引擎实例(None表示初始化失败)。
|
||||
"""
|
||||
self.logger.info("开始初始化RapidOCR引擎...")
|
||||
|
||||
# 第一步:检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
self.logger.error(f"OCR配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
def detect(self, frame):
|
||||
"""
|
||||
对单帧图像进行OCR,检测所有出现的违禁词并返回列表
|
||||
返回格式:(是否有违禁词, 违禁词列表, 对应的置信度列表)
|
||||
"""
|
||||
print("收到帧")
|
||||
if not self.ocr_engine or not self.forbidden_words:
|
||||
return False, [], []
|
||||
# 第二步:初始化OCR引擎(捕获RapidOCR相关异常)
|
||||
try:
|
||||
ocr_engine = RapidOCR(config_path=config_path)
|
||||
self.logger.info("RapidOCR引擎初始化成功")
|
||||
return ocr_engine
|
||||
except ImportError:
|
||||
self.logger.error("RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True)
|
||||
|
||||
all_prohibited = [] # 存储所有检测到的违禁词
|
||||
all_confidences = [] # 存储对应违禁词的置信度
|
||||
return None
|
||||
|
||||
def _check_dependencies(self) -> None:
|
||||
"""校验OCR引擎和违禁词列表是否就绪(输出警告日志)"""
|
||||
if not self.ocr_engine:
|
||||
self.logger.warning("⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
|
||||
if not self.forbidden_words:
|
||||
self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用")
|
||||
|
||||
def detect(self, frame) -> tuple[bool, list, list]:
|
||||
"""
|
||||
对单帧图像进行OCR违禁词检测(核心方法)
|
||||
|
||||
Args:
|
||||
frame: 输入图像帧(NumPy数组,BGR格式,cv2读取的图像)。
|
||||
|
||||
Returns:
|
||||
tuple[bool, list, list]:
|
||||
- 第一个元素:是否检测到违禁词(True/False);
|
||||
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
|
||||
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
|
||||
"""
|
||||
# 初始化返回结果
|
||||
has_violation = False
|
||||
violation_words = []
|
||||
violation_confs = []
|
||||
|
||||
# 前置校验:1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在
|
||||
if frame is None or frame.size == 0:
|
||||
self.logger.warning("输入图像帧为空或无效,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if not self.ocr_engine or not self.forbidden_words:
|
||||
self.logger.debug("OCR引擎未就绪或违禁词为空,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
try:
|
||||
# 执行OCR识别
|
||||
result = self.ocr_engine(frame)
|
||||
self.logger.debug(f"RapidOCR 原始返回结果: {result}")
|
||||
# 1. 执行OCR识别(获取RapidOCR原始结果)
|
||||
self.logger.debug("开始执行OCR识别...")
|
||||
ocr_result = self.ocr_engine(frame)
|
||||
self.logger.debug(f"RapidOCR原始结果: {ocr_result}")
|
||||
|
||||
if result is None:
|
||||
return False, [], []
|
||||
# 2. 校验OCR结果是否有效(避免None或格式异常)
|
||||
if ocr_result is None:
|
||||
self.logger.debug("OCR识别未返回任何结果(图像无文本或识别失败)")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 提取文本和置信度(适配RapidOCR的结果格式)
|
||||
texts = result.txts if hasattr(result, 'txts') else []
|
||||
confidences = result.scores if hasattr(result, 'scores') else []
|
||||
# 3. 检查txts和scores是否存在且不为None
|
||||
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
|
||||
self.logger.warning("OCR结果中txts为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 遍历所有识别结果,收集所有违禁词
|
||||
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
|
||||
self.logger.warning("OCR结果中scores为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 4. 转为列表并去None(防止单个元素为None)
|
||||
# 确保txts是可迭代的,如果不是则转为空列表
|
||||
if not isinstance(ocr_result.txts, (list, tuple)):
|
||||
self.logger.warning(f"OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}")
|
||||
texts = []
|
||||
else:
|
||||
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
|
||||
|
||||
# 确保scores是可迭代的,如果不是则转为空列表
|
||||
if not isinstance(ocr_result.scores, (list, tuple)):
|
||||
self.logger.warning(f"OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}")
|
||||
confidences = []
|
||||
else:
|
||||
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
|
||||
|
||||
# 5. 校验文本和置信度列表长度是否一致(避免zip迭代错误)
|
||||
if len(texts) != len(confidences):
|
||||
self.logger.warning(
|
||||
f"OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if len(texts) == 0:
|
||||
self.logger.debug("OCR未识别到任何有效文本")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤)
|
||||
self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
|
||||
for text, conf in zip(texts, confidences):
|
||||
# 过滤低置信度结果
|
||||
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
||||
self.logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
|
||||
self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
|
||||
continue
|
||||
|
||||
# 检查当前文本中是否包含多个违禁词
|
||||
for word in self.forbidden_words:
|
||||
if word in text:
|
||||
self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
|
||||
all_prohibited.append(word)
|
||||
all_confidences.append(conf)
|
||||
# 检查当前文本是否包含违禁词(支持一个文本含多个违禁词)
|
||||
matched_words = [word for word in self.forbidden_words if word in text]
|
||||
if matched_words:
|
||||
has_violation = True
|
||||
# 记录所有匹配的违禁词和对应置信度
|
||||
violation_words.extend(matched_words)
|
||||
violation_confs.extend([conf] * len(matched_words)) # 一个文本对应多个违禁词时,置信度复用
|
||||
self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)
|
||||
# 捕获所有异常,确保不中断上层调用
|
||||
self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True)
|
||||
|
||||
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
|
||||
return len(all_prohibited) > 0, all_prohibited, all_confidences
|
||||
|
||||
# def test_image(self, image_path: str, show_image: bool = True) -> tuple:
|
||||
# """
|
||||
# 对单张图片进行OCR违禁词检测并展示结果
|
||||
#
|
||||
# Args:
|
||||
# image_path (str): 图片文件路径
|
||||
# show_image (bool): 是否显示图片,默认为True
|
||||
#
|
||||
# Returns:
|
||||
# tuple: (是否有违禁词, 违禁词列表, 对应的置信度列表)
|
||||
# """
|
||||
# # 检查图片文件是否存在
|
||||
# if not os.path.exists(image_path):
|
||||
# self.logger.error(f"图片文件不存在: {image_path}")
|
||||
# return False, [], []
|
||||
#
|
||||
# try:
|
||||
# # 读取图片
|
||||
# frame = cv2.imread(image_path)
|
||||
# if frame is None:
|
||||
# self.logger.error(f"无法读取图片: {image_path}")
|
||||
# return False, [], []
|
||||
#
|
||||
# self.logger.info(f"开始处理图片: {image_path}")
|
||||
#
|
||||
# # 调用检测方法
|
||||
# has_violation, violations, confidences = self.detect(frame)
|
||||
#
|
||||
# # 输出检测结果
|
||||
# if has_violation:
|
||||
# self.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
# for word, conf in zip(violations, confidences):
|
||||
# self.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
# else:
|
||||
# self.logger.info("图片中未检测到违禁词")
|
||||
|
||||
# # 显示图片(如果需要)
|
||||
# if show_image:
|
||||
# # 调整图片大小以便于显示(如果太大)
|
||||
# height, width = frame.shape[:2]
|
||||
# max_size = 800
|
||||
# if max(height, width) > max_size:
|
||||
# scale = max_size / max(height, width)
|
||||
# frame = cv2.resize(frame, None, fx=scale, fy=scale)
|
||||
#
|
||||
# cv2.imshow(f"OCR检测结果: {'发现违禁词' if has_violation else '未发现违禁词'}", frame)
|
||||
# cv2.waitKey(0) # 等待用户按键
|
||||
# cv2.destroyAllWindows()
|
||||
#
|
||||
# return has_violation, violations, confidences
|
||||
#
|
||||
# except Exception as e:
|
||||
# self.logger.error(f"处理图片时发生错误: {str(e)}", exc_info=True)
|
||||
# return False, [], []
|
||||
#
|
||||
#
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 配置参数
|
||||
# forbidden_words_path = "forbidden_words.txt" # 违禁词文件路径
|
||||
# test_image_path = r"D:\Git\bin\video\ocr\images\img_7.png" # 测试图片路径
|
||||
# ocr_threshold = 0.6 # OCR置信度阈值
|
||||
#
|
||||
# # 创建检测器实例
|
||||
# detector = OCRViolationDetector(
|
||||
# forbidden_words_path=forbidden_words_path,
|
||||
# ocr_confidence_threshold=ocr_threshold,
|
||||
# log_level=logging.INFO,
|
||||
# log_file="ocr_detection.log"
|
||||
# )
|
||||
#
|
||||
# # 测试图片
|
||||
# detector.test_image(test_image_path, show_image=True)
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
Reference in New Issue
Block a user