76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
|
import os
|
|||
|
import cv2
|
|||
|
from rapidocr import RapidOCR
|
|||
|
from service.sensitive_service import get_all_sensitive_words
|
|||
|
|
|||
|
# Module-level state shared by load_model() (writer) and detect() (reader).
_ocr_engine = None          # RapidOCR instance, set by load_model()
_forbidden_words = set()    # replaced with the sensitive-word list by load_model()
_conf_threshold = 0.5       # OCR lines scoring below this are not screened

# RapidOCR config file, resolved relative to this module's directory:
# <module dir>/config/config.yaml
ocr_config_path = os.path.join(
    os.path.dirname(__file__),
    "config",
    "config.yaml",
)
def load_model():
    """Load the forbidden-word list and initialise the RapidOCR engine.

    Side effects:
        Rebinds the module globals ``_forbidden_words`` and ``_ocr_engine``.

    Loading the word list is best effort. On failure the error is printed and
    the previous list is kept. Note that ``detect()`` refuses to run while the
    list is empty.

    Returns:
        bool: True when the OCR engine was created. False when the config file
        is missing or ``RapidOCR`` raised during construction.
    """
    global _ocr_engine, _forbidden_words

    # Load forbidden words.
    try:
        raw_words = get_all_sensitive_words() or []
        # detect() matches with `word in text`. An empty/blank entry is a
        # substring of every text and would flag everything, so drop it.
        # Non-str entries cannot be used with `in` on a str either.
        # dict.fromkeys de-duplicates while keeping source order, so the
        # reported match list is deterministic across runs.
        _forbidden_words = list(dict.fromkeys(
            word.strip()
            for word in raw_words
            if isinstance(word, str) and word.strip()
        ))
    except Exception as e:
        print(f"Forbidden words load error: {e}")

    # Initialise the OCR engine.
    if not os.path.exists(ocr_config_path):
        print(f"OCR config not found: {ocr_config_path}")
        return False

    try:
        _ocr_engine = RapidOCR(config_path=ocr_config_path)
    except Exception as e:
        print(f"OCR model load failed: {e}")
        return False

    return _ocr_engine is not None
def detect(frame):
    """Run OCR on *frame* and screen the recognised text for forbidden words.

    Args:
        frame: image handed straight to the RapidOCR engine. Only
            ``frame.size`` is read here, so presumably a numpy array
            (e.g. from cv2) — TODO confirm with callers.

    Returns:
        tuple[bool, str]: ``(True, details)`` when at least one recognised line
        with score >= ``_conf_threshold`` contains a forbidden word. Otherwise
        ``(False, reason)`` with a short status message.
    """
    if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
        return (False, "未初始化或无效帧")

    try:
        ocr_res = _ocr_engine(frame)
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")

    if not ocr_res:
        return (False, "无OCR结果")

    # The fields may be absent, or present but None (e.g. when nothing was
    # detected). hasattr() alone would let None through to iteration.
    raw_txts = getattr(ocr_res, "txts", None)
    raw_scores = getattr(ocr_res, "scores", None)
    if raw_txts is None or raw_scores is None:
        return (False, "无OCR结果")

    try:
        raw_txts = list(raw_txts)
        raw_scores = list(raw_scores)
    except TypeError:
        return (False, "OCR结果格式异常")

    if len(raw_txts) != len(raw_scores):
        return (False, "OCR结果格式异常")

    # Pair each text with its own score *before* filtering. Filtering the two
    # sequences independently would shift one relative to the other whenever
    # only one side of a pair is dropped.
    lines = []
    for txt, score in zip(raw_txts, raw_scores):
        if not isinstance(txt, str) or not txt.strip():
            continue
        try:
            # float() also accepts numpy scalars such as np.float32, which
            # isinstance(score, float) would reject.
            conf = float(score)
        except (TypeError, ValueError):
            continue
        lines.append((txt.strip(), conf))

    if not lines:
        return (False, "未识别到文本")

    # Screen confident lines for forbidden words (plain substring match).
    vio_info = []
    for txt, conf in lines:
        if conf < _conf_threshold:
            continue
        matched = [w for w in _forbidden_words if w in txt]
        if matched:
            vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")

    if vio_info:
        return (True, "; ".join(vio_info))
    return (False, "未检测到违禁词")