diff --git a/.idea/Video.iml b/.idea/Video.iml index 8f67bb8..8437fe6 100644 --- a/.idea/Video.iml +++ b/.idea/Video.iml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 0f99d01..b6a491a 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/config.ini b/config.ini index 9f431f7..ca81a44 100644 --- a/config.ini +++ b/config.ini @@ -15,5 +15,5 @@ algorithm = HS256 access_token_expire_minutes = 30 [live] -rtmp_url = rtmp://192.168.110.65:1935/live/ -webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream= +rtmp_url = rtmp://192.168.110.25:1935/live/ +webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream= diff --git a/core/all.py b/core/all.py new file mode 100644 index 0000000..9065b80 --- /dev/null +++ b/core/all.py @@ -0,0 +1,45 @@ +from core.ocr import load_model as ocrLoadModel, detect as ocrDetect +from core.face import load_model as faceLoadModel, detect as faceDetect +from core.yolo import load_model as yoloLoadModel, detect as yoloDetect + +# 添加一个标记变量,用于监控load_model是否已被调用 +_model_loaded = False + + +def load_model(): + global _model_loaded + + # 如果已经调用过,直接忽略 + if _model_loaded: + return + + # 首次调用时加载模型 + ocrLoadModel() + faceLoadModel() + yoloLoadModel() + + # 标记为已调用 + _model_loaded = True + + +def detect(frame): + # 先进行YOLO检测 + yolo_flag, yolo_result = yoloDetect(frame) + print("YOLO检测结果:", yolo_result) + if yolo_flag: + return (True, yolo_result, "yolo") + + # YOLO未检测到,进行人脸检测 + face_flag, face_result = faceDetect(frame) + print("人脸检测结果:", face_result) + if face_flag: + return (True, face_result, "face") + + # 人脸未检测到,进行OCR检测 + ocr_flag, ocr_result = ocrDetect(frame) + print("OCR检测结果:", ocr_result) + if ocr_flag: + return (True, ocr_result, "ocr") + + # 所有检测都未检测到 + return (False, "未检测到任何内容", "none") \ No newline at end of file diff --git a/ocr/config/1.yaml b/core/config/config.yaml similarity index 100% rename from ocr/config/1.yaml rename to core/config/config.yaml diff --git a/core/face.py b/core/face.py new file mode 100644 index 0000000..e1edee4 --- /dev/null +++ b/core/face.py @@ -0,0 +1,113 @@ +import os +import numpy as np +import cv2 +from PIL import Image # 确保正确导入Image类 +from insightface.app import FaceAnalysis +# 导入获取人脸信息的服务 +from service.face_service import get_all_face_name_with_eigenvalue + +# 全局变量 +_face_app = None +_known_faces_embeddings = {} # 存储姓名到特征值的映射 +_known_faces_names = [] # 存储所有已知姓名 + + +def load_model(): + """加载人脸识别模型及已知人脸特征库""" + global _face_app, _known_faces_embeddings, _known_faces_names + + # 初始化InsightFace模型 + try: + _face_app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface')) + _face_app.prepare(ctx_id=0, det_size=(640, 640)) + except Exception as e: + print(f"Face model load failed: {e}") + return False + + # 从服务获取所有人脸姓名和特征值 + try: + face_data = get_all_face_name_with_eigenvalue() + + # 处理获取到的人脸数据 + for person_name, eigenvalue_data in face_data.items(): + # 处理特征值数据 - 兼容数组和字符串两种格式 + if isinstance(eigenvalue_data, np.ndarray): + # 如果已经是numpy数组,直接使用 + eigenvalue = eigenvalue_data.astype(np.float32) + elif isinstance(eigenvalue_data, str): + # 清理字符串:移除方括号、换行符和多余空格 + cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip() + # 按空格或逗号分割(处理可能的不同分隔符) + values = [v for v in cleaned.split() if v] + # 转换为数组 + eigenvalue = np.array(list(map(float, values)), dtype=np.float32) + else: + # 不支持的类型 + print(f"Unsupported eigenvalue type for {person_name}") + continue + + # 归一化处理 + norm = np.linalg.norm(eigenvalue) + if norm != 0: + eigenvalue = eigenvalue / norm + + _known_faces_embeddings[person_name] = eigenvalue + _known_faces_names.append(person_name) + + except Exception as e: + print(f"Error loading face data from service: {e}") + + return True if _face_app else False + + +def detect(frame, threshold=0.4): + """检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)""" + global _face_app, _known_faces_embeddings, _known_faces_names + + if not _face_app or not _known_faces_names or frame is None: + return (False, "未初始化或无效帧") + + try: + faces = _face_app.get(frame) + except Exception as e: + print(f"Face detect error: {e}") + return (False, f"检测错误: {str(e)}") + + result_parts = [] + has_matched = False # 新增标记:是否有匹配到的已知人脸 + + for face in faces: + # 特征归一化 + embedding = face.embedding.astype(np.float32) + norm = np.linalg.norm(embedding) + if norm == 0: + continue + embedding = embedding / norm + + # 对比已知人脸 + max_sim, best_name = -1.0, "Unknown" + for name in _known_faces_names: + known_emb = _known_faces_embeddings[name] + sim = np.dot(embedding, known_emb) + if sim > max_sim: + max_sim = sim + best_name = name + + # 判断匹配结果 + is_match = max_sim >= threshold + if is_match: + has_matched = True # 只要有一个匹配成功,就标记为True + + bbox = face.bbox + result_parts.append( + f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})" + ) + + # 构建结果字符串 + if not result_parts: + result_str = "未检测到人脸" + else: + result_str = "; ".join(result_parts) + + # 第一个返回值改为:是否匹配到已知人脸 + return (has_matched, result_str) diff --git a/core/models/best.pt b/core/models/best.pt new file mode 100644 index 0000000..eb9a0b6 Binary files /dev/null and b/core/models/best.pt differ diff --git a/core/ocr.py b/core/ocr.py new file mode 100644 index 0000000..3b38287 --- /dev/null +++ b/core/ocr.py @@ -0,0 +1,76 @@ +import os +import cv2 +from rapidocr import RapidOCR +from service.sensitive_service import get_all_sensitive_words + +# 全局变量 +_ocr_engine = None +_forbidden_words = set() +_conf_threshold = 0.5 + +ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml") + + +def load_model(): + """加载OCR引擎及违禁词列表""" + global _ocr_engine, _forbidden_words, _conf_threshold + + # 加载违禁词 + try: + _forbidden_words = get_all_sensitive_words() + except Exception as e: + print(f"Forbidden words load error: {e}") + + # 初始化OCR引擎 + if not os.path.exists(ocr_config_path): + print(f"OCR config not found: {ocr_config_path}") + return False + + try: + _ocr_engine = RapidOCR(config_path=ocr_config_path) + except Exception as e: + print(f"OCR model load failed: {e}") + return False + + return True if _ocr_engine else False + + +def detect(frame): + """OCR检测并筛选违禁词,返回(是否检测到违禁词, 结果字符串)""" + if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0: + return (False, "未初始化或无效帧") + + try: + ocr_res = _ocr_engine(frame) + except Exception as e: + print(f"OCR detect error: {e}") + return (False, f"检测错误: {str(e)}") + + if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'): + return (False, "无OCR结果") + + # 处理OCR结果 + texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)] + confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))] + if len(texts) != len(confs): + return (False, "OCR结果格式异常") + + # 筛选违禁词 + vio_info = [] + for txt, conf in zip(texts, confs): + if conf < _conf_threshold: + continue + matched = [w for w in _forbidden_words if w in txt] + if matched: + vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})") + + # 构建结果字符串 + has_text = len(texts) > 0 + has_violation = len(vio_info) > 0 + + if not has_text: + return (False, "未识别到文本") + elif has_violation: + return (True, "; ".join(vio_info)) + else: + return (False, "未检测到违禁词") \ No newline at end of file diff --git a/core/rtc.py b/core/rtc.py deleted file mode 100644 index 823cd27..0000000 --- a/core/rtc.py +++ /dev/null @@ -1,137 +0,0 @@ -import asyncio -import logging -from aiortc import RTCPeerConnection, RTCSessionDescription -import aiohttp -from ocr.ocr_violation_detector import OCRViolationDetector - -import logging - -# 创建检测器实例 -detector = OCRViolationDetector( - forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", - ocr_confidence_threshold=0.7, - log_level=logging.INFO, - log_file="ocr_detection.log" -) - -# 配置日志 -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("whep_video_puller") - - -async def whep_pull_video_stream(ip,whep_url): - """ - 通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息 - - Args: - whep_url: WHEP端点的URL - """ - pc = RTCPeerConnection() - - # 添加连接状态变化监听 - @pc.on("connectionstatechange") - async def on_connectionstatechange(): - print(f"连接状态: {pc.connectionState}") - - # 添加ICE连接状态变化监听 - @pc.on("iceconnectionstatechange") - async def on_iceconnectionstatechange(): - print(f"ICE连接状态: {pc.iceConnectionState}") - - # 添加视频接收器 - pc.addTransceiver("video", direction="recvonly") - - # 处理接收到的视频轨道 - @pc.on("track") - def on_track(track): - print(f"接收到轨道: {track.kind}") - if track.kind == "video": - print(f"轨道ID: {track.id}") - print(f"轨道就绪状态: {track.readyState}") - # 创建异步任务来处理视频帧 - asyncio.ensure_future(handle_video_track(track)) - - async def handle_video_track(track): - """处理视频轨道,接收并打印每一帧""" - frame_count = 0 - print("开始处理视频轨道...") - - while True: - try: - # 尝试接收帧 - frame = await track.recv() - frame_count += 1 - print(f"收到原始帧 (第{frame_count}帧)") - - # 打印帧的基本信息 - if hasattr(frame, 'width') and hasattr(frame, 'height'): - print(f" 尺寸: {frame.width}x{frame.height}") - if hasattr(frame, 'time_base'): - print(f" 时间基准: {frame.time_base}") - if hasattr(frame, 'pts'): - print(f" 显示时间戳: {frame.pts}") - - has_violation, violations, confidences = OCRViolationDetector.detect(frame) - - # 输出检测结果 - if has_violation: - detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") - for word, conf in zip(violations, confidences): - detector.logger.info(f"- {word} (置信度: {conf:.4f})") - else: - detector.logger.info("图片中未检测到违禁词") - except Exception as e: - print(f"接收帧时出错: {e}") - # 等待一段时间后重试 - await asyncio.sleep(0.1) - continue - - # 创建offer - offer = await pc.createOffer() - await pc.setLocalDescription(offer) - - print(f"本地SDP信息:\n{offer.sdp}") - - # 通过HTTP POST发送offer到WHEP端点 - async with aiohttp.ClientSession() as session: - async with session.post( - whep_url, - data=offer.sdp, - headers={"Content-Type": "application/sdp"} - ) as response: - if response.status != 201: - print(f"WHEP服务器返回错误: {response.status}") - print(f"响应内容: {await response.text()}") - raise Exception(f"WHEP服务器返回错误: {response.status}") - - # 获取answer SDP - answer_sdp = await response.text() - - # 创建RTCSessionDescription对象 - answer = RTCSessionDescription(sdp=answer_sdp, type="answer") - - print(f"收到远程SDP:\n{answer_sdp}") - - # 设置远程描述 - await pc.setRemoteDescription(answer) - - print("连接已建立,开始接收视频流...") - - # 保持连接,直到用户中断 - try: - while True: - await asyncio.sleep(1) - # 检查连接状态 - print(f"当前连接状态: {pc.connectionState}") - except KeyboardInterrupt: - print("用户中断,关闭连接...") - finally: - await pc.close() - - -if __name__ == "__main__": - # 替换为你的WHEP端点URL - WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416" - - # 运行拉流任务 - asyncio.run(whep_pull_video_stream(WHEP_URL)) diff --git a/core/rtmp.py b/core/rtmp.py deleted file mode 100644 index 6de5447..0000000 --- a/core/rtmp.py +++ /dev/null @@ -1,112 +0,0 @@ -import asyncio -import logging -import cv2 -import time -from ocr.model_violation_detector import MultiModelViolationDetector - - -# 配置文件相对路径(根据实际目录结构调整) -YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹 -FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt" -OCR_CONFIG_PATH = "../ocr/config/1.yaml" -KNOWN_FACES_DIR = "../ocr/known_faces" - -# 创建检测器实例 -detector = MultiModelViolationDetector( - forbidden_words_path=FORBIDDEN_WORDS_PATH, - ocr_config_path=OCR_CONFIG_PATH, - yolo_model_path=YOLO_MODEL_PATH, - known_faces_dir=KNOWN_FACES_DIR, - ocr_confidence_threshold=0.5 -) - -# 配置日志 -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("rtmp_video_puller") - - -async def rtmp_pull_video_stream(rtmp_url): - """ - 通过RTMP从指定URL拉取视频流并进行违规检测 - """ - cap = None # 初始化视频捕获对象 - try: - # 异步打开RTMP流 - cap = await asyncio.to_thread( - cv2.VideoCapture, - rtmp_url, - cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性 - ) - - # 检查RTMP流是否成功打开 - is_opened = await asyncio.to_thread(cap.isOpened) - if not is_opened: - raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)") - - # 获取RTMP流基础信息 - width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH) - height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT) - fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS) - - # 处理异常情况 - fps = fps if fps > 0 else 30.0 - width, height = int(width), int(height) - - # 打印流初始化成功信息 - print(f"RTMP流状态: 已成功连接") - print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS") - print("开始接收视频帧...(按 Ctrl+C 中断)") - - # 初始化帧统计参数 - frame_count = 0 - start_time = time.time() - - # 循环读取视频帧 - while True: - ret, frame = await asyncio.to_thread(cap.read) - - if not ret: - print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)") - break - - frame_count += 1 - - # 打印当前帧信息 - print(f"收到帧 (第{frame_count}帧)") - print(f" 帧尺寸: {width}x{height}") - print(f" 配置帧率: {fps:.2f} FPS") - - if frame is not None: - has_violation, violation_type, details = detector.detect_violations(frame) - if has_violation: - print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") - else: - print("未检测到任何违规内容") - else: - print(f"无法读取测试图像") - - # 每100帧统计一次实际接收帧率 - if frame_count % 100 == 0: - elapsed_time = time.time() - start_time - actual_fps = frame_count / elapsed_time - print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----") - - except KeyboardInterrupt: - print(f"\n用户操作: 已通过 Ctrl+C 中断程序") - except Exception as e: - logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True) - print(f"错误信息: {str(e)}") - finally: - if cap is not None: - await asyncio.to_thread(cap.release) - print(f"\n资源释放: RTMP流已关闭") - print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧") - - -if __name__ == "__main__": - RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416" - - try: - asyncio.run(rtmp_pull_video_stream(RTMP_URL)) - except Exception as e: - print(f"程序启动失败: {str(e)}") diff --git a/core/yolo.py b/core/yolo.py new file mode 100644 index 0000000..aed4698 --- /dev/null +++ b/core/yolo.py @@ -0,0 +1,55 @@ +import os + +import cv2 +from ultralytics import YOLO + +# 全局变量 +_yolo_model = None + + +model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt") + + +def load_model(): + """加载YOLO目标检测模型""" + global _yolo_model + + try: + _yolo_model = YOLO(model_path) + except Exception as e: + print(f"YOLO model load failed: {e}") + return False + + return True if _yolo_model else False + + +def detect(frame, conf_threshold=0.2): + """YOLO目标检测,返回(是否识别到, 结果字符串)""" + global _yolo_model + + if not _yolo_model or frame is None: + return (False, "未初始化或无效帧") + + try: + results = _yolo_model(frame, conf=conf_threshold) + # 检查是否有检测结果 + has_results = len(results[0].boxes) > 0 if results else False + + if not has_results: + return (False, "未检测到目标") + + # 构建结果字符串 + result_parts = [] + for box in results[0].boxes: + cls = int(box.cls[0]) + conf = float(box.conf[0]) + bbox = [float(x) for x in box.xyxy[0]] + class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" + result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})") + + result_str = "; ".join(result_parts) + return (has_results, result_str) + + except Exception as e: + print(f"YOLO detect error: {e}") + return (False, f"检测错误: {str(e)}") \ No newline at end of file diff --git a/main.py b/main.py index 12cfe33..d23764b 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,19 @@ -import uvicorn -from fastapi import FastAPI +from PIL import Image # 正确导入 +import numpy as np +import uvicorn +from PIL import Image +from fastapi import FastAPI +from core.all import load_model,detect from ds.config import SERVER_CONFIG from middle.error_handler import global_exception_handler from service.user_service import router as user_router +from service.sensitive_service import router as sensitive_router +from service.face_service import router as face_router from service.device_service import router as device_router from ws.ws import ws_router, lifespan + # ------------------------------ # 初始化 FastAPI 应用、指定生命周期管理 # ------------------------------ @@ -22,6 +29,8 @@ app = FastAPI( # ------------------------------ app.include_router(user_router) app.include_router(device_router) +app.include_router(face_router) +app.include_router(sensitive_router) app.include_router(ws_router) # ------------------------------ @@ -33,11 +42,19 @@ app.add_exception_handler(Exception, global_exception_handler) # 启动服务 # ------------------------------ if __name__ == "__main__": + # -------------------------- 配置调整 -------------------------- + # 模型配置路径(建议改为环境变量) + YOLO_MODEL_PATH = r"/core/models\best.pt" + OCR_CONFIG_PATH = r"/core/config\config.yaml" + + # 初始化项目(默认端口设为8000,避免初始化失败时port未定义) port = int(SERVER_CONFIG.get("port", 8000)) + + # 启动 UVicorn 服务 uvicorn.run( app="main:app", host="0.0.0.0", port=port, - reload=True, + workers=8, ws="websockets" ) diff --git a/middle/auth_middleware.py b/middle/auth_middleware.py index 9cac02d..9897f61 100644 --- a/middle/auth_middleware.py +++ b/middle/auth_middleware.py @@ -8,7 +8,8 @@ from passlib.context import CryptContext from ds.config import JWT_CONFIG from ds.db import db -from service.user_service import UserResponse + +# 移除这里的 from service.user_service import UserResponse 导入 # ------------------------------ # 密码加密配置 @@ -25,6 +26,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"]) # OAuth2 依赖(从请求头获取 Token、格式:Bearer ) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") + # ------------------------------ # 密码工具函数 # ------------------------------ @@ -32,10 +34,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: """验证明文密码与加密密码是否匹配""" return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: """对明文密码进行 bcrypt 加密""" return pwd_context.hash(password) + # ------------------------------ # JWT 工具函数 # ------------------------------ @@ -53,11 +57,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + # ------------------------------ # 认证依赖(获取当前登录用户) # ------------------------------ -def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse: +def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解 """从 Token 中解析用户信息、验证通过后返回当前用户""" + # 延迟导入,打破循环依赖 + from schema.user_schema import UserResponse # 在这里导入 + # 认证失败异常 credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -89,7 +97,7 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse: raise credentials_exception # 用户不存在 # 转换为 UserResponse 模型(自动校验字段) - return UserResponse(** user) + return UserResponse(**user) except Exception as e: raise credentials_exception from e finally: diff --git a/ocr/face_recognizer.py b/ocr/face_recognizer.py deleted file mode 100644 index 1947748..0000000 --- a/ocr/face_recognizer.py +++ /dev/null @@ -1,139 +0,0 @@ -import os -import cv2 -import numpy as np -import insightface -from insightface.app import FaceAnalysis - - -class FaceRecognizer: - """ - 封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。 - """ - - def __init__(self, known_faces_dir: str): - self.known_faces_dir = known_faces_dir - self.app = self._initialize_insightface() - self.known_faces_embeddings = {} - self.known_faces_names = [] - self._load_known_faces() - - def _initialize_insightface(self): - """初始化InsightFace FaceAnalysis应用""" - print("初始化InsightFace引擎...") - try: - app = FaceAnalysis(name='buffalo_l', root='~/.insightface') - app.prepare(ctx_id=0, det_size=(640, 640)) - print("InsightFace引擎初始化完成") - return app - except Exception as e: - print(f"InsightFace初始化失败: {e}") - print("请检查依赖是否安装及模型是否可访问") - return None - - def _load_known_faces(self): - """加载已知人脸特征""" - if not os.path.exists(self.known_faces_dir): - print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}") - os.makedirs(self.known_faces_dir, exist_ok=True) - return - - print(f"从目录加载人脸特征: {self.known_faces_dir}") - for person_name in os.listdir(self.known_faces_dir): - person_dir = os.path.join(self.known_faces_dir, person_name) - if os.path.isdir(person_dir): - print(f"处理人物: {person_name}") - embeddings = [] - for filename in os.listdir(person_dir): - if filename.lower().endswith(('.png', '.jpg', '.jpeg')): - image_path = os.path.join(person_dir, filename) - try: - img = cv2.imread(image_path) - if img is None: - print(f"无法读取图片: {image_path},已跳过") - continue - - faces = self.app.get(img) - if faces: - embeddings.append(faces[0].embedding) - print(f"提取特征成功: {filename}") - else: - print(f"未检测到人脸: {filename},已跳过") - except Exception as e: - print(f"处理图片出错 {image_path}: {e}") - - if embeddings: - self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0) - self.known_faces_names.append(person_name) - print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片") - else: - print(f"人物 {person_name} 无有效特征,已跳过") - print(f"人脸加载完成,共 {len(self.known_faces_names)} 人") - - def recognize(self, frame, threshold=0.4): - """识别人脸并返回结果""" - if not self.app or not self.known_faces_names: - return False, None, None - - faces = self.app.get(frame) - if not faces: - return False, None, None - - for face in faces: - for known_name in self.known_faces_names: - known_embedding = self.known_faces_embeddings[known_name] - - embedding1 = face.embedding.astype(np.float32) - embedding2 = known_embedding.astype(np.float32) - - dot_product = np.dot(embedding1, embedding2) - norm_embedding1 = np.linalg.norm(embedding1) - norm_embedding2 = np.linalg.norm(embedding2) - - similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else ( - dot_product / (norm_embedding1 * norm_embedding2) - ) - - if similarity >= threshold: - print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})") - return True, known_name, similarity - - return False, None, None - - def test_single_image(self, image_path: str, threshold=0.4): - """测试单张图片识别""" - if not os.path.exists(image_path): - print(f"图片不存在: {image_path}") - return False, None, None - - frame = cv2.imread(image_path) - if frame is None: - print(f"无法读取图片: {image_path}") - return False, None, None - - result, name, similarity = self.recognize(frame, threshold) - - if result: - print(f"识别结果: {name} (相似度: {similarity:.4f})") - - faces = self.app.get(frame) - for face in faces: - bbox = face.bbox.astype(int) - cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2) - text = f"{name}: {similarity:.2f}" - cv2.putText(frame, text, (bbox[0], bbox[1] - 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) - - cv2.imshow('识别结果', frame) - print("按任意键关闭窗口...") - cv2.waitKey(0) - cv2.destroyAllWindows() - else: - print("未识别到已知人脸") - - return result, name, similarity - -# -# if __name__ == "__main__": -# recognizer = FaceRecognizer(known_faces_dir="known_faces") -# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg" -# recognizer.test_single_image(test_image_path, threshold=0.4) \ No newline at end of file diff --git a/ocr/feature_extraction.py b/ocr/feature_extraction.py deleted file mode 100644 index ac9e1e9..0000000 --- a/ocr/feature_extraction.py +++ /dev/null @@ -1,156 +0,0 @@ -import cv2 -import numpy as np -import insightface -from insightface.app import FaceAnalysis -from io import BytesIO -from PIL import Image - - -class BinaryFaceFeatureHandler: - """ - 专门处理图片二进制数据的特征提取器,支持分批次接收二进制数据并累积计算平均特征 - """ - - def __init__(self): - self.app = self._init_insightface() - self.feature_list = [] # 存储所有图片二进制数据提取的特征 - - def _init_insightface(self): - """初始化InsightFace引擎""" - try: - print("正在初始化InsightFace引擎...") - app = FaceAnalysis(name='buffalo_l', root='~/.insightface') - app.prepare(ctx_id=0, det_size=(640, 640)) - print("InsightFace引擎初始化完成") - return app - except Exception as e: - print(f"InsightFace初始化失败: {e}") - return None - - def add_binary_data(self, binary_data): - """ - 接收单张图片的二进制数据,提取特征并保存 - - 参数: - binary_data: 图片的二进制数据(bytes类型) - - 返回: - 成功提取特征时返回 (True, 特征值numpy数组) - 失败时返回 (False, None) - """ - if not self.app: - print("引擎未初始化,无法处理") - return False, None - - try: - # 直接处理二进制数据:转换为图像格式 - img = Image.open(BytesIO(binary_data)) - frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - - # 提取特征 - faces = self.app.get(frame) - if faces: - # 获取当前提取的特征值 - current_feature = faces[0].embedding - # 添加到特征列表 - self.feature_list.append(current_feature) - print(f"已累计 {len(self.feature_list)} 个特征") - # 返回成功标志和当前特征值 - return True,current_feature - else: - print("二进制数据中未检测到人脸") - return False, None - except Exception as e: - print(f"处理二进制数据出错: {e}") - return False, None - - def get_average_feature(self, features): - """ - 计算多个特征向量的平均值 - - 参数: - features: 特征值列表,每个元素可以是字符串格式或numpy数组 - 例如: [feature1, feature2, ...] - 返回: - 单一平均特征向量的numpy数组,若无可计算数据则返回None - """ - try: - # 验证输入是否为列表且不为空 - if not isinstance(features, list) or len(features) == 0: - print("输入必须是包含至少一个特征值的列表") - return None - - # 处理每个特征值 - processed_features = [] - for i, embedding in enumerate(features): - try: - if isinstance(embedding, str): - # 处理包含括号和逗号的字符串格式 - embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip() - embedding_list = [float(num) for num in embedding_str.split() if num.strip()] - embedding_np = np.array(embedding_list, dtype=np.float32) - else: - embedding_np = np.array(embedding, dtype=np.float32) - - # 验证特征值格式 - if len(embedding_np.shape) == 1: - processed_features.append(embedding_np) - print(f"已添加第 {i + 1} 个特征值用于计算平均值") - else: - print(f"跳过第 {i + 1} 个特征值,不是一维数组") - - except Exception as e: - print(f"处理第 {i + 1} 个特征值时出错: {e}") - - # 确保有有效的特征值 - if not processed_features: - print("没有有效的特征值用于计算平均值") - return None - - # 检查所有特征向量维度是否相同 - dims = {feat.shape[0] for feat in processed_features} - if len(dims) > 1: - print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}") - return None - - # 计算平均值 - avg_feature = np.mean(processed_features, axis=0) - print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}") - - return avg_feature - - except Exception as e: - print(f"计算平均特征值时出错: {e}") - return None - - # def clear(self): - # """清空已存储的特征数据""" - # self.feature_list = [] - # print("已清空所有特征数据") - - -# # 使用示例 -# if __name__ == "__main__": -# handler = BinaryFaceFeatureHandler() -# -# # 模拟接收图片二进制数据 -# try: -# # 第一次接收 -# with open(r"D:\Git\bin\video\ocr\known_faces\B\102-f.jpg_1140x855.jpg", "rb") as f: -# bin_data1 = f.read() -# success, feature1 = handler.add_binary_data(bin_data1) -# if success: -# print(f"第一次提取的特征值前5个: {feature1[:5]}") -# -# # 第二次接收 -# with open(r"D:\Git\bin\video\ocr\known_faces\B\104-1.jpg", "rb") as f: -# bin_data2 = f.read() -# success, feature2 = handler.add_binary_data(bin_data2) -# if success: -# print(f"第二次提取的特征值前5个: {feature2[:5]}") -# -# # 计算平均值 -# avg_feature = handler.get_average_feature() -# -# except Exception as e: -# print(f"处理过程出错: {e}") diff --git a/ocr/forbidden_words.txt b/ocr/forbidden_words.txt deleted file mode 100644 index 1b96a31..0000000 --- a/ocr/forbidden_words.txt +++ /dev/null @@ -1,18 +0,0 @@ -打倒习近平 -打到习近平 -打倒毛泽东 -打到毛泽东 -打到主席 -打倒主席 -打到共产主义 -打倒共产主义 -打到共产党 -打倒共产党 -胖猫 -法轮功 -法轮大法 -法轮大法好 -法轮功大法好 -法轮 -李洪志 -习近平 \ No newline at end of file diff --git a/ocr/images/img.png b/ocr/images/img.png deleted file mode 100644 index 7c6e9bf..0000000 Binary files a/ocr/images/img.png and /dev/null differ diff --git a/ocr/images/img_7.png b/ocr/images/img_7.png deleted file mode 100644 index 2308c31..0000000 Binary files a/ocr/images/img_7.png and /dev/null differ diff --git a/ocr/known_faces/B/102-f.jpg_1140x855.jpg b/ocr/known_faces/B/102-f.jpg_1140x855.jpg deleted file mode 100644 index 9184dbc..0000000 Binary files a/ocr/known_faces/B/102-f.jpg_1140x855.jpg and /dev/null differ diff --git a/ocr/known_faces/B/104-1.jpg b/ocr/known_faces/B/104-1.jpg deleted file mode 100644 index e6a2a8f..0000000 Binary files a/ocr/known_faces/B/104-1.jpg and /dev/null differ diff --git a/ocr/known_faces/B/110627170414_boxilai_304x304_cns.jpg.webp b/ocr/known_faces/B/110627170414_boxilai_304x304_cns.jpg.webp deleted file mode 100644 index 9fb4dd6..0000000 Binary files a/ocr/known_faces/B/110627170414_boxilai_304x304_cns.jpg.webp and /dev/null differ diff --git a/ocr/known_faces/B/14sino-qiu02-master1050.jpg b/ocr/known_faces/B/14sino-qiu02-master1050.jpg deleted file mode 100644 index b18210c..0000000 Binary files a/ocr/known_faces/B/14sino-qiu02-master1050.jpg and /dev/null differ diff --git a/ocr/known_faces/B/xilai003.webp b/ocr/known_faces/B/xilai003.webp deleted file mode 100644 index 186417b..0000000 Binary files a/ocr/known_faces/B/xilai003.webp and /dev/null differ diff --git a/ocr/known_faces/W/120208041156_wang_lijun_304x171_xinhua.jpg.webp b/ocr/known_faces/W/120208041156_wang_lijun_304x171_xinhua.jpg.webp deleted file mode 100644 index 520bc13..0000000 Binary files a/ocr/known_faces/W/120208041156_wang_lijun_304x171_xinhua.jpg.webp and /dev/null differ diff --git a/ocr/known_faces/W/2f0f70db48.jpg b/ocr/known_faces/W/2f0f70db48.jpg deleted file mode 100644 index 7dbeda5..0000000 Binary files a/ocr/known_faces/W/2f0f70db48.jpg and /dev/null differ diff --git a/ocr/known_faces/W/lijun-jumbo.jpg b/ocr/known_faces/W/lijun-jumbo.jpg deleted file mode 100644 index c1b742d..0000000 Binary files a/ocr/known_faces/W/lijun-jumbo.jpg and /dev/null differ diff --git a/ocr/known_faces/X/1404123658308624.jpg b/ocr/known_faces/X/1404123658308624.jpg deleted file mode 100644 index 1ece977..0000000 Binary files a/ocr/known_faces/X/1404123658308624.jpg and /dev/null differ diff --git a/ocr/known_faces/X/Xu_CaiHou.jpg b/ocr/known_faces/X/Xu_CaiHou.jpg deleted file mode 100644 index e81f62e..0000000 Binary files a/ocr/known_faces/X/Xu_CaiHou.jpg and /dev/null differ diff --git a/ocr/known_faces/X/a0a2e8d4-69d2-409d-ac3e-fdf8f6755f0e_cx0_cy6_cw0_w1023_r1_s.jpg b/ocr/known_faces/X/a0a2e8d4-69d2-409d-ac3e-fdf8f6755f0e_cx0_cy6_cw0_w1023_r1_s.jpg deleted file mode 100644 index 591264d..0000000 Binary files a/ocr/known_faces/X/a0a2e8d4-69d2-409d-ac3e-fdf8f6755f0e_cx0_cy6_cw0_w1023_r1_s.jpg and /dev/null differ diff --git a/ocr/logger_config.py b/ocr/logger_config.py deleted file mode 100644 index 09052a5..0000000 --- a/ocr/logger_config.py +++ /dev/null @@ -1,49 +0,0 @@ -#日志文件 -import logging -import sys - -def setup_logger(): - """ - 配置一个全局日志记录器,支持输出到控制台和文件。 - """ - # 创建一个日志记录器 - - # 配置日志 - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - logger = logging.getLogger("ViolationDetectorLogger") - logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG - - # 如果已经有处理器了,就不要重复添加,防止日志重复打印 - if logger.hasHandlers(): - return logger - - # --- 控制台处理器 --- - console_handler = logging.StreamHandler(sys.stdout) - # 对于控制台,我们只显示INFO及以上级别的信息 - console_handler.setLevel(logging.INFO) - console_formatter = logging.Formatter( - '%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - console_handler.setFormatter(console_formatter) - - # --- 文件处理器 --- - file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8') - # 对于文件,我们记录所有DEBUG及以上级别的信息 - file_handler.setLevel(logging.DEBUG) - file_formatter = logging.Formatter( - '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - file_handler.setFormatter(file_formatter) - - # 将处理器添加到日志记录器 - logger.addHandler(console_handler) - logger.addHandler(file_handler) - - return logger - -# 创建并导出logger实例 -logger = setup_logger() diff --git a/ocr/model_violation_detector.py b/ocr/model_violation_detector.py deleted file mode 100644 index d7958ae..0000000 --- a/ocr/model_violation_detector.py +++ /dev/null @@ -1,136 +0,0 @@ -import os - -import cv2 -import yaml -from pathlib import Path -from .ocr_violation_detector import OCRViolationDetector -from .yolo_violation_detector import ViolationDetector as YoloViolationDetector -from .face_recognizer import FaceRecognizer - -class MultiModelViolationDetector: - """ - 多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果 - """ - - def __init__(self, - forbidden_words_path: str, - ocr_config_path: str, - yolo_model_path: str, - known_faces_dir: str, - ocr_confidence_threshold: float = 0.5): - """ - 初始化所有检测模型 - """ - # 初始化OCR检测器 - self.ocr_detector = OCRViolationDetector( - forbidden_words_path=forbidden_words_path, - ocr_config_path=ocr_config_path, - ocr_confidence_threshold=ocr_confidence_threshold - ) - - # 初始化人脸识别器 - self.face_recognizer = FaceRecognizer( - known_faces_dir=known_faces_dir - ) - - # 初始化YOLO检测器 - self.yolo_detector = YoloViolationDetector( - model_path=yolo_model_path - ) - - print("多模型违规检测器初始化完成") - - def detect_violations(self, frame): - """ - 串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果 - """ - # 1. 首先进行OCR违禁词检测 - try: - ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame) - if ocr_has_violation: - details = { - "words": ocr_words, - "confidences": ocr_confs - } - print(f"警告: OCR检测到违禁内容: {details}") - return (True, "ocr", details) - except Exception as e: - print(f"错误: OCR检测出错: {str(e)}") - - # 2. 接着进行人脸识别检测 - try: - face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame) - if face_has_violation: - details = { - "name": face_name, - "similarity": face_similarity - } - print(f"警告: 人脸识别到违规人员: {details}") - return (True, "face", details) - except Exception as e: - print(f"错误: 人脸识别出错: {str(e)}") - - # 3. 最后进行YOLO目标检测 - try: - yolo_results = self.yolo_detector.detect(frame) - if len(yolo_results.boxes) > 0: - details = { - "classes": yolo_results.names, - "boxes": yolo_results.boxes.xyxy.tolist(), - "confidences": yolo_results.boxes.conf.tolist(), - "class_ids": yolo_results.boxes.cls.tolist() - } - print(f"警告: YOLO检测到违规目标: {details}") - return (True, "yolo", details) - except Exception as e: - print(f"错误: YOLO检测出错: {str(e)}") - - # 所有检测均未发现违规 - return (False, None, None) - - -def load_config(config_path: str) -> dict: - """加载YAML配置文件""" - try: - with open(config_path, 'r', encoding='utf-8') as f: - return yaml.safe_load(f) - except FileNotFoundError: - print(f"错误: 配置文件未找到: {config_path}") - raise - except yaml.YAMLError as e: - print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}") - raise - except Exception as e: - print(f"错误: 加载配置文件出错: {str(e)}") - raise - - -# 使用示例 -# if __name__ == "__main__": -# # 加载配置文件 -# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改 -# -# # 初始化多模型检测器 -# detector = MultiModelViolationDetector( -# forbidden_words_path=config["forbidden_words_path"], -# ocr_config_path=config["ocr_config_path"], -# yolo_model_path=config["yolo_model_path"], -# known_faces_dir=config["known_faces_dir"], -# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5) -# ) -# -# # 读取测试图像(可替换为视频帧读取逻辑) -# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径 -# if test_image_path: -# frame = cv2.imread(test_image_path) -# -# if frame is not None: -# has_violation, violation_type, details = detector.detect_violations(frame) -# if has_violation: -# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") -# else: -# print("未检测到任何违规内容") -# else: -# print(f"无法读取测试图像: {test_image_path}") -# else: -# print("配置文件中未指定测试图像路径") \ No newline at end of file diff --git a/ocr/models/best.pt b/ocr/models/best.pt deleted file mode 100644 index c6958f7..0000000 Binary files a/ocr/models/best.pt and /dev/null differ diff --git a/ocr/ocr_violation_detector.py b/ocr/ocr_violation_detector.py deleted file mode 100644 index ebbb068..0000000 --- a/ocr/ocr_violation_detector.py +++ /dev/null @@ -1,178 +0,0 @@ -import os -import cv2 -from rapidocr import RapidOCR - - -class OCRViolationDetector: - """ - 封装RapidOCR引擎,用于检测图像帧中的违禁词。 - 核心功能:加载违禁词、初始化OCR引擎、单帧图像违禁词检测 - """ - - def __init__(self, - forbidden_words_path: str, - ocr_config_path: str, - ocr_confidence_threshold: float = 0.5): - """ - 初始化OCR引擎和违禁词列表。 - - Args: - forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 - ocr_config_path (str): OCR配置文件(如1.yaml)的路径。 - ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。 - """ - # 加载违禁词 - self.forbidden_words = self._load_forbidden_words(forbidden_words_path) - - # 初始化RapidOCR引擎 - self.ocr_engine = self._initialize_ocr(ocr_config_path) - - # 校验核心依赖是否就绪 - self._check_dependencies() - - # 设置置信度阈值(限制在0~1范围) - self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0)) - print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}") - - def _load_forbidden_words(self, path: str) -> set: - """ - 从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码) - """ - forbidden_words = set() - - # 检查文件是否存在 - if not os.path.exists(path): - print(f"错误:违禁词文件不存在: {path}") - return forbidden_words - - # 读取文件并处理内容 - try: - with open(path, 'r', encoding='utf-8') as f: - forbidden_words = { - line.strip() for line in f - if line.strip() # 跳过空行或纯空格行 - } - print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)") - except UnicodeDecodeError: - print(f"错误:违禁词文件编码错误(需UTF-8): {path}") - except PermissionError: - print(f"错误:无权限读取违禁词文件: {path}") - except Exception as e: - print(f"错误:加载违禁词失败: {str(e)}") - - return forbidden_words - - def _initialize_ocr(self, config_path: str) -> RapidOCR | None: - """ - 初始化RapidOCR引擎(校验配置文件、捕获初始化异常) - """ - print("开始初始化RapidOCR引擎...") - - # 检查配置文件是否存在 - if not os.path.exists(config_path): - print(f"错误:OCR配置文件不存在: {config_path}") - return None - - # 初始化OCR引擎 - try: - ocr_engine = RapidOCR(config_path=config_path) - print("RapidOCR引擎初始化成功") - return ocr_engine - except ImportError: - print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)") - except Exception as e: - print(f"错误:RapidOCR初始化失败: {str(e)}") - - return None - - def _check_dependencies(self) -> None: - """校验OCR引擎和违禁词列表是否就绪""" - if not self.ocr_engine: - print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用") - if not self.forbidden_words: - print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用") - - def detect(self, frame) -> tuple[bool, list, list]: - """ - 对单帧图像进行OCR违禁词检测(核心方法) - - Args: - frame: 输入图像帧(NumPy数组,BGR格式,cv2读取的图像)。 - - Returns: - tuple[bool, list, list]: - - 第一个元素:是否检测到违禁词(True/False); - - 第二个元素:检测到的违禁词列表(空列表表示无违禁词); - - 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。 - """ - # 初始化返回结果 - has_violation = False - violation_words = [] - violation_confs = [] - - # 前置校验 - if frame is None or frame.size == 0: - print("警告:输入图像帧为空或无效,跳过OCR检测") - return has_violation, violation_words, violation_confs - if not self.ocr_engine or not self.forbidden_words: - print("OCR引擎未就绪或违禁词为空,跳过OCR检测") - return has_violation, violation_words, violation_confs - - try: - # 执行OCR识别 - print("开始执行OCR识别...") - ocr_result = self.ocr_engine(frame) - print(f"RapidOCR原始结果: {ocr_result}") - - # 校验OCR结果是否有效 - if ocr_result is None: - print("OCR识别未返回任何结果(图像无文本或识别失败)") - return has_violation, violation_words, violation_confs - - # 检查txts和scores是否存在且不为None - if not hasattr(ocr_result, 'txts') or ocr_result.txts is None: - print("警告:OCR结果中txts为None或不存在") - return has_violation, violation_words, violation_confs - - if not hasattr(ocr_result, 'scores') or ocr_result.scores is None: - print("警告:OCR结果中scores为None或不存在") - return has_violation, violation_words, violation_confs - - # 转为列表并去None - if not isinstance(ocr_result.txts, (list, tuple)): - print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}") - texts = [] - else: - texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)] - - if not isinstance(ocr_result.scores, (list, tuple)): - print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}") - confidences = [] - else: - confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))] - - # 校验文本和置信度列表长度是否一致 - if len(texts) != len(confidences): - print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测") - return has_violation, violation_words, violation_confs - if len(texts) == 0: - print("OCR未识别到任何有效文本") - return has_violation, violation_words, violation_confs - - # 遍历识别结果,筛选违禁词 - print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})") - for text, conf in zip(texts, confidences): - if conf < self.OCR_CONFIDENCE_THRESHOLD: - print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过") - continue - matched_words = [word for word in self.forbidden_words if word in text] - if matched_words: - has_violation = True - violation_words.extend(matched_words) - violation_confs.extend([conf] * len(matched_words)) - print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})") - - except Exception as e: - print(f"错误:OCR检测过程异常: {str(e)}") - - return has_violation, violation_words, violation_confs \ No newline at end of file diff --git a/ocr/yolo_violation_detector.py b/ocr/yolo_violation_detector.py deleted file mode 100644 index 144503f..0000000 --- a/ocr/yolo_violation_detector.py +++ /dev/null @@ -1,47 +0,0 @@ -from ultralytics import YOLO -import cv2 - -class ViolationDetector: - """ - 用于加载YOLOv8 .pt模型并进行违规内容检测的类。 - """ - def __init__(self, model_path): - """ - 初始化检测器。 - - Args: - model_path (str): YOLO .pt模型的路径。 - """ - print(f"正在从 '{model_path}' 加载YOLO模型...") - self.model = YOLO(model_path) - print("YOLO模型加载成功。") - - def detect(self, frame): - """ - 对单帧图像进行目标检测。 - - Args: - frame: 输入的图像帧 (NumPy数组, BGR格式)。 - - Returns: - ultralytics.engine.results.Results: YOLO的检测结果对象。 - """ - # conf可以根据您的模型效果进行调整 - # --- 为了测试,我们暂时将置信度调低,例如 0.2 --- - results = self.model(frame, conf=0.2) - return results[0] - - def draw_boxes(self, frame, result): - """ - 在图像帧上绘制检测框。 - - Args: - frame: 原始图像帧。 - result: YOLO的检测结果对象。 - - Returns: - numpy.ndarray: 绘制了检测框的图像帧。 - """ - # 使用YOLO自带的plot功能,方便快捷 - annotated_frame = result.plot() - return annotated_frame \ No newline at end of file diff --git a/rtc/rtc.py b/rtc/rtc.py deleted file mode 100644 index d9e4a03..0000000 --- a/rtc/rtc.py +++ /dev/null @@ -1,164 +0,0 @@ -import queue -import asyncio -import aiohttp -import threading -import time -from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration -from aiortc.mediastreams import MediaStreamTrack - -# 创建一个长度为1的队列,用于生产者和消费者之间的通信 -frame_queue = queue.Queue(maxsize=1) - - -class VideoTrack(MediaStreamTrack): - """自定义视频轨道类,继承自MediaStreamTrack""" - kind = "video" - - def __init__(self, max_frames=100): - super().__init__() - self.frames = queue.Queue(maxsize=max_frames) - - async def recv(self): - return await super().recv() - - -def webrtc_producer(webrtc_url): - """ - 生产者方法:从WEBRTC读取视频帧并放入队列 - 仅当队列空时才放入新帧,否则丢弃 - """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # 创建RTCPeerConnection对象,不使用ICE服务器 - pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) - video_track = VideoTrack() - pc.addTrack(video_track) - - @pc.on("track") - async def on_track(track): - if track.kind == "video": - print("接收到视频轨道,开始接收视频帧") - while True: - # 从轨道接收视频帧 - frame = await track.recv() - # 转换为BGR24格式的NumPy数组 - frame_bgr24 = frame.to_ndarray(format='bgr24') - - # 检查队列是否为空,为空则加入,否则丢弃 - if frame_queue.empty(): - try: - frame_queue.put_nowait(frame_bgr24) - print("帧已放入队列") - except queue.Full: - print("队列已满,丢弃帧") - else: - print("队列非空,丢弃帧") - - async def main(): - # 创建并发送SDP Offer - offer = await pc.createOffer() - print("已创建本地SDP Offer") - await pc.setLocalDescription(offer) - - # 发送Offer到服务器并接收Answer - async with aiohttp.ClientSession() as session: - print(f"开始向服务器 {webrtc_url} 发送SDP Offer") - async with session.post( - webrtc_url, - data=offer.sdp.encode(), - headers={ - "Content-Type": "application/sdp", - "Content-Length": str(len(offer.sdp)) - }, - ssl=False - ) as response: - print("已接收到服务器的响应") - answer_sdp = await response.text() - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) - - # 保持连接 - try: - while True: - await asyncio.sleep(0.1) - except KeyboardInterrupt: - pass - finally: - print("关闭RTCPeerConnection") - await pc.close() - - try: - loop.run_until_complete(main()) - finally: - loop.close() - - -def frame_consumer(ip): - """ - 消费者方法:从队列中读取帧并处理 - 每次处理后休眠200ms模拟延迟 - """ - print("消费者启动,开始等待帧...") - try: - while True: - # 阻塞等待队列中的帧 - frame = frame_queue.get() - print(f"消费帧,大小: {frame.shape}") - - has_violation, violations, confidences = OCRViolationDetector.detect(frame) - - - # 输出检测结果 - if has_violation: - detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") - for word, conf in zip(violations, confidences): - detector.logger.info(f"- {word} (置信度: {conf:.4f})") - else: - detector.logger.info("图片中未检测到违禁词") - - - # 标记任务完成 - frame_queue.task_done() - except KeyboardInterrupt: - print("消费者退出") - - -def start_webrtc_stream(ip, webrtc_url): - """ - 启动WebRTC视频流处理的主方法 - 参数: webrtc_url - WebRTC服务器地址 - """ - print(f"开始连接到WebRTC服务器: {webrtc_url}") - - # 启动生产者线程 - producer_thread = threading.Thread( - target=webrtc_producer, - args=(webrtc_url,), - daemon=True, - name="webrtc-producer" - ) - - # 启动消费者线程 - consumer_thread = threading.Thread( - target=frame_consumer(ip), - daemon=True, - name="frame-consumer" - ) - - producer_thread.start() - consumer_thread.start() - print("生产者和消费者线程已启动") - - try: - # 保持主线程运行 - while True: - time.sleep(1) - except KeyboardInterrupt: - print("程序正在退出...") - - -if __name__ == "__main__": - # 示例用法 - # 实际使用时替换为真实的WebRTC服务器地址 - webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60" - start_webrtc_stream(webrtc_server_url) diff --git a/rtmp/rtmp.py b/rtmp/rtmp.py deleted file mode 100644 index c200c04..0000000 --- a/rtmp/rtmp.py +++ /dev/null @@ -1,101 +0,0 @@ -import asyncio -import logging -import cv2 -import time - -# 配置日志(与WHEP代码保持一致的日志风格) -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("rtmp_video_puller") - - -async def rtmp_pull_video_stream(rtmp_url): - """ - 通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息 - 功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理 - - Args: - rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key) - """ - cap = None # 初始化视频捕获对象 - try: - # 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环) - cap = await asyncio.to_thread( - cv2.VideoCapture, - rtmp_url, - cv2.CAP_FFMPEG # 必须指定FFmpeg后端,RTMP协议依赖该后端解析 - ) - - # 2. 检查RTMP流是否成功打开 - is_opened = await asyncio.to_thread(cap.isOpened) - if not is_opened: - raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)") - - # 3. 异步获取RTMP流基础信息(分辨率、帧率) - width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH) - height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT) - fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS) - - # 处理异常情况:部分RTMP流未返回帧率时默认30FPS - fps = fps if fps > 0 else 30.0 - # 分辨率转为整数(视频尺寸必然是整数) - width, height = int(width), int(height) - - # 打印流初始化成功信息(与WHEP连接成功信息风格一致) - print(f"RTMP流状态: 已成功连接") - print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS") - print("开始接收视频帧...(按 Ctrl+C 中断)") - - # 4. 初始化帧统计参数 - frame_count = 0 # 总接收帧数 - start_time = time.time() # 统计起始时间 - - # 5. 循环异步读取视频帧(核心逻辑) - while True: - # 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境) - ret, frame = await asyncio.to_thread(cap.read) - - # 检查帧是否读取成功(流中断/结束时ret为False) - if not ret: - print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)") - break - - # 帧计数累加 - frame_count += 1 - - # 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐) - print(f"收到帧 (第{frame_count}帧)") - print(f" 帧尺寸: {width}x{height}") - print(f" 配置帧率: {fps:.2f} FPS") - - # 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致) - if frame_count % 100 == 0: - elapsed_time = time.time() - start_time - actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率) - print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----") - - # (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑 - # 示例:yield frame (若需生成器模式,可调整函数为异步生成器) - - # 8. 异常处理(覆盖用户中断、通用错误) - except KeyboardInterrupt: - print(f"\n用户操作: 已通过 Ctrl+C 中断程序") - except Exception as e: - # 日志记录详细错误(便于问题排查),同时打印用户可见信息 - logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True) - print(f"错误信息: {str(e)}") - finally: - # 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏) - if cap is not None: - await asyncio.to_thread(cap.release) - print(f"\n资源释放: RTMP流已关闭") - print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧") - - -if __name__ == "__main__": - RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416" - - # 运行RTMP拉流任务(与WHEP一致的异步执行方式) - try: - asyncio.run(rtmp_pull_video_stream(RTMP_URL)) - except Exception as e: - print(f"程序启动失败: {str(e)}") \ No newline at end of file diff --git a/schema/face_schema.py b/schema/face_schema.py index 99b1a59..b611766 100644 --- a/schema/face_schema.py +++ b/schema/face_schema.py @@ -23,8 +23,8 @@ class FaceResponse(BaseModel): """人脸记录响应模型(仍包含ID,由数据库生成后返回)""" id: int = Field(..., description="主键ID(数据库自增)") name: str = Field(None, description="名称") - eigenvalue: str = Field(None, description="特征(暂为None)") + eigenvalue: str | None = Field(None, description="特征(可为空)") created_at: datetime = Field(..., description="记录创建时间") updated_at: datetime = Field(..., description="记录更新时间") - model_config = {"from_attributes": True} \ No newline at end of file + model_config = {"from_attributes": True} diff --git a/service/device_service.py b/service/device_service.py index dde502d..e5d850e 100644 --- a/service/device_service.py +++ b/service/device_service.py @@ -1,6 +1,6 @@ import json -from fastapi import APIRouter, Query, HTTPException +from fastapi import APIRouter, Query, HTTPException,Request from mysql.connector import Error as MySQLError from ds.db import db @@ -108,7 +108,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: # 原有接口保持不变 # ------------------------------ @router.post("/add", response_model=APIResponse, summary="创建设备信息") -async def create_device(device_data: DeviceCreateRequest): +async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象 # 原有代码保持不变 conn = None cursor = None @@ -125,11 +125,10 @@ async def create_device(device_data: DeviceCreateRequest): return APIResponse( code=200, message=f"设备IP {device_data.ip} 已存在,返回已有设备信息", - data=DeviceResponse(**existing_device) + data=DeviceResponse(** existing_device) ) - from fastapi import Request - request = Request(scope={"type": "http"}) + # 直接使用注入的request对象获取用户代理 user_agent = request.headers.get("User-Agent", "").lower() if user_agent == "default": @@ -184,7 +183,6 @@ async def create_device(device_data: DeviceCreateRequest): finally: db.close_connection(conn, cursor) - @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") async def get_device_list( page: int = Query(1, ge=1, description="页码,默认第1页"), diff --git a/service/face_service.py b/service/face_service.py index 8c26af6..c31f2d2 100644 --- a/service/face_service.py +++ b/service/face_service.py @@ -6,15 +6,15 @@ from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceRespons from schema.response_schema import APIResponse from middle.auth_middleware import get_current_user from schema.user_schema import UserResponse -from ocr.feature_extraction import BinaryFaceFeatureHandler + +from util.face_util import add_binary_data,get_average_feature +#初始化实例 router = APIRouter( prefix="/faces", tags=["人脸管理"] ) -# 创建 BinaryFaceFeatureHandler 的实例 -binary_face_feature_handler = BinaryFaceFeatureHandler() # ------------------------------ @@ -33,6 +33,8 @@ async def create_face( - ID 由数据库自动生成,无需前端传入 - 暂不处理文件内容,eigenvalue 设为 None """ + + # 调用你的方法 conn = None cursor = None try: @@ -45,14 +47,24 @@ async def create_face( # 把文件转为二进制数组 file_content = await file.read() - # 调用人脸识别得到特征值 + # 计算特征值 + flag, eigenvalue = add_binary_data(file_content) + + if flag == False: + raise HTTPException( + status_code=500, + detail="未检测到人脸" + ) + + # 打印数组长度 + print(f"文件大小:{len(file_content)} 字节") # 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None) insert_query = """ INSERT INTO face (name, eigenvalue) VALUES (%s, %s) """ - cursor.execute(insert_query, (face_create.name, None)) + cursor.execute(insert_query, (face_create.name, str(eigenvalue))) conn.commit() # 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录) @@ -60,19 +72,45 @@ async def create_face( cursor.execute(select_new_query) created_face = cursor.fetchone() + if not created_face: + raise HTTPException( + status_code=500, + detail="创建人脸记录成功,但无法获取新创建的记录" + ) + return APIResponse( code=201, message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})", - data=FaceResponse(**created_face) + data=FaceResponse(** created_face) ) except MySQLError as e: if conn: conn.rollback() - raise Exception(f"创建人脸记录失败:{str(e)}") from e + # 改为使用HTTPException + raise HTTPException( + status_code=500, + detail=f"创建人脸记录失败:{str(e)}" + ) from e + except Exception as e: + # 捕获其他可能的异常 + raise HTTPException( + status_code=500, + detail=f"服务器错误:{str(e)}" + ) from e finally: await file.close() # 关闭文件流 db.close_connection(conn, cursor) + # 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑) + flag, eigenvalue = add_binary_data(file_content) + if flag == False: + raise HTTPException( + status_code=500, + detail="未检测到人脸" + ) + + # 将 eigenvalue 转为 str + eigenvalue = str(eigenvalue) # ------------------------------ # 2. 获取单个人脸记录(不变,用自增ID查询) @@ -104,18 +142,21 @@ async def get_face( data=FaceResponse(**face) ) except MySQLError as e: - raise Exception(f"查询人脸记录失败:{str(e)}") from e + # 改为使用HTTPException + raise HTTPException( + status_code=500, + detail=f"查询人脸记录失败:{str(e)}" + ) from e finally: db.close_connection(conn, cursor) -# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改) +# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理) # ------------------------------ # 3. 获取所有人脸记录(不变) # ------------------------------ @router.get("", response_model=APIResponse, summary="获取所有人脸记录") async def get_all_faces( - current_user: UserResponse = Depends(get_current_user) ): conn = None cursor = None @@ -130,10 +171,13 @@ async def get_all_faces( return APIResponse( code=200, message="所有人脸记录查询成功", - data=[FaceResponse(**face) for face in faces] + data=[FaceResponse(** face) for face in faces] ) except MySQLError as e: - raise Exception(f"查询所有人脸记录失败:{str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"查询所有人脸记录失败:{str(e)}" + ) from e finally: db.close_connection(conn, cursor) @@ -194,7 +238,10 @@ async def update_face( except MySQLError as e: if conn: conn.rollback() - raise Exception(f"更新人脸记录失败:{str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"更新人脸记录失败:{str(e)}" + ) from e finally: db.close_connection(conn, cursor) @@ -234,7 +281,10 @@ async def delete_face( except MySQLError as e: if conn: conn.rollback() - raise Exception(f"删除人脸记录失败:{str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"删除人脸记录失败:{str(e)}" + ) from e finally: db.close_connection(conn, cursor) @@ -249,38 +299,43 @@ def get_all_face_name_with_eigenvalue() -> dict: conn = None cursor = None try: + # 1. 建立数据库连接并获取游标(dictionary=True使结果以字典形式返回) conn = db.get_connection() cursor = conn.cursor(dictionary=True) - # 只查询需要的字段,提高效率 + # 2. 执行SQL查询:只获取name非空的记录,减少数据传输 query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" cursor.execute(query) - faces = cursor.fetchall() + faces = cursor.fetchall() # 返回结果:列表套字典,如 [{"name":"张三","eigenvalue":...}, ...] - # 先收集所有名称对应的特征值列表(处理重复名称) + # 3. 收集同一名称对应的所有特征值(处理名称重复场景) name_to_eigenvalues = {} for face in faces: name = face["name"] eigenvalue = face["eigenvalue"] + # 若名称已存在,追加特征值;否则新建列表存储 if name in name_to_eigenvalues: name_to_eigenvalues[name].append(eigenvalue) else: name_to_eigenvalues[name] = [eigenvalue] - # 构建最终字典:重复名称取平均特征值,唯一名称直接取特征值 + # 4. 构建最终字典:重复名称取平均,唯一名称直接取特征值 face_dict = {} for name, eigenvalues in name_to_eigenvalues.items(): - print("调用的特征值是:" + eigenvalues) + + # 处理特征值:多个则求平均,单个则直接使用 if len(eigenvalues) > 1: - # 调用平均特征值计算方法 - face_dict[name] = binary_face_feature_handler.get_average_feature(eigenvalues) + # 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入) + face_dict[name] = get_average_feature(eigenvalues) else: + # 取列表中唯一的特征值(避免value为列表类型) face_dict[name] = eigenvalues[0] return face_dict except MySQLError as e: + # 捕获数据库异常,添加上下文信息后重新抛出(便于定位问题) raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e finally: - # 确保资源释放 - db.close_connection(conn, cursor) + # 5. 无论是否异常,均释放数据库连接和游标(避免资源泄漏) + db.close_connection(conn, cursor) \ No newline at end of file diff --git a/util/face_util.py b/util/face_util.py new file mode 100644 index 0000000..f3a1ad2 --- /dev/null +++ b/util/face_util.py @@ -0,0 +1,145 @@ +import cv2 +import numpy as np +import insightface +from insightface.app import FaceAnalysis +from io import BytesIO +from PIL import Image + +# 全局变量存储InsightFace引擎和特征列表 +_insightface_app = None +_feature_list = [] + + +def init_insightface(): + """初始化InsightFace引擎""" + global _insightface_app + try: + print("正在初始化InsightFace引擎...") + app = FaceAnalysis(name='buffalo_l', root='~/.insightface') + app.prepare(ctx_id=0, det_size=(640, 640)) + print("InsightFace引擎初始化完成") + _insightface_app = app + return app + except Exception as e: + print(f"InsightFace初始化失败: {e}") + return None + + +def add_binary_data(binary_data): + """ + 接收单张图片的二进制数据,提取特征并保存 + + 参数: + binary_data: 图片的二进制数据(bytes类型) + + 返回: + 成功提取特征时返回 (True, 特征值numpy数组) + 失败时返回 (False, None) + """ + global _insightface_app, _feature_list + + if not _insightface_app: + print("引擎未初始化,无法处理") + return False, None + + try: + # 直接处理二进制数据:转换为图像格式 + img = Image.open(BytesIO(binary_data)) + frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + + # 提取特征 + faces = _insightface_app.get(frame) + if faces: + # 获取当前提取的特征值 + current_feature = faces[0].embedding + # 添加到特征列表 + _feature_list.append(current_feature) + print(f"已累计 {len(_feature_list)} 个特征") + # 返回成功标志和当前特征值 + return True, current_feature + else: + print("二进制数据中未检测到人脸") + return False, None + except Exception as e: + print(f"处理二进制数据出错: {e}") + return False, None + + +def get_average_feature(features=None): + """ + 计算多个特征向量的平均值 + + 参数: + features: 可选,特征值列表。如果未提供,则使用全局存储的_feature_list + 每个元素可以是字符串格式或numpy数组 + + 返回: + 单一平均特征向量的numpy数组,若无可计算数据则返回None + """ + global _feature_list + + # 如果未提供features参数,则使用全局特征列表 + if features is None: + features = _feature_list + + try: + # 验证输入是否为列表且不为空 + if not isinstance(features, list) or len(features) == 0: + print("输入必须是包含至少一个特征值的列表") + return None + + # 处理每个特征值 + processed_features = [] + for i, embedding in enumerate(features): + try: + if isinstance(embedding, str): + # 处理包含括号和逗号的字符串格式 + embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip() + embedding_list = [float(num) for num in embedding_str.split() if num.strip()] + embedding_np = np.array(embedding_list, dtype=np.float32) + else: + embedding_np = np.array(embedding, dtype=np.float32) + + # 验证特征值格式 + if len(embedding_np.shape) == 1: + processed_features.append(embedding_np) + print(f"已添加第 {i + 1} 个特征值用于计算平均值") + else: + print(f"跳过第 {i + 1} 个特征值,不是一维数组") + + except Exception as e: + print(f"处理第 {i + 1} 个特征值时出错: {e}") + + # 确保有有效的特征值 + if not processed_features: + print("没有有效的特征值用于计算平均值") + return None + + # 检查所有特征向量维度是否相同 + dims = {feat.shape[0] for feat in processed_features} + if len(dims) > 1: + print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}") + return None + + # 计算平均值 + avg_feature = np.mean(processed_features, axis=0) + print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}") + + return avg_feature + + except Exception as e: + print(f"计算平均特征值时出错: {e}") + return None + + +def clear_features(): + """清空已存储的特征数据""" + global _feature_list + _feature_list = [] + print("已清空所有特征数据") + + +def get_feature_list(): + """获取当前存储的特征列表""" + global _feature_list + return _feature_list.copy() # 返回副本防止外部直接修改 \ No newline at end of file diff --git a/ws.html b/ws.html deleted file mode 100644 index d81ceb2..0000000 --- a/ws.html +++ /dev/null @@ -1,482 +0,0 @@ - - - - - - WebSocket 测试工具 - - - -
-

WebSocket 测试工具

- - -
-
连接状态:
-
未连接
-
服务地址:
-
ws://192.168.110.25:8000/ws
-
连接时间:
-
-
-
- - -
- - - - -
- - - - - -
-
- - -
-

发送自定义消息

- - -
- - -
-

消息日志

-
-
[加载完成] 请点击「建立连接」开始测试
-
- -
-
- - - - \ No newline at end of file diff --git a/ws/ws.py b/ws/ws.py index f125018..5f6571c 100644 --- a/ws/ws.py +++ b/ws/ws.py @@ -4,314 +4,300 @@ import json import os from contextlib import asynccontextmanager from typing import Dict, Optional, AsyncGenerator -from concurrent.futures import ThreadPoolExecutor # 新增:显式线程池 - from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip from service.device_action_service import add_device_action from schema.device_action_schema import DeviceActionCreate +from core.all import detect import cv2 import numpy as np from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI -from queue import Queue # 线程安全队列,无需额外Lock +from core.all import load_model -from ocr.model_violation_detector import MultiModelViolationDetector - -# -------------------------- 配置调整 -------------------------- -# 模型路径(建议改为环境变量) -YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" -OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" - -# 核心优化:模型池大小(决定最大并发任务数,显存占用=大小×单模型显存) -MODEL_POOL_SIZE = 5 # 示例:设为5,支持5个任务并行,显存会明显上升 -THREAD_POOL_SIZE = MODEL_POOL_SIZE * 2 # 线程池大小≥模型池,避免线程瓶颈 - -# 其他配置 -HEARTBEAT_INTERVAL = 30 # 心跳间隔(秒) +# 配置常量 +HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) -WS_ENDPOINT = "/ws" # WebSocket端点 -FRAME_QUEUE_SIZE = 5 # 增大帧队列,允许缓存更多帧(避免丢帧) +WS_ENDPOINT = "/ws" # WebSocket端点路径 +FRAME_QUEUE_SIZE = 1 # 帧队列大小限制 -# -------------------------- 工具函数 -------------------------- + +# 工具函数:获取格式化时间字符串(统一时间戳格式) def get_current_time_str() -> str: return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + def get_current_time_file_str() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") -# -------------------------- 模型池重构(核心修改1) -------------------------- -class ModelPool: - def __init__(self, pool_size: int = MODEL_POOL_SIZE): - self.pool = Queue(maxsize=pool_size) - # 移除冗余Lock:Queue.get()/put()本身线程安全 - self._init_models(pool_size) - print(f"[{get_current_time_str()}] 模型池初始化完成(共{pool_size}个实例,显存已预分配)") - def _init_models(self, pool_size: int): - """预加载所有模型实例(初始化时显存会一次性上升)""" - for i in range(pool_size): - try: - detector = MultiModelViolationDetector( - ocr_config_path=OCR_CONFIG_PATH, - yolo_model_path=YOLO_MODEL_PATH, - ocr_confidence_threshold=0.5 - ) - self.pool.put(detector) - print(f"[{get_current_time_str()}] 模型实例{i+1}/{pool_size}加载完成") - except Exception as e: - raise RuntimeError(f"模型实例{i+1}加载失败:{str(e)}") - - def get_model(self) -> MultiModelViolationDetector: - """获取模型(阻塞直到有空闲实例,确保并发安全)""" - return self.pool.get() - - def return_model(self, detector: MultiModelViolationDetector): - """归还模型(立即释放资源供其他任务使用)""" - self.pool.put(detector) - -# -------------------------- 全局资源初始化 -------------------------- -model_pool = ModelPool(pool_size=MODEL_POOL_SIZE) # 初始化模型池(预占显存) -thread_pool = ThreadPoolExecutor( # 显式创建线程池(核心修改2) - max_workers=THREAD_POOL_SIZE, - thread_name_prefix="ModelWorker-" # 线程命名,便于调试 -) - -# -------------------------- 客户端连接封装(核心修改3) -------------------------- +# 客户端连接封装 class ClientConnection: def __init__(self, websocket: WebSocket, client_ip: str): self.websocket = websocket self.client_ip = client_ip self.last_heartbeat = datetime.datetime.now() - self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 增大队列 + self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) self.consumer_task: Optional[asyncio.Task] = None - # 移除“客户端独占模型”:不再持有detector属性 def update_heartbeat(self): + """更新心跳时间(客户端发送心跳时调用)""" self.last_heartbeat = datetime.datetime.now() def is_alive(self) -> bool: + """判断客户端是否存活(心跳超时检查)""" timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() return timeout < HEARTBEAT_TIMEOUT def start_consumer(self): - """启动帧消费任务(每个客户端一个独立任务)""" + """启动帧消费任务""" self.consumer_task = asyncio.create_task(self.consume_frames()) return self.consumer_task async def send_frame_permit(self): - """发送帧许可信号(允许客户端继续发帧)""" + """ + 发送「帧发送许可信号」 + 通知客户端可发送下一帧图像 + """ try: - await self.websocket.send_json({ + frame_permit_msg = { "type": "frame", "timestamp": get_current_time_str(), "client_ip": self.client_ip - }) + } + await self.websocket.send_json(frame_permit_msg) + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)") except Exception as e: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可发送失败 - {str(e)}") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}") async def consume_frames(self) -> None: - """消费帧队列(并发核心:每帧临时借模型处理)""" + """消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理帧)""" try: while True: - # 1. 从队列取帧(无帧时阻塞) + # 1. 从队列取出帧(阻塞直到有帧可用) frame_data = await self.frame_queue.get() - # 2. 立即发送下一帧许可(让客户端持续发帧,积累并发任务) - await self.send_frame_permit() + + # -------------------------- 核心修改:取出帧后立即发送下一帧许可 -------------------------- + await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成 + # ----------------------------------------------------------------------------------------- + try: - # 3. 并行处理帧(核心:任务级借模型) + # 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧) await self.process_frame(frame_data) finally: - self.frame_queue.task_done() # 标记帧处理完成 + # 3. 标记帧任务完成(无论处理成功/失败,都需清理队列) + self.frame_queue.task_done() + except asyncio.CancelledError: print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消") except Exception as e: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}:消费逻辑错误 - {str(e)}") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}") async def process_frame(self, frame_data: bytes) -> None: - """处理单帧(核心修改4:任务级借还模型)""" - # 1. 临时借用模型(阻塞直到有空闲实例,显存随借用数上升) - detector = model_pool.get_model() + """处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)""" + # 二进制数据转OpenCV图像 + nparr = np.frombuffer(frame_data, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if img is None: + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据") + return + + # 确保图像保存目录存在 + os.makedirs('images', exist_ok=True) + + # 保存图像(按IP+时间戳命名,避免冲突) + filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg" try: - # 2. 二进制转OpenCV图像 - nparr = np.frombuffer(frame_data, np.uint8) - img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) - if img is None: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像解析失败") - return - - # 3. 保存图像(可选) - os.makedirs('images', exist_ok=True) - filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg" cv2.imwrite(filename, img) - - # 4. 显式线程池执行AI检测(真正并发,无线程瓶颈) - loop = asyncio.get_running_loop() - has_violation, violation_type, details = await loop.run_in_executor( - thread_pool, # 用自定义线程池,避免默认线程不足 - detector.detect_violations, # 临时借用的模型 - img # 输入图像 - ) - - # 5. 违规处理(与原逻辑一致) + print(f"[{get_current_time_str()}] 图像已保存至:{filename}") + has_violation, data, type = detect(img) + print(has_violation) + print(type) + print(data) if has_violation: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规 - {violation_type}") - # 违规次数更新(用线程池避免阻塞事件循环) - await loop.run_in_executor(thread_pool, increment_alarm_count_by_ip, self.client_ip) - # 发送危险通知 - await self.websocket.send_json({ + print( + f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {type}, 详情: {data}") + + # 调用违规次数加一方法 + try: + await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip) + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1") + except Exception as e: + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}") + + # 发送「危险通知」 + danger_msg = { "type": "danger", "timestamp": get_current_time_str(), - "client_ip": self.client_ip, - "violation_type": violation_type, - "details": details - }) + "client_ip": self.client_ip + } + await self.websocket.send_json(danger_msg) else: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无违规") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:未检测到违规") except Exception as e: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧处理错误 - {str(e)}") - finally: - # 6. 无论成功/失败,强制归还模型(核心:释放资源供其他任务使用) - model_pool.return_model(detector) - print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还(可复用)") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像处理错误 - {str(e)}") -# -------------------------- 全局状态与心跳 -------------------------- + +# 全局状态管理 connected_clients: Dict[str, ClientConnection] = {} -client_lock = asyncio.Lock() # 保护客户端字典的异步锁 heartbeat_task: Optional[asyncio.Task] = None + +# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法) async def heartbeat_checker(): - """心跳检查(移除模型归还逻辑,因模型已任务级归还)""" while True: current_time = get_current_time_str() - async with client_lock: - # 筛选超时客户端 - timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] + timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] - for ip in timeout_ips: - async with client_lock: - conn = connected_clients.get(ip) - if not conn: - continue - # 取消消费任务+关闭连接 - if conn.consumer_task and not conn.consumer_task.done(): - conn.consumer_task.cancel() - await conn.websocket.close(code=1008, reason="心跳超时") - # 标记离线(用线程池) - loop = asyncio.get_running_loop() - await loop.run_in_executor(thread_pool, update_online_status_by_ip, ip, 0) - await loop.run_in_executor( - thread_pool, add_device_action, DeviceActionCreate(client_ip=ip, action=0) - ) - connected_clients.pop(ip) - print(f"[{current_time}] 客户端{ip}:超时离线(资源已清理)") + if timeout_ips: + print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时(IP:{timeout_ips})") + for ip in timeout_ips: + try: + conn = connected_clients[ip] + if conn.consumer_task and not conn.consumer_task.done(): + conn.consumer_task.cancel() + await conn.websocket.close(code=1008, reason="心跳超时") - # 打印在线状态 - async with client_lock: + # 超时设为离线并记录 + try: + await asyncio.to_thread(update_online_status_by_ip, ip, 0) + action_data = DeviceActionCreate(client_ip=ip, action=0) + await asyncio.to_thread(add_device_action, action_data) + print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作") + except Exception as e: + print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}") + finally: + connected_clients.pop(ip, None) + else: print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线") + await asyncio.sleep(HEARTBEAT_INTERVAL) -# -------------------------- 应用生命周期(核心修改5:管理线程池) -------------------------- + +# 应用生命周期管理 @asynccontextmanager async def lifespan(app: FastAPI): global heartbeat_task - # 启动心跳任务 heartbeat_task = asyncio.create_task(heartbeat_checker()) - print(f"[{get_current_time_str()}] 心跳任务启动(ID:{id(heartbeat_task)})") - print(f"[{get_current_time_str()}] 线程池启动(最大线程数:{THREAD_POOL_SIZE})") - yield # 应用运行期间 - # 清理资源 + print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID:{id(heartbeat_task)})") + yield if heartbeat_task and not heartbeat_task.done(): heartbeat_task.cancel() - await heartbeat_task - print(f"[{get_current_time_str()}] 心跳任务已关闭") - # 关闭线程池(等待所有任务完成) - thread_pool.shutdown(wait=True) - print(f"[{get_current_time_str()}] 线程池已关闭") + try: + await heartbeat_task + print(f"[{get_current_time_str()}] 全局心跳检查任务已取消") + except asyncio.CancelledError: + pass -# -------------------------- WebSocket路由 -------------------------- + +# 消息处理工具函数 +async def send_heartbeat_ack(conn: ClientConnection): + try: + heartbeat_ack_msg = { + "type": "heart", + "timestamp": get_current_time_str(), + "client_ip": conn.client_ip + } + await conn.websocket.send_json(heartbeat_ack_msg) + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认") + return True + except Exception as e: + connected_clients.pop(conn.client_ip, None) + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}") + return False + + +async def handle_text_msg(conn: ClientConnection, text: str): + try: + msg = json.loads(text) + if msg.get("type") == "heart": + conn.update_heartbeat() + await send_heartbeat_ack(conn) + else: + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')})") + except json.JSONDecodeError: + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:无效JSON文本消息") + + +async def handle_binary_msg(conn: ClientConnection, data: bytes): + try: + conn.frame_queue.put_nowait(data) + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列") + except asyncio.QueueFull: + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据") + + +# WebSocket路由配置 ws_router = APIRouter() + @ws_router.websocket(WS_ENDPOINT) async def websocket_endpoint(websocket: WebSocket): + # 加载模型 + load_model() await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown_ip" current_time = get_current_time_str() - print(f"[{current_time}] 客户端{client_ip}:连接建立") + print(f"[{current_time}] 客户端{client_ip}:WebSocket连接已建立") - new_conn = None is_online_updated = False - try: - # 处理重复连接(关闭旧连接) - async with client_lock: - if client_ip in connected_clients: - old_conn = connected_clients[client_ip] - if old_conn.consumer_task and not old_conn.consumer_task.done(): - old_conn.consumer_task.cancel() - await old_conn.websocket.close(code=1008, reason="新连接抢占") - connected_clients.pop(client_ip) - print(f"[{current_time}] 客户端{client_ip}:旧连接已关闭") - # 创建新连接+启动消费任务 + try: + # 处理重复连接 + if client_ip in connected_clients: + old_conn = connected_clients[client_ip] + if old_conn.consumer_task and not old_conn.consumer_task.done(): + old_conn.consumer_task.cancel() + await old_conn.websocket.close(code=1008, reason="同一IP新连接建立") + connected_clients.pop(client_ip) + print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接") + + # 注册新连接 new_conn = ClientConnection(websocket, client_ip) + connected_clients[client_ip] = new_conn new_conn.start_consumer() - # 初始发送帧许可(让客户端立即发帧) + # 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发) await new_conn.send_frame_permit() - # 标记客户端在线 - loop = asyncio.get_running_loop() - await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 1) - await loop.run_in_executor( - thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=1) - ) - is_online_updated = True - async with client_lock: - connected_clients[client_ip] = new_conn - print(f"[{current_time}] 客户端{client_ip}:注册成功(在线数:{len(connected_clients)})") + # 标记上线并记录 + try: + await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) + action_data = DeviceActionCreate(client_ip=client_ip, action=1) + await asyncio.to_thread(add_device_action, action_data) + print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作") + is_online_updated = True + except Exception as e: + print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}") - # 消息循环(接收文本/二进制帧) + print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}") + + # 消息循环 while True: data = await websocket.receive() if "text" in data: - # 处理文本消息(如心跳) - try: - msg = json.loads(data["text"]) - if msg.get("type") == "heart": - new_conn.update_heartbeat() - # 回复心跳确认 - await websocket.send_json({ - "type": "heart", - "timestamp": get_current_time_str(), - "client_ip": client_ip - }) - except json.JSONDecodeError: - print(f"[{get_current_time_str()}] 客户端{client_ip}:无效JSON") + await handle_text_msg(new_conn, data["text"]) elif "bytes" in data: - # 处理二进制帧(图像) - try: - await new_conn.frame_queue.put(data["bytes"]) - print(f"[{get_current_time_str()}] 客户端{client_ip}:帧已入队(队列大小:{new_conn.frame_queue.qsize()})") - except asyncio.QueueFull: - print(f"[{get_current_time_str()}] 客户端{client_ip}:帧队列满(丢弃当前帧)") + await handle_binary_msg(new_conn, data["bytes"]) except WebSocketDisconnect as e: - print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开(代码:{e.code})") + print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code})") except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}") finally: - # 清理资源(无需归还模型,已在process_frame中归还) - if new_conn and client_ip in connected_clients: - async with client_lock: - conn = connected_clients.get(client_ip) - if conn: - if conn.consumer_task and not conn.consumer_task.done(): - conn.consumer_task.cancel() - # 标记离线(仅当在线状态已更新时) - if is_online_updated: - loop = asyncio.get_running_loop() - await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 0) - await loop.run_in_executor( - thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=0) - ) - connected_clients.pop(client_ip) - async with client_lock: - print(f"[{get_current_time_str()}] 客户端{client_ip}:资源清理完成(在线数:{len(connected_clients)})") + # 清理资源并标记离线 + if client_ip in connected_clients: + conn = connected_clients[client_ip] + if conn.consumer_task and not conn.consumer_task.done(): + conn.consumer_task.cancel() + + # 主动/异常断开时标记离线 + if is_online_updated: + try: + await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) + action_data = DeviceActionCreate(client_ip=client_ip, action=0) + await asyncio.to_thread(add_device_action, action_data) + print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线") + except Exception as e: + print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}") + + connected_clients.pop(client_ip, None) + print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")