最新可用

This commit is contained in:
2025-09-04 22:59:27 +08:00
parent ec6dbfde90
commit 30bf7c9fcb
42 changed files with 746 additions and 1967 deletions

2
.idea/Video.iml generated
View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="video" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>

View File

@ -15,5 +15,5 @@ algorithm = HS256
access_token_expire_minutes = 30
[live]
rtmp_url = rtmp://192.168.110.65:1935/live/
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=
rtmp_url = rtmp://192.168.110.25:1935/live/
webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=

45
core/all.py Normal file
View File

@ -0,0 +1,45 @@
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# 添加一个标记变量用于监控load_model是否已被调用
_model_loaded = False
def load_model():
global _model_loaded
# 如果已经调用过,直接忽略
if _model_loaded:
return
# 首次调用时加载模型
ocrLoadModel()
faceLoadModel()
yoloLoadModel()
# 标记为已调用
_model_loaded = True
def detect(frame):
# 先进行YOLO检测
yolo_flag, yolo_result = yoloDetect(frame)
print("YOLO检测结果", yolo_result)
if yolo_flag:
return (True, yolo_result, "yolo")
# YOLO未检测到进行人脸检测
face_flag, face_result = faceDetect(frame)
print("人脸检测结果:", face_result)
if face_flag:
return (True, face_result, "face")
# 人脸未检测到进行OCR检测
ocr_flag, ocr_result = ocrDetect(frame)
print("OCR检测结果", ocr_result)
if ocr_flag:
return (True, ocr_result, "ocr")
# 所有检测都未检测到
return (False, "未检测到任何内容", "none")

113
core/face.py Normal file
View File

@ -0,0 +1,113 @@
import os
import numpy as np
import cv2
from PIL import Image # 确保正确导入Image类
from insightface.app import FaceAnalysis
# 导入获取人脸信息的服务
from service.face_service import get_all_face_name_with_eigenvalue
# 全局变量
_face_app = None
_known_faces_embeddings = {} # 存储姓名到特征值的映射
_known_faces_names = [] # 存储所有已知姓名
def load_model():
"""加载人脸识别模型及已知人脸特征库"""
global _face_app, _known_faces_embeddings, _known_faces_names
# 初始化InsightFace模型
try:
_face_app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface'))
_face_app.prepare(ctx_id=0, det_size=(640, 640))
except Exception as e:
print(f"Face model load failed: {e}")
return False
# 从服务获取所有人脸姓名和特征值
try:
face_data = get_all_face_name_with_eigenvalue()
# 处理获取到的人脸数据
for person_name, eigenvalue_data in face_data.items():
# 处理特征值数据 - 兼容数组和字符串两种格式
if isinstance(eigenvalue_data, np.ndarray):
# 如果已经是numpy数组直接使用
eigenvalue = eigenvalue_data.astype(np.float32)
elif isinstance(eigenvalue_data, str):
# 清理字符串:移除方括号、换行符和多余空格
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
# 按空格或逗号分割(处理可能的不同分隔符)
values = [v for v in cleaned.split() if v]
# 转换为数组
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
else:
# 不支持的类型
print(f"Unsupported eigenvalue type for {person_name}")
continue
# 归一化处理
norm = np.linalg.norm(eigenvalue)
if norm != 0:
eigenvalue = eigenvalue / norm
_known_faces_embeddings[person_name] = eigenvalue
_known_faces_names.append(person_name)
except Exception as e:
print(f"Error loading face data from service: {e}")
return True if _face_app else False
def detect(frame, threshold=0.4):
"""检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)"""
global _face_app, _known_faces_embeddings, _known_faces_names
if not _face_app or not _known_faces_names or frame is None:
return (False, "未初始化或无效帧")
try:
faces = _face_app.get(frame)
except Exception as e:
print(f"Face detect error: {e}")
return (False, f"检测错误: {str(e)}")
result_parts = []
has_matched = False # 新增标记:是否有匹配到的已知人脸
for face in faces:
# 特征归一化
embedding = face.embedding.astype(np.float32)
norm = np.linalg.norm(embedding)
if norm == 0:
continue
embedding = embedding / norm
# 对比已知人脸
max_sim, best_name = -1.0, "Unknown"
for name in _known_faces_names:
known_emb = _known_faces_embeddings[name]
sim = np.dot(embedding, known_emb)
if sim > max_sim:
max_sim = sim
best_name = name
# 判断匹配结果
is_match = max_sim >= threshold
if is_match:
has_matched = True # 只要有一个匹配成功就标记为True
bbox = face.bbox
result_parts.append(
f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})"
)
# 构建结果字符串
if not result_parts:
result_str = "未检测到人脸"
else:
result_str = "; ".join(result_parts)
# 第一个返回值改为:是否匹配到已知人脸
return (has_matched, result_str)

BIN
core/models/best.pt Normal file

Binary file not shown.

76
core/ocr.py Normal file
View File

@ -0,0 +1,76 @@
import os
import cv2
from rapidocr import RapidOCR
from service.sensitive_service import get_all_sensitive_words
# 全局变量
_ocr_engine = None
_forbidden_words = set()
_conf_threshold = 0.5
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")
def load_model():
"""加载OCR引擎及违禁词列表"""
global _ocr_engine, _forbidden_words, _conf_threshold
# 加载违禁词
try:
_forbidden_words = get_all_sensitive_words()
except Exception as e:
print(f"Forbidden words load error: {e}")
# 初始化OCR引擎
if not os.path.exists(ocr_config_path):
print(f"OCR config not found: {ocr_config_path}")
return False
try:
_ocr_engine = RapidOCR(config_path=ocr_config_path)
except Exception as e:
print(f"OCR model load failed: {e}")
return False
return True if _ocr_engine else False
def detect(frame):
"""OCR检测并筛选违禁词返回(是否检测到违禁词, 结果字符串)"""
if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
return (False, "未初始化或无效帧")
try:
ocr_res = _ocr_engine(frame)
except Exception as e:
print(f"OCR detect error: {e}")
return (False, f"检测错误: {str(e)}")
if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'):
return (False, "无OCR结果")
# 处理OCR结果
texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)]
confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))]
if len(texts) != len(confs):
return (False, "OCR结果格式异常")
# 筛选违禁词
vio_info = []
for txt, conf in zip(texts, confs):
if conf < _conf_threshold:
continue
matched = [w for w in _forbidden_words if w in txt]
if matched:
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
# 构建结果字符串
has_text = len(texts) > 0
has_violation = len(vio_info) > 0
if not has_text:
return (False, "未识别到文本")
elif has_violation:
return (True, "; ".join(vio_info))
else:
return (False, "未检测到违禁词")

View File

@ -1,137 +0,0 @@
import asyncio
import logging
from aiortc import RTCPeerConnection, RTCSessionDescription
import aiohttp
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whep_video_puller")
async def whep_pull_video_stream(ip,whep_url):
"""
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
Args:
whep_url: WHEP端点的URL
"""
pc = RTCPeerConnection()
# 添加连接状态变化监听
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print(f"连接状态: {pc.connectionState}")
# 添加ICE连接状态变化监听
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
print(f"ICE连接状态: {pc.iceConnectionState}")
# 添加视频接收器
pc.addTransceiver("video", direction="recvonly")
# 处理接收到的视频轨道
@pc.on("track")
def on_track(track):
print(f"接收到轨道: {track.kind}")
if track.kind == "video":
print(f"轨道ID: {track.id}")
print(f"轨道就绪状态: {track.readyState}")
# 创建异步任务来处理视频帧
asyncio.ensure_future(handle_video_track(track))
async def handle_video_track(track):
"""处理视频轨道,接收并打印每一帧"""
frame_count = 0
print("开始处理视频轨道...")
while True:
try:
# 尝试接收帧
frame = await track.recv()
frame_count += 1
print(f"收到原始帧 (第{frame_count}帧)")
# 打印帧的基本信息
if hasattr(frame, 'width') and hasattr(frame, 'height'):
print(f" 尺寸: {frame.width}x{frame.height}")
if hasattr(frame, 'time_base'):
print(f" 时间基准: {frame.time_base}")
if hasattr(frame, 'pts'):
print(f" 显示时间戳: {frame.pts}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
except Exception as e:
print(f"接收帧时出错: {e}")
# 等待一段时间后重试
await asyncio.sleep(0.1)
continue
# 创建offer
offer = await pc.createOffer()
await pc.setLocalDescription(offer)
print(f"本地SDP信息:\n{offer.sdp}")
# 通过HTTP POST发送offer到WHEP端点
async with aiohttp.ClientSession() as session:
async with session.post(
whep_url,
data=offer.sdp,
headers={"Content-Type": "application/sdp"}
) as response:
if response.status != 201:
print(f"WHEP服务器返回错误: {response.status}")
print(f"响应内容: {await response.text()}")
raise Exception(f"WHEP服务器返回错误: {response.status}")
# 获取answer SDP
answer_sdp = await response.text()
# 创建RTCSessionDescription对象
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
print(f"收到远程SDP:\n{answer_sdp}")
# 设置远程描述
await pc.setRemoteDescription(answer)
print("连接已建立,开始接收视频流...")
# 保持连接,直到用户中断
try:
while True:
await asyncio.sleep(1)
# 检查连接状态
print(f"当前连接状态: {pc.connectionState}")
except KeyboardInterrupt:
print("用户中断,关闭连接...")
finally:
await pc.close()
if __name__ == "__main__":
# 替换为你的WHEP端点URL
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
# 运行拉流任务
asyncio.run(whep_pull_video_stream(WHEP_URL))

View File

@ -1,112 +0,0 @@
import asyncio
import logging
import cv2
import time
from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正从core目录向上一级找ocr文件夹
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
KNOWN_FACES_DIR = "../ocr/known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并进行违规检测
"""
cap = None # 初始化视频捕获对象
try:
# 异步打开RTMP流
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
)
# 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 获取RTMP流基础信息
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况
fps = fps if fps > 0 else 30.0
width, height = int(width), int(height)
# 打印流初始化成功信息
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 初始化帧统计参数
frame_count = 0
start_time = time.time()
# 循环读取视频帧
while True:
ret, frame = await asyncio.to_thread(cap.read)
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
frame_count += 1
# 打印当前帧信息
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
if frame is not None:
has_violation, violation_type, details = detector.detect_violations(frame)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
else:
print("未检测到任何违规内容")
else:
print(f"无法读取测试图像")
# 每100帧统计一次实际接收帧率
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

55
core/yolo.py Normal file
View File

@ -0,0 +1,55 @@
import os
import cv2
from ultralytics import YOLO
# 全局变量
_yolo_model = None
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
def load_model():
"""加载YOLO目标检测模型"""
global _yolo_model
try:
_yolo_model = YOLO(model_path)
except Exception as e:
print(f"YOLO model load failed: {e}")
return False
return True if _yolo_model else False
def detect(frame, conf_threshold=0.2):
"""YOLO目标检测返回(是否识别到, 结果字符串)"""
global _yolo_model
if not _yolo_model or frame is None:
return (False, "未初始化或无效帧")
try:
results = _yolo_model(frame, conf=conf_threshold)
# 检查是否有检测结果
has_results = len(results[0].boxes) > 0 if results else False
if not has_results:
return (False, "未检测到目标")
# 构建结果字符串
result_parts = []
for box in results[0].boxes:
cls = int(box.cls[0])
conf = float(box.conf[0])
bbox = [float(x) for x in box.xyxy[0]]
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
result_str = "; ".join(result_parts)
return (has_results, result_str)
except Exception as e:
print(f"YOLO detect error: {e}")
return (False, f"检测错误: {str(e)}")

23
main.py
View File

@ -1,12 +1,19 @@
import uvicorn
from fastapi import FastAPI
from PIL import Image # 正确导入
import numpy as np
import uvicorn
from PIL import Image
from fastapi import FastAPI
from core.all import load_model,detect
from ds.config import SERVER_CONFIG
from middle.error_handler import global_exception_handler
from service.user_service import router as user_router
from service.sensitive_service import router as sensitive_router
from service.face_service import router as face_router
from service.device_service import router as device_router
from ws.ws import ws_router, lifespan
# ------------------------------
# 初始化 FastAPI 应用、指定生命周期管理
# ------------------------------
@ -22,6 +29,8 @@ app = FastAPI(
# ------------------------------
app.include_router(user_router)
app.include_router(device_router)
app.include_router(face_router)
app.include_router(sensitive_router)
app.include_router(ws_router)
# ------------------------------
@ -33,11 +42,19 @@ app.add_exception_handler(Exception, global_exception_handler)
# 启动服务
# ------------------------------
if __name__ == "__main__":
# -------------------------- 配置调整 --------------------------
# 模型配置路径(建议改为环境变量)
YOLO_MODEL_PATH = r"/core/models\best.pt"
OCR_CONFIG_PATH = r"/core/config\config.yaml"
# 初始化项目默认端口设为8000避免初始化失败时port未定义
port = int(SERVER_CONFIG.get("port", 8000))
# 启动 UVicorn 服务
uvicorn.run(
app="main:app",
host="0.0.0.0",
port=port,
reload=True,
workers=8,
ws="websockets"
)

View File

@ -8,7 +8,8 @@ from passlib.context import CryptContext
from ds.config import JWT_CONFIG
from ds.db import db
from service.user_service import UserResponse
# 移除这里的 from service.user_service import UserResponse 导入
# ------------------------------
# 密码加密配置
@ -25,6 +26,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------
# 密码工具函数
# ------------------------------
@ -32,10 +34,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证明文密码与加密密码是否匹配"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""对明文密码进行 bcrypt 加密"""
return pwd_context.hash(password)
# ------------------------------
# JWT 工具函数
# ------------------------------
@ -53,11 +57,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# ------------------------------
# 认证依赖(获取当前登录用户)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
# 延迟导入,打破循环依赖
from schema.user_schema import UserResponse # 在这里导入
# 认证失败异常
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -89,7 +97,7 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
raise credentials_exception # 用户不存在
# 转换为 UserResponse 模型(自动校验字段)
return UserResponse(** user)
return UserResponse(**user)
except Exception as e:
raise credentials_exception from e
finally:

View File

@ -1,139 +0,0 @@
import os
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
class FaceRecognizer:
"""
封装InsightFace人脸识别功能支持从文件夹加载已知人脸。
"""
def __init__(self, known_faces_dir: str):
self.known_faces_dir = known_faces_dir
self.app = self._initialize_insightface()
self.known_faces_embeddings = {}
self.known_faces_names = []
self._load_known_faces()
def _initialize_insightface(self):
"""初始化InsightFace FaceAnalysis应用"""
print("初始化InsightFace引擎...")
try:
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
print("请检查依赖是否安装及模型是否可访问")
return None
def _load_known_faces(self):
"""加载已知人脸特征"""
if not os.path.exists(self.known_faces_dir):
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
os.makedirs(self.known_faces_dir, exist_ok=True)
return
print(f"从目录加载人脸特征: {self.known_faces_dir}")
for person_name in os.listdir(self.known_faces_dir):
person_dir = os.path.join(self.known_faces_dir, person_name)
if os.path.isdir(person_dir):
print(f"处理人物: {person_name}")
embeddings = []
for filename in os.listdir(person_dir):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(person_dir, filename)
try:
img = cv2.imread(image_path)
if img is None:
print(f"无法读取图片: {image_path},已跳过")
continue
faces = self.app.get(img)
if faces:
embeddings.append(faces[0].embedding)
print(f"提取特征成功: {filename}")
else:
print(f"未检测到人脸: {filename},已跳过")
except Exception as e:
print(f"处理图片出错 {image_path}: {e}")
if embeddings:
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
self.known_faces_names.append(person_name)
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
else:
print(f"人物 {person_name} 无有效特征,已跳过")
print(f"人脸加载完成,共 {len(self.known_faces_names)}")
def recognize(self, frame, threshold=0.4):
"""识别人脸并返回结果"""
if not self.app or not self.known_faces_names:
return False, None, None
faces = self.app.get(frame)
if not faces:
return False, None, None
for face in faces:
for known_name in self.known_faces_names:
known_embedding = self.known_faces_embeddings[known_name]
embedding1 = face.embedding.astype(np.float32)
embedding2 = known_embedding.astype(np.float32)
dot_product = np.dot(embedding1, embedding2)
norm_embedding1 = np.linalg.norm(embedding1)
norm_embedding2 = np.linalg.norm(embedding2)
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
dot_product / (norm_embedding1 * norm_embedding2)
)
if similarity >= threshold:
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
return True, known_name, similarity
return False, None, None
def test_single_image(self, image_path: str, threshold=0.4):
"""测试单张图片识别"""
if not os.path.exists(image_path):
print(f"图片不存在: {image_path}")
return False, None, None
frame = cv2.imread(image_path)
if frame is None:
print(f"无法读取图片: {image_path}")
return False, None, None
result, name, similarity = self.recognize(frame, threshold)
if result:
print(f"识别结果: {name} (相似度: {similarity:.4f})")
faces = self.app.get(frame)
for face in faces:
bbox = face.bbox.astype(int)
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
text = f"{name}: {similarity:.2f}"
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
cv2.imshow('识别结果', frame)
print("按任意键关闭窗口...")
cv2.waitKey(0)
cv2.destroyAllWindows()
else:
print("未识别到已知人脸")
return result, name, similarity
#
# if __name__ == "__main__":
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
# recognizer.test_single_image(test_image_path, threshold=0.4)

View File

@ -1,156 +0,0 @@
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
from io import BytesIO
from PIL import Image
class BinaryFaceFeatureHandler:
"""
专门处理图片二进制数据的特征提取器,支持分批次接收二进制数据并累积计算平均特征
"""
def __init__(self):
self.app = self._init_insightface()
self.feature_list = [] # 存储所有图片二进制数据提取的特征
def _init_insightface(self):
"""初始化InsightFace引擎"""
try:
print("正在初始化InsightFace引擎...")
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
return None
def add_binary_data(self, binary_data):
"""
接收单张图片的二进制数据,提取特征并保存
参数:
binary_data: 图片的二进制数据bytes类型
返回:
成功提取特征时返回 (True, 特征值numpy数组)
失败时返回 (False, None)
"""
if not self.app:
print("引擎未初始化,无法处理")
return False, None
try:
# 直接处理二进制数据:转换为图像格式
img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# 提取特征
faces = self.app.get(frame)
if faces:
# 获取当前提取的特征值
current_feature = faces[0].embedding
# 添加到特征列表
self.feature_list.append(current_feature)
print(f"已累计 {len(self.feature_list)} 个特征")
# 返回成功标志和当前特征值
return True,current_feature
else:
print("二进制数据中未检测到人脸")
return False, None
except Exception as e:
print(f"处理二进制数据出错: {e}")
return False, None
def get_average_feature(self, features):
"""
计算多个特征向量的平均值
参数:
features: 特征值列表每个元素可以是字符串格式或numpy数组
例如: [feature1, feature2, ...]
返回:
单一平均特征向量的numpy数组若无可计算数据则返回None
"""
try:
# 验证输入是否为列表且不为空
if not isinstance(features, list) or len(features) == 0:
print("输入必须是包含至少一个特征值的列表")
return None
# 处理每个特征值
processed_features = []
for i, embedding in enumerate(features):
try:
if isinstance(embedding, str):
# 处理包含括号和逗号的字符串格式
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32)
else:
embedding_np = np.array(embedding, dtype=np.float32)
# 验证特征值格式
if len(embedding_np.shape) == 1:
processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}")
# 确保有有效的特征值
if not processed_features:
print("没有有效的特征值用于计算平均值")
return None
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
return avg_feature
except Exception as e:
print(f"计算平均特征值时出错: {e}")
return None
# def clear(self):
# """清空已存储的特征数据"""
# self.feature_list = []
# print("已清空所有特征数据")
# # 使用示例
# if __name__ == "__main__":
# handler = BinaryFaceFeatureHandler()
#
# # 模拟接收图片二进制数据
# try:
# # 第一次接收
# with open(r"D:\Git\bin\video\ocr\known_faces\B\102-f.jpg_1140x855.jpg", "rb") as f:
# bin_data1 = f.read()
# success, feature1 = handler.add_binary_data(bin_data1)
# if success:
# print(f"第一次提取的特征值前5个: {feature1[:5]}")
#
# # 第二次接收
# with open(r"D:\Git\bin\video\ocr\known_faces\B\104-1.jpg", "rb") as f:
# bin_data2 = f.read()
# success, feature2 = handler.add_binary_data(bin_data2)
# if success:
# print(f"第二次提取的特征值前5个: {feature2[:5]}")
#
# # 计算平均值
# avg_feature = handler.get_average_feature()
#
# except Exception as e:
# print(f"处理过程出错: {e}")

View File

@ -1,18 +0,0 @@
打倒习近平
打到习近平
打倒毛泽东
打到毛泽东
打到主席
打倒主席
打到共产主义
打倒共产主义
打到共产党
打倒共产党
胖猫
法轮功
法轮大法
法轮大法好
法轮功大法好
法轮
李洪志
习近平

Binary file not shown.

Before

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 657 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 386 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 62 KiB

View File

@ -1,49 +0,0 @@
#日志文件
import logging
import sys
def setup_logger():
"""
配置一个全局日志记录器,支持输出到控制台和文件。
"""
# 创建一个日志记录器
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = logging.getLogger("ViolationDetectorLogger")
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
if logger.hasHandlers():
return logger
# --- 控制台处理器 ---
console_handler = logging.StreamHandler(sys.stdout)
# 对于控制台我们只显示INFO及以上级别的信息
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_formatter)
# --- 文件处理器 ---
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
# 对于文件我们记录所有DEBUG及以上级别的信息
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
# 将处理器添加到日志记录器
logger.addHandler(console_handler)
logger.addHandler(file_handler)
return logger
# 创建并导出logger实例
logger = setup_logger()

View File

@ -1,136 +0,0 @@
import os
import cv2
import yaml
from pathlib import Path
from .ocr_violation_detector import OCRViolationDetector
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
from .face_recognizer import FaceRecognizer
class MultiModelViolationDetector:
"""
多模型违规检测封装类串行调用OCR、人脸识别和YOLO模型任一模型检测到违规即返回结果
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
yolo_model_path: str,
known_faces_dir: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化所有检测模型
"""
# 初始化OCR检测器
self.ocr_detector = OCRViolationDetector(
forbidden_words_path=forbidden_words_path,
ocr_config_path=ocr_config_path,
ocr_confidence_threshold=ocr_confidence_threshold
)
# 初始化人脸识别器
self.face_recognizer = FaceRecognizer(
known_faces_dir=known_faces_dir
)
# 初始化YOLO检测器
self.yolo_detector = YoloViolationDetector(
model_path=yolo_model_path
)
print("多模型违规检测器初始化完成")
def detect_violations(self, frame):
"""
串行调用三个检测模型OCR → 人脸识别 → YOLO任一检测到违规即返回结果
"""
# 1. 首先进行OCR违禁词检测
try:
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
if ocr_has_violation:
details = {
"words": ocr_words,
"confidences": ocr_confs
}
print(f"警告: OCR检测到违禁内容: {details}")
return (True, "ocr", details)
except Exception as e:
print(f"错误: OCR检测出错: {str(e)}")
# 2. 接着进行人脸识别检测
try:
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
if face_has_violation:
details = {
"name": face_name,
"similarity": face_similarity
}
print(f"警告: 人脸识别到违规人员: {details}")
return (True, "face", details)
except Exception as e:
print(f"错误: 人脸识别出错: {str(e)}")
# 3. 最后进行YOLO目标检测
try:
yolo_results = self.yolo_detector.detect(frame)
if len(yolo_results.boxes) > 0:
details = {
"classes": yolo_results.names,
"boxes": yolo_results.boxes.xyxy.tolist(),
"confidences": yolo_results.boxes.conf.tolist(),
"class_ids": yolo_results.boxes.cls.tolist()
}
print(f"警告: YOLO检测到违规目标: {details}")
return (True, "yolo", details)
except Exception as e:
print(f"错误: YOLO检测出错: {str(e)}")
# 所有检测均未发现违规
return (False, None, None)
def load_config(config_path: str) -> dict:
"""加载YAML配置文件"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
except FileNotFoundError:
print(f"错误: 配置文件未找到: {config_path}")
raise
except yaml.YAMLError as e:
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
raise
except Exception as e:
print(f"错误: 加载配置文件出错: {str(e)}")
raise
# 使用示例
# if __name__ == "__main__":
# # 加载配置文件
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
#
# # 初始化多模型检测器
# detector = MultiModelViolationDetector(
# forbidden_words_path=config["forbidden_words_path"],
# ocr_config_path=config["ocr_config_path"],
# yolo_model_path=config["yolo_model_path"],
# known_faces_dir=config["known_faces_dir"],
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
# )
#
# # 读取测试图像(可替换为视频帧读取逻辑)
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
# if test_image_path:
# frame = cv2.imread(test_image_path)
#
# if frame is not None:
# has_violation, violation_type, details = detector.detect_violations(frame)
# if has_violation:
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# else:
# print("未检测到任何违规内容")
# else:
# print(f"无法读取测试图像: {test_image_path}")
# else:
# print("配置文件中未指定测试图像路径")

Binary file not shown.

View File

@ -1,178 +0,0 @@
import os
import cv2
from rapidocr import RapidOCR
class OCRViolationDetector:
"""
封装RapidOCR引擎用于检测图像帧中的违禁词。
核心功能加载违禁词、初始化OCR引擎、单帧图像违禁词检测
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化OCR引擎和违禁词列表。
Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_config_path (str): OCR配置文件如1.yaml的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值0~1
"""
# 加载违禁词
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化RapidOCR引擎
self.ocr_engine = self._initialize_ocr(ocr_config_path)
# 校验核心依赖是否就绪
self._check_dependencies()
# 设置置信度阈值限制在0~1范围
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
print(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _load_forbidden_words(self, path: str) -> set:
"""
从TXT文件加载违禁词去重、过滤空行支持UTF-8编码
"""
forbidden_words = set()
# 检查文件是否存在
if not os.path.exists(path):
print(f"错误:违禁词文件不存在: {path}")
return forbidden_words
# 读取文件并处理内容
try:
with open(path, 'r', encoding='utf-8') as f:
forbidden_words = {
line.strip() for line in f
if line.strip() # 跳过空行或纯空格行
}
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
except UnicodeDecodeError:
print(f"错误违禁词文件编码错误需UTF-8: {path}")
except PermissionError:
print(f"错误:无权限读取违禁词文件: {path}")
except Exception as e:
print(f"错误:加载违禁词失败: {str(e)}")
return forbidden_words
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
"""
初始化RapidOCR引擎校验配置文件、捕获初始化异常
"""
print("开始初始化RapidOCR引擎...")
# 检查配置文件是否存在
if not os.path.exists(config_path):
print(f"错误OCR配置文件不存在: {config_path}")
return None
# 初始化OCR引擎
try:
ocr_engine = RapidOCR(config_path=config_path)
print("RapidOCR引擎初始化成功")
return ocr_engine
except ImportError:
print("错误RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
except Exception as e:
print(f"错误RapidOCR初始化失败: {str(e)}")
return None
def _check_dependencies(self) -> None:
"""校验OCR引擎和违禁词列表是否就绪"""
if not self.ocr_engine:
print("警告:⚠️ OCR引擎未就绪违禁词检测功能将禁用")
if not self.forbidden_words:
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
def detect(self, frame) -> tuple[bool, list, list]:
"""
对单帧图像进行OCR违禁词检测核心方法
Args:
frame: 输入图像帧NumPy数组BGR格式cv2读取的图像
Returns:
tuple[bool, list, list]:
- 第一个元素是否检测到违禁词True/False
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
"""
# 初始化返回结果
has_violation = False
violation_words = []
violation_confs = []
# 前置校验
if frame is None or frame.size == 0:
print("警告输入图像帧为空或无效跳过OCR检测")
return has_violation, violation_words, violation_confs
if not self.ocr_engine or not self.forbidden_words:
print("OCR引擎未就绪或违禁词为空跳过OCR检测")
return has_violation, violation_words, violation_confs
try:
# 执行OCR识别
print("开始执行OCR识别...")
ocr_result = self.ocr_engine(frame)
print(f"RapidOCR原始结果: {ocr_result}")
# 校验OCR结果是否有效
if ocr_result is None:
print("OCR识别未返回任何结果图像无文本或识别失败")
return has_violation, violation_words, violation_confs
# 检查txts和scores是否存在且不为None
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
print("警告OCR结果中txts为None或不存在")
return has_violation, violation_words, violation_confs
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
print("警告OCR结果中scores为None或不存在")
return has_violation, violation_words, violation_confs
# 转为列表并去None
if not isinstance(ocr_result.txts, (list, tuple)):
print(f"警告OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
texts = []
else:
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
if not isinstance(ocr_result.scores, (list, tuple)):
print(f"警告OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
confidences = []
else:
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
# 校验文本和置信度列表长度是否一致
if len(texts) != len(confidences):
print(f"警告OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
return has_violation, violation_words, violation_confs
if len(texts) == 0:
print("OCR未识别到任何有效文本")
return has_violation, violation_words, violation_confs
# 遍历识别结果,筛选违禁词
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
for text, conf in zip(texts, confidences):
if conf < self.OCR_CONFIDENCE_THRESHOLD:
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
continue
matched_words = [word for word in self.forbidden_words if word in text]
if matched_words:
has_violation = True
violation_words.extend(matched_words)
violation_confs.extend([conf] * len(matched_words))
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
except Exception as e:
print(f"错误OCR检测过程异常: {str(e)}")
return has_violation, violation_words, violation_confs

View File

@ -1,47 +0,0 @@
from ultralytics import YOLO
import cv2
class ViolationDetector:
"""
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
"""
def __init__(self, model_path):
"""
初始化检测器。
Args:
model_path (str): YOLO .pt模型的路径。
"""
print(f"正在从 '{model_path}' 加载YOLO模型...")
self.model = YOLO(model_path)
print("YOLO模型加载成功。")
def detect(self, frame):
"""
对单帧图像进行目标检测。
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)。
Returns:
ultralytics.engine.results.Results: YOLO的检测结果对象。
"""
# conf可以根据您的模型效果进行调整
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
results = self.model(frame, conf=0.2)
return results[0]
def draw_boxes(self, frame, result):
"""
在图像帧上绘制检测框。
Args:
frame: 原始图像帧。
result: YOLO的检测结果对象。
Returns:
numpy.ndarray: 绘制了检测框的图像帧。
"""
# 使用YOLO自带的plot功能方便快捷
annotated_frame = result.plot()
return annotated_frame

View File

@ -1,164 +0,0 @@
import queue
import asyncio
import aiohttp
import threading
import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
# 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1)
class VideoTrack(MediaStreamTrack):
"""自定义视频轨道类继承自MediaStreamTrack"""
kind = "video"
def __init__(self, max_frames=100):
super().__init__()
self.frames = queue.Queue(maxsize=max_frames)
async def recv(self):
return await super().recv()
def webrtc_producer(webrtc_url):
"""
生产者方法从WEBRTC读取视频帧并放入队列
仅当队列空时才放入新帧,否则丢弃
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 创建RTCPeerConnection对象不使用ICE服务器
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
pc.addTrack(video_track)
@pc.on("track")
async def on_track(track):
if track.kind == "video":
print("接收到视频轨道,开始接收视频帧")
while True:
# 从轨道接收视频帧
frame = await track.recv()
# 转换为BGR24格式的NumPy数组
frame_bgr24 = frame.to_ndarray(format='bgr24')
# 检查队列是否为空,为空则加入,否则丢弃
if frame_queue.empty():
try:
frame_queue.put_nowait(frame_bgr24)
print("帧已放入队列")
except queue.Full:
print("队列已满,丢弃帧")
else:
print("队列非空,丢弃帧")
async def main():
# 创建并发送SDP Offer
offer = await pc.createOffer()
print("已创建本地SDP Offer")
await pc.setLocalDescription(offer)
# 发送Offer到服务器并接收Answer
async with aiohttp.ClientSession() as session:
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
async with session.post(
webrtc_url,
data=offer.sdp.encode(),
headers={
"Content-Type": "application/sdp",
"Content-Length": str(len(offer.sdp))
},
ssl=False
) as response:
print("已接收到服务器的响应")
answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
# 保持连接
try:
while True:
await asyncio.sleep(0.1)
except KeyboardInterrupt:
pass
finally:
print("关闭RTCPeerConnection")
await pc.close()
try:
loop.run_until_complete(main())
finally:
loop.close()
def frame_consumer(ip):
"""
消费者方法:从队列中读取帧并处理
每次处理后休眠200ms模拟延迟
"""
print("消费者启动,开始等待帧...")
try:
while True:
# 阻塞等待队列中的帧
frame = frame_queue.get()
print(f"消费帧,大小: {frame.shape}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
# 标记任务完成
frame_queue.task_done()
except KeyboardInterrupt:
print("消费者退出")
def start_webrtc_stream(ip, webrtc_url):
"""
启动WebRTC视频流处理的主方法
参数: webrtc_url - WebRTC服务器地址
"""
print(f"开始连接到WebRTC服务器: {webrtc_url}")
# 启动生产者线程
producer_thread = threading.Thread(
target=webrtc_producer,
args=(webrtc_url,),
daemon=True,
name="webrtc-producer"
)
# 启动消费者线程
consumer_thread = threading.Thread(
target=frame_consumer(ip),
daemon=True,
name="frame-consumer"
)
producer_thread.start()
consumer_thread.start()
print("生产者和消费者线程已启动")
try:
# 保持主线程运行
while True:
time.sleep(1)
except KeyboardInterrupt:
print("程序正在退出...")
if __name__ == "__main__":
# 示例用法
# 实际使用时替换为真实的WebRTC服务器地址
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
start_webrtc_stream(webrtc_server_url)

View File

@ -1,101 +0,0 @@
import asyncio
import logging
import cv2
import time
# 配置日志与WHEP代码保持一致的日志风格
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
功能与WHEP拉流函数对齐流状态反馈、帧信息打印、帧率统计、异常处理
Args:
rtmp_url: RTMP流的URL地址如 rtmp://xxx/live/stream_key
"""
cap = None # 初始化视频捕获对象
try:
# 1. 异步打开RTMP流指定FFmpeg后端确保RTMP兼容性同步操作通过to_thread避免阻塞事件循环
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 必须指定FFmpeg后端RTMP协议依赖该后端解析
)
# 2. 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 3. 异步获取RTMP流基础信息分辨率、帧率
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况部分RTMP流未返回帧率时默认30FPS
fps = fps if fps > 0 else 30.0
# 分辨率转为整数(视频尺寸必然是整数)
width, height = int(width), int(height)
# 打印流初始化成功信息与WHEP连接成功信息风格一致
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 4. 初始化帧统计参数
frame_count = 0 # 总接收帧数
start_time = time.time() # 统计起始时间
# 5. 循环异步读取视频帧(核心逻辑)
while True:
# 异步读取一帧cv2.read是同步操作用to_thread适配异步环境
ret, frame = await asyncio.to_thread(cap.read)
# 检查帧是否读取成功(流中断/结束时ret为False
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
# 帧计数累加
frame_count += 1
# 6. 打印当前帧基础信息与WHEP帧信息打印风格对齐
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
# 7. 每100帧统计一次实际接收帧率补充性能监控与原RTMP示例逻辑一致
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
# 示例yield frame (若需生成器模式,可调整函数为异步生成器)
# 8. 异常处理(覆盖用户中断、通用错误)
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
# 运行RTMP拉流任务与WHEP一致的异步执行方式
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

View File

@ -23,7 +23,7 @@ class FaceResponse(BaseModel):
"""人脸记录响应模型仍包含ID由数据库生成后返回"""
id: int = Field(..., description="主键ID数据库自增")
name: str = Field(None, description="名称")
eigenvalue: str = Field(None, description="特征(暂为None")
eigenvalue: str | None = Field(None, description="特征(可为空")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")

View File

@ -1,6 +1,6 @@
import json
from fastapi import APIRouter, Query, HTTPException
from fastapi import APIRouter, Query, HTTPException,Request
from mysql.connector import Error as MySQLError
from ds.db import db
@ -108,7 +108,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
# 原有接口保持不变
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(device_data: DeviceCreateRequest):
async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象
# 原有代码保持不变
conn = None
cursor = None
@ -125,11 +125,10 @@ async def create_device(device_data: DeviceCreateRequest):
return APIResponse(
code=200,
message=f"设备IP {device_data.ip} 已存在,返回已有设备信息",
data=DeviceResponse(**existing_device)
data=DeviceResponse(** existing_device)
)
from fastapi import Request
request = Request(scope={"type": "http"})
# 直接使用注入的request对象获取用户代理
user_agent = request.headers.get("User-Agent", "").lower()
if user_agent == "default":
@ -184,7 +183,6 @@ async def create_device(device_data: DeviceCreateRequest):
finally:
db.close_connection(conn, cursor)
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
async def get_device_list(
page: int = Query(1, ge=1, description="页码默认第1页"),

View File

@ -6,15 +6,15 @@ from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceRespons
from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse
from ocr.feature_extraction import BinaryFaceFeatureHandler
from util.face_util import add_binary_data,get_average_feature
#初始化实例
router = APIRouter(
prefix="/faces",
tags=["人脸管理"]
)
# 创建 BinaryFaceFeatureHandler 的实例
binary_face_feature_handler = BinaryFaceFeatureHandler()
# ------------------------------
@ -33,6 +33,8 @@ async def create_face(
- ID 由数据库自动生成,无需前端传入
- 暂不处理文件内容eigenvalue 设为 None
"""
# 调用你的方法
conn = None
cursor = None
try:
@ -45,14 +47,24 @@ async def create_face(
# 把文件转为二进制数组
file_content = await file.read()
# 调用人脸识别得到特征值
# 计算特征值
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 打印数组长度
print(f"文件大小:{len(file_content)} 字节")
# 2. 插入数据库:无需传 ID自增只传 name 和 eigenvalueNone
insert_query = """
INSERT INTO face (name, eigenvalue)
VALUES (%s, %s)
"""
cursor.execute(insert_query, (face_create.name, None))
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
conn.commit()
# 3. 获取数据库自动生成的 ID关键用 LAST_INSERT_ID() 查刚插入的记录)
@ -60,19 +72,45 @@ async def create_face(
cursor.execute(select_new_query)
created_face = cursor.fetchone()
if not created_face:
raise HTTPException(
status_code=500,
detail="创建人脸记录成功,但无法获取新创建的记录"
)
return APIResponse(
code=201,
message=f"人脸记录创建成功ID{created_face['id']},文件名:{file.filename}",
data=FaceResponse(**created_face)
data=FaceResponse(** created_face)
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建人脸记录失败:{str(e)}") from e
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"创建人脸记录失败:{str(e)}"
) from e
except Exception as e:
# 捕获其他可能的异常
raise HTTPException(
status_code=500,
detail=f"服务器错误:{str(e)}"
) from e
finally:
await file.close() # 关闭文件流
db.close_connection(conn, cursor)
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 将 eigenvalue 转为 str
eigenvalue = str(eigenvalue)
# ------------------------------
# 2. 获取单个人脸记录不变用自增ID查询
@ -104,18 +142,21 @@ async def get_face(
data=FaceResponse(**face)
)
except MySQLError as e:
raise Exception(f"查询人脸记录失败:{str(e)}") from e
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"查询人脸记录失败:{str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理
# ------------------------------
# 3. 获取所有人脸记录(不变)
# ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
async def get_all_faces(
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
@ -130,10 +171,13 @@ async def get_all_faces(
return APIResponse(
code=200,
message="所有人脸记录查询成功",
data=[FaceResponse(**face) for face in faces]
data=[FaceResponse(** face) for face in faces]
)
except MySQLError as e:
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e
raise HTTPException(
status_code=500,
detail=f"查询所有人脸记录失败:{str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
@ -194,7 +238,10 @@ async def update_face(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新人脸记录失败:{str(e)}") from e
raise HTTPException(
status_code=500,
detail=f"更新人脸记录失败:{str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
@ -234,7 +281,10 @@ async def delete_face(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除人脸记录失败:{str(e)}") from e
raise HTTPException(
status_code=500,
detail=f"删除人脸记录失败:{str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
@ -249,38 +299,43 @@ def get_all_face_name_with_eigenvalue() -> dict:
conn = None
cursor = None
try:
# 1. 建立数据库连接并获取游标dictionary=True使结果以字典形式返回
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 只查询需要的字段,提高效率
# 2. 执行SQL查询只获取name非空的记录减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query)
faces = cursor.fetchall()
faces = cursor.fetchall() # 返回结果:列表套字典,如 [{"name":"张三","eigenvalue":...}, ...]
# 先收集所有名称对应的特征值列表(处理重复名称)
# 3. 收集同一名称对应的所有特征值(处理名称重复场景
name_to_eigenvalues = {}
for face in faces:
name = face["name"]
eigenvalue = face["eigenvalue"]
# 若名称已存在,追加特征值;否则新建列表存储
if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue)
else:
name_to_eigenvalues[name] = [eigenvalue]
# 构建最终字典:重复名称取平均特征值,唯一名称直接取特征值
# 4. 构建最终字典:重复名称取平均,唯一名称直接取特征值
face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items():
print("调用的特征值是:" + eigenvalues)
# 处理特征值:多个则求平均,单个则直接使用
if len(eigenvalues) > 1:
# 调用平均特征值计算方法
face_dict[name] = binary_face_feature_handler.get_average_feature(eigenvalues)
# 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = get_average_feature(eigenvalues)
else:
# 取列表中唯一的特征值避免value为列表类型
face_dict[name] = eigenvalues[0]
return face_dict
except MySQLError as e:
# 捕获数据库异常,添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
finally:
# 确保资源释放
# 5. 无论是否异常,均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor)

145
util/face_util.py Normal file
View File

@ -0,0 +1,145 @@
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
from io import BytesIO
from PIL import Image
# 全局变量存储InsightFace引擎和特征列表
_insightface_app = None
_feature_list = []
def init_insightface():
"""初始化InsightFace引擎"""
global _insightface_app
try:
print("正在初始化InsightFace引擎...")
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
_insightface_app = app
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
return None
def add_binary_data(binary_data):
"""
接收单张图片的二进制数据,提取特征并保存
参数:
binary_data: 图片的二进制数据bytes类型
返回:
成功提取特征时返回 (True, 特征值numpy数组)
失败时返回 (False, None)
"""
global _insightface_app, _feature_list
if not _insightface_app:
print("引擎未初始化,无法处理")
return False, None
try:
# 直接处理二进制数据:转换为图像格式
img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# 提取特征
faces = _insightface_app.get(frame)
if faces:
# 获取当前提取的特征值
current_feature = faces[0].embedding
# 添加到特征列表
_feature_list.append(current_feature)
print(f"已累计 {len(_feature_list)} 个特征")
# 返回成功标志和当前特征值
return True, current_feature
else:
print("二进制数据中未检测到人脸")
return False, None
except Exception as e:
print(f"处理二进制数据出错: {e}")
return False, None
def get_average_feature(features=None):
"""
计算多个特征向量的平均值
参数:
features: 可选特征值列表。如果未提供则使用全局存储的_feature_list
每个元素可以是字符串格式或numpy数组
返回:
单一平均特征向量的numpy数组若无可计算数据则返回None
"""
global _feature_list
# 如果未提供features参数则使用全局特征列表
if features is None:
features = _feature_list
try:
# 验证输入是否为列表且不为空
if not isinstance(features, list) or len(features) == 0:
print("输入必须是包含至少一个特征值的列表")
return None
# 处理每个特征值
processed_features = []
for i, embedding in enumerate(features):
try:
if isinstance(embedding, str):
# 处理包含括号和逗号的字符串格式
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32)
else:
embedding_np = np.array(embedding, dtype=np.float32)
# 验证特征值格式
if len(embedding_np.shape) == 1:
processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}")
# 确保有有效的特征值
if not processed_features:
print("没有有效的特征值用于计算平均值")
return None
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
return avg_feature
except Exception as e:
print(f"计算平均特征值时出错: {e}")
return None
def clear_features():
"""清空已存储的特征数据"""
global _feature_list
_feature_list = []
print("已清空所有特征数据")
def get_feature_list():
"""获取当前存储的特征列表"""
global _feature_list
return _feature_list.copy() # 返回副本防止外部直接修改

482
ws.html
View File

@ -1,482 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WebSocket 测试工具</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
}
body {
max-width: 1200px;
margin: 20px auto;
padding: 0 20px;
background-color: #f5f7fa;
}
.container {
background: white;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
padding: 25px;
margin-bottom: 20px;
}
h1 {
color: #2c3e50;
margin-bottom: 20px;
font-size: 24px;
border-bottom: 2px solid #3498db;
padding-bottom: 10px;
}
.status-bar {
display: flex;
align-items: center;
gap: 15px;
margin-bottom: 20px;
padding: 12px 15px;
background-color: #f8f9fa;
border-radius: 6px;
}
.status-label {
font-weight: bold;
color: #495057;
}
.status-value {
padding: 4px 10px;
border-radius: 4px;
font-weight: bold;
}
.status-connected {
background-color: #d4edda;
color: #155724;
}
.status-disconnected {
background-color: #f8d7da;
color: #721c24;
}
.status-connecting {
background-color: #fff3cd;
color: #856404;
}
.btn {
padding: 8px 16px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
font-weight: 500;
transition: background-color 0.2s;
}
.btn-primary {
background-color: #3498db;
color: white;
}
.btn-primary:hover {
background-color: #2980b9;
}
.btn-danger {
background-color: #e74c3c;
color: white;
}
.btn-danger:hover {
background-color: #c0392b;
}
.btn-success {
background-color: #2ecc71;
color: white;
}
.btn-success:hover {
background-color: #27ae60;
}
.control-group {
display: flex;
gap: 15px;
margin-bottom: 20px;
align-items: center;
}
.input-group {
display: flex;
gap: 10px;
align-items: center;
}
.input-group label {
color: #495057;
font-weight: 500;
}
.input-group input, .input-group select {
padding: 8px 12px;
border: 1px solid #ced4da;
border-radius: 4px;
font-size: 14px;
}
.message-area {
margin-top: 20px;
}
.message-input {
width: 100%;
height: 100px;
padding: 12px;
border: 1px solid #ced4da;
border-radius: 6px;
resize: none;
font-size: 14px;
margin-bottom: 10px;
}
.log-area {
width: 100%;
height: 300px;
padding: 15px;
border: 1px solid #ced4da;
border-radius: 6px;
background-color: #f8f9fa;
overflow-y: auto;
font-size: 14px;
line-height: 1.6;
}
.log-item {
margin-bottom: 8px;
padding-bottom: 8px;
border-bottom: 1px dashed #e9ecef;
}
.log-time {
color: #6c757d;
font-size: 12px;
margin-right: 10px;
}
.log-send {
color: #2980b9;
}
.log-receive {
color: #27ae60;
}
.log-status {
color: #856404;
}
.log-error {
color: #e74c3c;
}
</style>
</head>
<body>
<div class="container">
<h1>WebSocket 测试工具</h1>
<!-- 连接状态区 -->
<div class="status-bar">
<div class="status-label">连接状态:</div>
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
<div class="status-label">服务地址:</div>
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
<div class="status-label">连接时间:</div>
<div id="connectTime" class="status-value">-</div>
</div>
<!-- 控制按钮区 -->
<div class="control-group">
<button id="connectBtn" class="btn btn-primary">建立连接</button>
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
<!-- 心跳控制 -->
<div class="input-group">
<label>自动心跳:</label>
<select id="autoHeartbeat">
<option value="on">开启</option>
<option value="off">关闭</option>
</select>
<label>间隔(秒)</label>
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
</div>
</div>
<!-- 自定义消息发送区 -->
<div class="message-area">
<h3>发送自定义消息</h3>
<textarea id="messageInput" class="message-input"
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
</div>
<!-- 日志显示区 -->
<div class="message-area">
<h3>消息日志</h3>
<div id="logContainer" class="log-area">
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
</div>
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
</div>
</div>
<script>
// 全局变量
let ws = null;
let heartbeatTimer = null;
const wsUrl = "ws://192.168.110.25:8000/ws";
// DOM 元素
const connectionStatus = document.getElementById('connectionStatus');
const connectTime = document.getElementById('connectTime');
const connectBtn = document.getElementById('connectBtn');
const disconnectBtn = document.getElementById('disconnectBtn');
const sendMessageBtn = document.getElementById('sendMessageBtn');
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
const autoHeartbeat = document.getElementById('autoHeartbeat');
const heartbeatInterval = document.getElementById('heartbeatInterval');
const messageInput = document.getElementById('messageInput');
const logContainer = document.getElementById('logContainer');
const clearLogBtn = document.getElementById('clearLogBtn');
// 工具函数:添加日志
function addLog(content, type = 'status') {
const now = new Date().toLocaleString('zh-CN', {
year: 'numeric', month: '2-digit', day: '2-digit',
hour: '2-digit', minute: '2-digit', second: '2-digit'
});
const logItem = document.createElement('div');
logItem.className = 'log-item';
let logClass = '';
switch (type) {
case 'send':
logClass = 'log-send';
break;
case 'receive':
logClass = 'log-receive';
break;
case 'error':
logClass = 'log-error';
break;
default:
logClass = 'log-status';
}
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
logContainer.appendChild(logItem);
// 滚动到最新日志
logContainer.scrollTop = logContainer.scrollHeight;
}
// 工具函数格式化JSON便于日志显示
function formatJson(jsonStr) {
try {
const obj = JSON.parse(jsonStr);
return JSON.stringify(obj, null, 2);
} catch (e) {
return jsonStr; // 非JSON格式直接返回
}
}
// 建立WebSocket连接
function connectWebSocket() {
if (ws) {
addLog('已存在连接,无需重复建立', 'error');
return;
}
try {
ws = new WebSocket(wsUrl);
// 连接成功
ws.onopen = function () {
connectionStatus.className = 'status-value status-connected';
connectionStatus.textContent = '已连接';
const now = new Date().toLocaleString('zh-CN');
connectTime.textContent = now;
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
// 更新按钮状态
connectBtn.disabled = true;
disconnectBtn.disabled = false;
sendMessageBtn.disabled = false;
// 开启自动心跳(默认开启)
if (autoHeartbeat.value === 'on') {
startAutoHeartbeat();
}
};
// 接收消息
ws.onmessage = function (event) {
const message = event.data;
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
};
// 连接关闭
ws.onclose = function (event) {
connectionStatus.className = 'status-value status-disconnected';
connectionStatus.textContent = '已断开';
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
// 清除自动心跳
stopAutoHeartbeat();
// 更新按钮状态
connectBtn.disabled = false;
disconnectBtn.disabled = true;
sendMessageBtn.disabled = true;
// 重置WebSocket对象
ws = null;
};
// 连接错误
ws.onerror = function (error) {
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
};
} catch (e) {
addLog(`建立连接失败:${e.message}`, 'error');
ws = null;
}
}
// 断开WebSocket连接
function disconnectWebSocket() {
if (!ws) {
addLog('当前无连接,无需断开', 'error');
return;
}
ws.close(1000, '手动断开连接');
}
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"}
function sendHeartbeat() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送心跳失败:当前无有效连接', 'error');
return;
}
const heartbeatMsg = {
timestamp: Date.now(), // 当前毫秒时间戳
type: "heartbeat"
};
const msgStr = JSON.stringify(heartbeatMsg);
ws.send(msgStr);
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
}
// 开启自动心跳
function startAutoHeartbeat() {
// 先停止已有定时器
stopAutoHeartbeat();
const interval = parseInt(heartbeatInterval.value) * 1000;
if (isNaN(interval) || interval < 10000) {
addLog('自动心跳间隔无效已重置为30秒', 'error');
heartbeatInterval.value = 30;
return startAutoHeartbeat();
}
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}`, 'status');
heartbeatTimer = setInterval(sendHeartbeat, interval);
}
// 停止自动心跳
function stopAutoHeartbeat() {
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
addLog('已停止自动心跳', 'status');
}
}
// 发送自定义消息
function sendCustomMessage() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送消息失败:当前无有效连接', 'error');
return;
}
const msgStr = messageInput.value.trim();
if (!msgStr) {
addLog('发送消息失败:消息内容不能为空', 'error');
return;
}
try {
// 验证JSON格式可选仅提示不强制
JSON.parse(msgStr);
ws.send(msgStr);
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
} catch (e) {
addLog(`JSON格式错误${e.message},仍尝试发送原始内容`, 'error');
ws.send(msgStr);
addLog(`发送自定义消息非JSON\n${msgStr}`, 'send');
}
}
// 绑定按钮事件
connectBtn.addEventListener('click', connectWebSocket);
disconnectBtn.addEventListener('click', disconnectWebSocket);
sendMessageBtn.addEventListener('click', sendCustomMessage);
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
clearLogBtn.addEventListener('click', () => {
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
});
// 自动心跳开关变更事件
autoHeartbeat.addEventListener('change', function () {
if (ws && ws.readyState === WebSocket.OPEN) {
if (this.value === 'on') {
startAutoHeartbeat();
} else {
stopAutoHeartbeat();
}
} else {
addLog('需先建立有效连接才能控制自动心跳', 'error');
// 重置选择
this.value = 'off';
}
});
// 心跳间隔变更事件(实时生效)
heartbeatInterval.addEventListener('change', function () {
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
startAutoHeartbeat();
}
});
// 快捷键支持Ctrl+Enter发送消息
messageInput.addEventListener('keydown', function (e) {
if (e.ctrlKey && e.key === 'Enter') {
sendCustomMessage();
e.preventDefault();
}
});
</script>
</body>
</html>

400
ws/ws.py
View File

@ -4,314 +4,300 @@ import json
import os
from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator
from concurrent.futures import ThreadPoolExecutor # 新增:显式线程池
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect
import cv2
import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from queue import Queue # 线程安全队列无需额外Lock
from core.all import load_model
from ocr.model_violation_detector import MultiModelViolationDetector
# -------------------------- 配置调整 --------------------------
# 模型路径(建议改为环境变量)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
# 核心优化:模型池大小(决定最大并发任务数,显存占用=大小×单模型显存)
MODEL_POOL_SIZE = 5 # 示例设为5支持5个任务并行显存会明显上升
THREAD_POOL_SIZE = MODEL_POOL_SIZE * 2 # 线程池大小≥模型池,避免线程瓶颈
# 其他配置
HEARTBEAT_INTERVAL = 30 # 心跳间隔(秒)
# 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点
FRAME_QUEUE_SIZE = 5 # 增大帧队列,允许缓存更多帧(避免丢帧)
WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# -------------------------- 工具函数 --------------------------
# 工具函数:获取格式化时间字符串(统一时间戳格式)
def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# -------------------------- 模型池重构核心修改1 --------------------------
class ModelPool:
def __init__(self, pool_size: int = MODEL_POOL_SIZE):
self.pool = Queue(maxsize=pool_size)
# 移除冗余LockQueue.get()/put()本身线程安全
self._init_models(pool_size)
print(f"[{get_current_time_str()}] 模型池初始化完成(共{pool_size}个实例,显存已预分配)")
def _init_models(self, pool_size: int):
"""预加载所有模型实例(初始化时显存会一次性上升)"""
for i in range(pool_size):
try:
detector = MultiModelViolationDetector(
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
ocr_confidence_threshold=0.5
)
self.pool.put(detector)
print(f"[{get_current_time_str()}] 模型实例{i+1}/{pool_size}加载完成")
except Exception as e:
raise RuntimeError(f"模型实例{i+1}加载失败:{str(e)}")
def get_model(self) -> MultiModelViolationDetector:
"""获取模型(阻塞直到有空闲实例,确保并发安全)"""
return self.pool.get()
def return_model(self, detector: MultiModelViolationDetector):
"""归还模型(立即释放资源供其他任务使用)"""
self.pool.put(detector)
# -------------------------- 全局资源初始化 --------------------------
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE) # 初始化模型池(预占显存)
thread_pool = ThreadPoolExecutor( # 显式创建线程池核心修改2
max_workers=THREAD_POOL_SIZE,
thread_name_prefix="ModelWorker-" # 线程命名,便于调试
)
# -------------------------- 客户端连接封装核心修改3 --------------------------
# 客户端连接封装
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 增大队列
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None
# 移除“客户端独占模型”不再持有detector属性
def update_heartbeat(self):
"""更新心跳时间(客户端发送心跳时调用)"""
self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool:
"""判断客户端是否存活(心跳超时检查)"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT
def start_consumer(self):
"""启动帧消费任务(每个客户端一个独立任务)"""
"""启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task
async def send_frame_permit(self):
"""发送帧许可信号(允许客户端继续发帧)"""
"""
发送「帧发送许可信号」
通知客户端可发送下一帧图像
"""
try:
await self.websocket.send_json({
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
})
}
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可发送失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列(并发核心:每帧临时借模型处理)"""
"""消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理"""
try:
while True:
# 1. 从队列取帧(无帧时阻塞)
# 1. 从队列取帧(阻塞直到有帧可用
frame_data = await self.frame_queue.get()
# 2. 立即发送下一帧许可(让客户端持续发帧,积累并发任务)
await self.send_frame_permit()
# -------------------------- 核心修改:取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成
# -----------------------------------------------------------------------------------------
try:
# 3. 并行处理帧(核心:任务级借模型
# 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧
await self.process_frame(frame_data)
finally:
self.frame_queue.task_done() # 标记帧处理完成
# 3. 标记帧任务完成(无论处理成功/失败,都需清理队列)
self.frame_queue.task_done()
except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:消费逻辑错误 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧核心修改4任务级借还模型"""
# 1. 临时借用模型(阻塞直到有空闲实例,显存随借用数上升)
detector = model_pool.get_model()
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法"""
# 二进制数据转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
return
# 确保图像保存目录存在
os.makedirs('images', exist_ok=True)
# 保存图像按IP+时间戳命名,避免冲突)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try:
# 2. 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像解析失败")
return
# 3. 保存图像(可选)
os.makedirs('images', exist_ok=True)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
cv2.imwrite(filename, img)
# 4. 显式线程池执行AI检测真正并发无线程瓶颈
loop = asyncio.get_running_loop()
has_violation, violation_type, details = await loop.run_in_executor(
thread_pool, # 用自定义线程池,避免默认线程不足
detector.detect_violations, # 临时借用的模型
img # 输入图像
)
# 5. 违规处理(与原逻辑一致)
print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
has_violation, data, type = detect(img)
print(has_violation)
print(type)
print(data)
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规 - {violation_type}")
# 违规次数更新(用线程池避免阻塞事件循环)
await loop.run_in_executor(thread_pool, increment_alarm_count_by_ip, self.client_ip)
# 发送危险通知
await self.websocket.send_json({
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {type}, 详情: {data}")
# 调用违规次数加一方法
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
# 发送「危险通知」
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip,
"violation_type": violation_type,
"details": details
})
"client_ip": self.client_ip
}
await self.websocket.send_json(danger_msg)
else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}违规")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}未检测到违规")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}处理错误 - {str(e)}")
finally:
# 6. 无论成功/失败,强制归还模型(核心:释放资源供其他任务使用)
model_pool.return_model(detector)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还(可复用)")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}图像处理错误 - {str(e)}")
# -------------------------- 全局状态与心跳 --------------------------
# 全局状态管理
connected_clients: Dict[str, ClientConnection] = {}
client_lock = asyncio.Lock() # 保护客户端字典的异步锁
heartbeat_task: Optional[asyncio.Task] = None
# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法)
async def heartbeat_checker():
"""心跳检查(移除模型归还逻辑,因模型已任务级归还)"""
while True:
current_time = get_current_time_str()
async with client_lock:
# 筛选超时客户端
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
for ip in timeout_ips:
async with client_lock:
conn = connected_clients.get(ip)
if not conn:
continue
# 取消消费任务+关闭连接
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线(用线程池)
loop = asyncio.get_running_loop()
await loop.run_in_executor(thread_pool, update_online_status_by_ip, ip, 0)
await loop.run_in_executor(
thread_pool, add_device_action, DeviceActionCreate(client_ip=ip, action=0)
)
connected_clients.pop(ip)
print(f"[{current_time}] 客户端{ip}:超时离线(资源已清理)")
if timeout_ips:
print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时IP{timeout_ips}")
for ip in timeout_ips:
try:
conn = connected_clients[ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 打印在线状态
async with client_lock:
# 超时设为离线并记录
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期核心修改5管理线程池 --------------------------
# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
global heartbeat_task
# 启动心跳任务
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 心跳任务启动ID{id(heartbeat_task)}")
print(f"[{get_current_time_str()}] 线程池启动(最大线程数:{THREAD_POOL_SIZE}")
yield # 应用运行期间
# 清理资源
print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID{id(heartbeat_task)}")
yield
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
await heartbeat_task
print(f"[{get_current_time_str()}] 心跳任务已关闭")
# 关闭线程池(等待所有任务完成)
thread_pool.shutdown(wait=True)
print(f"[{get_current_time_str()}] 线程池已关闭")
try:
await heartbeat_task
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
except asyncio.CancelledError:
pass
# -------------------------- WebSocket路由 --------------------------
# 消息处理工具函数
async def send_heartbeat_ack(conn: ClientConnection):
try:
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
return True
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
return False
async def handle_text_msg(conn: ClientConnection, text: str):
try:
msg = json.loads(text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn)
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}无效JSON文本消息")
async def handle_binary_msg(conn: ClientConnection, data: bytes):
try:
conn.frame_queue.put_nowait(data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据")
# WebSocket路由配置
ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
# 加载模型
load_model()
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}:连接建立")
print(f"[{current_time}] 客户端{client_ip}WebSocket连接建立")
new_conn = None
is_online_updated = False
try:
# 处理重复连接(关闭旧连接)
async with client_lock:
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="新连接抢占")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}:旧连接已关闭")
# 创建新连接+启动消费任务
try:
# 处理重复连接
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接")
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer()
# 初始发送帧许可(让客户端立即发帧
# 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发
await new_conn.send_frame_permit()
# 标记客户端在线
loop = asyncio.get_running_loop()
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 1)
await loop.run_in_executor(
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=1)
)
is_online_updated = True
async with client_lock:
connected_clients[client_ip] = new_conn
print(f"[{current_time}] 客户端{client_ip}:注册成功(在线数:{len(connected_clients)}")
# 标记上线并记录
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
# 消息循环(接收文本/二进制帧)
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
# 消息循环
while True:
data = await websocket.receive()
if "text" in data:
# 处理文本消息(如心跳)
try:
msg = json.loads(data["text"])
if msg.get("type") == "heart":
new_conn.update_heartbeat()
# 回复心跳确认
await websocket.send_json({
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": client_ip
})
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{client_ip}无效JSON")
await handle_text_msg(new_conn, data["text"])
elif "bytes" in data:
# 处理二进制帧(图像)
try:
await new_conn.frame_queue.put(data["bytes"])
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧已入队(队列大小:{new_conn.frame_queue.qsize()}")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧队列满(丢弃当前帧)")
await handle_binary_msg(new_conn, data["bytes"])
except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开(代码:{e.code}")
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code}")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
finally:
# 清理资源无需归还模型已在process_frame中归还
if new_conn and client_ip in connected_clients:
async with client_lock:
conn = connected_clients.get(client_ip)
if conn:
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 标记离线(仅当在线状态已更新时)
if is_online_updated:
loop = asyncio.get_running_loop()
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 0)
await loop.run_in_executor(
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=0)
)
connected_clients.pop(client_ip)
async with client_lock:
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源清理完成(在线数:{len(connected_clients)}")
# 清理资源并标记离线
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 主动/异常断开时标记离线
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")