paddleocr
This commit is contained in:
113
core/ocr.py
113
core/ocr.py
@ -3,14 +3,21 @@ import cv2
|
|||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
from rapidocr import RapidOCR
|
import numpy as np
|
||||||
|
from paddleocr import PaddleOCR
|
||||||
from service.sensitive_service import get_all_sensitive_words
|
from service.sensitive_service import get_all_sensitive_words
|
||||||
|
|
||||||
|
# Compatibility shim: put the ``np.int`` alias back when the installed NumPy
# no longer exposes it, so code that still references ``np.int`` keeps
# working. Any unexpected failure is reported and otherwise ignored.
try:
    if getattr(np, "int", None) is None:
        np.int = int
except Exception as err:
    print(f"处理NumPy兼容性时出错: {err}")
|
||||||
|
|
||||||
# 全局变量
|
# 全局变量
|
||||||
_ocr_engine = None
|
_ocr_engine = None
|
||||||
_forbidden_words = set()
|
_forbidden_words = set()
|
||||||
_conf_threshold = 0.5
|
_conf_threshold = 0.5
|
||||||
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")
|
|
||||||
|
|
||||||
# 资源管理变量
|
# 资源管理变量
|
||||||
_ref_count = 0
|
_ref_count = 0
|
||||||
@ -19,6 +26,9 @@ _lock = threading.Lock()
|
|||||||
_release_timeout = 5 # 30秒无使用则释放
|
_release_timeout = 5 # idle seconds compared against elapsed time since last use; NOTE(review): the earlier comment said 30 s while the value is 5 - confirm which is intended
|
||||||
_is_releasing = False # 标记是否正在释放
|
_is_releasing = False # 标记是否正在释放
|
||||||
|
|
||||||
|
# 并行处理配置
|
||||||
|
_max_workers = 4 # 并行处理的线程数
|
||||||
|
|
||||||
# 调试用计数器
|
# 调试用计数器
|
||||||
_debug_counter = {
|
_debug_counter = {
|
||||||
"created": 0,
|
"created": 0,
|
||||||
@ -35,9 +45,6 @@ def _release_engine():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
_is_releasing = True
|
_is_releasing = True
|
||||||
# 如果有释放方法则调用
|
|
||||||
if hasattr(_ocr_engine, 'release'):
|
|
||||||
_ocr_engine.release()
|
|
||||||
_ocr_engine = None
|
_ocr_engine = None
|
||||||
_debug_counter["released"] += 1
|
_debug_counter["released"] += 1
|
||||||
print(f"OCR engine released. Stats: {_debug_counter}")
|
print(f"OCR engine released. Stats: {_debug_counter}")
|
||||||
@ -52,8 +59,9 @@ def _release_engine():
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import paddle
|
||||||
tf.keras.backend.clear_session()
|
if paddle.is_compiled_with_cuda():
|
||||||
|
paddle.device.cuda.empty_cache()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
@ -61,12 +69,11 @@ def _release_engine():
|
|||||||
|
|
||||||
|
|
||||||
def _monitor_thread():
|
def _monitor_thread():
|
||||||
"""监控线程、优化检查逻辑"""
|
"""监控线程,优化检查逻辑"""
|
||||||
global _ref_count, _last_used_time, _ocr_engine
|
global _ref_count, _last_used_time, _ocr_engine
|
||||||
while True:
|
while True:
|
||||||
time.sleep(5) # 每5秒检查一次
|
time.sleep(5) # 每5秒检查一次
|
||||||
with _lock:
|
with _lock:
|
||||||
# 只有当引擎存在、没有引用且超时、才释放
|
|
||||||
if _ocr_engine and _ref_count == 0 and not _is_releasing:
|
if _ocr_engine and _ref_count == 0 and not _is_releasing:
|
||||||
elapsed = time.time() - _last_used_time
|
elapsed = time.time() - _last_used_time
|
||||||
if elapsed > _release_timeout:
|
if elapsed > _release_timeout:
|
||||||
@ -91,25 +98,18 @@ def load_model():
|
|||||||
print(f"Forbidden words load error: {e}")
|
print(f"Forbidden words load error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 验证配置文件
|
|
||||||
if not os.path.exists(ocr_config_path):
|
|
||||||
print(f"OCR config not found: {ocr_config_path}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def detect(frame):
|
def detect(frame):
|
||||||
"""OCR检测、优化引用计数管理"""
|
"""OCR检测,支持并行处理"""
|
||||||
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time
|
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time, _max_workers
|
||||||
|
|
||||||
# 验证前置条件
|
# 验证前置条件
|
||||||
if not _forbidden_words:
|
if not _forbidden_words:
|
||||||
return (False, "违禁词未初始化")
|
return (False, "违禁词未初始化")
|
||||||
if frame is None or frame.size == 0:
|
if frame is None or frame.size == 0:
|
||||||
return (False, "无效帧数据")
|
return (False, "无效帧数据")
|
||||||
if not os.path.exists(ocr_config_path):
|
|
||||||
return (False, f"OCR配置文件不存在: {ocr_config_path}")
|
|
||||||
|
|
||||||
# 增加引用计数并获取引擎实例
|
# 增加引用计数并获取引擎实例
|
||||||
engine = None
|
engine = None
|
||||||
@ -121,15 +121,22 @@ def detect(frame):
|
|||||||
# 初始化引擎(如果未初始化且不在释放中)
|
# 初始化引擎(如果未初始化且不在释放中)
|
||||||
if not _ocr_engine and not _is_releasing:
|
if not _ocr_engine and not _is_releasing:
|
||||||
try:
|
try:
|
||||||
_ocr_engine = RapidOCR(config_path=ocr_config_path)
|
# 初始化PaddleOCR,设置并行处理参数
|
||||||
|
_ocr_engine = PaddleOCR(
|
||||||
|
use_angle_cls=True,
|
||||||
|
lang="ch",
|
||||||
|
show_log=False,
|
||||||
|
use_gpu=True,
|
||||||
|
max_text_length=1024,
|
||||||
|
threads=_max_workers
|
||||||
|
)
|
||||||
_debug_counter["created"] += 1
|
_debug_counter["created"] += 1
|
||||||
print(f"OCR engine initialized. Stats: {_debug_counter}")
|
print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"OCR model load failed: {e}")
|
print(f"OCR model load failed: {e}")
|
||||||
_ref_count -= 1 # 恢复引用计数
|
_ref_count -= 1
|
||||||
return (False, f"引擎初始化失败: {str(e)}")
|
return (False, f"引擎初始化失败: {str(e)}")
|
||||||
|
|
||||||
# 获取当前引擎引用
|
|
||||||
engine = _ocr_engine
|
engine = _ocr_engine
|
||||||
|
|
||||||
# 检查引擎是否可用
|
# 检查引擎是否可用
|
||||||
@ -140,15 +147,56 @@ def detect(frame):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行OCR检测
|
# 执行OCR检测
|
||||||
ocr_res = engine(frame)
|
ocr_res = engine.ocr(frame, cls=True)
|
||||||
|
|
||||||
# 验证OCR结果格式
|
# 验证OCR结果格式
|
||||||
if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'):
|
if not ocr_res or not isinstance(ocr_res, list):
|
||||||
return (False, "无OCR结果")
|
return (False, "无OCR结果")
|
||||||
|
|
||||||
# 处理OCR结果
|
# 处理OCR结果 - 兼容多种格式
|
||||||
texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)]
|
texts = []
|
||||||
confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))]
|
confs = []
|
||||||
|
for line in ocr_res:
|
||||||
|
if line is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理line可能是列表或直接是文本信息的情况
|
||||||
|
if isinstance(line, list):
|
||||||
|
items_to_process = line
|
||||||
|
else:
|
||||||
|
items_to_process = [line]
|
||||||
|
|
||||||
|
for item in items_to_process:
|
||||||
|
# 跳过纯数字列表(可能是坐标信息)
|
||||||
|
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理元组形式的文本和置信度 (text, confidence)
|
||||||
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
|
text, conf = item
|
||||||
|
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||||
|
texts.append(text.strip())
|
||||||
|
confs.append(float(conf))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理列表形式的[坐标信息, (text, confidence)]
|
||||||
|
if isinstance(item, list) and len(item) >= 2:
|
||||||
|
# 尝试从列表中提取文本和置信度
|
||||||
|
text_data = item[1]
|
||||||
|
if isinstance(text_data, tuple) and len(text_data) == 2:
|
||||||
|
text, conf = text_data
|
||||||
|
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||||
|
texts.append(text.strip())
|
||||||
|
confs.append(float(conf))
|
||||||
|
continue
|
||||||
|
elif isinstance(text_data, str):
|
||||||
|
# 只有文本没有置信度的情况
|
||||||
|
texts.append(text_data.strip())
|
||||||
|
confs.append(1.0) # 默认最高置信度
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 无法识别的格式,记录日志
|
||||||
|
print(f"无法解析的OCR结果格式: {item}")
|
||||||
|
|
||||||
if len(texts) != len(confs):
|
if len(texts) != len(confs):
|
||||||
return (False, "OCR结果格式异常")
|
return (False, "OCR结果格式异常")
|
||||||
@ -178,9 +226,16 @@ def detect(frame):
|
|||||||
return (False, f"检测错误: {str(e)}")
|
return (False, f"检测错误: {str(e)}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 减少引用计数、确保线程安全
|
# 减少引用计数,确保线程安全
|
||||||
with _lock:
|
with _lock:
|
||||||
_ref_count = max(0, _ref_count - 1)
|
_ref_count = max(0, _ref_count - 1)
|
||||||
# 持续使用时更新最后使用时间
|
|
||||||
if _ref_count > 0:
|
if _ref_count > 0:
|
||||||
_last_used_time = time.time()
|
_last_used_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def batch_detect(frames):
    """Run ``detect`` on every frame and return the results in input order.

    Frames are handled one after another by a plain loop; this wrapper does
    not itself distribute work across threads or processes.

    Args:
        frames: Iterable of frames; each one is passed unchanged to ``detect``.

    Returns:
        list: One ``detect`` result per frame, in the same order as ``frames``.
    """
    return [detect(frame) for frame in frames]
|
||||||
|
Reference in New Issue
Block a user