From 6cdb1e3d7d90ca3c3a19d222d23e623394d5634a Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Fri, 10 Oct 2025 11:35:37 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96ocr=E6=A3=80=E6=B5=8B?= =?UTF-8?q?=E6=97=B6=E9=97=B4,=E5=8A=A0=E8=BD=BD=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.ini | 6 +- core/detect.py | 26 ------ ds/db.py | 2 +- service/model_service.py | 161 +++++++++++++++++++++++++------------ service/ocr_service.py | 166 ++++++++++++++++++++++++++++++++------- 5 files changed, 250 insertions(+), 111 deletions(-) diff --git a/config.ini b/config.ini index 000a996..239e22f 100644 --- a/config.ini +++ b/config.ini @@ -2,10 +2,10 @@ port = 8000 [mysql] -host = 192.168.110.65 -port = 6975 +host = 192.168.110.2 +port = 13386 user = video_check -password = fsjPfhxCs8NrFGmL +password = taWtMSpXh88SHnps database = video_check charset = utf8mb4 diff --git a/core/detect.py b/core/detect.py index 4bc51a0..bc99b9f 100644 --- a/core/detect.py +++ b/core/detect.py @@ -44,7 +44,6 @@ def save_db(model_type, client_ip, result): def detectFrame(client_ip, frame): - # YOLO检测 yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"])) if yolo_flag: @@ -103,36 +102,11 @@ def danger_handler(client_ip): json_data=json.dumps(lock_msg) ) ) - # 增加危险记录次数 increment_alarm_count_by_ip(client_ip) - # 更新设备状态为未处理 update_is_need_handler_by_client_ip(client_ip, 1) -def extract_prohibited_words(ocr_result: str) -> str: - """ - 从多文本块的ocr_result中提取所有违禁词(去重后用逗号拼接) - 适配格式:多个"文本: ... 包含违禁词: ...;"片段 - """ - # 用正则匹配所有"包含违禁词: ...;"的片段(非贪婪匹配到分号) - # 匹配规则:"包含违禁词: "后面的内容,直到遇到";"结束 - pattern = r"包含违禁词: (.*?);" - all_prohibited_segments = re.findall(pattern, ocr_result, re.DOTALL) - - all_words = [] - for segment in all_prohibited_segments: - # 去除每个片段中的置信度信息(如"(置信度: 1.00)") - cleaned = re.sub(r"\s*\([^)]*\)", "", segment.strip()) - # 分割词语并过滤空值 - words = [word.strip() for word in cleaned.split(",") if word.strip()] - all_words.extend(words) - - # 去重后用逗号拼接 - unique_words = list(set(all_words)) - return ",".join(unique_words) - - def extract_face_names(face_result: str) -> str: pattern = r"匹配: (.*?) \(" all_names = re.findall(pattern, face_result) diff --git a/ds/db.py b/ds/db.py index dff996d..b56d18f 100644 --- a/ds/db.py +++ b/ds/db.py @@ -56,4 +56,4 @@ class Database: # 暴露数据库操作工具 -db = Database() +db = Database() \ No newline at end of file diff --git a/service/model_service.py b/service/model_service.py index 33a3a06..a20dc52 100644 --- a/service/model_service.py +++ b/service/model_service.py @@ -1,4 +1,4 @@ -from http.client import HTTPException +from fastapi import HTTPException import numpy as np import torch @@ -9,7 +9,7 @@ import os from ds.db import db from service.file_service import get_absolute_path -# 全局变量 +# 全局变量:初始化时为None,无模型时保持None current_yolo_model = None current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例 @@ -18,114 +18,173 @@ MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB def load_yolo_model(): - """加载模型并存储绝对路径""" + """ + 加载模型并存储绝对路径 + 无有效模型路径/模型文件不存在/加载失败时,跳过加载(不抛出异常) + """ global current_yolo_model, current_model_absolute_path + # 1. 获取数据库中的模型路径(无模型时返回None) model_rel_path = get_enabled_model_rel_path() + + # 2. 无模型路径时,跳过加载 + if not model_rel_path: + print("[模型初始化] 未获取到有效模型路径,已跳过模型加载") + current_yolo_model = None + current_model_absolute_path = None + return None + + # 3. 有模型路径时,执行正常加载流程 print(f"[模型初始化] 加载模型:{model_rel_path}") - - # 计算并存储绝对路径 - current_model_absolute_path = get_absolute_path(model_rel_path) - print(f"[模型初始化] 绝对路径:{current_model_absolute_path}") - - # 检查模型文件 - if not os.path.exists(current_model_absolute_path): - raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}") - try: + # 计算绝对路径(避免路径处理异常) + current_model_absolute_path = get_absolute_path(model_rel_path) + print(f"[模型初始化] 模型绝对路径:{current_model_absolute_path}") + + # 检查模型文件是否存在 + if not os.path.exists(current_model_absolute_path): + print(f"[模型初始化] 警告:模型文件不存在({current_model_absolute_path}),已跳过加载") + current_yolo_model = None + current_model_absolute_path = None + return None + + # 加载YOLO模型 new_model = YOLO(current_model_absolute_path) + # 设备分配(GPU/CPU) if torch.cuda.is_available(): new_model.to('cuda') - print("模型已移动到GPU") + print("[模型初始化] 模型已移动到GPU设备") else: - print("使用CPU进行推理") + print("[模型初始化] 未检测到GPU,使用CPU进行推理") + + # 更新全局模型变量 current_yolo_model = new_model - print(f"成功加载模型: {current_model_absolute_path}") + print(f"[模型初始化] 成功加载模型:{current_model_absolute_path}") return current_yolo_model + + # 捕获所有加载异常,避免中断项目启动 except Exception as e: - print(f"模型加载失败:{str(e)}") - raise + print(f"[模型初始化] 警告:模型加载失败({str(e)}),已跳过加载") + current_yolo_model = None + current_model_absolute_path = None + return None def get_current_model(): - """获取当前模型实例""" - if current_yolo_model is None: - raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型") + """ + 获取当前模型实例 + 无模型时返回None(不抛出异常,避免中断流程) + """ return current_yolo_model def detect(image_np, conf_threshold=0.8): - # 1. 输入格式验证 + """ + 执行YOLO检测 + 无模型时返回明确提示,不崩溃;有模型时正常返回检测结果 + """ + # 优先检查模型是否已加载 + model = get_current_model() + if not model: + error_msg = "检测失败:未加载任何YOLO模型(数据库中无默认模型或模型加载失败)" + print(f"[检测流程] {error_msg}") + return False, error_msg # 返回False+错误提示,而非None + + # 2. 输入格式验证(保留原逻辑,格式错误仍抛异常,属于参数问题) if not isinstance(image_np, np.ndarray): - raise ValueError("输入必须是numpy数组(BGR图像)") + raise ValueError("输入必须是numpy数组(BGR图像格式)") if image_np.ndim != 3 or image_np.shape[-1] != 3: - raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}") + raise ValueError(f"输入图像格式错误,需为 (高度, 宽度, 3) 的BGR数组,当前shape: {image_np.shape}") + detection_results = [] try: - model = get_current_model() - if not current_model_absolute_path: - raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型") + # 3. 检测配置 device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"检测设备:{device} | 置信度阈值:{conf_threshold}") - - # 图像尺寸信息 img_height, img_width = image_np.shape[:2] - print(f"输入图像尺寸:{img_width}x{img_height}") + print(f"[检测流程] 设备:{device} | 置信度阈值:{conf_threshold} | 图像尺寸:{img_width}x{img_height}") - # YOLO检测 - print("执行YOLO检测") + # 4. 执行YOLO预测 + print("[检测流程] 开始执行YOLO检测") results = model.predict( image_np, conf=conf_threshold, device=device, - show=False, + show=False, # 不显示检测窗口 + verbose=False # 关闭YOLO内部日志(可选,减少冗余输出) ) - # 4. 整理检测结果(仅保留Chest类别,ID=2) + # 5. 整理检测结果(仅保留置信度达标结果,原逻辑保留) for box in results[0].boxes: - class_id = int(box.cls[0]) # 类别ID + class_id = int(box.cls[0]) class_name = model.names[class_id] confidence = float(box.conf[0]) + # 转换为整数坐标(x1, y1, x2, y2) bbox = tuple(map(int, box.xyxy[0])) - # 过滤条件:置信度达标 + 类别为Chest(class_id=2) - # and class_id == 2 - if confidence >= conf_threshold: + # 过滤条件:置信度达标 + if confidence >= conf_threshold and 0 <= class_id <= 5: detection_results.append({ "class": class_name, - "confidence": confidence, + "confidence": round(confidence, 4), # 保留4位小数,优化输出 "bbox": bbox }) - # 判断是否有目标 + # 6. 判断是否检测到目标 has_content = len(detection_results) > 0 + print(f"[检测流程] 检测完成:共检测到 {len(detection_results)} 个目标") return has_content, detection_results + # 7. 捕获检测过程异常,返回明确错误信息 except Exception as e: error_msg = f"检测过程出错:{str(e)}" - print(error_msg) - return False, None + print(f"[检测流程] {error_msg}") + return False, error_msg def get_enabled_model_rel_path(): - """获取数据库中启用的模型相对路径""" + """ + 从数据库获取启用的默认模型相对路径 + 无模型/数据库错误时返回None,仅记录警告日志 + """ conn = None cursor = None try: + # 建立数据库连接 conn = db.get_connection() cursor = conn.cursor(dictionary=True) + # 查询默认模型(is_default=1) query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1" cursor.execute(query) result = cursor.fetchone() - if not result or not result.get('path'): - raise HTTPException(status_code=404, detail="未找到启用的默认模型") + # 有有效路径则返回,否则返回None + if result and isinstance(result.get('path'), str) and result['path'].strip(): + model_path = result['path'].strip() + print(f"找到默认模型路径:{model_path}") + return model_path + else: + print("警告:未找到启用的默认模型") + return None - return result['path'] + # 捕获MySQL相关错误 except MySQLError as e: - raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e + print(f"警告:查询默认模型时发生数据库错误({str(e)})") + return None + # 捕获其他通用错误 except Exception as e: - if isinstance(e, HTTPException): - raise e - raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e + print(f"[数据库查询] 警告:获取默认模型路径失败({str(e)})") + return None + # 确保数据库连接和游标关闭 finally: - db.close_connection(conn, cursor) \ No newline at end of file + if cursor: + try: + cursor.close() + print("游标已关闭") + except Exception as e: + print(f"关闭游标时出错:{str(e)}") + # 关闭连接(允许重复关闭,无需检查是否已关闭) + if conn: + try: + conn.close() + print("数据库连接已关闭") + except Exception as e: + print(f"关闭数据库连接时出错:{str(e)}") \ No newline at end of file diff --git a/service/ocr_service.py b/service/ocr_service.py index fa9cfcd..4f850ea 100644 --- a/service/ocr_service.py +++ b/service/ocr_service.py @@ -1,4 +1,4 @@ -# 首先添加NumPy兼容处理 +import time import numpy as np # 修复np.int已弃用的问题 @@ -8,29 +8,120 @@ if not hasattr(np, 'int'): from paddleocr import PaddleOCR from service.sensitive_service import get_all_sensitive_words + +# AC自动机节点定义 +class AhoNode: + def __init__(self): + self.children = {} # 子节点映射(字符->节点) + self.fail = None # 失败指针(类似KMP的next数组) + self.is_end = False # 标记是否为某个模式串的结尾 + self.word = None # 存储当前结尾对应的完整违禁词 + + +# AC自动机实现(多模式字符串匹配) +class AhoCorasick: + def __init__(self): + self.root = AhoNode() # 根节点 + + def add_word(self, word): + """添加违禁词到Trie树""" + if not isinstance(word, str) or not word.strip(): + return # 过滤无效词 + node = self.root + for char in word: + if char not in node.children: + node.children[char] = AhoNode() + node = node.children[char] + node.is_end = True + node.word = word # 记录完整词 + + def build_fail(self): + """构建失败指针(BFS遍历)""" + queue = [] + # 根节点的子节点失败指针指向根节点 + for child in self.root.children.values(): + child.fail = self.root + queue.append(child) + + # BFS处理其他节点 + while queue: + current_node = queue.pop(0) + # 遍历当前节点的所有子节点 + for char, child in current_node.children.items(): + # 寻找失败指针目标节点 + fail_node = current_node.fail + while fail_node is not None and char not in fail_node.children: + fail_node = fail_node.fail + # 确定失败指针指向 + child.fail = fail_node.children[char] if (fail_node and char in fail_node.children) else self.root + queue.append(child) + + def match(self, text): + """匹配文本中所有出现的违禁词(去重)""" + result = set() + node = self.root + for char in text: + # 沿失败链查找可用节点 + while node is not None and char not in node.children: + node = node.fail + # 重置到根节点(如果没找到) + node = node.children[char] if (node and char in node.children) else self.root + + # 收集所有匹配的违禁词(包括失败链上的) + temp = node + while temp != self.root: + if temp.is_end: + result.add(temp.word) + temp = temp.fail + return list(result) + + +# 全局变量 _ocr_engine = None -_forbidden_words = set() +_ac_automaton = None # 替换原有的_forbidden_words集合 _conf_threshold = 0.5 + def set_forbidden_words(new_words): - global _forbidden_words + """更新违禁词(使用AC自动机存储)""" + global _ac_automaton if not isinstance(new_words, (set, list, tuple)): raise TypeError("新违禁词必须是集合、列表或元组类型") - _forbidden_words = set(new_words) # 确保是集合类型 - print(f"已通过函数更新违禁词,当前数量: {len(_forbidden_words)}") + + # 初始化AC自动机并添加有效词 + _ac_automaton = AhoCorasick() + valid_words = [word for word in new_words if isinstance(word, str) and word.strip()] + for word in valid_words: + _ac_automaton.add_word(word.strip()) + # 构建失败指针(关键步骤) + _ac_automaton.build_fail() + + print(f"已通过函数更新违禁词,当前数量: {len(valid_words)}") + def load_forbidden_words(): - global _forbidden_words + """从敏感词服务加载违禁词并初始化AC自动机""" + global _ac_automaton try: - _forbidden_words = get_all_sensitive_words() - print(f"加载的违禁词数量: {len(_forbidden_words)}") + sensitive_words = get_all_sensitive_words() # 保持原接口不变(返回list[str]) + _ac_automaton = AhoCorasick() + + # 添加所有有效敏感词 + valid_words = [word for word in sensitive_words if isinstance(word, str) and word.strip()] + for word in valid_words: + _ac_automaton.add_word(word.strip()) + + # 构建失败指针 + _ac_automaton.build_fail() + print(f"加载的违禁词数量: {len(valid_words)}") + return True except Exception as e: print(f"Forbidden words load error: {e}") return False - return True def init_ocr_engine(): + """初始化OCR引擎和违禁词自动机""" global _ocr_engine try: _ocr_engine = PaddleOCR( @@ -52,34 +143,39 @@ def init_ocr_engine(): def detect(frame, conf_threshold=0.8): + """检测帧中的文本是否包含违禁词(拆分OCR和匹配时间)""" print("开始进行OCR检测...") + total_start = time.time() # 总耗时开始 + ocr_time = 0.0 # OCR及结果解析耗时 + match_time = 0.0 # 违禁词匹配耗时 + try: + if not _ocr_engine or not _ac_automaton: + return (False, "OCR引擎或违禁词库未初始化") + + # 1. OCR识别及结果解析阶段 + ocr_start = time.time() ocr_res = _ocr_engine.ocr(frame, cls=True) if not ocr_res or not isinstance(ocr_res, list): return (False, "无OCR结果") texts = [] confs = [] + # 解析OCR结果 for line in ocr_res: if line is None: continue - if isinstance(line, list): - items_to_process = line - else: - items_to_process = [line] + items_to_process = line if isinstance(line, list) else [line] for item in items_to_process: + # 过滤坐标类数据 if isinstance(item, list) and len(item) == 4: - is_coordinate = True - for point in item: - if not (isinstance(point, list) and len(point) == 2 and - all(isinstance(coord, (int, float)) for coord in point)): - is_coordinate = False - break + is_coordinate = all(isinstance(p, list) and len(p) == 2 and + all(isinstance(c, (int, float)) for c in p) + for p in item) if is_coordinate: continue - if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item): - continue + # 提取文本和置信度 if isinstance(item, tuple) and len(item) == 2: text, conf = item if isinstance(text, str) and isinstance(conf, (int, float)): @@ -98,22 +194,26 @@ def detect(frame, conf_threshold=0.8): texts.append(text_data.strip()) confs.append(1.0) continue - print(f"无法解析的OCR结果格式: {item}") + ocr_end = time.time() + ocr_time = ocr_end - ocr_start # 计算OCR阶段耗时 if len(texts) != len(confs): return (False, "OCR结果格式异常") - # 收集所有识别到的违禁词(去重且保持出现顺序) + # 2. 违禁词匹配阶段 + match_start = time.time() vio_words = [] for txt, conf in zip(texts, confs): - if conf < _conf_threshold: # 过滤低置信度结果 + if conf < _conf_threshold: continue - # 提取当前文本中包含的违禁词 - matched = [w for w in _forbidden_words if w in txt] - # 仅添加未记录过的违禁词(去重) - for word in matched: + # 用AC自动机匹配当前文本中的所有违禁词 + matched_words = _ac_automaton.match(txt) + # 全局去重并保持顺序 + for word in matched_words: if word not in vio_words: vio_words.append(word) + match_end = time.time() + match_time = match_end - match_start # 计算匹配阶段耗时 has_text = len(texts) > 0 has_violation = len(vio_words) > 0 @@ -121,11 +221,17 @@ def detect(frame, conf_threshold=0.8): if not has_text: return (False, "未识别到文本") elif has_violation: - # 多个违禁词用逗号拼接 return (True, ", ".join(vio_words)) else: return (False, "未检测到违禁词") except Exception as e: print(f"OCR detect error: {e}") - return (False, f"检测错误: {str(e)}") \ No newline at end of file + return (False, f"检测错误: {str(e)}") + finally: + # 打印各阶段耗时 + total_time = time.time() - total_start + print(f"当前帧耗时明细:") + print(f" OCR识别及解析:{ocr_time:.8f}秒") + print(f" 违禁词匹配:{match_time:.8f}秒") + print(f" 总耗时:{total_time:.8f}秒") \ No newline at end of file