This commit is contained in:
2025-09-02 21:30:28 +08:00
parent df74b688fa
commit e21432c6a1
4 changed files with 219 additions and 62 deletions

119
ocr/config/1.yaml Normal file
View File

@ -0,0 +1,119 @@
Global:
text_score: 0.5
use_det: true
use_cls: true
use_rec: true
min_height: 30
width_height_ratio: 8
max_side_len: 2000
min_side_len: 30
return_word_box: false
return_single_char_box: false
font_path: null
EngineConfig:
onnxruntime:
intra_op_num_threads: -1
inter_op_num_threads: -1
enable_cpu_mem_arena: false
cpu_ep_cfg:
arena_extend_strategy: "kSameAsRequested"
use_cuda: true # 改为true以启用CUDA
cuda_ep_cfg:
device_id: 0
arena_extend_strategy: "kNextPowerOfTwo"
cudnn_conv_algo_search: "EXHAUSTIVE"
do_copy_in_default_stream: true
use_dml: false
dm_ep_cfg: null
use_cann: false
cann_ep_cfg:
device_id: 0
arena_extend_strategy: "kNextPowerOfTwo"
npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024
op_select_impl_mode: "high_performance"
optypelist_for_implmode: "Gelu"
enable_cann_graph: true
openvino:
inference_num_threads: -1
performance_hint: null
performance_num_requests: -1
enable_cpu_pinning: null
num_streams: -1
enable_hyper_threading: null
scheduling_core_type: null
paddle:
cpu_math_library_num_threads: -1
use_npu: false
npu_id: 0
use_cuda: true # 改为true以启用CUDA
gpu_id: 0
gpu_mem: 500
torch:
use_cuda: true # 已经是true
gpu_id: 0
Det:
engine_type: "torch"
lang_type: "ch"
model_type: "mobile"
ocr_version: "PP-OCRv4"
task_type: "det"
model_path: null
model_dir: null
limit_side_len: 736
limit_type: min
std: [ 0.5, 0.5, 0.5 ]
mean: [ 0.5, 0.5, 0.5 ]
thresh: 0.3
box_thresh: 0.5
max_candidates: 1000
unclip_ratio: 1.6
use_dilation: true
score_mode: fast
Cls:
engine_type: "torch"
lang_type: "ch"
model_type: "mobile"
ocr_version: "PP-OCRv4"
task_type: "cls"
model_path: null
model_dir: null
cls_image_shape: [3, 48, 192]
cls_batch_num: 6
cls_thresh: 0.9
label_list: ["0", "180"]
Rec:
engine_type: "torch"
lang_type: "ch"
model_type: "mobile"
ocr_version: "PP-OCRv4"
task_type: "rec"
model_path: null
model_dir: null
rec_keys_path: null
rec_img_shape: [3, 48, 320]
rec_batch_num: 6

View File

@ -15,3 +15,4 @@
法轮功大法好 法轮功大法好
法轮 法轮
李洪志 李洪志
习近平

View File

@ -1,7 +1,7 @@
import os import os
import cv2 import cv2
import logging
from rapidocr import RapidOCR from rapidocr import RapidOCR
from logger_config import logger
class OCRViolationDetector: class OCRViolationDetector:
@ -9,47 +9,110 @@ class OCRViolationDetector:
封装RapidOCR引擎用于检测图像帧中的违禁词。 封装RapidOCR引擎用于检测图像帧中的违禁词。
""" """
def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5): # 降低阈值提高检出率 def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,
log_level: int = logging.INFO, log_file: str = None):
""" """
初始化OCR引擎违禁词列表。 初始化OCR引擎违禁词列表和日志配置
Args: Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值。 ocr_confidence_threshold (float): OCR识别结果的置信度阈值。
log_level (int): 日志级别默认为logging.INFO
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台
""" """
# 初始化日志
self.logger = self._setup_logger(log_level, log_file)
# 加载违禁词
self.forbidden_words = self._load_forbidden_words(forbidden_words_path) self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化OCR引擎
self.ocr_engine = self._initialize_ocr() self.ocr_engine = self._initialize_ocr()
# 设置置信度阈值
self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
self.logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}")
def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
"""
配置日志系统
Args:
log_level: 日志级别
log_file: 日志文件路径如为None则只输出到控制台
Returns:
配置好的logger实例
"""
# 创建logger
logger = logging.getLogger('OCRViolationDetector')
logger.setLevel(log_level)
# 避免重复添加处理器
if logger.handlers:
return logger
# 定义日志格式
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 添加控制台处理器
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 如果提供了日志文件路径,则添加文件处理器
if log_file:
try:
# 确保日志目录存在
log_dir = os.path.dirname(log_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir)
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.info(f"日志文件将保存至: {log_file}")
except Exception as e:
logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台")
return logger
def _load_forbidden_words(self, path): def _load_forbidden_words(self, path):
"""从txt文件加载违禁词列表与rapidocr_test.py保持一致""" """从txt文件加载违禁词列表"""
words = set() words = set()
if not os.path.exists(path): if not os.path.exists(path):
logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测") self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
return words return words
try: try:
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
# 去除每行首尾空格和换行符,过滤空行(不排除注释行,与测试代码统一) # 去除每行首尾空格和换行符,过滤空行
words = {line.strip() for line in f if line.strip()} words = {line.strip() for line in f if line.strip()}
logger.info(f"成功加载 {len(words)} 个违禁词。") self.logger.info(f"成功加载 {len(words)} 个违禁词。")
except Exception as e: except Exception as e:
logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测") self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
return words return words
def _initialize_ocr(self): def _initialize_ocr(self):
"""初始化RapidOCR引擎""" """初始化RapidOCR引擎"""
logger.info("正在初始化RapidOCR引擎...") self.logger.info("正在初始化RapidOCR引擎...")
config_path = r".\config\1.yaml" config_path = r"../ocr/config/1.yaml"
try: try:
# 检查配置文件是否存在
if not os.path.exists(config_path):
self.logger.error(f"RapidOCR配置文件不存在: {config_path}")
return None
engine = RapidOCR( engine = RapidOCR(
config_path=config_path config_path=config_path
) )
logger.info("RapidOCR引擎初始化成功。") self.logger.info("RapidOCR引擎初始化成功。")
return engine return engine
except Exception as e: except Exception as e:
logger.error(f"RapidOCR引擎初始化失败: {e}") self.logger.error(f"RapidOCR引擎初始化失败: {e}")
return None return None
def detect(self, frame): def detect(self, frame):
@ -57,6 +120,7 @@ class OCRViolationDetector:
对单帧图像进行OCR检测所有出现的违禁词并返回列表 对单帧图像进行OCR检测所有出现的违禁词并返回列表
返回格式:(是否有违禁词, 违禁词列表, 对应的置信度列表) 返回格式:(是否有违禁词, 违禁词列表, 对应的置信度列表)
""" """
print("收到帧")
if not self.ocr_engine or not self.forbidden_words: if not self.ocr_engine or not self.forbidden_words:
return False, [], [] return False, [], []
@ -64,9 +128,9 @@ class OCRViolationDetector:
all_confidences = [] # 存储对应违禁词的置信度 all_confidences = [] # 存储对应违禁词的置信度
try: try:
# 执行OCR识别(修正调用方式,与测试代码一致) # 执行OCR识别
result = self.ocr_engine(frame) result = self.ocr_engine(frame)
logger.debug(f"RapidOCR 原始返回结果: {result}") self.logger.debug(f"RapidOCR 原始返回结果: {result}")
if result is None: if result is None:
return False, [], [] return False, [], []
@ -78,59 +142,19 @@ class OCRViolationDetector:
# 遍历所有识别结果,收集所有违禁词 # 遍历所有识别结果,收集所有违禁词
for text, conf in zip(texts, confidences): for text, conf in zip(texts, confidences):
if conf < self.OCR_CONFIDENCE_THRESHOLD: if conf < self.OCR_CONFIDENCE_THRESHOLD:
logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过") self.logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
continue continue
# 检查当前文本中是否包含多个违禁词 # 检查当前文本中是否包含多个违禁词
for word in self.forbidden_words: for word in self.forbidden_words:
if word in text: if word in text:
logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}") self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
all_prohibited.append(word) all_prohibited.append(word)
all_confidences.append(conf) all_confidences.append(conf)
except Exception as e: except Exception as e:
logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True) self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
return len(all_prohibited) > 0, all_prohibited, all_confidences return len(all_prohibited) > 0, all_prohibited, all_confidences
# def test_single_image():
# """测试单张图片的OCR违规检测显示所有违禁词"""
# TEST_IMAGE_PATH = r"ocr/images/img_7.png" # 修正路径格式
# FORBIDDEN_WORDS_PATH = r"ocr/forbidden_words.txt"
# CONFIDENCE_THRESHOLD = 0.5
#
# detector = OCRViolationDetector(
# forbidden_words_path=FORBIDDEN_WORDS_PATH,
# ocr_confidence_threshold=CONFIDENCE_THRESHOLD
# )
#
# if not os.path.exists(TEST_IMAGE_PATH):
# print(f"错误:图片文件不存在 - {TEST_IMAGE_PATH}")
# return
#
# frame = cv2.imread(TEST_IMAGE_PATH)
# if frame is None:
# print(f"错误:无法读取图片 - {TEST_IMAGE_PATH}")
# return
#
# # 执行检测
# has_violation, words, confidences = detector.detect(frame)
#
# # 输出所有检测到的违禁词
# if has_violation:
# print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
# for word, conf in zip(words, confidences):
# print(f"- {word}(置信度:{conf:.4f}")
# else:
# print("测试结果:图片中未检测到违禁词")
#
#
# if __name__ == "__main__":
# print("开始单张图片OCR违规检测测试...")
# test_single_image()
# print("测试完成")

View File

@ -4,6 +4,7 @@ import cv2 # 导入OpenCV库
import numpy as np import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack from aiortc.mediastreams import MediaStreamTrack
from ocr.ocr_violation_detector import OCRViolationDetector
class VideoTrack(MediaStreamTrack): class VideoTrack(MediaStreamTrack):
@ -47,7 +48,7 @@ async def rtc_frame_receiver(url, frame_queue):
if frame_queue.empty(): if frame_queue.empty():
# 队列为空、放入当前cv2帧 # 队列为空、放入当前cv2帧
await frame_queue.put(frame_cv2) await frame_queue.put(frame_cv2)
print(f"{total_frames}队列为空、已放入新的cv2帧尺寸: {frame_cv2.shape}") # print(f"第{total_frames}队列为空、已放入新的cv2帧尺寸: {frame_cv2.shape}")
else: else:
# 队列非空、说明上一帧还未处理、跳过当前帧 # 队列非空、说明上一帧还未处理、跳过当前帧
print(f"{total_frames}帧:队列非空、跳过该帧") print(f"{total_frames}帧:队列非空、跳过该帧")
@ -93,23 +94,35 @@ async def frame_consumer(frame_queue):
Args: frame_queue: 帧队列 Args: frame_queue: 帧队列
""" """
# 创建OCR检测器实例请替换为实际的违禁词文件路径
ocr_detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径
ocr_confidence_threshold=0.5,)
while True: while True:
# 从队列中获取cv2帧队列为空时会阻塞等待新帧 # 从队列中获取cv2帧队列为空时会阻塞等待新帧
current_frame = await frame_queue.get() current_frame = await frame_queue.get()
ocr_detector.detect(current_frame)
# 验证这是cv2可以处理的帧 # 验证这是cv2可以处理的帧
print(f"从队列获取到cv2帧、尺寸: {current_frame.shape}、数据类型: {current_frame.dtype}") # print(f"从队列获取到cv2帧、尺寸: {current_frame.shape}、数据类型: {current_frame.dtype}")
# 这里可以添加cv2的处理代码例如显示帧 # 这里可以添加cv2的处理代码例如显示帧
# cv2.imshow('Received Frame', current_frame) # cv2.imshow('Received Frame', current_frame)
# if cv2.waitKey(1) & 0xFF == ord('q'): # if cv2.waitKey(1) & 0xFF == ord('q'):
# break # break
print("cv2帧处理完成") # print("cv2帧处理完成")
# 标记任务完成 # 标记任务完成
frame_queue.task_done() frame_queue.task_done()
print("帧处理完成、队列已清空") # print("帧处理完成、队列已清空")
async def main(): async def main():