Files
video/core/ocr.py
2025-09-05 17:23:50 +08:00

187 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gc
import numbers
import os
import sys
import threading
import time

import cv2
from rapidocr import RapidOCR

from service.sensitive_service import get_all_sensitive_words
# Module-level state shared by load_model(), detect() and the monitor thread.
_ocr_engine = None  # lazily created RapidOCR instance; None when not loaded
_forbidden_words = set()  # filled by load_model()
_conf_threshold = 0.5  # text lines scoring below this are not checked
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")

# Resource-management state. detect() and the monitor thread read/write these
# while holding _lock.
_ref_count = 0  # number of detect() calls currently using the engine
_last_used_time = 0  # time.time() of the most recent engine use
_lock = threading.Lock()
_release_timeout = 5  # release the engine after more than 5 s with no use
_is_releasing = False  # set while _release_engine() is running

# Debug counters printed when the engine is created / released.
_debug_counter = {
    "created": 0,
    "released": 0,
    "detected": 0,
}
def _release_engine():
    """Drop the shared OCR engine and best-effort free framework GPU caches.

    The caller is expected to hold ``_lock`` (the monitor thread calls this
    inside ``with _lock``). Errors raised by the engine's own ``release()`` or
    by framework cache cleanup are printed instead of propagated. A failing
    cleanup therefore can neither leave a half-released engine behind nor
    terminate the monitor thread.
    """
    global _ocr_engine, _is_releasing
    if _ocr_engine is None or _is_releasing:
        return
    _is_releasing = True
    try:
        # Clear the shared reference first, so the engine counts as released
        # even if its own release() fails.
        engine, _ocr_engine = _ocr_engine, None
        release = getattr(engine, "release", None)
        if callable(release):
            try:
                release()
            except Exception as e:
                print(f"OCR engine release() failed: {e}")
        # Drop the local references so gc.collect() below can reclaim the engine.
        del engine, release
        _debug_counter["released"] += 1
        print(f"OCR engine released. Stats: {_debug_counter}")
        gc.collect()
        _clear_framework_caches()
    finally:
        _is_releasing = False


def _clear_framework_caches():
    """Free cached GPU memory of torch / TensorFlow, if they are already imported.

    Only modules already present in ``sys.modules`` are touched. Importing a
    framework here just to clear it would be slow, because this runs while
    ``_lock`` is held, and a framework that was never imported holds nothing
    to free.
    """
    torch = sys.modules.get("torch")
    if torch is not None:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
        except Exception as e:
            print(f"torch cache cleanup failed: {e}")
    tf = sys.modules.get("tensorflow")
    if tf is not None:
        try:
            tf.keras.backend.clear_session()
        except Exception as e:
            print(f"tensorflow session cleanup failed: {e}")
def _monitor_thread():
    """Background loop that frees the OCR engine once it has been idle.

    Every 5 seconds it checks, under ``_lock``, three conditions: an engine
    exists, no ``detect`` call holds a reference (``_ref_count == 0``), and
    more than ``_release_timeout`` seconds have passed since
    ``_last_used_time``. When all hold, it calls ``_release_engine()``. The
    loop never returns; it is started as a daemon thread by ``load_model()``.
    """
    while True:
        time.sleep(5)  # polling interval
        try:
            with _lock:
                # Release only when an engine exists, nobody is using it and the
                # idle timeout has elapsed.
                if _ocr_engine and _ref_count == 0 and not _is_releasing:
                    elapsed = time.time() - _last_used_time
                    if elapsed > _release_timeout:
                        print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine")
                        _release_engine()
        except Exception as e:
            # Keep the monitor alive. An uncaught error would end this thread,
            # and the engine would never be released again.
            print(f"OCR monitor error: {e}")
def load_model():
    """Load the forbidden-word list and start the idle-release monitor thread.

    The monitor thread is looked up by its name ("OCRMonitor") and started only
    if no thread with that name is alive, so repeated sequential calls do not
    spawn duplicates. NOTE(review): the check is not synchronized, so two
    concurrent first calls could still race.

    Returns:
        True if the word list was loaded and the OCR config file exists,
        False otherwise (the reason is printed).
    """
    global _forbidden_words
    # Start the monitor thread at most once.
    if not any(t.name == "OCRMonitor" for t in threading.enumerate()):
        threading.Thread(target=_monitor_thread, daemon=True, name="OCRMonitor").start()
        print("OCR monitor thread started")

    # Load forbidden words.
    try:
        words = get_all_sensitive_words()
        # Drop empty / non-string entries: "" is a substring of every text, so
        # a single blank entry would make detect() flag every recognized line.
        # dict.fromkeys de-duplicates while keeping the service's order.
        cleaned = list(dict.fromkeys(w for w in words if isinstance(w, str) and w))
    except Exception as e:
        print(f"Forbidden words load error: {e}")
        return False
    _forbidden_words = cleaned
    print(f"Loaded {len(_forbidden_words)} forbidden words")

    # Validate the OCR config file.
    if not os.path.exists(ocr_config_path):
        print(f"OCR config not found: {ocr_config_path}")
        return False
    return True
def detect(frame):
    """Run OCR on ``frame`` and report whether recognized text contains a forbidden word.

    The shared engine is created lazily on first use and reference-counted, so
    the idle monitor never releases it while a call is in progress.

    Args:
        frame: image to scan; must expose ``.size`` (presumably a NumPy array
            from OpenCV -- confirm with callers). ``None`` or an empty frame
            is rejected.

    Returns:
        ``(True, details)`` when at least one text line with confidence
        ``>= _conf_threshold`` contains a forbidden word, otherwise
        ``(False, reason)``. Failures are reported as ``(False, message)``
        rather than raised.
    """
    # Preconditions.
    if not _forbidden_words:
        return (False, "违禁词未初始化")
    if frame is None or frame.size == 0:
        return (False, "无效帧数据")
    if not os.path.exists(ocr_config_path):
        return (False, f"OCR配置文件不存在: {ocr_config_path}")

    engine, error = _acquire_engine()
    if engine is None:
        return (False, error)
    try:
        return _evaluate_ocr_result(engine(frame))
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")
    finally:
        _release_ref()


def _acquire_engine():
    """Take a reference on the shared engine, creating it if needed.

    Returns ``(engine, None)`` on success; the caller must later call
    ``_release_ref()``. On failure it returns ``(None, message)`` and holds
    no reference.
    """
    global _ocr_engine, _ref_count, _last_used_time
    with _lock:
        _ref_count += 1
        _last_used_time = time.time()
        _debug_counter["detected"] += 1
        # Create the engine if it is missing and not currently being released.
        if not _ocr_engine and not _is_releasing:
            try:
                _ocr_engine = RapidOCR(config_path=ocr_config_path)
                _debug_counter["created"] += 1
                print(f"OCR engine initialized. Stats: {_debug_counter}")
            except Exception as e:
                print(f"OCR model load failed: {e}")
                _ref_count -= 1  # undo the reference taken above
                return None, f"引擎初始化失败: {str(e)}"
        if _ocr_engine:
            return _ocr_engine, None
        _ref_count -= 1
        return None, "OCR引擎不可用"


def _release_ref():
    """Drop a reference taken by ``_acquire_engine`` and restart the idle timer."""
    global _ref_count, _last_used_time
    with _lock:
        _ref_count = max(0, _ref_count - 1)
        # Always stamp the finish time. The idle timeout must count from when
        # the engine was last in use, not from when the last call started;
        # otherwise a slow call is followed by an immediate release.
        _last_used_time = time.time()


def _evaluate_ocr_result(ocr_res):
    """Turn an OCR result (``.txts`` / ``.scores``) into detect()'s return tuple."""
    if not ocr_res or not hasattr(ocr_res, "txts") or not hasattr(ocr_res, "scores"):
        return (False, "无OCR结果")
    raw_txts = () if ocr_res.txts is None else ocr_res.txts
    raw_scores = () if ocr_res.scores is None else ocr_res.scores
    if len(raw_txts) != len(raw_scores):
        return (False, "OCR结果格式异常")

    # Filter text/score *pairs* together, so each text keeps its own score
    # even when a neighbouring text is empty or scored 0.0. numbers.Real also
    # accepts NumPy scalar scores such as numpy.float32.
    lines = []
    for txt, conf in zip(raw_txts, raw_scores):
        if not isinstance(txt, str) or not txt.strip():
            continue
        if not isinstance(conf, numbers.Real):
            continue
        lines.append((txt.strip(), float(conf)))
    if not lines:
        return (False, "未识别到文本")

    vio_info = []
    for txt, conf in lines:
        if conf < _conf_threshold:
            continue
        # Skip blank words: "" would match every text.
        matched = [w for w in _forbidden_words if w and w in txt]
        if matched:
            vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
    if vio_info:
        return (True, "; ".join(vio_info))
    return (False, "未检测到违禁词")