Files
video/core/ocr.py

254 lines
8.2 KiB
Python
Raw Normal View History

2025-09-04 22:59:27 +08:00
import os
import cv2
2025-09-05 17:23:50 +08:00
import gc
import time
import threading
2025-09-09 09:42:23 +08:00
import numpy as np
from paddleocr import PaddleOCR
2025-09-04 22:59:27 +08:00
from service.sensitive_service import get_all_sensitive_words
2025-09-09 09:42:23 +08:00
# NumPy deprecated the ``np.int`` alias in 1.20 and removed it in 1.24.
# Something loaded alongside this module presumably still reads ``np.int``
# (TODO confirm which dependency), so re-create the alias as the builtin
# ``int`` it always pointed to. Any failure here is only reported, not raised.
try:
    _numpy_lacks_int = not hasattr(np, "int")
    if _numpy_lacks_int:
        setattr(np, "int", int)
except Exception as e:
    print(f"处理NumPy兼容性时出错: {e}")
2025-09-04 22:59:27 +08:00
# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------
_ocr_engine = None  # shared PaddleOCR instance; created lazily by detect()
_forbidden_words = set()  # filled by load_model() from get_all_sensitive_words()
_conf_threshold = 0.5  # detect() ignores OCR lines whose confidence is below this

# Resource management
_ref_count = 0  # number of detect() calls currently holding the engine
_last_used_time = 0  # time.time() stamp refreshed by detect() when the engine is used
_lock = threading.Lock()  # guards _ocr_engine, _ref_count and _last_used_time
# Seconds the engine may sit idle (no active detect() calls) before the
# monitor thread releases it.
_release_timeout = 5
_is_releasing = False  # set while _release_engine() is running

# Parallelism
# Passed to PaddleOCR as ``threads=`` and echoed in the init log.
# NOTE(review): confirm PaddleOCR actually honours a ``threads`` keyword.
_max_workers = 4

# Debug counters: engine creations, engine releases, detect() calls.
_debug_counter = {
    "created": 0,
    "released": 0,
    "detected": 0,
}
def _empty_torch_cuda_cache():
    """Best-effort: free torch's CUDA cache, only if torch already initialised CUDA."""
    import sys

    torch = sys.modules.get("torch")
    if torch is None:
        # This process never imported torch; importing it here just to clear a
        # cache would be slow, and ipc_collect() would create a fresh CUDA context.
        return
    try:
        if torch.cuda.is_available() and torch.cuda.is_initialized():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    except Exception as e:
        print(f"torch CUDA cache cleanup failed: {e}")


def _empty_paddle_cuda_cache():
    """Best-effort: release paddle's cached GPU memory when paddle is built with CUDA."""
    try:
        import paddle
    except ImportError:
        return
    try:
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.empty_cache()
    except Exception as e:
        print(f"paddle CUDA cache cleanup failed: {e}")


def _release_engine():
    """Drop the shared OCR engine and try to give cached GPU memory back.

    No-op when there is no engine or a release is already in progress.
    The visible caller (``_monitor_thread``) invokes this while holding ``_lock``.
    Cache cleanup is best-effort: failures are printed, never raised, so the
    calling monitor loop keeps running.
    """
    global _ocr_engine, _is_releasing
    if not _ocr_engine or _is_releasing:
        return
    _is_releasing = True
    try:
        _ocr_engine = None
        _debug_counter["released"] += 1
        print(f"OCR engine released. Stats: {_debug_counter}")
        # Collect first so the dropped engine's buffers are unreferenced before
        # the framework allocators are asked to free their caches.
        gc.collect()
        _empty_torch_cuda_cache()
        _empty_paddle_cuda_cache()
    finally:
        _is_releasing = False
def _monitor_thread():
    """Background loop that releases the OCR engine once it has been idle.

    Every 5 seconds, while holding ``_lock``, checks that an engine exists, no
    ``detect`` call holds a reference, no release is in progress, and more than
    ``_release_timeout`` seconds have passed since ``_last_used_time``; if so it
    calls ``_release_engine``. Never returns; ``load_model`` starts it as a
    daemon thread named "OCRMonitor".
    """
    while True:
        time.sleep(5)  # polling interval
        try:
            with _lock:
                if _ocr_engine and _ref_count == 0 and not _is_releasing:
                    elapsed = time.time() - _last_used_time
                    if elapsed > _release_timeout:
                        print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine")
                        _release_engine()
        except Exception as e:
            # Keep monitoring: an unexpected error in one check/release must not
            # silently stop idle releases for the rest of the process lifetime.
            print(f"OCR monitor error: {e}")
2025-09-04 22:59:27 +08:00
def load_model():
    """Load the forbidden-word list and make sure the idle monitor is running.

    The monitor is recognised by its thread name, "OCRMonitor"; a new daemon
    thread running ``_monitor_thread`` is started only when no live thread
    has that name.

    Returns:
        True if ``get_all_sensitive_words()`` succeeded; False if loading
        raised (the error is printed).
    """
    global _forbidden_words

    monitor_name = "OCRMonitor"
    live_names = (t.name for t in threading.enumerate())
    if monitor_name not in live_names:
        threading.Thread(target=_monitor_thread, daemon=True, name=monitor_name).start()
        print("OCR monitor thread started")

    try:
        _forbidden_words = get_all_sensitive_words()
        print(f"Loaded {len(_forbidden_words)} forbidden words")
    except Exception as e:
        print(f"Forbidden words load error: {e}")
        return False
    else:
        return True
2025-09-04 22:59:27 +08:00
# Numeric scalar types accepted in OCR output. NumPy scalars are included
# because coordinates or scores may arrive as e.g. np.float32, not Python float.
_OCR_NUMBER_TYPES = (int, float, np.integer, np.floating)


def _is_number(value):
    """Return True if *value* is an int/float scalar (Python or NumPy)."""
    return isinstance(value, _OCR_NUMBER_TYPES)


def _is_quad_box(item):
    """Return True if *item* looks like a 4-point box [[x1, y1], ..., [x4, y4]]."""
    if not (isinstance(item, list) and len(item) == 4):
        return False
    return all(
        isinstance(point, (list, tuple))
        and len(point) == 2
        and all(_is_number(coord) for coord in point)
        for point in item
    )


def _as_text_conf(pair):
    """Return (stripped_text, float_conf) if *pair* is a (str, number) tuple, else None."""
    if isinstance(pair, tuple) and len(pair) == 2:
        text, conf = pair
        if isinstance(text, str) and _is_number(conf):
            return text.strip(), float(conf)
    return None


def _parse_ocr_result(ocr_res):
    """Flatten an ``engine.ocr`` result list into (text, confidence) pairs.

    Handles both a flat list of ``[box, (text, conf)]`` lines and a nested
    ``[[[box, (text, conf)], ...]]`` form. ``None`` entries and bare coordinate
    lists are skipped; a bare string in place of ``(text, conf)`` gets
    confidence 1.0; anything else unrecognised is logged and skipped.
    """
    entries = []
    for line in ocr_res:
        if line is None:
            continue
        items = line if isinstance(line, list) else [line]
        for item in items:
            if _is_quad_box(item):
                continue  # polygon coordinates, not text
            if isinstance(item, list) and all(_is_number(x) for x in item):
                continue  # other purely numeric (coordinate-like) lists
            parsed = _as_text_conf(item)
            if parsed is None and isinstance(item, list) and len(item) >= 2:
                payload = item[1]
                if isinstance(payload, str):
                    # Text without a score: treat it as fully confident.
                    parsed = (payload.strip(), 1.0)
                else:
                    parsed = _as_text_conf(payload)
            if parsed is None:
                print(f"无法解析的OCR结果格式: {item}")
                continue
            entries.append(parsed)
    return entries


def _find_violations(entries, forbidden_words, conf_threshold):
    """Describe each sufficiently confident text that contains forbidden words."""
    violations = []
    for txt, conf in entries:
        if conf < conf_threshold:
            continue
        # Skip empty words: "" is a substring of every text and would flag everything.
        matched = [w for w in forbidden_words if w and w in txt]
        if matched:
            violations.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
    return violations


def _acquire_engine():
    """Register one engine user and return ``(engine, None)``, creating the engine lazily.

    On failure returns ``(None, error_message)`` with the reference count
    already rolled back, so the caller must not decrement it again.
    """
    global _ocr_engine, _ref_count, _last_used_time
    with _lock:
        _ref_count += 1
        _last_used_time = time.time()
        _debug_counter["detected"] += 1
        if not _ocr_engine and not _is_releasing:
            try:
                # NOTE(review): ``threads`` is not obviously a PaddleOCR option
                # (CPU threading is usually ``cpu_threads``); confirm it has any effect.
                _ocr_engine = PaddleOCR(
                    use_angle_cls=True,
                    lang="ch",
                    show_log=False,
                    use_gpu=True,
                    max_text_length=1024,
                    threads=_max_workers
                )
                _debug_counter["created"] += 1
                print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}")
            except Exception as e:
                print(f"OCR model load failed: {e}")
                _ref_count -= 1
                return None, f"引擎初始化失败: {str(e)}"
        engine = _ocr_engine
        if not engine:
            _ref_count -= 1
            return None, "OCR引擎不可用"
    return engine, None


def detect(frame):
    """Run OCR on *frame* and check the recognised text for forbidden words.

    Args:
        frame: image handed to ``engine.ocr``; must expose ``.size``
            (presumably a NumPy image array, e.g. from cv2 — confirm with callers).

    Returns:
        ``(True, description)`` when at least one text line with confidence
        >= ``_conf_threshold`` contains a forbidden word, otherwise
        ``(False, reason)``. OCR and engine errors are reported as
        ``(False, message)`` rather than raised; a *frame* without ``.size``
        still raises ``AttributeError``.
    """
    global _ref_count, _last_used_time

    if not _forbidden_words:
        return (False, "违禁词未初始化")
    if frame is None or frame.size == 0:
        return (False, "无效帧数据")

    engine, error = _acquire_engine()
    if engine is None:
        return (False, error)

    try:
        # NOTE(review): concurrent detect() calls share this single engine without
        # serialising engine.ocr(); confirm PaddleOCR inference is thread-safe.
        ocr_res = engine.ocr(frame, cls=True)
        # Type check first so a non-list (e.g. ndarray) never hits ambiguous truthiness.
        if not isinstance(ocr_res, list) or not ocr_res:
            return (False, "无OCR结果")

        entries = _parse_ocr_result(ocr_res)
        if not entries:
            return (False, "未识别到文本")

        violations = _find_violations(entries, _forbidden_words, _conf_threshold)
        if violations:
            return (True, "; ".join(violations))
        return (False, "未检测到违禁词")
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")
    finally:
        with _lock:
            _ref_count = max(0, _ref_count - 1)
            # Always refresh the stamp, especially when the count reaches 0: the
            # monitor measures idleness from _last_used_time, so it must mark when
            # the last caller finished, not when it started.
            _last_used_time = time.time()
2025-09-09 09:42:23 +08:00
def batch_detect(frames):
    """Run ``detect`` on every frame and collect the results in input order.

    Frames are processed one after another in the calling thread; this
    function does not parallelise anything itself.

    Args:
        frames: iterable of frames accepted by ``detect``.

    Returns:
        list with one ``detect`` result (a ``(bool, str)`` tuple) per frame.
    """
    return [detect(frame) for frame in frames]