最新可用

This commit is contained in:
2025-09-04 22:59:27 +08:00
parent ec6dbfde90
commit 30bf7c9fcb
42 changed files with 746 additions and 1967 deletions

2
.idea/Video.iml generated
View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="Black"> <component name="Black">
<option name="sdkName" value="video" /> <option name="sdkName" value="video" />
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project> </project>

View File

@ -15,5 +15,5 @@ algorithm = HS256
access_token_expire_minutes = 30 access_token_expire_minutes = 30
[live] [live]
rtmp_url = rtmp://192.168.110.65:1935/live/ rtmp_url = rtmp://192.168.110.25:1935/live/
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream= webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=

45
core/all.py Normal file
View File

@ -0,0 +1,45 @@
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# 添加一个标记变量用于监控load_model是否已被调用
_model_loaded = False
def load_model():
global _model_loaded
# 如果已经调用过,直接忽略
if _model_loaded:
return
# 首次调用时加载模型
ocrLoadModel()
faceLoadModel()
yoloLoadModel()
# 标记为已调用
_model_loaded = True
def detect(frame):
# 先进行YOLO检测
yolo_flag, yolo_result = yoloDetect(frame)
print("YOLO检测结果", yolo_result)
if yolo_flag:
return (True, yolo_result, "yolo")
# YOLO未检测到进行人脸检测
face_flag, face_result = faceDetect(frame)
print("人脸检测结果:", face_result)
if face_flag:
return (True, face_result, "face")
# 人脸未检测到进行OCR检测
ocr_flag, ocr_result = ocrDetect(frame)
print("OCR检测结果", ocr_result)
if ocr_flag:
return (True, ocr_result, "ocr")
# 所有检测都未检测到
return (False, "未检测到任何内容", "none")

113
core/face.py Normal file
View File

@ -0,0 +1,113 @@
import os
import numpy as np
import cv2
from PIL import Image # 确保正确导入Image类
from insightface.app import FaceAnalysis
# 导入获取人脸信息的服务
from service.face_service import get_all_face_name_with_eigenvalue
# 全局变量
_face_app = None
_known_faces_embeddings = {} # 存储姓名到特征值的映射
_known_faces_names = [] # 存储所有已知姓名
def load_model():
"""加载人脸识别模型及已知人脸特征库"""
global _face_app, _known_faces_embeddings, _known_faces_names
# 初始化InsightFace模型
try:
_face_app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface'))
_face_app.prepare(ctx_id=0, det_size=(640, 640))
except Exception as e:
print(f"Face model load failed: {e}")
return False
# 从服务获取所有人脸姓名和特征值
try:
face_data = get_all_face_name_with_eigenvalue()
# 处理获取到的人脸数据
for person_name, eigenvalue_data in face_data.items():
# 处理特征值数据 - 兼容数组和字符串两种格式
if isinstance(eigenvalue_data, np.ndarray):
# 如果已经是numpy数组直接使用
eigenvalue = eigenvalue_data.astype(np.float32)
elif isinstance(eigenvalue_data, str):
# 清理字符串:移除方括号、换行符和多余空格
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
# 按空格或逗号分割(处理可能的不同分隔符)
values = [v for v in cleaned.split() if v]
# 转换为数组
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
else:
# 不支持的类型
print(f"Unsupported eigenvalue type for {person_name}")
continue
# 归一化处理
norm = np.linalg.norm(eigenvalue)
if norm != 0:
eigenvalue = eigenvalue / norm
_known_faces_embeddings[person_name] = eigenvalue
_known_faces_names.append(person_name)
except Exception as e:
print(f"Error loading face data from service: {e}")
return True if _face_app else False
def detect(frame, threshold=0.4):
"""检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)"""
global _face_app, _known_faces_embeddings, _known_faces_names
if not _face_app or not _known_faces_names or frame is None:
return (False, "未初始化或无效帧")
try:
faces = _face_app.get(frame)
except Exception as e:
print(f"Face detect error: {e}")
return (False, f"检测错误: {str(e)}")
result_parts = []
has_matched = False # 新增标记:是否有匹配到的已知人脸
for face in faces:
# 特征归一化
embedding = face.embedding.astype(np.float32)
norm = np.linalg.norm(embedding)
if norm == 0:
continue
embedding = embedding / norm
# 对比已知人脸
max_sim, best_name = -1.0, "Unknown"
for name in _known_faces_names:
known_emb = _known_faces_embeddings[name]
sim = np.dot(embedding, known_emb)
if sim > max_sim:
max_sim = sim
best_name = name
# 判断匹配结果
is_match = max_sim >= threshold
if is_match:
has_matched = True # 只要有一个匹配成功就标记为True
bbox = face.bbox
result_parts.append(
f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})"
)
# 构建结果字符串
if not result_parts:
result_str = "未检测到人脸"
else:
result_str = "; ".join(result_parts)
# 第一个返回值改为:是否匹配到已知人脸
return (has_matched, result_str)

BIN
core/models/best.pt Normal file

Binary file not shown.

76
core/ocr.py Normal file
View File

@ -0,0 +1,76 @@
import os
import cv2
from rapidocr import RapidOCR
from service.sensitive_service import get_all_sensitive_words
# 全局变量
_ocr_engine = None
_forbidden_words = set()
_conf_threshold = 0.5
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")
def load_model():
"""加载OCR引擎及违禁词列表"""
global _ocr_engine, _forbidden_words, _conf_threshold
# 加载违禁词
try:
_forbidden_words = get_all_sensitive_words()
except Exception as e:
print(f"Forbidden words load error: {e}")
# 初始化OCR引擎
if not os.path.exists(ocr_config_path):
print(f"OCR config not found: {ocr_config_path}")
return False
try:
_ocr_engine = RapidOCR(config_path=ocr_config_path)
except Exception as e:
print(f"OCR model load failed: {e}")
return False
return True if _ocr_engine else False
def detect(frame):
"""OCR检测并筛选违禁词返回(是否检测到违禁词, 结果字符串)"""
if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
return (False, "未初始化或无效帧")
try:
ocr_res = _ocr_engine(frame)
except Exception as e:
print(f"OCR detect error: {e}")
return (False, f"检测错误: {str(e)}")
if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'):
return (False, "无OCR结果")
# 处理OCR结果
texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)]
confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))]
if len(texts) != len(confs):
return (False, "OCR结果格式异常")
# 筛选违禁词
vio_info = []
for txt, conf in zip(texts, confs):
if conf < _conf_threshold:
continue
matched = [w for w in _forbidden_words if w in txt]
if matched:
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
# 构建结果字符串
has_text = len(texts) > 0
has_violation = len(vio_info) > 0
if not has_text:
return (False, "未识别到文本")
elif has_violation:
return (True, "; ".join(vio_info))
else:
return (False, "未检测到违禁词")

View File

@ -1,137 +0,0 @@
import asyncio
import logging
from aiortc import RTCPeerConnection, RTCSessionDescription
import aiohttp
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whep_video_puller")
async def whep_pull_video_stream(ip,whep_url):
"""
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
Args:
whep_url: WHEP端点的URL
"""
pc = RTCPeerConnection()
# 添加连接状态变化监听
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print(f"连接状态: {pc.connectionState}")
# 添加ICE连接状态变化监听
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
print(f"ICE连接状态: {pc.iceConnectionState}")
# 添加视频接收器
pc.addTransceiver("video", direction="recvonly")
# 处理接收到的视频轨道
@pc.on("track")
def on_track(track):
print(f"接收到轨道: {track.kind}")
if track.kind == "video":
print(f"轨道ID: {track.id}")
print(f"轨道就绪状态: {track.readyState}")
# 创建异步任务来处理视频帧
asyncio.ensure_future(handle_video_track(track))
async def handle_video_track(track):
"""处理视频轨道,接收并打印每一帧"""
frame_count = 0
print("开始处理视频轨道...")
while True:
try:
# 尝试接收帧
frame = await track.recv()
frame_count += 1
print(f"收到原始帧 (第{frame_count}帧)")
# 打印帧的基本信息
if hasattr(frame, 'width') and hasattr(frame, 'height'):
print(f" 尺寸: {frame.width}x{frame.height}")
if hasattr(frame, 'time_base'):
print(f" 时间基准: {frame.time_base}")
if hasattr(frame, 'pts'):
print(f" 显示时间戳: {frame.pts}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
except Exception as e:
print(f"接收帧时出错: {e}")
# 等待一段时间后重试
await asyncio.sleep(0.1)
continue
# 创建offer
offer = await pc.createOffer()
await pc.setLocalDescription(offer)
print(f"本地SDP信息:\n{offer.sdp}")
# 通过HTTP POST发送offer到WHEP端点
async with aiohttp.ClientSession() as session:
async with session.post(
whep_url,
data=offer.sdp,
headers={"Content-Type": "application/sdp"}
) as response:
if response.status != 201:
print(f"WHEP服务器返回错误: {response.status}")
print(f"响应内容: {await response.text()}")
raise Exception(f"WHEP服务器返回错误: {response.status}")
# 获取answer SDP
answer_sdp = await response.text()
# 创建RTCSessionDescription对象
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
print(f"收到远程SDP:\n{answer_sdp}")
# 设置远程描述
await pc.setRemoteDescription(answer)
print("连接已建立,开始接收视频流...")
# 保持连接,直到用户中断
try:
while True:
await asyncio.sleep(1)
# 检查连接状态
print(f"当前连接状态: {pc.connectionState}")
except KeyboardInterrupt:
print("用户中断,关闭连接...")
finally:
await pc.close()
if __name__ == "__main__":
# 替换为你的WHEP端点URL
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
# 运行拉流任务
asyncio.run(whep_pull_video_stream(WHEP_URL))

View File

@ -1,112 +0,0 @@
import asyncio
import logging
import cv2
import time
from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正从core目录向上一级找ocr文件夹
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
KNOWN_FACES_DIR = "../ocr/known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并进行违规检测
"""
cap = None # 初始化视频捕获对象
try:
# 异步打开RTMP流
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
)
# 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 获取RTMP流基础信息
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况
fps = fps if fps > 0 else 30.0
width, height = int(width), int(height)
# 打印流初始化成功信息
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 初始化帧统计参数
frame_count = 0
start_time = time.time()
# 循环读取视频帧
while True:
ret, frame = await asyncio.to_thread(cap.read)
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
frame_count += 1
# 打印当前帧信息
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
if frame is not None:
has_violation, violation_type, details = detector.detect_violations(frame)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
else:
print("未检测到任何违规内容")
else:
print(f"无法读取测试图像")
# 每100帧统计一次实际接收帧率
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

55
core/yolo.py Normal file
View File

@ -0,0 +1,55 @@
import os
import cv2
from ultralytics import YOLO
# 全局变量
_yolo_model = None
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
def load_model():
"""加载YOLO目标检测模型"""
global _yolo_model
try:
_yolo_model = YOLO(model_path)
except Exception as e:
print(f"YOLO model load failed: {e}")
return False
return True if _yolo_model else False
def detect(frame, conf_threshold=0.2):
"""YOLO目标检测返回(是否识别到, 结果字符串)"""
global _yolo_model
if not _yolo_model or frame is None:
return (False, "未初始化或无效帧")
try:
results = _yolo_model(frame, conf=conf_threshold)
# 检查是否有检测结果
has_results = len(results[0].boxes) > 0 if results else False
if not has_results:
return (False, "未检测到目标")
# 构建结果字符串
result_parts = []
for box in results[0].boxes:
cls = int(box.cls[0])
conf = float(box.conf[0])
bbox = [float(x) for x in box.xyxy[0]]
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
result_str = "; ".join(result_parts)
return (has_results, result_str)
except Exception as e:
print(f"YOLO detect error: {e}")
return (False, f"检测错误: {str(e)}")

23
main.py
View File

@ -1,12 +1,19 @@
import uvicorn from PIL import Image # 正确导入
from fastapi import FastAPI import numpy as np
import uvicorn
from PIL import Image
from fastapi import FastAPI
from core.all import load_model,detect
from ds.config import SERVER_CONFIG from ds.config import SERVER_CONFIG
from middle.error_handler import global_exception_handler from middle.error_handler import global_exception_handler
from service.user_service import router as user_router from service.user_service import router as user_router
from service.sensitive_service import router as sensitive_router
from service.face_service import router as face_router
from service.device_service import router as device_router from service.device_service import router as device_router
from ws.ws import ws_router, lifespan from ws.ws import ws_router, lifespan
# ------------------------------ # ------------------------------
# 初始化 FastAPI 应用、指定生命周期管理 # 初始化 FastAPI 应用、指定生命周期管理
# ------------------------------ # ------------------------------
@ -22,6 +29,8 @@ app = FastAPI(
# ------------------------------ # ------------------------------
app.include_router(user_router) app.include_router(user_router)
app.include_router(device_router) app.include_router(device_router)
app.include_router(face_router)
app.include_router(sensitive_router)
app.include_router(ws_router) app.include_router(ws_router)
# ------------------------------ # ------------------------------
@ -33,11 +42,19 @@ app.add_exception_handler(Exception, global_exception_handler)
# 启动服务 # 启动服务
# ------------------------------ # ------------------------------
if __name__ == "__main__": if __name__ == "__main__":
# -------------------------- 配置调整 --------------------------
# 模型配置路径(建议改为环境变量)
YOLO_MODEL_PATH = r"/core/models\best.pt"
OCR_CONFIG_PATH = r"/core/config\config.yaml"
# 初始化项目默认端口设为8000避免初始化失败时port未定义
port = int(SERVER_CONFIG.get("port", 8000)) port = int(SERVER_CONFIG.get("port", 8000))
# 启动 UVicorn 服务
uvicorn.run( uvicorn.run(
app="main:app", app="main:app",
host="0.0.0.0", host="0.0.0.0",
port=port, port=port,
reload=True, workers=8,
ws="websockets" ws="websockets"
) )

View File

@ -8,7 +8,8 @@ from passlib.context import CryptContext
from ds.config import JWT_CONFIG from ds.config import JWT_CONFIG
from ds.db import db from ds.db import db
from service.user_service import UserResponse
# 移除这里的 from service.user_service import UserResponse 导入
# ------------------------------ # ------------------------------
# 密码加密配置 # 密码加密配置
@ -25,6 +26,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token> # OAuth2 依赖(从请求头获取 Token、格式Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------ # ------------------------------
# 密码工具函数 # 密码工具函数
# ------------------------------ # ------------------------------
@ -32,10 +34,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证明文密码与加密密码是否匹配""" """验证明文密码与加密密码是否匹配"""
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
"""对明文密码进行 bcrypt 加密""" """对明文密码进行 bcrypt 加密"""
return pwd_context.hash(password) return pwd_context.hash(password)
# ------------------------------ # ------------------------------
# JWT 工具函数 # JWT 工具函数
# ------------------------------ # ------------------------------
@ -53,11 +57,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
# ------------------------------ # ------------------------------
# 认证依赖(获取当前登录用户) # 认证依赖(获取当前登录用户)
# ------------------------------ # ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse: def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
"""从 Token 中解析用户信息、验证通过后返回当前用户""" """从 Token 中解析用户信息、验证通过后返回当前用户"""
# 延迟导入,打破循环依赖
from schema.user_schema import UserResponse # 在这里导入
# 认证失败异常 # 认证失败异常
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -89,7 +97,7 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
raise credentials_exception # 用户不存在 raise credentials_exception # 用户不存在
# 转换为 UserResponse 模型(自动校验字段) # 转换为 UserResponse 模型(自动校验字段)
return UserResponse(** user) return UserResponse(**user)
except Exception as e: except Exception as e:
raise credentials_exception from e raise credentials_exception from e
finally: finally:

View File

@ -1,139 +0,0 @@
import os
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
class FaceRecognizer:
"""
封装InsightFace人脸识别功能支持从文件夹加载已知人脸。
"""
def __init__(self, known_faces_dir: str):
self.known_faces_dir = known_faces_dir
self.app = self._initialize_insightface()
self.known_faces_embeddings = {}
self.known_faces_names = []
self._load_known_faces()
def _initialize_insightface(self):
"""初始化InsightFace FaceAnalysis应用"""
print("初始化InsightFace引擎...")
try:
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
print("请检查依赖是否安装及模型是否可访问")
return None
def _load_known_faces(self):
"""加载已知人脸特征"""
if not os.path.exists(self.known_faces_dir):
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
os.makedirs(self.known_faces_dir, exist_ok=True)
return
print(f"从目录加载人脸特征: {self.known_faces_dir}")
for person_name in os.listdir(self.known_faces_dir):
person_dir = os.path.join(self.known_faces_dir, person_name)
if os.path.isdir(person_dir):
print(f"处理人物: {person_name}")
embeddings = []
for filename in os.listdir(person_dir):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(person_dir, filename)
try:
img = cv2.imread(image_path)
if img is None:
print(f"无法读取图片: {image_path},已跳过")
continue
faces = self.app.get(img)
if faces:
embeddings.append(faces[0].embedding)
print(f"提取特征成功: {filename}")
else:
print(f"未检测到人脸: {filename},已跳过")
except Exception as e:
print(f"处理图片出错 {image_path}: {e}")
if embeddings:
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
self.known_faces_names.append(person_name)
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
else:
print(f"人物 {person_name} 无有效特征,已跳过")
print(f"人脸加载完成,共 {len(self.known_faces_names)}")
def recognize(self, frame, threshold=0.4):
"""识别人脸并返回结果"""
if not self.app or not self.known_faces_names:
return False, None, None
faces = self.app.get(frame)
if not faces:
return False, None, None
for face in faces:
for known_name in self.known_faces_names:
known_embedding = self.known_faces_embeddings[known_name]
embedding1 = face.embedding.astype(np.float32)
embedding2 = known_embedding.astype(np.float32)
dot_product = np.dot(embedding1, embedding2)
norm_embedding1 = np.linalg.norm(embedding1)
norm_embedding2 = np.linalg.norm(embedding2)
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
dot_product / (norm_embedding1 * norm_embedding2)
)
if similarity >= threshold:
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
return True, known_name, similarity
return False, None, None
def test_single_image(self, image_path: str, threshold=0.4):
"""测试单张图片识别"""
if not os.path.exists(image_path):
print(f"图片不存在: {image_path}")
return False, None, None
frame = cv2.imread(image_path)
if frame is None:
print(f"无法读取图片: {image_path}")
return False, None, None
result, name, similarity = self.recognize(frame, threshold)
if result:
print(f"识别结果: {name} (相似度: {similarity:.4f})")
faces = self.app.get(frame)
for face in faces:
bbox = face.bbox.astype(int)
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
text = f"{name}: {similarity:.2f}"
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
cv2.imshow('识别结果', frame)
print("按任意键关闭窗口...")
cv2.waitKey(0)
cv2.destroyAllWindows()
else:
print("未识别到已知人脸")
return result, name, similarity
#
# if __name__ == "__main__":
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
# recognizer.test_single_image(test_image_path, threshold=0.4)

View File

@ -1,156 +0,0 @@
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
from io import BytesIO
from PIL import Image
class BinaryFaceFeatureHandler:
"""
专门处理图片二进制数据的特征提取器,支持分批次接收二进制数据并累积计算平均特征
"""
def __init__(self):
self.app = self._init_insightface()
self.feature_list = [] # 存储所有图片二进制数据提取的特征
def _init_insightface(self):
"""初始化InsightFace引擎"""
try:
print("正在初始化InsightFace引擎...")
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
return None
def add_binary_data(self, binary_data):
"""
接收单张图片的二进制数据,提取特征并保存
参数:
binary_data: 图片的二进制数据bytes类型
返回:
成功提取特征时返回 (True, 特征值numpy数组)
失败时返回 (False, None)
"""
if not self.app:
print("引擎未初始化,无法处理")
return False, None
try:
# 直接处理二进制数据:转换为图像格式
img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# 提取特征
faces = self.app.get(frame)
if faces:
# 获取当前提取的特征值
current_feature = faces[0].embedding
# 添加到特征列表
self.feature_list.append(current_feature)
print(f"已累计 {len(self.feature_list)} 个特征")
# 返回成功标志和当前特征值
return True,current_feature
else:
print("二进制数据中未检测到人脸")
return False, None
except Exception as e:
print(f"处理二进制数据出错: {e}")
return False, None
def get_average_feature(self, features):
"""
计算多个特征向量的平均值
参数:
features: 特征值列表每个元素可以是字符串格式或numpy数组
例如: [feature1, feature2, ...]
返回:
单一平均特征向量的numpy数组若无可计算数据则返回None
"""
try:
# 验证输入是否为列表且不为空
if not isinstance(features, list) or len(features) == 0:
print("输入必须是包含至少一个特征值的列表")
return None
# 处理每个特征值
processed_features = []
for i, embedding in enumerate(features):
try:
if isinstance(embedding, str):
# 处理包含括号和逗号的字符串格式
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32)
else:
embedding_np = np.array(embedding, dtype=np.float32)
# 验证特征值格式
if len(embedding_np.shape) == 1:
processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}")
# 确保有有效的特征值
if not processed_features:
print("没有有效的特征值用于计算平均值")
return None
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
return avg_feature
except Exception as e:
print(f"计算平均特征值时出错: {e}")
return None
# def clear(self):
# """清空已存储的特征数据"""
# self.feature_list = []
# print("已清空所有特征数据")
# # 使用示例
# if __name__ == "__main__":
# handler = BinaryFaceFeatureHandler()
#
# # 模拟接收图片二进制数据
# try:
# # 第一次接收
# with open(r"D:\Git\bin\video\ocr\known_faces\B\102-f.jpg_1140x855.jpg", "rb") as f:
# bin_data1 = f.read()
# success, feature1 = handler.add_binary_data(bin_data1)
# if success:
# print(f"第一次提取的特征值前5个: {feature1[:5]}")
#
# # 第二次接收
# with open(r"D:\Git\bin\video\ocr\known_faces\B\104-1.jpg", "rb") as f:
# bin_data2 = f.read()
# success, feature2 = handler.add_binary_data(bin_data2)
# if success:
# print(f"第二次提取的特征值前5个: {feature2[:5]}")
#
# # 计算平均值
# avg_feature = handler.get_average_feature()
#
# except Exception as e:
# print(f"处理过程出错: {e}")

View File

@ -1,18 +0,0 @@
打倒习近平
打到习近平
打倒毛泽东
打到毛泽东
打到主席
打倒主席
打到共产主义
打倒共产主义
打到共产党
打倒共产党
胖猫
法轮功
法轮大法
法轮大法好
法轮功大法好
法轮
李洪志
习近平

Binary file not shown.

Before

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 657 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 386 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 62 KiB

View File

@ -1,49 +0,0 @@
#日志文件
import logging
import sys
def setup_logger():
"""
配置一个全局日志记录器,支持输出到控制台和文件。
"""
# 创建一个日志记录器
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = logging.getLogger("ViolationDetectorLogger")
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
if logger.hasHandlers():
return logger
# --- 控制台处理器 ---
console_handler = logging.StreamHandler(sys.stdout)
# 对于控制台我们只显示INFO及以上级别的信息
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_formatter)
# --- 文件处理器 ---
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
# 对于文件我们记录所有DEBUG及以上级别的信息
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
# 将处理器添加到日志记录器
logger.addHandler(console_handler)
logger.addHandler(file_handler)
return logger
# 创建并导出logger实例
logger = setup_logger()

View File

@ -1,136 +0,0 @@
import os
import cv2
import yaml
from pathlib import Path
from .ocr_violation_detector import OCRViolationDetector
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
from .face_recognizer import FaceRecognizer
class MultiModelViolationDetector:
"""
多模型违规检测封装类串行调用OCR、人脸识别和YOLO模型任一模型检测到违规即返回结果
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
yolo_model_path: str,
known_faces_dir: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化所有检测模型
"""
# 初始化OCR检测器
self.ocr_detector = OCRViolationDetector(
forbidden_words_path=forbidden_words_path,
ocr_config_path=ocr_config_path,
ocr_confidence_threshold=ocr_confidence_threshold
)
# 初始化人脸识别器
self.face_recognizer = FaceRecognizer(
known_faces_dir=known_faces_dir
)
# 初始化YOLO检测器
self.yolo_detector = YoloViolationDetector(
model_path=yolo_model_path
)
print("多模型违规检测器初始化完成")
def detect_violations(self, frame):
"""
串行调用三个检测模型OCR → 人脸识别 → YOLO任一检测到违规即返回结果
"""
# 1. 首先进行OCR违禁词检测
try:
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
if ocr_has_violation:
details = {
"words": ocr_words,
"confidences": ocr_confs
}
print(f"警告: OCR检测到违禁内容: {details}")
return (True, "ocr", details)
except Exception as e:
print(f"错误: OCR检测出错: {str(e)}")
# 2. 接着进行人脸识别检测
try:
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
if face_has_violation:
details = {
"name": face_name,
"similarity": face_similarity
}
print(f"警告: 人脸识别到违规人员: {details}")
return (True, "face", details)
except Exception as e:
print(f"错误: 人脸识别出错: {str(e)}")
# 3. 最后进行YOLO目标检测
try:
yolo_results = self.yolo_detector.detect(frame)
if len(yolo_results.boxes) > 0:
details = {
"classes": yolo_results.names,
"boxes": yolo_results.boxes.xyxy.tolist(),
"confidences": yolo_results.boxes.conf.tolist(),
"class_ids": yolo_results.boxes.cls.tolist()
}
print(f"警告: YOLO检测到违规目标: {details}")
return (True, "yolo", details)
except Exception as e:
print(f"错误: YOLO检测出错: {str(e)}")
# 所有检测均未发现违规
return (False, None, None)
def load_config(config_path: str) -> dict:
"""加载YAML配置文件"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
except FileNotFoundError:
print(f"错误: 配置文件未找到: {config_path}")
raise
except yaml.YAMLError as e:
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
raise
except Exception as e:
print(f"错误: 加载配置文件出错: {str(e)}")
raise
# 使用示例
# if __name__ == "__main__":
# # 加载配置文件
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
#
# # 初始化多模型检测器
# detector = MultiModelViolationDetector(
# forbidden_words_path=config["forbidden_words_path"],
# ocr_config_path=config["ocr_config_path"],
# yolo_model_path=config["yolo_model_path"],
# known_faces_dir=config["known_faces_dir"],
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
# )
#
# # 读取测试图像(可替换为视频帧读取逻辑)
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
# if test_image_path:
# frame = cv2.imread(test_image_path)
#
# if frame is not None:
# has_violation, violation_type, details = detector.detect_violations(frame)
# if has_violation:
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# else:
# print("未检测到任何违规内容")
# else:
# print(f"无法读取测试图像: {test_image_path}")
# else:
# print("配置文件中未指定测试图像路径")

Binary file not shown.

View File

@ -1,178 +0,0 @@
import os
import cv2
from rapidocr import RapidOCR
class OCRViolationDetector:
"""
封装RapidOCR引擎用于检测图像帧中的违禁词。
核心功能加载违禁词、初始化OCR引擎、单帧图像违禁词检测
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化OCR引擎和违禁词列表。
Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_config_path (str): OCR配置文件如1.yaml的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值0~1
"""
# 加载违禁词
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化RapidOCR引擎
self.ocr_engine = self._initialize_ocr(ocr_config_path)
# 校验核心依赖是否就绪
self._check_dependencies()
# 设置置信度阈值限制在0~1范围
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
print(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _load_forbidden_words(self, path: str) -> set:
"""
从TXT文件加载违禁词去重、过滤空行支持UTF-8编码
"""
forbidden_words = set()
# 检查文件是否存在
if not os.path.exists(path):
print(f"错误:违禁词文件不存在: {path}")
return forbidden_words
# 读取文件并处理内容
try:
with open(path, 'r', encoding='utf-8') as f:
forbidden_words = {
line.strip() for line in f
if line.strip() # 跳过空行或纯空格行
}
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
except UnicodeDecodeError:
print(f"错误违禁词文件编码错误需UTF-8: {path}")
except PermissionError:
print(f"错误:无权限读取违禁词文件: {path}")
except Exception as e:
print(f"错误:加载违禁词失败: {str(e)}")
return forbidden_words
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
"""
初始化RapidOCR引擎校验配置文件、捕获初始化异常
"""
print("开始初始化RapidOCR引擎...")
# 检查配置文件是否存在
if not os.path.exists(config_path):
print(f"错误OCR配置文件不存在: {config_path}")
return None
# 初始化OCR引擎
try:
ocr_engine = RapidOCR(config_path=config_path)
print("RapidOCR引擎初始化成功")
return ocr_engine
except ImportError:
print("错误RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
except Exception as e:
print(f"错误RapidOCR初始化失败: {str(e)}")
return None
def _check_dependencies(self) -> None:
"""校验OCR引擎和违禁词列表是否就绪"""
if not self.ocr_engine:
print("警告:⚠️ OCR引擎未就绪违禁词检测功能将禁用")
if not self.forbidden_words:
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
def detect(self, frame) -> tuple[bool, list, list]:
"""
对单帧图像进行OCR违禁词检测核心方法
Args:
frame: 输入图像帧NumPy数组BGR格式cv2读取的图像
Returns:
tuple[bool, list, list]:
- 第一个元素是否检测到违禁词True/False
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
"""
# 初始化返回结果
has_violation = False
violation_words = []
violation_confs = []
# 前置校验
if frame is None or frame.size == 0:
print("警告输入图像帧为空或无效跳过OCR检测")
return has_violation, violation_words, violation_confs
if not self.ocr_engine or not self.forbidden_words:
print("OCR引擎未就绪或违禁词为空跳过OCR检测")
return has_violation, violation_words, violation_confs
try:
# 执行OCR识别
print("开始执行OCR识别...")
ocr_result = self.ocr_engine(frame)
print(f"RapidOCR原始结果: {ocr_result}")
# 校验OCR结果是否有效
if ocr_result is None:
print("OCR识别未返回任何结果图像无文本或识别失败")
return has_violation, violation_words, violation_confs
# 检查txts和scores是否存在且不为None
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
print("警告OCR结果中txts为None或不存在")
return has_violation, violation_words, violation_confs
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
print("警告OCR结果中scores为None或不存在")
return has_violation, violation_words, violation_confs
# 转为列表并去None
if not isinstance(ocr_result.txts, (list, tuple)):
print(f"警告OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
texts = []
else:
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
if not isinstance(ocr_result.scores, (list, tuple)):
print(f"警告OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
confidences = []
else:
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
# 校验文本和置信度列表长度是否一致
if len(texts) != len(confidences):
print(f"警告OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
return has_violation, violation_words, violation_confs
if len(texts) == 0:
print("OCR未识别到任何有效文本")
return has_violation, violation_words, violation_confs
# 遍历识别结果,筛选违禁词
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
for text, conf in zip(texts, confidences):
if conf < self.OCR_CONFIDENCE_THRESHOLD:
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
continue
matched_words = [word for word in self.forbidden_words if word in text]
if matched_words:
has_violation = True
violation_words.extend(matched_words)
violation_confs.extend([conf] * len(matched_words))
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
except Exception as e:
print(f"错误OCR检测过程异常: {str(e)}")
return has_violation, violation_words, violation_confs

View File

@ -1,47 +0,0 @@
from ultralytics import YOLO
import cv2
class ViolationDetector:
"""
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
"""
def __init__(self, model_path):
"""
初始化检测器。
Args:
model_path (str): YOLO .pt模型的路径。
"""
print(f"正在从 '{model_path}' 加载YOLO模型...")
self.model = YOLO(model_path)
print("YOLO模型加载成功。")
def detect(self, frame):
"""
对单帧图像进行目标检测。
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)。
Returns:
ultralytics.engine.results.Results: YOLO的检测结果对象。
"""
# conf可以根据您的模型效果进行调整
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
results = self.model(frame, conf=0.2)
return results[0]
def draw_boxes(self, frame, result):
"""
在图像帧上绘制检测框。
Args:
frame: 原始图像帧。
result: YOLO的检测结果对象。
Returns:
numpy.ndarray: 绘制了检测框的图像帧。
"""
# 使用YOLO自带的plot功能方便快捷
annotated_frame = result.plot()
return annotated_frame

View File

@ -1,164 +0,0 @@
import queue
import asyncio
import aiohttp
import threading
import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
# 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1)
class VideoTrack(MediaStreamTrack):
"""自定义视频轨道类继承自MediaStreamTrack"""
kind = "video"
def __init__(self, max_frames=100):
super().__init__()
self.frames = queue.Queue(maxsize=max_frames)
async def recv(self):
return await super().recv()
def webrtc_producer(webrtc_url):
"""
生产者方法从WEBRTC读取视频帧并放入队列
仅当队列空时才放入新帧,否则丢弃
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 创建RTCPeerConnection对象不使用ICE服务器
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
pc.addTrack(video_track)
@pc.on("track")
async def on_track(track):
if track.kind == "video":
print("接收到视频轨道,开始接收视频帧")
while True:
# 从轨道接收视频帧
frame = await track.recv()
# 转换为BGR24格式的NumPy数组
frame_bgr24 = frame.to_ndarray(format='bgr24')
# 检查队列是否为空,为空则加入,否则丢弃
if frame_queue.empty():
try:
frame_queue.put_nowait(frame_bgr24)
print("帧已放入队列")
except queue.Full:
print("队列已满,丢弃帧")
else:
print("队列非空,丢弃帧")
async def main():
# 创建并发送SDP Offer
offer = await pc.createOffer()
print("已创建本地SDP Offer")
await pc.setLocalDescription(offer)
# 发送Offer到服务器并接收Answer
async with aiohttp.ClientSession() as session:
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
async with session.post(
webrtc_url,
data=offer.sdp.encode(),
headers={
"Content-Type": "application/sdp",
"Content-Length": str(len(offer.sdp))
},
ssl=False
) as response:
print("已接收到服务器的响应")
answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
# 保持连接
try:
while True:
await asyncio.sleep(0.1)
except KeyboardInterrupt:
pass
finally:
print("关闭RTCPeerConnection")
await pc.close()
try:
loop.run_until_complete(main())
finally:
loop.close()
def frame_consumer(ip):
"""
消费者方法:从队列中读取帧并处理
每次处理后休眠200ms模拟延迟
"""
print("消费者启动,开始等待帧...")
try:
while True:
# 阻塞等待队列中的帧
frame = frame_queue.get()
print(f"消费帧,大小: {frame.shape}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
# 标记任务完成
frame_queue.task_done()
except KeyboardInterrupt:
print("消费者退出")
def start_webrtc_stream(ip, webrtc_url):
"""
启动WebRTC视频流处理的主方法
参数: webrtc_url - WebRTC服务器地址
"""
print(f"开始连接到WebRTC服务器: {webrtc_url}")
# 启动生产者线程
producer_thread = threading.Thread(
target=webrtc_producer,
args=(webrtc_url,),
daemon=True,
name="webrtc-producer"
)
# 启动消费者线程
consumer_thread = threading.Thread(
target=frame_consumer(ip),
daemon=True,
name="frame-consumer"
)
producer_thread.start()
consumer_thread.start()
print("生产者和消费者线程已启动")
try:
# 保持主线程运行
while True:
time.sleep(1)
except KeyboardInterrupt:
print("程序正在退出...")
if __name__ == "__main__":
# 示例用法
# 实际使用时替换为真实的WebRTC服务器地址
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
start_webrtc_stream(webrtc_server_url)

View File

@ -1,101 +0,0 @@
import asyncio
import logging
import cv2
import time
# 配置日志与WHEP代码保持一致的日志风格
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
功能与WHEP拉流函数对齐流状态反馈、帧信息打印、帧率统计、异常处理
Args:
rtmp_url: RTMP流的URL地址如 rtmp://xxx/live/stream_key
"""
cap = None # 初始化视频捕获对象
try:
# 1. 异步打开RTMP流指定FFmpeg后端确保RTMP兼容性同步操作通过to_thread避免阻塞事件循环
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 必须指定FFmpeg后端RTMP协议依赖该后端解析
)
# 2. 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 3. 异步获取RTMP流基础信息分辨率、帧率
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况部分RTMP流未返回帧率时默认30FPS
fps = fps if fps > 0 else 30.0
# 分辨率转为整数(视频尺寸必然是整数)
width, height = int(width), int(height)
# 打印流初始化成功信息与WHEP连接成功信息风格一致
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 4. 初始化帧统计参数
frame_count = 0 # 总接收帧数
start_time = time.time() # 统计起始时间
# 5. 循环异步读取视频帧(核心逻辑)
while True:
# 异步读取一帧cv2.read是同步操作用to_thread适配异步环境
ret, frame = await asyncio.to_thread(cap.read)
# 检查帧是否读取成功(流中断/结束时ret为False
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
# 帧计数累加
frame_count += 1
# 6. 打印当前帧基础信息与WHEP帧信息打印风格对齐
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
# 7. 每100帧统计一次实际接收帧率补充性能监控与原RTMP示例逻辑一致
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
# 示例yield frame (若需生成器模式,可调整函数为异步生成器)
# 8. 异常处理(覆盖用户中断、通用错误)
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
# 运行RTMP拉流任务与WHEP一致的异步执行方式
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

View File

@ -23,8 +23,8 @@ class FaceResponse(BaseModel):
"""人脸记录响应模型仍包含ID由数据库生成后返回""" """人脸记录响应模型仍包含ID由数据库生成后返回"""
id: int = Field(..., description="主键ID数据库自增") id: int = Field(..., description="主键ID数据库自增")
name: str = Field(None, description="名称") name: str = Field(None, description="名称")
eigenvalue: str = Field(None, description="特征(暂为None") eigenvalue: str | None = Field(None, description="特征(可为空")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")
model_config = {"from_attributes": True} model_config = {"from_attributes": True}

View File

@ -1,6 +1,6 @@
import json import json
from fastapi import APIRouter, Query, HTTPException from fastapi import APIRouter, Query, HTTPException,Request
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
from ds.db import db from ds.db import db
@ -108,7 +108,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
# 原有接口保持不变 # 原有接口保持不变
# ------------------------------ # ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息") @router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(device_data: DeviceCreateRequest): async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象
# 原有代码保持不变 # 原有代码保持不变
conn = None conn = None
cursor = None cursor = None
@ -125,11 +125,10 @@ async def create_device(device_data: DeviceCreateRequest):
return APIResponse( return APIResponse(
code=200, code=200,
message=f"设备IP {device_data.ip} 已存在,返回已有设备信息", message=f"设备IP {device_data.ip} 已存在,返回已有设备信息",
data=DeviceResponse(**existing_device) data=DeviceResponse(** existing_device)
) )
from fastapi import Request # 直接使用注入的request对象获取用户代理
request = Request(scope={"type": "http"})
user_agent = request.headers.get("User-Agent", "").lower() user_agent = request.headers.get("User-Agent", "").lower()
if user_agent == "default": if user_agent == "default":
@ -184,7 +183,6 @@ async def create_device(device_data: DeviceCreateRequest):
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
async def get_device_list( async def get_device_list(
page: int = Query(1, ge=1, description="页码默认第1页"), page: int = Query(1, ge=1, description="页码默认第1页"),

View File

@ -6,15 +6,15 @@ from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceRespons
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse from schema.user_schema import UserResponse
from ocr.feature_extraction import BinaryFaceFeatureHandler
from util.face_util import add_binary_data,get_average_feature
#初始化实例
router = APIRouter( router = APIRouter(
prefix="/faces", prefix="/faces",
tags=["人脸管理"] tags=["人脸管理"]
) )
# 创建 BinaryFaceFeatureHandler 的实例
binary_face_feature_handler = BinaryFaceFeatureHandler()
# ------------------------------ # ------------------------------
@ -33,6 +33,8 @@ async def create_face(
- ID 由数据库自动生成,无需前端传入 - ID 由数据库自动生成,无需前端传入
- 暂不处理文件内容eigenvalue 设为 None - 暂不处理文件内容eigenvalue 设为 None
""" """
# 调用你的方法
conn = None conn = None
cursor = None cursor = None
try: try:
@ -45,14 +47,24 @@ async def create_face(
# 把文件转为二进制数组 # 把文件转为二进制数组
file_content = await file.read() file_content = await file.read()
# 调用人脸识别得到特征值 # 计算特征值
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 打印数组长度
print(f"文件大小:{len(file_content)} 字节")
# 2. 插入数据库:无需传 ID自增只传 name 和 eigenvalueNone # 2. 插入数据库:无需传 ID自增只传 name 和 eigenvalueNone
insert_query = """ insert_query = """
INSERT INTO face (name, eigenvalue) INSERT INTO face (name, eigenvalue)
VALUES (%s, %s) VALUES (%s, %s)
""" """
cursor.execute(insert_query, (face_create.name, None)) cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
conn.commit() conn.commit()
# 3. 获取数据库自动生成的 ID关键用 LAST_INSERT_ID() 查刚插入的记录) # 3. 获取数据库自动生成的 ID关键用 LAST_INSERT_ID() 查刚插入的记录)
@ -60,19 +72,45 @@ async def create_face(
cursor.execute(select_new_query) cursor.execute(select_new_query)
created_face = cursor.fetchone() created_face = cursor.fetchone()
if not created_face:
raise HTTPException(
status_code=500,
detail="创建人脸记录成功,但无法获取新创建的记录"
)
return APIResponse( return APIResponse(
code=201, code=201,
message=f"人脸记录创建成功ID{created_face['id']},文件名:{file.filename}", message=f"人脸记录创建成功ID{created_face['id']},文件名:{file.filename}",
data=FaceResponse(**created_face) data=FaceResponse(** created_face)
) )
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建人脸记录失败:{str(e)}") from e # 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"创建人脸记录失败:{str(e)}"
) from e
except Exception as e:
# 捕获其他可能的异常
raise HTTPException(
status_code=500,
detail=f"服务器错误:{str(e)}"
) from e
finally: finally:
await file.close() # 关闭文件流 await file.close() # 关闭文件流
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 将 eigenvalue 转为 str
eigenvalue = str(eigenvalue)
# ------------------------------ # ------------------------------
# 2. 获取单个人脸记录不变用自增ID查询 # 2. 获取单个人脸记录不变用自增ID查询
@ -104,18 +142,21 @@ async def get_face(
data=FaceResponse(**face) data=FaceResponse(**face)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询人脸记录失败:{str(e)}") from e # 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"查询人脸记录失败:{str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改 # 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理
# ------------------------------ # ------------------------------
# 3. 获取所有人脸记录(不变) # 3. 获取所有人脸记录(不变)
# ------------------------------ # ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有人脸记录") @router.get("", response_model=APIResponse, summary="获取所有人脸记录")
async def get_all_faces( async def get_all_faces(
current_user: UserResponse = Depends(get_current_user)
): ):
conn = None conn = None
cursor = None cursor = None
@ -130,10 +171,13 @@ async def get_all_faces(
return APIResponse( return APIResponse(
code=200, code=200,
message="所有人脸记录查询成功", message="所有人脸记录查询成功",
data=[FaceResponse(**face) for face in faces] data=[FaceResponse(** face) for face in faces]
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e raise HTTPException(
status_code=500,
detail=f"查询所有人脸记录失败:{str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -194,7 +238,10 @@ async def update_face(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"更新人脸记录失败:{str(e)}") from e raise HTTPException(
status_code=500,
detail=f"更新人脸记录失败:{str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -234,7 +281,10 @@ async def delete_face(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"删除人脸记录失败:{str(e)}") from e raise HTTPException(
status_code=500,
detail=f"删除人脸记录失败:{str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -249,38 +299,43 @@ def get_all_face_name_with_eigenvalue() -> dict:
conn = None conn = None
cursor = None cursor = None
try: try:
# 1. 建立数据库连接并获取游标dictionary=True使结果以字典形式返回
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 只查询需要的字段,提高效率 # 2. 执行SQL查询只获取name非空的记录减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query) cursor.execute(query)
faces = cursor.fetchall() faces = cursor.fetchall() # 返回结果:列表套字典,如 [{"name":"张三","eigenvalue":...}, ...]
# 先收集所有名称对应的特征值列表(处理重复名称) # 3. 收集同一名称对应的所有特征值(处理名称重复场景
name_to_eigenvalues = {} name_to_eigenvalues = {}
for face in faces: for face in faces:
name = face["name"] name = face["name"]
eigenvalue = face["eigenvalue"] eigenvalue = face["eigenvalue"]
# 若名称已存在,追加特征值;否则新建列表存储
if name in name_to_eigenvalues: if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue) name_to_eigenvalues[name].append(eigenvalue)
else: else:
name_to_eigenvalues[name] = [eigenvalue] name_to_eigenvalues[name] = [eigenvalue]
# 构建最终字典:重复名称取平均特征值,唯一名称直接取特征值 # 4. 构建最终字典:重复名称取平均,唯一名称直接取特征值
face_dict = {} face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items(): for name, eigenvalues in name_to_eigenvalues.items():
print("调用的特征值是:" + eigenvalues)
# 处理特征值:多个则求平均,单个则直接使用
if len(eigenvalues) > 1: if len(eigenvalues) > 1:
# 调用平均特征值计算方法 # 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = binary_face_feature_handler.get_average_feature(eigenvalues) face_dict[name] = get_average_feature(eigenvalues)
else: else:
# 取列表中唯一的特征值避免value为列表类型
face_dict[name] = eigenvalues[0] face_dict[name] = eigenvalues[0]
return face_dict return face_dict
except MySQLError as e: except MySQLError as e:
# 捕获数据库异常,添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
finally: finally:
# 确保资源释放 # 5. 无论是否异常,均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

145
util/face_util.py Normal file
View File

@ -0,0 +1,145 @@
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
from io import BytesIO
from PIL import Image
# 全局变量存储InsightFace引擎和特征列表
_insightface_app = None
_feature_list = []
def init_insightface():
"""初始化InsightFace引擎"""
global _insightface_app
try:
print("正在初始化InsightFace引擎...")
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
_insightface_app = app
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
return None
def add_binary_data(binary_data):
"""
接收单张图片的二进制数据,提取特征并保存
参数:
binary_data: 图片的二进制数据bytes类型
返回:
成功提取特征时返回 (True, 特征值numpy数组)
失败时返回 (False, None)
"""
global _insightface_app, _feature_list
if not _insightface_app:
print("引擎未初始化,无法处理")
return False, None
try:
# 直接处理二进制数据:转换为图像格式
img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# 提取特征
faces = _insightface_app.get(frame)
if faces:
# 获取当前提取的特征值
current_feature = faces[0].embedding
# 添加到特征列表
_feature_list.append(current_feature)
print(f"已累计 {len(_feature_list)} 个特征")
# 返回成功标志和当前特征值
return True, current_feature
else:
print("二进制数据中未检测到人脸")
return False, None
except Exception as e:
print(f"处理二进制数据出错: {e}")
return False, None
def get_average_feature(features=None):
"""
计算多个特征向量的平均值
参数:
features: 可选特征值列表。如果未提供则使用全局存储的_feature_list
每个元素可以是字符串格式或numpy数组
返回:
单一平均特征向量的numpy数组若无可计算数据则返回None
"""
global _feature_list
# 如果未提供features参数则使用全局特征列表
if features is None:
features = _feature_list
try:
# 验证输入是否为列表且不为空
if not isinstance(features, list) or len(features) == 0:
print("输入必须是包含至少一个特征值的列表")
return None
# 处理每个特征值
processed_features = []
for i, embedding in enumerate(features):
try:
if isinstance(embedding, str):
# 处理包含括号和逗号的字符串格式
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32)
else:
embedding_np = np.array(embedding, dtype=np.float32)
# 验证特征值格式
if len(embedding_np.shape) == 1:
processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}")
# 确保有有效的特征值
if not processed_features:
print("没有有效的特征值用于计算平均值")
return None
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
return avg_feature
except Exception as e:
print(f"计算平均特征值时出错: {e}")
return None
def clear_features():
"""清空已存储的特征数据"""
global _feature_list
_feature_list = []
print("已清空所有特征数据")
def get_feature_list():
"""获取当前存储的特征列表"""
global _feature_list
return _feature_list.copy() # 返回副本防止外部直接修改

482
ws.html
View File

@ -1,482 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WebSocket 测试工具</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
}
body {
max-width: 1200px;
margin: 20px auto;
padding: 0 20px;
background-color: #f5f7fa;
}
.container {
background: white;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
padding: 25px;
margin-bottom: 20px;
}
h1 {
color: #2c3e50;
margin-bottom: 20px;
font-size: 24px;
border-bottom: 2px solid #3498db;
padding-bottom: 10px;
}
.status-bar {
display: flex;
align-items: center;
gap: 15px;
margin-bottom: 20px;
padding: 12px 15px;
background-color: #f8f9fa;
border-radius: 6px;
}
.status-label {
font-weight: bold;
color: #495057;
}
.status-value {
padding: 4px 10px;
border-radius: 4px;
font-weight: bold;
}
.status-connected {
background-color: #d4edda;
color: #155724;
}
.status-disconnected {
background-color: #f8d7da;
color: #721c24;
}
.status-connecting {
background-color: #fff3cd;
color: #856404;
}
.btn {
padding: 8px 16px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
font-weight: 500;
transition: background-color 0.2s;
}
.btn-primary {
background-color: #3498db;
color: white;
}
.btn-primary:hover {
background-color: #2980b9;
}
.btn-danger {
background-color: #e74c3c;
color: white;
}
.btn-danger:hover {
background-color: #c0392b;
}
.btn-success {
background-color: #2ecc71;
color: white;
}
.btn-success:hover {
background-color: #27ae60;
}
.control-group {
display: flex;
gap: 15px;
margin-bottom: 20px;
align-items: center;
}
.input-group {
display: flex;
gap: 10px;
align-items: center;
}
.input-group label {
color: #495057;
font-weight: 500;
}
.input-group input, .input-group select {
padding: 8px 12px;
border: 1px solid #ced4da;
border-radius: 4px;
font-size: 14px;
}
.message-area {
margin-top: 20px;
}
.message-input {
width: 100%;
height: 100px;
padding: 12px;
border: 1px solid #ced4da;
border-radius: 6px;
resize: none;
font-size: 14px;
margin-bottom: 10px;
}
.log-area {
width: 100%;
height: 300px;
padding: 15px;
border: 1px solid #ced4da;
border-radius: 6px;
background-color: #f8f9fa;
overflow-y: auto;
font-size: 14px;
line-height: 1.6;
}
.log-item {
margin-bottom: 8px;
padding-bottom: 8px;
border-bottom: 1px dashed #e9ecef;
}
.log-time {
color: #6c757d;
font-size: 12px;
margin-right: 10px;
}
.log-send {
color: #2980b9;
}
.log-receive {
color: #27ae60;
}
.log-status {
color: #856404;
}
.log-error {
color: #e74c3c;
}
</style>
</head>
<body>
<div class="container">
<h1>WebSocket 测试工具</h1>
<!-- 连接状态区 -->
<div class="status-bar">
<div class="status-label">连接状态:</div>
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
<div class="status-label">服务地址:</div>
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
<div class="status-label">连接时间:</div>
<div id="connectTime" class="status-value">-</div>
</div>
<!-- 控制按钮区 -->
<div class="control-group">
<button id="connectBtn" class="btn btn-primary">建立连接</button>
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
<!-- 心跳控制 -->
<div class="input-group">
<label>自动心跳:</label>
<select id="autoHeartbeat">
<option value="on">开启</option>
<option value="off">关闭</option>
</select>
<label>间隔(秒)</label>
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
</div>
</div>
<!-- 自定义消息发送区 -->
<div class="message-area">
<h3>发送自定义消息</h3>
<textarea id="messageInput" class="message-input"
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
</div>
<!-- 日志显示区 -->
<div class="message-area">
<h3>消息日志</h3>
<div id="logContainer" class="log-area">
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
</div>
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
</div>
</div>
<script>
// 全局变量
let ws = null;
let heartbeatTimer = null;
const wsUrl = "ws://192.168.110.25:8000/ws";
// DOM 元素
const connectionStatus = document.getElementById('connectionStatus');
const connectTime = document.getElementById('connectTime');
const connectBtn = document.getElementById('connectBtn');
const disconnectBtn = document.getElementById('disconnectBtn');
const sendMessageBtn = document.getElementById('sendMessageBtn');
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
const autoHeartbeat = document.getElementById('autoHeartbeat');
const heartbeatInterval = document.getElementById('heartbeatInterval');
const messageInput = document.getElementById('messageInput');
const logContainer = document.getElementById('logContainer');
const clearLogBtn = document.getElementById('clearLogBtn');
// 工具函数:添加日志
function addLog(content, type = 'status') {
const now = new Date().toLocaleString('zh-CN', {
year: 'numeric', month: '2-digit', day: '2-digit',
hour: '2-digit', minute: '2-digit', second: '2-digit'
});
const logItem = document.createElement('div');
logItem.className = 'log-item';
let logClass = '';
switch (type) {
case 'send':
logClass = 'log-send';
break;
case 'receive':
logClass = 'log-receive';
break;
case 'error':
logClass = 'log-error';
break;
default:
logClass = 'log-status';
}
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
logContainer.appendChild(logItem);
// 滚动到最新日志
logContainer.scrollTop = logContainer.scrollHeight;
}
// 工具函数格式化JSON便于日志显示
function formatJson(jsonStr) {
try {
const obj = JSON.parse(jsonStr);
return JSON.stringify(obj, null, 2);
} catch (e) {
return jsonStr; // 非JSON格式直接返回
}
}
// 建立WebSocket连接
function connectWebSocket() {
if (ws) {
addLog('已存在连接,无需重复建立', 'error');
return;
}
try {
ws = new WebSocket(wsUrl);
// 连接成功
ws.onopen = function () {
connectionStatus.className = 'status-value status-connected';
connectionStatus.textContent = '已连接';
const now = new Date().toLocaleString('zh-CN');
connectTime.textContent = now;
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
// 更新按钮状态
connectBtn.disabled = true;
disconnectBtn.disabled = false;
sendMessageBtn.disabled = false;
// 开启自动心跳(默认开启)
if (autoHeartbeat.value === 'on') {
startAutoHeartbeat();
}
};
// 接收消息
ws.onmessage = function (event) {
const message = event.data;
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
};
// 连接关闭
ws.onclose = function (event) {
connectionStatus.className = 'status-value status-disconnected';
connectionStatus.textContent = '已断开';
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
// 清除自动心跳
stopAutoHeartbeat();
// 更新按钮状态
connectBtn.disabled = false;
disconnectBtn.disabled = true;
sendMessageBtn.disabled = true;
// 重置WebSocket对象
ws = null;
};
// 连接错误
ws.onerror = function (error) {
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
};
} catch (e) {
addLog(`建立连接失败:${e.message}`, 'error');
ws = null;
}
}
// 断开WebSocket连接
function disconnectWebSocket() {
if (!ws) {
addLog('当前无连接,无需断开', 'error');
return;
}
ws.close(1000, '手动断开连接');
}
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"}
function sendHeartbeat() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送心跳失败:当前无有效连接', 'error');
return;
}
const heartbeatMsg = {
timestamp: Date.now(), // 当前毫秒时间戳
type: "heartbeat"
};
const msgStr = JSON.stringify(heartbeatMsg);
ws.send(msgStr);
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
}
// 开启自动心跳
function startAutoHeartbeat() {
// 先停止已有定时器
stopAutoHeartbeat();
const interval = parseInt(heartbeatInterval.value) * 1000;
if (isNaN(interval) || interval < 10000) {
addLog('自动心跳间隔无效已重置为30秒', 'error');
heartbeatInterval.value = 30;
return startAutoHeartbeat();
}
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}`, 'status');
heartbeatTimer = setInterval(sendHeartbeat, interval);
}
// 停止自动心跳
function stopAutoHeartbeat() {
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
addLog('已停止自动心跳', 'status');
}
}
// 发送自定义消息
function sendCustomMessage() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送消息失败:当前无有效连接', 'error');
return;
}
const msgStr = messageInput.value.trim();
if (!msgStr) {
addLog('发送消息失败:消息内容不能为空', 'error');
return;
}
try {
// 验证JSON格式可选仅提示不强制
JSON.parse(msgStr);
ws.send(msgStr);
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
} catch (e) {
addLog(`JSON格式错误${e.message},仍尝试发送原始内容`, 'error');
ws.send(msgStr);
addLog(`发送自定义消息非JSON\n${msgStr}`, 'send');
}
}
// 绑定按钮事件
connectBtn.addEventListener('click', connectWebSocket);
disconnectBtn.addEventListener('click', disconnectWebSocket);
sendMessageBtn.addEventListener('click', sendCustomMessage);
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
clearLogBtn.addEventListener('click', () => {
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
});
// 自动心跳开关变更事件
autoHeartbeat.addEventListener('change', function () {
if (ws && ws.readyState === WebSocket.OPEN) {
if (this.value === 'on') {
startAutoHeartbeat();
} else {
stopAutoHeartbeat();
}
} else {
addLog('需先建立有效连接才能控制自动心跳', 'error');
// 重置选择
this.value = 'off';
}
});
// 心跳间隔变更事件(实时生效)
heartbeatInterval.addEventListener('change', function () {
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
startAutoHeartbeat();
}
});
// 快捷键支持Ctrl+Enter发送消息
messageInput.addEventListener('keydown', function (e) {
if (e.ctrlKey && e.key === 'Enter') {
sendCustomMessage();
e.preventDefault();
}
});
</script>
</body>
</html>

400
ws/ws.py
View File

@ -4,314 +4,300 @@ import json
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator from typing import Dict, Optional, AsyncGenerator
from concurrent.futures import ThreadPoolExecutor # 新增:显式线程池
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate from schema.device_action_schema import DeviceActionCreate
from core.all import detect
import cv2 import cv2
import numpy as np import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from queue import Queue # 线程安全队列无需额外Lock from core.all import load_model
from ocr.model_violation_detector import MultiModelViolationDetector # 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
# -------------------------- 配置调整 --------------------------
# 模型路径(建议改为环境变量)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
# 核心优化:模型池大小(决定最大并发任务数,显存占用=大小×单模型显存)
MODEL_POOL_SIZE = 5 # 示例设为5支持5个任务并行显存会明显上升
THREAD_POOL_SIZE = MODEL_POOL_SIZE * 2 # 线程池大小≥模型池,避免线程瓶颈
# 其他配置
HEARTBEAT_INTERVAL = 30 # 心跳间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点 WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 5 # 增大帧队列,允许缓存更多帧(避免丢帧) FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# -------------------------- 工具函数 --------------------------
# 工具函数:获取格式化时间字符串(统一时间戳格式)
def get_current_time_str() -> str: def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str: def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# -------------------------- 模型池重构核心修改1 --------------------------
class ModelPool:
def __init__(self, pool_size: int = MODEL_POOL_SIZE):
self.pool = Queue(maxsize=pool_size)
# 移除冗余LockQueue.get()/put()本身线程安全
self._init_models(pool_size)
print(f"[{get_current_time_str()}] 模型池初始化完成(共{pool_size}个实例,显存已预分配)")
def _init_models(self, pool_size: int): # 客户端连接封装
"""预加载所有模型实例(初始化时显存会一次性上升)"""
for i in range(pool_size):
try:
detector = MultiModelViolationDetector(
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
ocr_confidence_threshold=0.5
)
self.pool.put(detector)
print(f"[{get_current_time_str()}] 模型实例{i+1}/{pool_size}加载完成")
except Exception as e:
raise RuntimeError(f"模型实例{i+1}加载失败:{str(e)}")
def get_model(self) -> MultiModelViolationDetector:
"""获取模型(阻塞直到有空闲实例,确保并发安全)"""
return self.pool.get()
def return_model(self, detector: MultiModelViolationDetector):
"""归还模型(立即释放资源供其他任务使用)"""
self.pool.put(detector)
# -------------------------- 全局资源初始化 --------------------------
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE) # 初始化模型池(预占显存)
thread_pool = ThreadPoolExecutor( # 显式创建线程池核心修改2
max_workers=THREAD_POOL_SIZE,
thread_name_prefix="ModelWorker-" # 线程命名,便于调试
)
# -------------------------- 客户端连接封装核心修改3 --------------------------
class ClientConnection: class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str): def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket self.websocket = websocket
self.client_ip = client_ip self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 增大队列 self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None self.consumer_task: Optional[asyncio.Task] = None
# 移除“客户端独占模型”不再持有detector属性
def update_heartbeat(self): def update_heartbeat(self):
"""更新心跳时间(客户端发送心跳时调用)"""
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool: def is_alive(self) -> bool:
"""判断客户端是否存活(心跳超时检查)"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT return timeout < HEARTBEAT_TIMEOUT
def start_consumer(self): def start_consumer(self):
"""启动帧消费任务(每个客户端一个独立任务)""" """启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames()) self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task return self.consumer_task
async def send_frame_permit(self): async def send_frame_permit(self):
"""发送帧许可信号(允许客户端继续发帧)""" """
发送「帧发送许可信号」
通知客户端可发送下一帧图像
"""
try: try:
await self.websocket.send_json({ frame_permit_msg = {
"type": "frame", "type": "frame",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": self.client_ip "client_ip": self.client_ip
}) }
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可发送失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
async def consume_frames(self) -> None: async def consume_frames(self) -> None:
"""消费队列(并发核心:每帧临时借模型处理)""" """消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理"""
try: try:
while True: while True:
# 1. 从队列取帧(无帧时阻塞) # 1. 从队列取帧(阻塞直到有帧可用
frame_data = await self.frame_queue.get() frame_data = await self.frame_queue.get()
# 2. 立即发送下一帧许可(让客户端持续发帧,积累并发任务)
await self.send_frame_permit() # -------------------------- 核心修改:取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成
# -----------------------------------------------------------------------------------------
try: try:
# 3. 并行处理帧(核心:任务级借模型 # 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧
await self.process_frame(frame_data) await self.process_frame(frame_data)
finally: finally:
self.frame_queue.task_done() # 标记帧处理完成 # 3. 标记帧任务完成(无论处理成功/失败,都需清理队列)
self.frame_queue.task_done()
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:消费逻辑错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧核心修改4任务级借还模型""" """处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法"""
# 1. 临时借用模型(阻塞直到有空闲实例,显存随借用数上升) # 二进制数据转OpenCV图像
detector = model_pool.get_model() nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
return
# 确保图像保存目录存在
os.makedirs('images', exist_ok=True)
# 保存图像按IP+时间戳命名,避免冲突)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try: try:
# 2. 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像解析失败")
return
# 3. 保存图像(可选)
os.makedirs('images', exist_ok=True)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
cv2.imwrite(filename, img) cv2.imwrite(filename, img)
print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
# 4. 显式线程池执行AI检测真正并发无线程瓶颈 has_violation, data, type = detect(img)
loop = asyncio.get_running_loop() print(has_violation)
has_violation, violation_type, details = await loop.run_in_executor( print(type)
thread_pool, # 用自定义线程池,避免默认线程不足 print(data)
detector.detect_violations, # 临时借用的模型
img # 输入图像
)
# 5. 违规处理(与原逻辑一致)
if has_violation: if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规 - {violation_type}") print(
# 违规次数更新(用线程池避免阻塞事件循环) f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {type}, 详情: {data}")
await loop.run_in_executor(thread_pool, increment_alarm_count_by_ip, self.client_ip)
# 发送危险通知 # 调用违规次数加一方法
await self.websocket.send_json({ try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
# 发送「危险通知」
danger_msg = {
"type": "danger", "type": "danger",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": self.client_ip, "client_ip": self.client_ip
"violation_type": violation_type, }
"details": details await self.websocket.send_json(danger_msg)
})
else: else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}违规") print(f"[{get_current_time_str()}] 客户端{self.client_ip}未检测到违规")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}图像处理错误 - {str(e)}")
finally:
# 6. 无论成功/失败,强制归还模型(核心:释放资源供其他任务使用)
model_pool.return_model(detector)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还(可复用)")
# -------------------------- 全局状态与心跳 --------------------------
# 全局状态管理
connected_clients: Dict[str, ClientConnection] = {} connected_clients: Dict[str, ClientConnection] = {}
client_lock = asyncio.Lock() # 保护客户端字典的异步锁
heartbeat_task: Optional[asyncio.Task] = None heartbeat_task: Optional[asyncio.Task] = None
# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法)
async def heartbeat_checker(): async def heartbeat_checker():
"""心跳检查(移除模型归还逻辑,因模型已任务级归还)"""
while True: while True:
current_time = get_current_time_str() current_time = get_current_time_str()
async with client_lock: timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
# 筛选超时客户端
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
for ip in timeout_ips: if timeout_ips:
async with client_lock: print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时IP{timeout_ips}")
conn = connected_clients.get(ip) for ip in timeout_ips:
if not conn: try:
continue conn = connected_clients[ip]
# 取消消费任务+关闭连接 if conn.consumer_task and not conn.consumer_task.done():
if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel()
conn.consumer_task.cancel() await conn.websocket.close(code=1008, reason="心跳超时")
await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线(用线程池)
loop = asyncio.get_running_loop()
await loop.run_in_executor(thread_pool, update_online_status_by_ip, ip, 0)
await loop.run_in_executor(
thread_pool, add_device_action, DeviceActionCreate(client_ip=ip, action=0)
)
connected_clients.pop(ip)
print(f"[{current_time}] 客户端{ip}:超时离线(资源已清理)")
# 打印在线状态 # 超时设为离线并记录
async with client_lock: try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线") print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期核心修改5管理线程池 --------------------------
# 应用生命周期管理
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global heartbeat_task global heartbeat_task
# 启动心跳任务
heartbeat_task = asyncio.create_task(heartbeat_checker()) heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 心跳任务启动ID{id(heartbeat_task)}") print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID{id(heartbeat_task)}")
print(f"[{get_current_time_str()}] 线程池启动(最大线程数:{THREAD_POOL_SIZE}") yield
yield # 应用运行期间
# 清理资源
if heartbeat_task and not heartbeat_task.done(): if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel() heartbeat_task.cancel()
await heartbeat_task try:
print(f"[{get_current_time_str()}] 心跳任务已关闭") await heartbeat_task
# 关闭线程池(等待所有任务完成) print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
thread_pool.shutdown(wait=True) except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 线程池已关闭") pass
# -------------------------- WebSocket路由 --------------------------
# 消息处理工具函数
async def send_heartbeat_ack(conn: ClientConnection):
try:
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
return True
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
return False
async def handle_text_msg(conn: ClientConnection, text: str):
try:
msg = json.loads(text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn)
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}无效JSON文本消息")
async def handle_binary_msg(conn: ClientConnection, data: bytes):
try:
conn.frame_queue.put_nowait(data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据")
# WebSocket路由配置
ws_router = APIRouter() ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
# 加载模型
load_model()
await websocket.accept() await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip" client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str() current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}:连接建立") print(f"[{current_time}] 客户端{client_ip}WebSocket连接建立")
new_conn = None
is_online_updated = False is_online_updated = False
try:
# 处理重复连接(关闭旧连接)
async with client_lock:
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="新连接抢占")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}:旧连接已关闭")
# 创建新连接+启动消费任务 try:
# 处理重复连接
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接")
# 注册新连接
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer() new_conn.start_consumer()
# 初始发送帧许可(让客户端立即发帧 # 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发
await new_conn.send_frame_permit() await new_conn.send_frame_permit()
# 标记客户端在线 # 标记上线并记录
loop = asyncio.get_running_loop() try:
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 1) await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
await loop.run_in_executor( action_data = DeviceActionCreate(client_ip=client_ip, action=1)
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=1) await asyncio.to_thread(add_device_action, action_data)
) print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作")
is_online_updated = True is_online_updated = True
async with client_lock: except Exception as e:
connected_clients[client_ip] = new_conn print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}:注册成功(在线数:{len(connected_clients)}")
# 消息循环(接收文本/二进制帧) print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
# 消息循环
while True: while True:
data = await websocket.receive() data = await websocket.receive()
if "text" in data: if "text" in data:
# 处理文本消息(如心跳) await handle_text_msg(new_conn, data["text"])
try:
msg = json.loads(data["text"])
if msg.get("type") == "heart":
new_conn.update_heartbeat()
# 回复心跳确认
await websocket.send_json({
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": client_ip
})
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{client_ip}无效JSON")
elif "bytes" in data: elif "bytes" in data:
# 处理二进制帧(图像) await handle_binary_msg(new_conn, data["bytes"])
try:
await new_conn.frame_queue.put(data["bytes"])
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧已入队(队列大小:{new_conn.frame_queue.qsize()}")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧队列满(丢弃当前帧)")
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开(代码:{e.code}") print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code}")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}") print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
finally: finally:
# 清理资源无需归还模型已在process_frame中归还 # 清理资源并标记离线
if new_conn and client_ip in connected_clients: if client_ip in connected_clients:
async with client_lock: conn = connected_clients[client_ip]
conn = connected_clients.get(client_ip) if conn.consumer_task and not conn.consumer_task.done():
if conn: conn.consumer_task.cancel()
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel() # 主动/异常断开时标记离线
# 标记离线(仅当在线状态已更新时) if is_online_updated:
if is_online_updated: try:
loop = asyncio.get_running_loop() await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 0) action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await loop.run_in_executor( await asyncio.to_thread(add_device_action, action_data)
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=0) print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
) except Exception as e:
connected_clients.pop(client_ip) print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
async with client_lock:
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源清理完成(在线数:{len(connected_clients)}") connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")