Files
video/core/ocr.py
2025-09-09 09:42:23 +08:00

242 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import gc
import time
import threading
import numpy as np
from paddleocr import PaddleOCR
from service.sensitive_service import get_all_sensitive_words
# Compatibility shim: restore the ``np.int`` alias when the installed NumPy
# no longer provides it, so code that still refers to ``np.int`` keeps working.
try:
    if not hasattr(np, "int"):
        setattr(np, "int", int)
except Exception as e:
    # Best effort only: report the problem and carry on with module import.
    print(f"处理NumPy兼容性时出错: {e}")
# Module-level state
# Shared PaddleOCR instance; None until detect() creates it and again after
# _release_engine() drops it.
_ocr_engine = None
# Forbidden words matched against recognised text; filled in by load_model().
_forbidden_words = set()
# Recognised text whose confidence is below this value is ignored by detect().
_conf_threshold = 0.5
# Resource-management state
# Number of detect() calls currently in flight.
_ref_count = 0
# time.time() stamp of the most recent engine use (0 means never used).
_last_used_time = 0
# Held while the engine, reference count and timestamp are read or changed.
_lock = threading.Lock()
# Idle seconds after which the monitor thread releases the engine.
# NOTE(review): the original comment said 30 seconds but the value is 5 --
# confirm which one is intended.
_release_timeout = 5
# True while _release_engine() is tearing the engine down.
_is_releasing = False
# Parallel-processing configuration
# Passed to PaddleOCR as ``threads`` when the engine is created.
_max_workers = 4
# Debug counters: engines created, engines released, detect() calls started.
_debug_counter = {
    "created": 0,
    "released": 0,
    "detected": 0
}
def _release_engine():
    """Drop the shared OCR engine and ask the runtimes to free cached memory.

    Does nothing when no engine exists or a release is already in progress.
    ``_is_releasing`` is set for the duration of the teardown and is always
    cleared again, even if one of the cleanup steps raises.
    """
    global _ocr_engine, _is_releasing
    if not (_ocr_engine and not _is_releasing):
        return

    _is_releasing = True
    try:
        _ocr_engine = None
        _debug_counter["released"] += 1
        print(f"OCR engine released. Stats: {_debug_counter}")
        # Collect first so the engine object is gone before caches are emptied.
        gc.collect()
        _clear_torch_cuda_cache()
        _clear_paddle_cuda_cache()
    finally:
        _is_releasing = False


def _clear_torch_cuda_cache():
    """Empty PyTorch's CUDA caches; silently skipped when torch cannot be imported."""
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    except ImportError:
        pass


def _clear_paddle_cuda_cache():
    """Empty Paddle's CUDA cache; silently skipped when paddle cannot be imported."""
    try:
        import paddle
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.empty_cache()
    except ImportError:
        pass
def _monitor_thread():
    """Background loop that releases the OCR engine once it has been idle too long.

    Wakes every 5 seconds. Under ``_lock`` it releases the engine when one
    exists, no detect() call is in flight, no release is already running and
    more than ``_release_timeout`` seconds have passed since
    ``_last_used_time``. The loop never returns.
    """
    global _ref_count, _last_used_time, _ocr_engine
    while True:
        time.sleep(5)  # polling interval in seconds
        with _lock:
            engine_is_idle = _ocr_engine and _ref_count == 0 and not _is_releasing
            if not engine_is_idle:
                continue
            elapsed = time.time() - _last_used_time
            if elapsed <= _release_timeout:
                continue
            print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine")
            _release_engine()
def load_model():
    """Start the idle monitor (once) and load the forbidden-word list.

    Returns:
        bool: True when the forbidden words were loaded, False when loading
        raised an exception (the error is printed).
    """
    global _forbidden_words

    # A thread named "OCRMonitor" already being alive means the monitor was
    # started earlier, so it is not started a second time.
    running_names = {t.name for t in threading.enumerate()}
    if "OCRMonitor" not in running_names:
        monitor = threading.Thread(name="OCRMonitor", target=_monitor_thread, daemon=True)
        monitor.start()
        print("OCR monitor thread started")

    try:
        _forbidden_words = get_all_sensitive_words()
        print(f"Loaded {len(_forbidden_words)} forbidden words")
    except Exception as e:
        print(f"Forbidden words load error: {e}")
        return False
    else:
        return True
def detect(frame):
    """Run OCR on one frame and report forbidden words found in the text.

    The shared PaddleOCR engine is created on first use. Every call is counted
    in ``_ref_count`` while it runs so the monitor thread does not release the
    engine underneath it.

    Args:
        frame: Image to analyse. It must expose a ``size`` attribute
            (presumably a NumPy/OpenCV image array -- confirm with callers).

    Returns:
        tuple: ``(has_violation, message)``. ``has_violation`` is True only
        when at least one recognised text with confidence of at least
        ``_conf_threshold`` contains a forbidden word; ``message`` then lists
        every hit. In all other cases (no forbidden words loaded, invalid
        frame, engine unavailable, no text, no hit, or an error during OCR)
        the first element is False and ``message`` explains why.
    """
    global _ocr_engine, _ref_count, _last_used_time

    # Preconditions
    if not _forbidden_words:
        return (False, "违禁词未初始化")
    if frame is None or frame.size == 0:
        return (False, "无效帧数据")

    # Register this call and fetch (or lazily create) the shared engine.
    engine = None
    with _lock:
        _ref_count += 1
        _last_used_time = time.time()
        _debug_counter["detected"] += 1
        if not _ocr_engine and not _is_releasing:
            try:
                # NOTE(review): ``threads`` is passed through as in the
                # original code; confirm PaddleOCR honours this keyword (its
                # documented CPU setting is ``cpu_threads``).
                _ocr_engine = PaddleOCR(
                    use_angle_cls=True,
                    lang="ch",
                    show_log=False,
                    use_gpu=True,
                    max_text_length=1024,
                    threads=_max_workers
                )
                _debug_counter["created"] += 1
                print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}")
            except Exception as e:
                print(f"OCR model load failed: {e}")
                _ref_count -= 1
                return (False, f"引擎初始化失败: {str(e)}")
        engine = _ocr_engine

    if not engine:
        with _lock:
            _ref_count -= 1
        return (False, "OCR引擎不可用")

    try:
        ocr_res = engine.ocr(frame, cls=True)
        if not ocr_res or not isinstance(ocr_res, list):
            return (False, "无OCR结果")

        entries = _extract_text_entries(ocr_res)
        if not entries:
            return (False, "未识别到文本")

        vio_info = _find_violations(entries, _forbidden_words, _conf_threshold)
        if vio_info:
            return (True, "; ".join(vio_info))
        return (False, "未检测到违禁词")
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")
    finally:
        with _lock:
            _ref_count = max(0, _ref_count - 1)
            # Always stamp completion time. Previously the stamp was refreshed
            # only while other calls were still running, so after a detection
            # longer than the idle timeout the engine was already considered
            # idle the moment the last call finished.
            _last_used_time = time.time()


def _parse_text_conf(pair):
    """Return ``(stripped_text, float_conf)`` for a ``(str, number)`` 2-tuple, else None."""
    if isinstance(pair, tuple) and len(pair) == 2:
        text, conf = pair
        if isinstance(text, str) and isinstance(conf, (int, float)):
            return (text.strip(), float(conf))
    return None


def _extract_text_entries(ocr_res):
    """Flatten a raw OCR result list into ``(text, confidence)`` pairs.

    Accepted item shapes are ``(text, conf)`` tuples and
    ``[box, (text, conf)]`` / ``[box, text]`` lists; a bare text gets
    confidence 1.0. Purely numeric lists are skipped silently, anything else
    is printed as unparseable and skipped.
    """
    entries = []
    for line in ocr_res:
        if line is None:
            continue
        items = line if isinstance(line, list) else [line]
        for item in items:
            # A flat list of numbers is treated as coordinate data, not text.
            if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
                continue

            parsed = None
            if isinstance(item, tuple):
                parsed = _parse_text_conf(item)
            elif isinstance(item, list) and len(item) >= 2:
                payload = item[1]
                parsed = _parse_text_conf(payload)
                if parsed is None and isinstance(payload, str):
                    # Text without a confidence: assume the highest confidence.
                    parsed = (payload.strip(), 1.0)

            if parsed is None:
                print(f"无法解析的OCR结果格式: {item}")
            else:
                entries.append(parsed)
    return entries


def _find_violations(entries, forbidden_words, conf_threshold):
    """Return one description per sufficiently confident text containing a forbidden word."""
    vio_info = []
    for txt, conf in entries:
        if conf < conf_threshold:
            continue
        # Skip empty words: "" is a substring of every text and would flag
        # every frame that contains any text at all.
        matched = [w for w in forbidden_words if w and w in txt]
        if matched:
            vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
    return vio_info
def batch_detect(frames):
    """Run detect() on every frame and return the results in input order.

    Frames are processed one after another in the calling thread.

    Args:
        frames: Iterable of frames, each passed unchanged to detect().

    Returns:
        list: One detect() result per frame.
    """
    return [detect(frame) for frame in frames]