优化ocr检测时间,加载默认模型
This commit is contained in:
@ -2,10 +2,10 @@
|
|||||||
port = 8000
|
port = 8000
|
||||||
|
|
||||||
[mysql]
|
[mysql]
|
||||||
host = 192.168.110.65
|
host = 192.168.110.2
|
||||||
port = 6975
|
port = 13386
|
||||||
user = video_check
|
user = video_check
|
||||||
password = fsjPfhxCs8NrFGmL
|
password = taWtMSpXh88SHnps
|
||||||
database = video_check
|
database = video_check
|
||||||
charset = utf8mb4
|
charset = utf8mb4
|
||||||
|
|
||||||
|
|||||||
@ -44,7 +44,6 @@ def save_db(model_type, client_ip, result):
|
|||||||
|
|
||||||
|
|
||||||
def detectFrame(client_ip, frame):
|
def detectFrame(client_ip, frame):
|
||||||
|
|
||||||
# YOLO检测
|
# YOLO检测
|
||||||
yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
|
yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
|
||||||
if yolo_flag:
|
if yolo_flag:
|
||||||
@ -103,36 +102,11 @@ def danger_handler(client_ip):
|
|||||||
json_data=json.dumps(lock_msg)
|
json_data=json.dumps(lock_msg)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 增加危险记录次数
|
# 增加危险记录次数
|
||||||
increment_alarm_count_by_ip(client_ip)
|
increment_alarm_count_by_ip(client_ip)
|
||||||
|
|
||||||
# 更新设备状态为未处理
|
# 更新设备状态为未处理
|
||||||
update_is_need_handler_by_client_ip(client_ip, 1)
|
update_is_need_handler_by_client_ip(client_ip, 1)
|
||||||
|
|
||||||
def extract_prohibited_words(ocr_result: str) -> str:
|
|
||||||
"""
|
|
||||||
从多文本块的ocr_result中提取所有违禁词(去重后用逗号拼接)
|
|
||||||
适配格式:多个"文本: ... 包含违禁词: ...;"片段
|
|
||||||
"""
|
|
||||||
# 用正则匹配所有"包含违禁词: ...;"的片段(非贪婪匹配到分号)
|
|
||||||
# 匹配规则:"包含违禁词: "后面的内容,直到遇到";"结束
|
|
||||||
pattern = r"包含违禁词: (.*?);"
|
|
||||||
all_prohibited_segments = re.findall(pattern, ocr_result, re.DOTALL)
|
|
||||||
|
|
||||||
all_words = []
|
|
||||||
for segment in all_prohibited_segments:
|
|
||||||
# 去除每个片段中的置信度信息(如"(置信度: 1.00)")
|
|
||||||
cleaned = re.sub(r"\s*\([^)]*\)", "", segment.strip())
|
|
||||||
# 分割词语并过滤空值
|
|
||||||
words = [word.strip() for word in cleaned.split(",") if word.strip()]
|
|
||||||
all_words.extend(words)
|
|
||||||
|
|
||||||
# 去重后用逗号拼接
|
|
||||||
unique_words = list(set(all_words))
|
|
||||||
return ",".join(unique_words)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_face_names(face_result: str) -> str:
|
def extract_face_names(face_result: str) -> str:
|
||||||
pattern = r"匹配: (.*?) \("
|
pattern = r"匹配: (.*?) \("
|
||||||
all_names = re.findall(pattern, face_result)
|
all_names = re.findall(pattern, face_result)
|
||||||
|
|||||||
2
ds/db.py
2
ds/db.py
@ -56,4 +56,4 @@ class Database:
|
|||||||
|
|
||||||
|
|
||||||
# 暴露数据库操作工具
|
# 暴露数据库操作工具
|
||||||
db = Database()
|
db = Database()
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from http.client import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -9,7 +9,7 @@ import os
|
|||||||
from ds.db import db
|
from ds.db import db
|
||||||
from service.file_service import get_absolute_path
|
from service.file_service import get_absolute_path
|
||||||
|
|
||||||
# 全局变量
|
# 全局变量:初始化时为None,无模型时保持None
|
||||||
current_yolo_model = None
|
current_yolo_model = None
|
||||||
current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例
|
current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例
|
||||||
|
|
||||||
@ -18,114 +18,173 @@ MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
|||||||
|
|
||||||
|
|
||||||
def load_yolo_model():
|
def load_yolo_model():
|
||||||
"""加载模型并存储绝对路径"""
|
"""
|
||||||
|
加载模型并存储绝对路径
|
||||||
|
无有效模型路径/模型文件不存在/加载失败时,跳过加载(不抛出异常)
|
||||||
|
"""
|
||||||
global current_yolo_model, current_model_absolute_path
|
global current_yolo_model, current_model_absolute_path
|
||||||
|
# 1. 获取数据库中的模型路径(无模型时返回None)
|
||||||
model_rel_path = get_enabled_model_rel_path()
|
model_rel_path = get_enabled_model_rel_path()
|
||||||
|
|
||||||
|
# 2. 无模型路径时,跳过加载
|
||||||
|
if not model_rel_path:
|
||||||
|
print("[模型初始化] 未获取到有效模型路径,已跳过模型加载")
|
||||||
|
current_yolo_model = None
|
||||||
|
current_model_absolute_path = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 3. 有模型路径时,执行正常加载流程
|
||||||
print(f"[模型初始化] 加载模型:{model_rel_path}")
|
print(f"[模型初始化] 加载模型:{model_rel_path}")
|
||||||
|
|
||||||
# 计算并存储绝对路径
|
|
||||||
current_model_absolute_path = get_absolute_path(model_rel_path)
|
|
||||||
print(f"[模型初始化] 绝对路径:{current_model_absolute_path}")
|
|
||||||
|
|
||||||
# 检查模型文件
|
|
||||||
if not os.path.exists(current_model_absolute_path):
|
|
||||||
raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 计算绝对路径(避免路径处理异常)
|
||||||
|
current_model_absolute_path = get_absolute_path(model_rel_path)
|
||||||
|
print(f"[模型初始化] 模型绝对路径:{current_model_absolute_path}")
|
||||||
|
|
||||||
|
# 检查模型文件是否存在
|
||||||
|
if not os.path.exists(current_model_absolute_path):
|
||||||
|
print(f"[模型初始化] 警告:模型文件不存在({current_model_absolute_path}),已跳过加载")
|
||||||
|
current_yolo_model = None
|
||||||
|
current_model_absolute_path = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 加载YOLO模型
|
||||||
new_model = YOLO(current_model_absolute_path)
|
new_model = YOLO(current_model_absolute_path)
|
||||||
|
# 设备分配(GPU/CPU)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
new_model.to('cuda')
|
new_model.to('cuda')
|
||||||
print("模型已移动到GPU")
|
print("[模型初始化] 模型已移动到GPU设备")
|
||||||
else:
|
else:
|
||||||
print("使用CPU进行推理")
|
print("[模型初始化] 未检测到GPU,使用CPU进行推理")
|
||||||
|
|
||||||
|
# 更新全局模型变量
|
||||||
current_yolo_model = new_model
|
current_yolo_model = new_model
|
||||||
print(f"成功加载模型: {current_model_absolute_path}")
|
print(f"[模型初始化] 成功加载模型:{current_model_absolute_path}")
|
||||||
return current_yolo_model
|
return current_yolo_model
|
||||||
|
|
||||||
|
# 捕获所有加载异常,避免中断项目启动
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"模型加载失败:{str(e)}")
|
print(f"[模型初始化] 警告:模型加载失败({str(e)}),已跳过加载")
|
||||||
raise
|
current_yolo_model = None
|
||||||
|
current_model_absolute_path = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_current_model():
|
def get_current_model():
|
||||||
"""获取当前模型实例"""
|
"""
|
||||||
if current_yolo_model is None:
|
获取当前模型实例
|
||||||
raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型")
|
无模型时返回None(不抛出异常,避免中断流程)
|
||||||
|
"""
|
||||||
return current_yolo_model
|
return current_yolo_model
|
||||||
|
|
||||||
|
|
||||||
def detect(image_np, conf_threshold=0.8):
|
def detect(image_np, conf_threshold=0.8):
|
||||||
# 1. 输入格式验证
|
"""
|
||||||
|
执行YOLO检测
|
||||||
|
无模型时返回明确提示,不崩溃;有模型时正常返回检测结果
|
||||||
|
"""
|
||||||
|
# 优先检查模型是否已加载
|
||||||
|
model = get_current_model()
|
||||||
|
if not model:
|
||||||
|
error_msg = "检测失败:未加载任何YOLO模型(数据库中无默认模型或模型加载失败)"
|
||||||
|
print(f"[检测流程] {error_msg}")
|
||||||
|
return False, error_msg # 返回False+错误提示,而非None
|
||||||
|
|
||||||
|
# 2. 输入格式验证(保留原逻辑,格式错误仍抛异常,属于参数问题)
|
||||||
if not isinstance(image_np, np.ndarray):
|
if not isinstance(image_np, np.ndarray):
|
||||||
raise ValueError("输入必须是numpy数组(BGR图像)")
|
raise ValueError("输入必须是numpy数组(BGR图像格式)")
|
||||||
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||||
raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}")
|
raise ValueError(f"输入图像格式错误,需为 (高度, 宽度, 3) 的BGR数组,当前shape: {image_np.shape}")
|
||||||
|
|
||||||
detection_results = []
|
detection_results = []
|
||||||
try:
|
try:
|
||||||
model = get_current_model()
|
# 3. 检测配置
|
||||||
if not current_model_absolute_path:
|
|
||||||
raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型")
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
print(f"检测设备:{device} | 置信度阈值:{conf_threshold}")
|
|
||||||
|
|
||||||
# 图像尺寸信息
|
|
||||||
img_height, img_width = image_np.shape[:2]
|
img_height, img_width = image_np.shape[:2]
|
||||||
print(f"输入图像尺寸:{img_width}x{img_height}")
|
print(f"[检测流程] 设备:{device} | 置信度阈值:{conf_threshold} | 图像尺寸:{img_width}x{img_height}")
|
||||||
|
|
||||||
# YOLO检测
|
# 4. 执行YOLO预测
|
||||||
print("执行YOLO检测")
|
print("[检测流程] 开始执行YOLO检测")
|
||||||
results = model.predict(
|
results = model.predict(
|
||||||
image_np,
|
image_np,
|
||||||
conf=conf_threshold,
|
conf=conf_threshold,
|
||||||
device=device,
|
device=device,
|
||||||
show=False,
|
show=False, # 不显示检测窗口
|
||||||
|
verbose=False # 关闭YOLO内部日志(可选,减少冗余输出)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 整理检测结果(仅保留Chest类别,ID=2)
|
# 5. 整理检测结果(仅保留置信度达标结果,原逻辑保留)
|
||||||
for box in results[0].boxes:
|
for box in results[0].boxes:
|
||||||
class_id = int(box.cls[0]) # 类别ID
|
class_id = int(box.cls[0])
|
||||||
class_name = model.names[class_id]
|
class_name = model.names[class_id]
|
||||||
confidence = float(box.conf[0])
|
confidence = float(box.conf[0])
|
||||||
|
# 转换为整数坐标(x1, y1, x2, y2)
|
||||||
bbox = tuple(map(int, box.xyxy[0]))
|
bbox = tuple(map(int, box.xyxy[0]))
|
||||||
|
|
||||||
# 过滤条件:置信度达标 + 类别为Chest(class_id=2)
|
# 过滤条件:置信度达标
|
||||||
# and class_id == 2
|
if confidence >= conf_threshold and 0 <= class_id <= 5:
|
||||||
if confidence >= conf_threshold:
|
|
||||||
detection_results.append({
|
detection_results.append({
|
||||||
"class": class_name,
|
"class": class_name,
|
||||||
"confidence": confidence,
|
"confidence": round(confidence, 4), # 保留4位小数,优化输出
|
||||||
"bbox": bbox
|
"bbox": bbox
|
||||||
})
|
})
|
||||||
|
|
||||||
# 判断是否有目标
|
# 6. 判断是否检测到目标
|
||||||
has_content = len(detection_results) > 0
|
has_content = len(detection_results) > 0
|
||||||
|
print(f"[检测流程] 检测完成:共检测到 {len(detection_results)} 个目标")
|
||||||
return has_content, detection_results
|
return has_content, detection_results
|
||||||
|
|
||||||
|
# 7. 捕获检测过程异常,返回明确错误信息
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"检测过程出错:{str(e)}"
|
error_msg = f"检测过程出错:{str(e)}"
|
||||||
print(error_msg)
|
print(f"[检测流程] {error_msg}")
|
||||||
return False, None
|
return False, error_msg
|
||||||
|
|
||||||
|
|
||||||
def get_enabled_model_rel_path():
|
def get_enabled_model_rel_path():
|
||||||
"""获取数据库中启用的模型相对路径"""
|
"""
|
||||||
|
从数据库获取启用的默认模型相对路径
|
||||||
|
无模型/数据库错误时返回None,仅记录警告日志
|
||||||
|
"""
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
|
# 建立数据库连接
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
# 查询默认模型(is_default=1)
|
||||||
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
|
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
|
||||||
if not result or not result.get('path'):
|
# 有有效路径则返回,否则返回None
|
||||||
raise HTTPException(status_code=404, detail="未找到启用的默认模型")
|
if result and isinstance(result.get('path'), str) and result['path'].strip():
|
||||||
|
model_path = result['path'].strip()
|
||||||
|
print(f"找到默认模型路径:{model_path}")
|
||||||
|
return model_path
|
||||||
|
else:
|
||||||
|
print("警告:未找到启用的默认模型")
|
||||||
|
return None
|
||||||
|
|
||||||
return result['path']
|
# 捕获MySQL相关错误
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e
|
print(f"警告:查询默认模型时发生数据库错误({str(e)})")
|
||||||
|
return None
|
||||||
|
# 捕获其他通用错误
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, HTTPException):
|
print(f"[数据库查询] 警告:获取默认模型路径失败({str(e)})")
|
||||||
raise e
|
return None
|
||||||
raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e
|
# 确保数据库连接和游标关闭
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
if cursor:
|
||||||
|
try:
|
||||||
|
cursor.close()
|
||||||
|
print("游标已关闭")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"关闭游标时出错:{str(e)}")
|
||||||
|
# 关闭连接(允许重复关闭,无需检查是否已关闭)
|
||||||
|
if conn:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
print("数据库连接已关闭")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"关闭数据库连接时出错:{str(e)}")
|
||||||
@ -1,4 +1,4 @@
|
|||||||
# 首先添加NumPy兼容处理
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# 修复np.int已弃用的问题
|
# 修复np.int已弃用的问题
|
||||||
@ -8,29 +8,120 @@ if not hasattr(np, 'int'):
|
|||||||
from paddleocr import PaddleOCR
|
from paddleocr import PaddleOCR
|
||||||
from service.sensitive_service import get_all_sensitive_words
|
from service.sensitive_service import get_all_sensitive_words
|
||||||
|
|
||||||
|
|
||||||
|
# AC自动机节点定义
|
||||||
|
class AhoNode:
|
||||||
|
def __init__(self):
|
||||||
|
self.children = {} # 子节点映射(字符->节点)
|
||||||
|
self.fail = None # 失败指针(类似KMP的next数组)
|
||||||
|
self.is_end = False # 标记是否为某个模式串的结尾
|
||||||
|
self.word = None # 存储当前结尾对应的完整违禁词
|
||||||
|
|
||||||
|
|
||||||
|
# AC自动机实现(多模式字符串匹配)
|
||||||
|
class AhoCorasick:
|
||||||
|
def __init__(self):
|
||||||
|
self.root = AhoNode() # 根节点
|
||||||
|
|
||||||
|
def add_word(self, word):
|
||||||
|
"""添加违禁词到Trie树"""
|
||||||
|
if not isinstance(word, str) or not word.strip():
|
||||||
|
return # 过滤无效词
|
||||||
|
node = self.root
|
||||||
|
for char in word:
|
||||||
|
if char not in node.children:
|
||||||
|
node.children[char] = AhoNode()
|
||||||
|
node = node.children[char]
|
||||||
|
node.is_end = True
|
||||||
|
node.word = word # 记录完整词
|
||||||
|
|
||||||
|
def build_fail(self):
|
||||||
|
"""构建失败指针(BFS遍历)"""
|
||||||
|
queue = []
|
||||||
|
# 根节点的子节点失败指针指向根节点
|
||||||
|
for child in self.root.children.values():
|
||||||
|
child.fail = self.root
|
||||||
|
queue.append(child)
|
||||||
|
|
||||||
|
# BFS处理其他节点
|
||||||
|
while queue:
|
||||||
|
current_node = queue.pop(0)
|
||||||
|
# 遍历当前节点的所有子节点
|
||||||
|
for char, child in current_node.children.items():
|
||||||
|
# 寻找失败指针目标节点
|
||||||
|
fail_node = current_node.fail
|
||||||
|
while fail_node is not None and char not in fail_node.children:
|
||||||
|
fail_node = fail_node.fail
|
||||||
|
# 确定失败指针指向
|
||||||
|
child.fail = fail_node.children[char] if (fail_node and char in fail_node.children) else self.root
|
||||||
|
queue.append(child)
|
||||||
|
|
||||||
|
def match(self, text):
|
||||||
|
"""匹配文本中所有出现的违禁词(去重)"""
|
||||||
|
result = set()
|
||||||
|
node = self.root
|
||||||
|
for char in text:
|
||||||
|
# 沿失败链查找可用节点
|
||||||
|
while node is not None and char not in node.children:
|
||||||
|
node = node.fail
|
||||||
|
# 重置到根节点(如果没找到)
|
||||||
|
node = node.children[char] if (node and char in node.children) else self.root
|
||||||
|
|
||||||
|
# 收集所有匹配的违禁词(包括失败链上的)
|
||||||
|
temp = node
|
||||||
|
while temp != self.root:
|
||||||
|
if temp.is_end:
|
||||||
|
result.add(temp.word)
|
||||||
|
temp = temp.fail
|
||||||
|
return list(result)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局变量
|
||||||
_ocr_engine = None
|
_ocr_engine = None
|
||||||
_forbidden_words = set()
|
_ac_automaton = None # 替换原有的_forbidden_words集合
|
||||||
_conf_threshold = 0.5
|
_conf_threshold = 0.5
|
||||||
|
|
||||||
|
|
||||||
def set_forbidden_words(new_words):
|
def set_forbidden_words(new_words):
|
||||||
global _forbidden_words
|
"""更新违禁词(使用AC自动机存储)"""
|
||||||
|
global _ac_automaton
|
||||||
if not isinstance(new_words, (set, list, tuple)):
|
if not isinstance(new_words, (set, list, tuple)):
|
||||||
raise TypeError("新违禁词必须是集合、列表或元组类型")
|
raise TypeError("新违禁词必须是集合、列表或元组类型")
|
||||||
_forbidden_words = set(new_words) # 确保是集合类型
|
|
||||||
print(f"已通过函数更新违禁词,当前数量: {len(_forbidden_words)}")
|
# 初始化AC自动机并添加有效词
|
||||||
|
_ac_automaton = AhoCorasick()
|
||||||
|
valid_words = [word for word in new_words if isinstance(word, str) and word.strip()]
|
||||||
|
for word in valid_words:
|
||||||
|
_ac_automaton.add_word(word.strip())
|
||||||
|
# 构建失败指针(关键步骤)
|
||||||
|
_ac_automaton.build_fail()
|
||||||
|
|
||||||
|
print(f"已通过函数更新违禁词,当前数量: {len(valid_words)}")
|
||||||
|
|
||||||
|
|
||||||
def load_forbidden_words():
|
def load_forbidden_words():
|
||||||
global _forbidden_words
|
"""从敏感词服务加载违禁词并初始化AC自动机"""
|
||||||
|
global _ac_automaton
|
||||||
try:
|
try:
|
||||||
_forbidden_words = get_all_sensitive_words()
|
sensitive_words = get_all_sensitive_words() # 保持原接口不变(返回list[str])
|
||||||
print(f"加载的违禁词数量: {len(_forbidden_words)}")
|
_ac_automaton = AhoCorasick()
|
||||||
|
|
||||||
|
# 添加所有有效敏感词
|
||||||
|
valid_words = [word for word in sensitive_words if isinstance(word, str) and word.strip()]
|
||||||
|
for word in valid_words:
|
||||||
|
_ac_automaton.add_word(word.strip())
|
||||||
|
|
||||||
|
# 构建失败指针
|
||||||
|
_ac_automaton.build_fail()
|
||||||
|
print(f"加载的违禁词数量: {len(valid_words)}")
|
||||||
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Forbidden words load error: {e}")
|
print(f"Forbidden words load error: {e}")
|
||||||
return False
|
return False
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def init_ocr_engine():
|
def init_ocr_engine():
|
||||||
|
"""初始化OCR引擎和违禁词自动机"""
|
||||||
global _ocr_engine
|
global _ocr_engine
|
||||||
try:
|
try:
|
||||||
_ocr_engine = PaddleOCR(
|
_ocr_engine = PaddleOCR(
|
||||||
@ -52,34 +143,39 @@ def init_ocr_engine():
|
|||||||
|
|
||||||
|
|
||||||
def detect(frame, conf_threshold=0.8):
|
def detect(frame, conf_threshold=0.8):
|
||||||
|
"""检测帧中的文本是否包含违禁词(拆分OCR和匹配时间)"""
|
||||||
print("开始进行OCR检测...")
|
print("开始进行OCR检测...")
|
||||||
|
total_start = time.time() # 总耗时开始
|
||||||
|
ocr_time = 0.0 # OCR及结果解析耗时
|
||||||
|
match_time = 0.0 # 违禁词匹配耗时
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not _ocr_engine or not _ac_automaton:
|
||||||
|
return (False, "OCR引擎或违禁词库未初始化")
|
||||||
|
|
||||||
|
# 1. OCR识别及结果解析阶段
|
||||||
|
ocr_start = time.time()
|
||||||
ocr_res = _ocr_engine.ocr(frame, cls=True)
|
ocr_res = _ocr_engine.ocr(frame, cls=True)
|
||||||
if not ocr_res or not isinstance(ocr_res, list):
|
if not ocr_res or not isinstance(ocr_res, list):
|
||||||
return (False, "无OCR结果")
|
return (False, "无OCR结果")
|
||||||
|
|
||||||
texts = []
|
texts = []
|
||||||
confs = []
|
confs = []
|
||||||
|
# 解析OCR结果
|
||||||
for line in ocr_res:
|
for line in ocr_res:
|
||||||
if line is None:
|
if line is None:
|
||||||
continue
|
continue
|
||||||
if isinstance(line, list):
|
items_to_process = line if isinstance(line, list) else [line]
|
||||||
items_to_process = line
|
|
||||||
else:
|
|
||||||
items_to_process = [line]
|
|
||||||
|
|
||||||
for item in items_to_process:
|
for item in items_to_process:
|
||||||
|
# 过滤坐标类数据
|
||||||
if isinstance(item, list) and len(item) == 4:
|
if isinstance(item, list) and len(item) == 4:
|
||||||
is_coordinate = True
|
is_coordinate = all(isinstance(p, list) and len(p) == 2 and
|
||||||
for point in item:
|
all(isinstance(c, (int, float)) for c in p)
|
||||||
if not (isinstance(point, list) and len(point) == 2 and
|
for p in item)
|
||||||
all(isinstance(coord, (int, float)) for coord in point)):
|
|
||||||
is_coordinate = False
|
|
||||||
break
|
|
||||||
if is_coordinate:
|
if is_coordinate:
|
||||||
continue
|
continue
|
||||||
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
|
# 提取文本和置信度
|
||||||
continue
|
|
||||||
if isinstance(item, tuple) and len(item) == 2:
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
text, conf = item
|
text, conf = item
|
||||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||||
@ -98,22 +194,26 @@ def detect(frame, conf_threshold=0.8):
|
|||||||
texts.append(text_data.strip())
|
texts.append(text_data.strip())
|
||||||
confs.append(1.0)
|
confs.append(1.0)
|
||||||
continue
|
continue
|
||||||
print(f"无法解析的OCR结果格式: {item}")
|
ocr_end = time.time()
|
||||||
|
ocr_time = ocr_end - ocr_start # 计算OCR阶段耗时
|
||||||
|
|
||||||
if len(texts) != len(confs):
|
if len(texts) != len(confs):
|
||||||
return (False, "OCR结果格式异常")
|
return (False, "OCR结果格式异常")
|
||||||
|
|
||||||
# 收集所有识别到的违禁词(去重且保持出现顺序)
|
# 2. 违禁词匹配阶段
|
||||||
|
match_start = time.time()
|
||||||
vio_words = []
|
vio_words = []
|
||||||
for txt, conf in zip(texts, confs):
|
for txt, conf in zip(texts, confs):
|
||||||
if conf < _conf_threshold: # 过滤低置信度结果
|
if conf < _conf_threshold:
|
||||||
continue
|
continue
|
||||||
# 提取当前文本中包含的违禁词
|
# 用AC自动机匹配当前文本中的所有违禁词
|
||||||
matched = [w for w in _forbidden_words if w in txt]
|
matched_words = _ac_automaton.match(txt)
|
||||||
# 仅添加未记录过的违禁词(去重)
|
# 全局去重并保持顺序
|
||||||
for word in matched:
|
for word in matched_words:
|
||||||
if word not in vio_words:
|
if word not in vio_words:
|
||||||
vio_words.append(word)
|
vio_words.append(word)
|
||||||
|
match_end = time.time()
|
||||||
|
match_time = match_end - match_start # 计算匹配阶段耗时
|
||||||
|
|
||||||
has_text = len(texts) > 0
|
has_text = len(texts) > 0
|
||||||
has_violation = len(vio_words) > 0
|
has_violation = len(vio_words) > 0
|
||||||
@ -121,11 +221,17 @@ def detect(frame, conf_threshold=0.8):
|
|||||||
if not has_text:
|
if not has_text:
|
||||||
return (False, "未识别到文本")
|
return (False, "未识别到文本")
|
||||||
elif has_violation:
|
elif has_violation:
|
||||||
# 多个违禁词用逗号拼接
|
|
||||||
return (True, ", ".join(vio_words))
|
return (True, ", ".join(vio_words))
|
||||||
else:
|
else:
|
||||||
return (False, "未检测到违禁词")
|
return (False, "未检测到违禁词")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"OCR detect error: {e}")
|
print(f"OCR detect error: {e}")
|
||||||
return (False, f"检测错误: {str(e)}")
|
return (False, f"检测错误: {str(e)}")
|
||||||
|
finally:
|
||||||
|
# 打印各阶段耗时
|
||||||
|
total_time = time.time() - total_start
|
||||||
|
print(f"当前帧耗时明细:")
|
||||||
|
print(f" OCR识别及解析:{ocr_time:.8f}秒")
|
||||||
|
print(f" 违禁词匹配:{match_time:.8f}秒")
|
||||||
|
print(f" 总耗时:{total_time:.8f}秒")
|
||||||
Reference in New Issue
Block a user