ocr1.0
This commit is contained in:
119
ocr/config/1.yaml
Normal file
119
ocr/config/1.yaml
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
Global:
|
||||||
|
text_score: 0.5
|
||||||
|
|
||||||
|
use_det: true
|
||||||
|
use_cls: true
|
||||||
|
use_rec: true
|
||||||
|
|
||||||
|
min_height: 30
|
||||||
|
width_height_ratio: 8
|
||||||
|
max_side_len: 2000
|
||||||
|
min_side_len: 30
|
||||||
|
|
||||||
|
return_word_box: false
|
||||||
|
return_single_char_box: false
|
||||||
|
|
||||||
|
font_path: null
|
||||||
|
|
||||||
|
EngineConfig:
|
||||||
|
onnxruntime:
|
||||||
|
intra_op_num_threads: -1
|
||||||
|
inter_op_num_threads: -1
|
||||||
|
enable_cpu_mem_arena: false
|
||||||
|
|
||||||
|
cpu_ep_cfg:
|
||||||
|
arena_extend_strategy: "kSameAsRequested"
|
||||||
|
|
||||||
|
use_cuda: true # 改为true以启用CUDA
|
||||||
|
cuda_ep_cfg:
|
||||||
|
device_id: 0
|
||||||
|
arena_extend_strategy: "kNextPowerOfTwo"
|
||||||
|
cudnn_conv_algo_search: "EXHAUSTIVE"
|
||||||
|
do_copy_in_default_stream: true
|
||||||
|
|
||||||
|
use_dml: false
|
||||||
|
dm_ep_cfg: null
|
||||||
|
|
||||||
|
use_cann: false
|
||||||
|
cann_ep_cfg:
|
||||||
|
device_id: 0
|
||||||
|
arena_extend_strategy: "kNextPowerOfTwo"
|
||||||
|
npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024
|
||||||
|
op_select_impl_mode: "high_performance"
|
||||||
|
optypelist_for_implmode: "Gelu"
|
||||||
|
enable_cann_graph: true
|
||||||
|
|
||||||
|
openvino:
|
||||||
|
inference_num_threads: -1
|
||||||
|
performance_hint: null
|
||||||
|
performance_num_requests: -1
|
||||||
|
enable_cpu_pinning: null
|
||||||
|
num_streams: -1
|
||||||
|
enable_hyper_threading: null
|
||||||
|
scheduling_core_type: null
|
||||||
|
|
||||||
|
paddle:
|
||||||
|
cpu_math_library_num_threads: -1
|
||||||
|
use_npu: false
|
||||||
|
npu_id: 0
|
||||||
|
use_cuda: true # 改为true以启用CUDA
|
||||||
|
gpu_id: 0
|
||||||
|
gpu_mem: 500
|
||||||
|
|
||||||
|
torch:
|
||||||
|
use_cuda: true # 已经是true
|
||||||
|
gpu_id: 0
|
||||||
|
|
||||||
|
Det:
|
||||||
|
engine_type: "torch"
|
||||||
|
lang_type: "ch"
|
||||||
|
model_type: "mobile"
|
||||||
|
ocr_version: "PP-OCRv4"
|
||||||
|
|
||||||
|
task_type: "det"
|
||||||
|
|
||||||
|
model_path: null
|
||||||
|
model_dir: null
|
||||||
|
|
||||||
|
limit_side_len: 736
|
||||||
|
limit_type: min
|
||||||
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
|
|
||||||
|
thresh: 0.3
|
||||||
|
box_thresh: 0.5
|
||||||
|
max_candidates: 1000
|
||||||
|
unclip_ratio: 1.6
|
||||||
|
use_dilation: true
|
||||||
|
score_mode: fast
|
||||||
|
|
||||||
|
Cls:
|
||||||
|
engine_type: "torch"
|
||||||
|
lang_type: "ch"
|
||||||
|
model_type: "mobile"
|
||||||
|
ocr_version: "PP-OCRv4"
|
||||||
|
|
||||||
|
task_type: "cls"
|
||||||
|
|
||||||
|
model_path: null
|
||||||
|
model_dir: null
|
||||||
|
|
||||||
|
cls_image_shape: [3, 48, 192]
|
||||||
|
cls_batch_num: 6
|
||||||
|
cls_thresh: 0.9
|
||||||
|
label_list: ["0", "180"]
|
||||||
|
|
||||||
|
Rec:
|
||||||
|
engine_type: "torch"
|
||||||
|
lang_type: "ch"
|
||||||
|
model_type: "mobile"
|
||||||
|
ocr_version: "PP-OCRv4"
|
||||||
|
|
||||||
|
task_type: "rec"
|
||||||
|
|
||||||
|
model_path: null
|
||||||
|
model_dir: null
|
||||||
|
|
||||||
|
rec_keys_path: null
|
||||||
|
rec_img_shape: [3, 48, 320]
|
||||||
|
rec_batch_num: 6
|
@ -15,3 +15,4 @@
|
|||||||
法轮功大法好
|
法轮功大法好
|
||||||
法轮
|
法轮
|
||||||
李洪志
|
李洪志
|
||||||
|
习近平
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
|
import logging
|
||||||
from rapidocr import RapidOCR
|
from rapidocr import RapidOCR
|
||||||
from logger_config import logger
|
|
||||||
|
|
||||||
|
|
||||||
class OCRViolationDetector:
|
class OCRViolationDetector:
|
||||||
@ -9,47 +9,110 @@ class OCRViolationDetector:
|
|||||||
封装RapidOCR引擎,用于检测图像帧中的违禁词。
|
封装RapidOCR引擎,用于检测图像帧中的违禁词。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,
             log_level: int = logging.INFO, log_file: str = None):
    """
    Set up logging, load the forbidden-word list and create the OCR engine.

    Args:
        forbidden_words_path (str): Path of the .txt file holding the forbidden words.
        ocr_confidence_threshold (float): Confidence threshold applied to OCR results.
        log_level (int): Logging level; defaults to logging.INFO.
        log_file (str, optional): Log file path; when omitted, only the console is used.
    """
    # The logger must exist before anything else: both loaders below
    # report their progress through self.logger.
    detector_logger = self._setup_logger(log_level, log_file)
    self.logger = detector_logger

    # Forbidden words first, then the OCR engine (same order as the log output expects).
    self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
    self.ocr_engine = self._initialize_ocr()

    # Remember the threshold and record it in the log.
    self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
    detector_logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}")
|
||||||
|
|
||||||
|
def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
配置日志系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_level: 日志级别
|
||||||
|
log_file: 日志文件路径,如为None则只输出到控制台
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的logger实例
|
||||||
|
"""
|
||||||
|
# 创建logger
|
||||||
|
logger = logging.getLogger('OCRViolationDetector')
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
# 避免重复添加处理器
|
||||||
|
if logger.handlers:
|
||||||
|
return logger
|
||||||
|
|
||||||
|
# 定义日志格式
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加控制台处理器
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# 如果提供了日志文件路径,则添加文件处理器
|
||||||
|
if log_file:
|
||||||
|
try:
|
||||||
|
# 确保日志目录存在
|
||||||
|
log_dir = os.path.dirname(log_file)
|
||||||
|
if log_dir and not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir)
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.info(f"日志文件将保存至: {log_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台")
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
def _load_forbidden_words(self, path):
|
def _load_forbidden_words(self, path):
|
||||||
"""从txt文件加载违禁词列表(与rapidocr_test.py保持一致)"""
|
"""从txt文件加载违禁词列表"""
|
||||||
words = set()
|
words = set()
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
|
self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
|
||||||
return words
|
return words
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
# 去除每行首尾空格和换行符,过滤空行(不排除注释行,与测试代码统一)
|
# 去除每行首尾空格和换行符,过滤空行
|
||||||
words = {line.strip() for line in f if line.strip()}
|
words = {line.strip() for line in f if line.strip()}
|
||||||
logger.info(f"成功加载 {len(words)} 个违禁词。")
|
self.logger.info(f"成功加载 {len(words)} 个违禁词。")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
|
self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
|
||||||
return words
|
return words
|
||||||
|
|
||||||
def _initialize_ocr(self):
|
def _initialize_ocr(self):
|
||||||
"""初始化RapidOCR引擎"""
|
"""初始化RapidOCR引擎"""
|
||||||
logger.info("正在初始化RapidOCR引擎...")
|
self.logger.info("正在初始化RapidOCR引擎...")
|
||||||
|
|
||||||
config_path = r".\config\1.yaml"
|
config_path = r"../ocr/config/1.yaml"
|
||||||
try:
|
try:
|
||||||
|
# 检查配置文件是否存在
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
self.logger.error(f"RapidOCR配置文件不存在: {config_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
engine = RapidOCR(
|
engine = RapidOCR(
|
||||||
config_path=config_path
|
config_path=config_path
|
||||||
)
|
)
|
||||||
logger.info("RapidOCR引擎初始化成功。")
|
self.logger.info("RapidOCR引擎初始化成功。")
|
||||||
return engine
|
return engine
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RapidOCR引擎初始化失败: {e}")
|
self.logger.error(f"RapidOCR引擎初始化失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def detect(self, frame):
|
def detect(self, frame):
|
||||||
@ -57,6 +120,7 @@ class OCRViolationDetector:
|
|||||||
对单帧图像进行OCR,检测所有出现的违禁词并返回列表
|
对单帧图像进行OCR,检测所有出现的违禁词并返回列表
|
||||||
返回格式:(是否有违禁词, 违禁词列表, 对应的置信度列表)
|
返回格式:(是否有违禁词, 违禁词列表, 对应的置信度列表)
|
||||||
"""
|
"""
|
||||||
|
print("收到帧")
|
||||||
if not self.ocr_engine or not self.forbidden_words:
|
if not self.ocr_engine or not self.forbidden_words:
|
||||||
return False, [], []
|
return False, [], []
|
||||||
|
|
||||||
@ -64,9 +128,9 @@ class OCRViolationDetector:
|
|||||||
all_confidences = [] # 存储对应违禁词的置信度
|
all_confidences = [] # 存储对应违禁词的置信度
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行OCR识别(修正调用方式,与测试代码一致)
|
# 执行OCR识别
|
||||||
result = self.ocr_engine(frame)
|
result = self.ocr_engine(frame)
|
||||||
logger.debug(f"RapidOCR 原始返回结果: {result}")
|
self.logger.debug(f"RapidOCR 原始返回结果: {result}")
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
return False, [], []
|
return False, [], []
|
||||||
@ -78,59 +142,19 @@ class OCRViolationDetector:
|
|||||||
# 遍历所有识别结果,收集所有违禁词
|
# 遍历所有识别结果,收集所有违禁词
|
||||||
for text, conf in zip(texts, confidences):
|
for text, conf in zip(texts, confidences):
|
||||||
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
||||||
logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
|
self.logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查当前文本中是否包含多个违禁词
|
# 检查当前文本中是否包含多个违禁词
|
||||||
for word in self.forbidden_words:
|
for word in self.forbidden_words:
|
||||||
if word in text:
|
if word in text:
|
||||||
logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
|
self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
|
||||||
all_prohibited.append(word)
|
all_prohibited.append(word)
|
||||||
all_confidences.append(conf)
|
all_confidences.append(conf)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)
|
self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
|
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
|
||||||
return len(all_prohibited) > 0, all_prohibited, all_confidences
|
return len(all_prohibited) > 0, all_prohibited, all_confidences
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def test_single_image():
|
|
||||||
# """测试单张图片的OCR违规检测(显示所有违禁词)"""
|
|
||||||
# TEST_IMAGE_PATH = r"ocr/images/img_7.png" # 修正路径格式
|
|
||||||
# FORBIDDEN_WORDS_PATH = r"ocr/forbidden_words.txt"
|
|
||||||
# CONFIDENCE_THRESHOLD = 0.5
|
|
||||||
#
|
|
||||||
# detector = OCRViolationDetector(
|
|
||||||
# forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
|
||||||
# ocr_confidence_threshold=CONFIDENCE_THRESHOLD
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# if not os.path.exists(TEST_IMAGE_PATH):
|
|
||||||
# print(f"错误:图片文件不存在 - {TEST_IMAGE_PATH}")
|
|
||||||
# return
|
|
||||||
#
|
|
||||||
# frame = cv2.imread(TEST_IMAGE_PATH)
|
|
||||||
# if frame is None:
|
|
||||||
# print(f"错误:无法读取图片 - {TEST_IMAGE_PATH}")
|
|
||||||
# return
|
|
||||||
#
|
|
||||||
# # 执行检测
|
|
||||||
# has_violation, words, confidences = detector.detect(frame)
|
|
||||||
#
|
|
||||||
# # 输出所有检测到的违禁词
|
|
||||||
# if has_violation:
|
|
||||||
# print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
|
|
||||||
# for word, conf in zip(words, confidences):
|
|
||||||
# print(f"- {word}(置信度:{conf:.4f})")
|
|
||||||
# else:
|
|
||||||
# print("测试结果:图片中未检测到违禁词")
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# print("开始单张图片OCR违规检测测试...")
|
|
||||||
# test_single_image()
|
|
||||||
# print("测试完成")
|
|
21
rtc/rtc.py
21
rtc/rtc.py
@ -4,6 +4,7 @@ import cv2 # 导入OpenCV库
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
||||||
from aiortc.mediastreams import MediaStreamTrack
|
from aiortc.mediastreams import MediaStreamTrack
|
||||||
|
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||||
|
|
||||||
|
|
||||||
class VideoTrack(MediaStreamTrack):
|
class VideoTrack(MediaStreamTrack):
|
||||||
@ -47,7 +48,7 @@ async def rtc_frame_receiver(url, frame_queue):
|
|||||||
if frame_queue.empty():
|
if frame_queue.empty():
|
||||||
# 队列为空、放入当前cv2帧
|
# 队列为空、放入当前cv2帧
|
||||||
await frame_queue.put(frame_cv2)
|
await frame_queue.put(frame_cv2)
|
||||||
print(f"第{total_frames}帧:队列为空、已放入新的cv2帧,尺寸: {frame_cv2.shape}")
|
# print(f"第{total_frames}帧:队列为空、已放入新的cv2帧,尺寸: {frame_cv2.shape}")
|
||||||
else:
|
else:
|
||||||
# 队列非空、说明上一帧还未处理、跳过当前帧
|
# 队列非空、说明上一帧还未处理、跳过当前帧
|
||||||
print(f"第{total_frames}帧:队列非空、跳过该帧")
|
print(f"第{total_frames}帧:队列非空、跳过该帧")
|
||||||
@ -93,23 +94,35 @@ async def frame_consumer(frame_queue):
|
|||||||
|
|
||||||
Args: frame_queue: 帧队列
|
Args: frame_queue: 帧队列
|
||||||
"""
|
"""
|
||||||
|
# 创建OCR检测器实例(请替换为实际的违禁词文件路径)
|
||||||
|
ocr_detector = OCRViolationDetector(
|
||||||
|
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径
|
||||||
|
ocr_confidence_threshold=0.5,)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# 从队列中获取cv2帧(队列为空时会阻塞等待新帧)
|
# 从队列中获取cv2帧(队列为空时会阻塞等待新帧)
|
||||||
current_frame = await frame_queue.get()
|
current_frame = await frame_queue.get()
|
||||||
|
|
||||||
|
ocr_detector.detect(current_frame)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 验证这是cv2可以处理的帧
|
# 验证这是cv2可以处理的帧
|
||||||
print(f"从队列获取到cv2帧、尺寸: {current_frame.shape}、数据类型: {current_frame.dtype}")
|
# print(f"从队列获取到cv2帧、尺寸: {current_frame.shape}、数据类型: {current_frame.dtype}")
|
||||||
|
|
||||||
# 这里可以添加cv2的处理代码,例如显示帧
|
# 这里可以添加cv2的处理代码,例如显示帧
|
||||||
# cv2.imshow('Received Frame', current_frame)
|
# cv2.imshow('Received Frame', current_frame)
|
||||||
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
# break
|
# break
|
||||||
|
|
||||||
print("cv2帧处理完成")
|
# print("cv2帧处理完成")
|
||||||
|
|
||||||
# 标记任务完成
|
# 标记任务完成
|
||||||
frame_queue.task_done()
|
frame_queue.task_done()
|
||||||
print("帧处理完成、队列已清空")
|
# print("帧处理完成、队列已清空")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
Reference in New Issue
Block a user