最新可用
2
.idea/Video.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="video" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||
</project>
|
@ -15,5 +15,5 @@ algorithm = HS256
|
||||
access_token_expire_minutes = 30
|
||||
|
||||
[live]
|
||||
rtmp_url = rtmp://192.168.110.65:1935/live/
|
||||
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=
|
||||
rtmp_url = rtmp://192.168.110.25:1935/live/
|
||||
webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=
|
||||
|
45
core/all.py
Normal file
@ -0,0 +1,45 @@
|
||||
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
|
||||
from core.face import load_model as faceLoadModel, detect as faceDetect
|
||||
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
|
||||
|
||||
# 添加一个标记变量,用于监控load_model是否已被调用
|
||||
_model_loaded = False
|
||||
|
||||
|
||||
def load_model():
|
||||
global _model_loaded
|
||||
|
||||
# 如果已经调用过,直接忽略
|
||||
if _model_loaded:
|
||||
return
|
||||
|
||||
# 首次调用时加载模型
|
||||
ocrLoadModel()
|
||||
faceLoadModel()
|
||||
yoloLoadModel()
|
||||
|
||||
# 标记为已调用
|
||||
_model_loaded = True
|
||||
|
||||
|
||||
def detect(frame):
|
||||
# 先进行YOLO检测
|
||||
yolo_flag, yolo_result = yoloDetect(frame)
|
||||
print("YOLO检测结果:", yolo_result)
|
||||
if yolo_flag:
|
||||
return (True, yolo_result, "yolo")
|
||||
|
||||
# YOLO未检测到,进行人脸检测
|
||||
face_flag, face_result = faceDetect(frame)
|
||||
print("人脸检测结果:", face_result)
|
||||
if face_flag:
|
||||
return (True, face_result, "face")
|
||||
|
||||
# 人脸未检测到,进行OCR检测
|
||||
ocr_flag, ocr_result = ocrDetect(frame)
|
||||
print("OCR检测结果:", ocr_result)
|
||||
if ocr_flag:
|
||||
return (True, ocr_result, "ocr")
|
||||
|
||||
# 所有检测都未检测到
|
||||
return (False, "未检测到任何内容", "none")
|
113
core/face.py
Normal file
@ -0,0 +1,113 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image # 确保正确导入Image类
|
||||
from insightface.app import FaceAnalysis
|
||||
# 导入获取人脸信息的服务
|
||||
from service.face_service import get_all_face_name_with_eigenvalue
|
||||
|
||||
# 全局变量
|
||||
_face_app = None
|
||||
_known_faces_embeddings = {} # 存储姓名到特征值的映射
|
||||
_known_faces_names = [] # 存储所有已知姓名
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载人脸识别模型及已知人脸特征库"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names
|
||||
|
||||
# 初始化InsightFace模型
|
||||
try:
|
||||
_face_app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface'))
|
||||
_face_app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
except Exception as e:
|
||||
print(f"Face model load failed: {e}")
|
||||
return False
|
||||
|
||||
# 从服务获取所有人脸姓名和特征值
|
||||
try:
|
||||
face_data = get_all_face_name_with_eigenvalue()
|
||||
|
||||
# 处理获取到的人脸数据
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 处理特征值数据 - 兼容数组和字符串两种格式
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
# 如果已经是numpy数组,直接使用
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 清理字符串:移除方括号、换行符和多余空格
|
||||
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
|
||||
# 按空格或逗号分割(处理可能的不同分隔符)
|
||||
values = [v for v in cleaned.split() if v]
|
||||
# 转换为数组
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
# 不支持的类型
|
||||
print(f"Unsupported eigenvalue type for {person_name}")
|
||||
continue
|
||||
|
||||
# 归一化处理
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm != 0:
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading face data from service: {e}")
|
||||
|
||||
return True if _face_app else False
|
||||
|
||||
|
||||
def detect(frame, threshold=0.4):
|
||||
"""检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names
|
||||
|
||||
if not _face_app or not _known_faces_names or frame is None:
|
||||
return (False, "未初始化或无效帧")
|
||||
|
||||
try:
|
||||
faces = _face_app.get(frame)
|
||||
except Exception as e:
|
||||
print(f"Face detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
result_parts = []
|
||||
has_matched = False # 新增标记:是否有匹配到的已知人脸
|
||||
|
||||
for face in faces:
|
||||
# 特征归一化
|
||||
embedding = face.embedding.astype(np.float32)
|
||||
norm = np.linalg.norm(embedding)
|
||||
if norm == 0:
|
||||
continue
|
||||
embedding = embedding / norm
|
||||
|
||||
# 对比已知人脸
|
||||
max_sim, best_name = -1.0, "Unknown"
|
||||
for name in _known_faces_names:
|
||||
known_emb = _known_faces_embeddings[name]
|
||||
sim = np.dot(embedding, known_emb)
|
||||
if sim > max_sim:
|
||||
max_sim = sim
|
||||
best_name = name
|
||||
|
||||
# 判断匹配结果
|
||||
is_match = max_sim >= threshold
|
||||
if is_match:
|
||||
has_matched = True # 只要有一个匹配成功,就标记为True
|
||||
|
||||
bbox = face.bbox
|
||||
result_parts.append(
|
||||
f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})"
|
||||
)
|
||||
|
||||
# 构建结果字符串
|
||||
if not result_parts:
|
||||
result_str = "未检测到人脸"
|
||||
else:
|
||||
result_str = "; ".join(result_parts)
|
||||
|
||||
# 第一个返回值改为:是否匹配到已知人脸
|
||||
return (has_matched, result_str)
|
BIN
core/models/best.pt
Normal file
76
core/ocr.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import cv2
|
||||
from rapidocr import RapidOCR
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
# 全局变量
|
||||
_ocr_engine = None
|
||||
_forbidden_words = set()
|
||||
_conf_threshold = 0.5
|
||||
|
||||
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载OCR引擎及违禁词列表"""
|
||||
global _ocr_engine, _forbidden_words, _conf_threshold
|
||||
|
||||
# 加载违禁词
|
||||
try:
|
||||
_forbidden_words = get_all_sensitive_words()
|
||||
except Exception as e:
|
||||
print(f"Forbidden words load error: {e}")
|
||||
|
||||
# 初始化OCR引擎
|
||||
if not os.path.exists(ocr_config_path):
|
||||
print(f"OCR config not found: {ocr_config_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
_ocr_engine = RapidOCR(config_path=ocr_config_path)
|
||||
except Exception as e:
|
||||
print(f"OCR model load failed: {e}")
|
||||
return False
|
||||
|
||||
return True if _ocr_engine else False
|
||||
|
||||
|
||||
def detect(frame):
|
||||
"""OCR检测并筛选违禁词,返回(是否检测到违禁词, 结果字符串)"""
|
||||
if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
|
||||
return (False, "未初始化或无效帧")
|
||||
|
||||
try:
|
||||
ocr_res = _ocr_engine(frame)
|
||||
except Exception as e:
|
||||
print(f"OCR detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'):
|
||||
return (False, "无OCR结果")
|
||||
|
||||
# 处理OCR结果
|
||||
texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)]
|
||||
confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))]
|
||||
if len(texts) != len(confs):
|
||||
return (False, "OCR结果格式异常")
|
||||
|
||||
# 筛选违禁词
|
||||
vio_info = []
|
||||
for txt, conf in zip(texts, confs):
|
||||
if conf < _conf_threshold:
|
||||
continue
|
||||
matched = [w for w in _forbidden_words if w in txt]
|
||||
if matched:
|
||||
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
|
||||
|
||||
# 构建结果字符串
|
||||
has_text = len(texts) > 0
|
||||
has_violation = len(vio_info) > 0
|
||||
|
||||
if not has_text:
|
||||
return (False, "未识别到文本")
|
||||
elif has_violation:
|
||||
return (True, "; ".join(vio_info))
|
||||
else:
|
||||
return (False, "未检测到违禁词")
|
137
core/rtc.py
@ -1,137 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
import aiohttp
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
|
||||
import logging
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("whep_video_puller")
|
||||
|
||||
|
||||
async def whep_pull_video_stream(ip,whep_url):
|
||||
"""
|
||||
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
|
||||
Args:
|
||||
whep_url: WHEP端点的URL
|
||||
"""
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
# 添加连接状态变化监听
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
print(f"连接状态: {pc.connectionState}")
|
||||
|
||||
# 添加ICE连接状态变化监听
|
||||
@pc.on("iceconnectionstatechange")
|
||||
async def on_iceconnectionstatechange():
|
||||
print(f"ICE连接状态: {pc.iceConnectionState}")
|
||||
|
||||
# 添加视频接收器
|
||||
pc.addTransceiver("video", direction="recvonly")
|
||||
|
||||
# 处理接收到的视频轨道
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
print(f"接收到轨道: {track.kind}")
|
||||
if track.kind == "video":
|
||||
print(f"轨道ID: {track.id}")
|
||||
print(f"轨道就绪状态: {track.readyState}")
|
||||
# 创建异步任务来处理视频帧
|
||||
asyncio.ensure_future(handle_video_track(track))
|
||||
|
||||
async def handle_video_track(track):
|
||||
"""处理视频轨道,接收并打印每一帧"""
|
||||
frame_count = 0
|
||||
print("开始处理视频轨道...")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 尝试接收帧
|
||||
frame = await track.recv()
|
||||
frame_count += 1
|
||||
print(f"收到原始帧 (第{frame_count}帧)")
|
||||
|
||||
# 打印帧的基本信息
|
||||
if hasattr(frame, 'width') and hasattr(frame, 'height'):
|
||||
print(f" 尺寸: {frame.width}x{frame.height}")
|
||||
if hasattr(frame, 'time_base'):
|
||||
print(f" 时间基准: {frame.time_base}")
|
||||
if hasattr(frame, 'pts'):
|
||||
print(f" 显示时间戳: {frame.pts}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
except Exception as e:
|
||||
print(f"接收帧时出错: {e}")
|
||||
# 等待一段时间后重试
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 创建offer
|
||||
offer = await pc.createOffer()
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
print(f"本地SDP信息:\n{offer.sdp}")
|
||||
|
||||
# 通过HTTP POST发送offer到WHEP端点
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
whep_url,
|
||||
data=offer.sdp,
|
||||
headers={"Content-Type": "application/sdp"}
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
print(f"WHEP服务器返回错误: {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
raise Exception(f"WHEP服务器返回错误: {response.status}")
|
||||
|
||||
# 获取answer SDP
|
||||
answer_sdp = await response.text()
|
||||
|
||||
# 创建RTCSessionDescription对象
|
||||
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
|
||||
|
||||
print(f"收到远程SDP:\n{answer_sdp}")
|
||||
|
||||
# 设置远程描述
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
print("连接已建立,开始接收视频流...")
|
||||
|
||||
# 保持连接,直到用户中断
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
# 检查连接状态
|
||||
print(f"当前连接状态: {pc.connectionState}")
|
||||
except KeyboardInterrupt:
|
||||
print("用户中断,关闭连接...")
|
||||
finally:
|
||||
await pc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 替换为你的WHEP端点URL
|
||||
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行拉流任务
|
||||
asyncio.run(whep_pull_video_stream(WHEP_URL))
|
112
core/rtmp.py
@ -1,112 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
|
||||
# 配置文件相对路径(根据实际目录结构调整)
|
||||
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹
|
||||
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
|
||||
KNOWN_FACES_DIR = "../ocr/known_faces"
|
||||
|
||||
# 创建检测器实例
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并进行违规检测
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 异步打开RTMP流
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
|
||||
)
|
||||
|
||||
# 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 获取RTMP流基础信息
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况
|
||||
fps = fps if fps > 0 else 30.0
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 初始化帧统计参数
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
# 循环读取视频帧
|
||||
while True:
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 打印当前帧信息
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
if frame is not None:
|
||||
has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
if has_violation:
|
||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
else:
|
||||
print("未检测到任何违规内容")
|
||||
else:
|
||||
print(f"无法读取测试图像")
|
||||
|
||||
# 每100帧统计一次实际接收帧率
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
print(f"程序启动失败: {str(e)}")
|
55
core/yolo.py
Normal file
@ -0,0 +1,55 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
# 全局变量
|
||||
_yolo_model = None
|
||||
|
||||
|
||||
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载YOLO目标检测模型"""
|
||||
global _yolo_model
|
||||
|
||||
try:
|
||||
_yolo_model = YOLO(model_path)
|
||||
except Exception as e:
|
||||
print(f"YOLO model load failed: {e}")
|
||||
return False
|
||||
|
||||
return True if _yolo_model else False
|
||||
|
||||
|
||||
def detect(frame, conf_threshold=0.2):
|
||||
"""YOLO目标检测,返回(是否识别到, 结果字符串)"""
|
||||
global _yolo_model
|
||||
|
||||
if not _yolo_model or frame is None:
|
||||
return (False, "未初始化或无效帧")
|
||||
|
||||
try:
|
||||
results = _yolo_model(frame, conf=conf_threshold)
|
||||
# 检查是否有检测结果
|
||||
has_results = len(results[0].boxes) > 0 if results else False
|
||||
|
||||
if not has_results:
|
||||
return (False, "未检测到目标")
|
||||
|
||||
# 构建结果字符串
|
||||
result_parts = []
|
||||
for box in results[0].boxes:
|
||||
cls = int(box.cls[0])
|
||||
conf = float(box.conf[0])
|
||||
bbox = [float(x) for x in box.xyxy[0]]
|
||||
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
||||
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
|
||||
|
||||
result_str = "; ".join(result_parts)
|
||||
return (has_results, result_str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"YOLO detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
23
main.py
@ -1,12 +1,19 @@
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from PIL import Image # 正确导入
|
||||
import numpy as np
|
||||
|
||||
import uvicorn
|
||||
from PIL import Image
|
||||
from fastapi import FastAPI
|
||||
from core.all import load_model,detect
|
||||
from ds.config import SERVER_CONFIG
|
||||
from middle.error_handler import global_exception_handler
|
||||
from service.user_service import router as user_router
|
||||
from service.sensitive_service import router as sensitive_router
|
||||
from service.face_service import router as face_router
|
||||
from service.device_service import router as device_router
|
||||
from ws.ws import ws_router, lifespan
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 初始化 FastAPI 应用、指定生命周期管理
|
||||
# ------------------------------
|
||||
@ -22,6 +29,8 @@ app = FastAPI(
|
||||
# ------------------------------
|
||||
app.include_router(user_router)
|
||||
app.include_router(device_router)
|
||||
app.include_router(face_router)
|
||||
app.include_router(sensitive_router)
|
||||
app.include_router(ws_router)
|
||||
|
||||
# ------------------------------
|
||||
@ -33,11 +42,19 @@ app.add_exception_handler(Exception, global_exception_handler)
|
||||
# 启动服务
|
||||
# ------------------------------
|
||||
if __name__ == "__main__":
|
||||
# -------------------------- 配置调整 --------------------------
|
||||
# 模型配置路径(建议改为环境变量)
|
||||
YOLO_MODEL_PATH = r"/core/models\best.pt"
|
||||
OCR_CONFIG_PATH = r"/core/config\config.yaml"
|
||||
|
||||
# 初始化项目(默认端口设为8000,避免初始化失败时port未定义)
|
||||
port = int(SERVER_CONFIG.get("port", 8000))
|
||||
|
||||
# 启动 UVicorn 服务
|
||||
uvicorn.run(
|
||||
app="main:app",
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
workers=8,
|
||||
ws="websockets"
|
||||
)
|
||||
|
@ -8,7 +8,8 @@ from passlib.context import CryptContext
|
||||
|
||||
from ds.config import JWT_CONFIG
|
||||
from ds.db import db
|
||||
from service.user_service import UserResponse
|
||||
|
||||
# 移除这里的 from service.user_service import UserResponse 导入
|
||||
|
||||
# ------------------------------
|
||||
# 密码加密配置
|
||||
@ -25,6 +26,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
||||
# OAuth2 依赖(从请求头获取 Token、格式:Bearer <token>)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 密码工具函数
|
||||
# ------------------------------
|
||||
@ -32,10 +34,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证明文密码与加密密码是否匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""对明文密码进行 bcrypt 加密"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# JWT 工具函数
|
||||
# ------------------------------
|
||||
@ -53,11 +57,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 认证依赖(获取当前登录用户)
|
||||
# ------------------------------
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
|
||||
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
||||
# 延迟导入,打破循环依赖
|
||||
from schema.user_schema import UserResponse # 在这里导入
|
||||
|
||||
# 认证失败异常
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -89,7 +97,7 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
|
||||
raise credentials_exception # 用户不存在
|
||||
|
||||
# 转换为 UserResponse 模型(自动校验字段)
|
||||
return UserResponse(** user)
|
||||
return UserResponse(**user)
|
||||
except Exception as e:
|
||||
raise credentials_exception from e
|
||||
finally:
|
||||
|
@ -1,139 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
|
||||
class FaceRecognizer:
|
||||
"""
|
||||
封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。
|
||||
"""
|
||||
|
||||
def __init__(self, known_faces_dir: str):
|
||||
self.known_faces_dir = known_faces_dir
|
||||
self.app = self._initialize_insightface()
|
||||
self.known_faces_embeddings = {}
|
||||
self.known_faces_names = []
|
||||
self._load_known_faces()
|
||||
|
||||
def _initialize_insightface(self):
|
||||
"""初始化InsightFace FaceAnalysis应用"""
|
||||
print("初始化InsightFace引擎...")
|
||||
try:
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
print("请检查依赖是否安装及模型是否可访问")
|
||||
return None
|
||||
|
||||
def _load_known_faces(self):
|
||||
"""加载已知人脸特征"""
|
||||
if not os.path.exists(self.known_faces_dir):
|
||||
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
|
||||
os.makedirs(self.known_faces_dir, exist_ok=True)
|
||||
return
|
||||
|
||||
print(f"从目录加载人脸特征: {self.known_faces_dir}")
|
||||
for person_name in os.listdir(self.known_faces_dir):
|
||||
person_dir = os.path.join(self.known_faces_dir, person_name)
|
||||
if os.path.isdir(person_dir):
|
||||
print(f"处理人物: {person_name}")
|
||||
embeddings = []
|
||||
for filename in os.listdir(person_dir):
|
||||
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
||||
image_path = os.path.join(person_dir, filename)
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
print(f"无法读取图片: {image_path},已跳过")
|
||||
continue
|
||||
|
||||
faces = self.app.get(img)
|
||||
if faces:
|
||||
embeddings.append(faces[0].embedding)
|
||||
print(f"提取特征成功: {filename}")
|
||||
else:
|
||||
print(f"未检测到人脸: {filename},已跳过")
|
||||
except Exception as e:
|
||||
print(f"处理图片出错 {image_path}: {e}")
|
||||
|
||||
if embeddings:
|
||||
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
|
||||
self.known_faces_names.append(person_name)
|
||||
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
|
||||
else:
|
||||
print(f"人物 {person_name} 无有效特征,已跳过")
|
||||
print(f"人脸加载完成,共 {len(self.known_faces_names)} 人")
|
||||
|
||||
def recognize(self, frame, threshold=0.4):
|
||||
"""识别人脸并返回结果"""
|
||||
if not self.app or not self.known_faces_names:
|
||||
return False, None, None
|
||||
|
||||
faces = self.app.get(frame)
|
||||
if not faces:
|
||||
return False, None, None
|
||||
|
||||
for face in faces:
|
||||
for known_name in self.known_faces_names:
|
||||
known_embedding = self.known_faces_embeddings[known_name]
|
||||
|
||||
embedding1 = face.embedding.astype(np.float32)
|
||||
embedding2 = known_embedding.astype(np.float32)
|
||||
|
||||
dot_product = np.dot(embedding1, embedding2)
|
||||
norm_embedding1 = np.linalg.norm(embedding1)
|
||||
norm_embedding2 = np.linalg.norm(embedding2)
|
||||
|
||||
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
|
||||
dot_product / (norm_embedding1 * norm_embedding2)
|
||||
)
|
||||
|
||||
if similarity >= threshold:
|
||||
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
|
||||
return True, known_name, similarity
|
||||
|
||||
return False, None, None
|
||||
|
||||
def test_single_image(self, image_path: str, threshold=0.4):
|
||||
"""测试单张图片识别"""
|
||||
if not os.path.exists(image_path):
|
||||
print(f"图片不存在: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
frame = cv2.imread(image_path)
|
||||
if frame is None:
|
||||
print(f"无法读取图片: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
result, name, similarity = self.recognize(frame, threshold)
|
||||
|
||||
if result:
|
||||
print(f"识别结果: {name} (相似度: {similarity:.4f})")
|
||||
|
||||
faces = self.app.get(frame)
|
||||
for face in faces:
|
||||
bbox = face.bbox.astype(int)
|
||||
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
||||
text = f"{name}: {similarity:.2f}"
|
||||
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||
|
||||
cv2.imshow('识别结果', frame)
|
||||
print("按任意键关闭窗口...")
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
print("未识别到已知人脸")
|
||||
|
||||
return result, name, similarity
|
||||
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
|
||||
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
|
||||
# recognizer.test_single_image(test_image_path, threshold=0.4)
|
@ -1,156 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class BinaryFaceFeatureHandler:
|
||||
"""
|
||||
专门处理图片二进制数据的特征提取器,支持分批次接收二进制数据并累积计算平均特征
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.app = self._init_insightface()
|
||||
self.feature_list = [] # 存储所有图片二进制数据提取的特征
|
||||
|
||||
def _init_insightface(self):
|
||||
"""初始化InsightFace引擎"""
|
||||
try:
|
||||
print("正在初始化InsightFace引擎...")
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
return None
|
||||
|
||||
def add_binary_data(self, binary_data):
|
||||
"""
|
||||
接收单张图片的二进制数据,提取特征并保存
|
||||
|
||||
参数:
|
||||
binary_data: 图片的二进制数据(bytes类型)
|
||||
|
||||
返回:
|
||||
成功提取特征时返回 (True, 特征值numpy数组)
|
||||
失败时返回 (False, None)
|
||||
"""
|
||||
if not self.app:
|
||||
print("引擎未初始化,无法处理")
|
||||
return False, None
|
||||
|
||||
try:
|
||||
# 直接处理二进制数据:转换为图像格式
|
||||
img = Image.open(BytesIO(binary_data))
|
||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# 提取特征
|
||||
faces = self.app.get(frame)
|
||||
if faces:
|
||||
# 获取当前提取的特征值
|
||||
current_feature = faces[0].embedding
|
||||
# 添加到特征列表
|
||||
self.feature_list.append(current_feature)
|
||||
print(f"已累计 {len(self.feature_list)} 个特征")
|
||||
# 返回成功标志和当前特征值
|
||||
return True,current_feature
|
||||
else:
|
||||
print("二进制数据中未检测到人脸")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
print(f"处理二进制数据出错: {e}")
|
||||
return False, None
|
||||
|
||||
def get_average_feature(self, features):
|
||||
"""
|
||||
计算多个特征向量的平均值
|
||||
|
||||
参数:
|
||||
features: 特征值列表,每个元素可以是字符串格式或numpy数组
|
||||
例如: [feature1, feature2, ...]
|
||||
返回:
|
||||
单一平均特征向量的numpy数组,若无可计算数据则返回None
|
||||
"""
|
||||
try:
|
||||
# 验证输入是否为列表且不为空
|
||||
if not isinstance(features, list) or len(features) == 0:
|
||||
print("输入必须是包含至少一个特征值的列表")
|
||||
return None
|
||||
|
||||
# 处理每个特征值
|
||||
processed_features = []
|
||||
for i, embedding in enumerate(features):
|
||||
try:
|
||||
if isinstance(embedding, str):
|
||||
# 处理包含括号和逗号的字符串格式
|
||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||
else:
|
||||
embedding_np = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 验证特征值格式
|
||||
if len(embedding_np.shape) == 1:
|
||||
processed_features.append(embedding_np)
|
||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||
else:
|
||||
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
||||
|
||||
# 确保有有效的特征值
|
||||
if not processed_features:
|
||||
print("没有有效的特征值用于计算平均值")
|
||||
return None
|
||||
|
||||
# 检查所有特征向量维度是否相同
|
||||
dims = {feat.shape[0] for feat in processed_features}
|
||||
if len(dims) > 1:
|
||||
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
|
||||
return None
|
||||
|
||||
# 计算平均值
|
||||
avg_feature = np.mean(processed_features, axis=0)
|
||||
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
|
||||
|
||||
return avg_feature
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算平均特征值时出错: {e}")
|
||||
return None
|
||||
|
||||
# def clear(self):
|
||||
# """清空已存储的特征数据"""
|
||||
# self.feature_list = []
|
||||
# print("已清空所有特征数据")
|
||||
|
||||
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# handler = BinaryFaceFeatureHandler()
|
||||
#
|
||||
# # 模拟接收图片二进制数据
|
||||
# try:
|
||||
# # 第一次接收
|
||||
# with open(r"D:\Git\bin\video\ocr\known_faces\B\102-f.jpg_1140x855.jpg", "rb") as f:
|
||||
# bin_data1 = f.read()
|
||||
# success, feature1 = handler.add_binary_data(bin_data1)
|
||||
# if success:
|
||||
# print(f"第一次提取的特征值前5个: {feature1[:5]}")
|
||||
#
|
||||
# # 第二次接收
|
||||
# with open(r"D:\Git\bin\video\ocr\known_faces\B\104-1.jpg", "rb") as f:
|
||||
# bin_data2 = f.read()
|
||||
# success, feature2 = handler.add_binary_data(bin_data2)
|
||||
# if success:
|
||||
# print(f"第二次提取的特征值前5个: {feature2[:5]}")
|
||||
#
|
||||
# # 计算平均值
|
||||
# avg_feature = handler.get_average_feature()
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"处理过程出错: {e}")
|
@ -1,18 +0,0 @@
|
||||
打倒习近平
|
||||
打到习近平
|
||||
打倒毛泽东
|
||||
打到毛泽东
|
||||
打到主席
|
||||
打倒主席
|
||||
打到共产主义
|
||||
打倒共产主义
|
||||
打到共产党
|
||||
打倒共产党
|
||||
胖猫
|
||||
法轮功
|
||||
法轮大法
|
||||
法轮大法好
|
||||
法轮功大法好
|
||||
法轮
|
||||
李洪志
|
||||
习近平
|
Before Width: | Height: | Size: 195 KiB |
Before Width: | Height: | Size: 208 KiB |
Before Width: | Height: | Size: 657 KiB |
Before Width: | Height: | Size: 53 KiB |
Before Width: | Height: | Size: 8.1 KiB |
Before Width: | Height: | Size: 14 KiB |
Before Width: | Height: | Size: 58 KiB |
Before Width: | Height: | Size: 4.9 KiB |
Before Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 155 KiB |
Before Width: | Height: | Size: 386 KiB |
Before Width: | Height: | Size: 1.4 MiB |
Before Width: | Height: | Size: 62 KiB |
@ -1,49 +0,0 @@
|
||||
#日志文件
|
||||
import logging
|
||||
import sys
|
||||
|
||||
def setup_logger():
|
||||
"""
|
||||
配置一个全局日志记录器,支持输出到控制台和文件。
|
||||
"""
|
||||
# 创建一个日志记录器
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger("ViolationDetectorLogger")
|
||||
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
|
||||
|
||||
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
|
||||
# --- 控制台处理器 ---
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
# 对于控制台,我们只显示INFO及以上级别的信息
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
|
||||
# --- 文件处理器 ---
|
||||
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
|
||||
# 对于文件,我们记录所有DEBUG及以上级别的信息
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
|
||||
# 将处理器添加到日志记录器
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
# 创建并导出logger实例
|
||||
logger = setup_logger()
|
@ -1,136 +0,0 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from .ocr_violation_detector import OCRViolationDetector
|
||||
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
||||
from .face_recognizer import FaceRecognizer
|
||||
|
||||
class MultiModelViolationDetector:
|
||||
"""
|
||||
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
yolo_model_path: str,
|
||||
known_faces_dir: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化所有检测模型
|
||||
"""
|
||||
# 初始化OCR检测器
|
||||
self.ocr_detector = OCRViolationDetector(
|
||||
forbidden_words_path=forbidden_words_path,
|
||||
ocr_config_path=ocr_config_path,
|
||||
ocr_confidence_threshold=ocr_confidence_threshold
|
||||
)
|
||||
|
||||
# 初始化人脸识别器
|
||||
self.face_recognizer = FaceRecognizer(
|
||||
known_faces_dir=known_faces_dir
|
||||
)
|
||||
|
||||
# 初始化YOLO检测器
|
||||
self.yolo_detector = YoloViolationDetector(
|
||||
model_path=yolo_model_path
|
||||
)
|
||||
|
||||
print("多模型违规检测器初始化完成")
|
||||
|
||||
def detect_violations(self, frame):
|
||||
"""
|
||||
串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果
|
||||
"""
|
||||
# 1. 首先进行OCR违禁词检测
|
||||
try:
|
||||
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
|
||||
if ocr_has_violation:
|
||||
details = {
|
||||
"words": ocr_words,
|
||||
"confidences": ocr_confs
|
||||
}
|
||||
print(f"警告: OCR检测到违禁内容: {details}")
|
||||
return (True, "ocr", details)
|
||||
except Exception as e:
|
||||
print(f"错误: OCR检测出错: {str(e)}")
|
||||
|
||||
# 2. 接着进行人脸识别检测
|
||||
try:
|
||||
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
|
||||
if face_has_violation:
|
||||
details = {
|
||||
"name": face_name,
|
||||
"similarity": face_similarity
|
||||
}
|
||||
print(f"警告: 人脸识别到违规人员: {details}")
|
||||
return (True, "face", details)
|
||||
except Exception as e:
|
||||
print(f"错误: 人脸识别出错: {str(e)}")
|
||||
|
||||
# 3. 最后进行YOLO目标检测
|
||||
try:
|
||||
yolo_results = self.yolo_detector.detect(frame)
|
||||
if len(yolo_results.boxes) > 0:
|
||||
details = {
|
||||
"classes": yolo_results.names,
|
||||
"boxes": yolo_results.boxes.xyxy.tolist(),
|
||||
"confidences": yolo_results.boxes.conf.tolist(),
|
||||
"class_ids": yolo_results.boxes.cls.tolist()
|
||||
}
|
||||
print(f"警告: YOLO检测到违规目标: {details}")
|
||||
return (True, "yolo", details)
|
||||
except Exception as e:
|
||||
print(f"错误: YOLO检测出错: {str(e)}")
|
||||
|
||||
# 所有检测均未发现违规
|
||||
return (False, None, None)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
"""加载YAML配置文件"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"错误: 配置文件未找到: {config_path}")
|
||||
raise
|
||||
except yaml.YAMLError as e:
|
||||
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"错误: 加载配置文件出错: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 加载配置文件
|
||||
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
|
||||
#
|
||||
# # 初始化多模型检测器
|
||||
# detector = MultiModelViolationDetector(
|
||||
# forbidden_words_path=config["forbidden_words_path"],
|
||||
# ocr_config_path=config["ocr_config_path"],
|
||||
# yolo_model_path=config["yolo_model_path"],
|
||||
# known_faces_dir=config["known_faces_dir"],
|
||||
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
|
||||
# )
|
||||
#
|
||||
# # 读取测试图像(可替换为视频帧读取逻辑)
|
||||
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
|
||||
# if test_image_path:
|
||||
# frame = cv2.imread(test_image_path)
|
||||
#
|
||||
# if frame is not None:
|
||||
# has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
# if has_violation:
|
||||
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
# else:
|
||||
# print("未检测到任何违规内容")
|
||||
# else:
|
||||
# print(f"无法读取测试图像: {test_image_path}")
|
||||
# else:
|
||||
# print("配置文件中未指定测试图像路径")
|
@ -1,178 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
from rapidocr import RapidOCR
|
||||
|
||||
|
||||
class OCRViolationDetector:
|
||||
"""
|
||||
封装RapidOCR引擎,用于检测图像帧中的违禁词。
|
||||
核心功能:加载违禁词、初始化OCR引擎、单帧图像违禁词检测
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化OCR引擎和违禁词列表。
|
||||
|
||||
Args:
|
||||
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
|
||||
ocr_config_path (str): OCR配置文件(如1.yaml)的路径。
|
||||
ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。
|
||||
"""
|
||||
# 加载违禁词
|
||||
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
|
||||
|
||||
# 初始化RapidOCR引擎
|
||||
self.ocr_engine = self._initialize_ocr(ocr_config_path)
|
||||
|
||||
# 校验核心依赖是否就绪
|
||||
self._check_dependencies()
|
||||
|
||||
# 设置置信度阈值(限制在0~1范围)
|
||||
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
|
||||
print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
|
||||
|
||||
def _load_forbidden_words(self, path: str) -> set:
|
||||
"""
|
||||
从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码)
|
||||
"""
|
||||
forbidden_words = set()
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(path):
|
||||
print(f"错误:违禁词文件不存在: {path}")
|
||||
return forbidden_words
|
||||
|
||||
# 读取文件并处理内容
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
forbidden_words = {
|
||||
line.strip() for line in f
|
||||
if line.strip() # 跳过空行或纯空格行
|
||||
}
|
||||
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
|
||||
except UnicodeDecodeError:
|
||||
print(f"错误:违禁词文件编码错误(需UTF-8): {path}")
|
||||
except PermissionError:
|
||||
print(f"错误:无权限读取违禁词文件: {path}")
|
||||
except Exception as e:
|
||||
print(f"错误:加载违禁词失败: {str(e)}")
|
||||
|
||||
return forbidden_words
|
||||
|
||||
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
|
||||
"""
|
||||
初始化RapidOCR引擎(校验配置文件、捕获初始化异常)
|
||||
"""
|
||||
print("开始初始化RapidOCR引擎...")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
print(f"错误:OCR配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
# 初始化OCR引擎
|
||||
try:
|
||||
ocr_engine = RapidOCR(config_path=config_path)
|
||||
print("RapidOCR引擎初始化成功")
|
||||
return ocr_engine
|
||||
except ImportError:
|
||||
print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
|
||||
except Exception as e:
|
||||
print(f"错误:RapidOCR初始化失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
def _check_dependencies(self) -> None:
|
||||
"""校验OCR引擎和违禁词列表是否就绪"""
|
||||
if not self.ocr_engine:
|
||||
print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
|
||||
if not self.forbidden_words:
|
||||
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
|
||||
|
||||
def detect(self, frame) -> tuple[bool, list, list]:
|
||||
"""
|
||||
对单帧图像进行OCR违禁词检测(核心方法)
|
||||
|
||||
Args:
|
||||
frame: 输入图像帧(NumPy数组,BGR格式,cv2读取的图像)。
|
||||
|
||||
Returns:
|
||||
tuple[bool, list, list]:
|
||||
- 第一个元素:是否检测到违禁词(True/False);
|
||||
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
|
||||
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
|
||||
"""
|
||||
# 初始化返回结果
|
||||
has_violation = False
|
||||
violation_words = []
|
||||
violation_confs = []
|
||||
|
||||
# 前置校验
|
||||
if frame is None or frame.size == 0:
|
||||
print("警告:输入图像帧为空或无效,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if not self.ocr_engine or not self.forbidden_words:
|
||||
print("OCR引擎未就绪或违禁词为空,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
try:
|
||||
# 执行OCR识别
|
||||
print("开始执行OCR识别...")
|
||||
ocr_result = self.ocr_engine(frame)
|
||||
print(f"RapidOCR原始结果: {ocr_result}")
|
||||
|
||||
# 校验OCR结果是否有效
|
||||
if ocr_result is None:
|
||||
print("OCR识别未返回任何结果(图像无文本或识别失败)")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 检查txts和scores是否存在且不为None
|
||||
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
|
||||
print("警告:OCR结果中txts为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
|
||||
print("警告:OCR结果中scores为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 转为列表并去None
|
||||
if not isinstance(ocr_result.txts, (list, tuple)):
|
||||
print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}")
|
||||
texts = []
|
||||
else:
|
||||
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
|
||||
|
||||
if not isinstance(ocr_result.scores, (list, tuple)):
|
||||
print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}")
|
||||
confidences = []
|
||||
else:
|
||||
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
|
||||
|
||||
# 校验文本和置信度列表长度是否一致
|
||||
if len(texts) != len(confidences):
|
||||
print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if len(texts) == 0:
|
||||
print("OCR未识别到任何有效文本")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 遍历识别结果,筛选违禁词
|
||||
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
|
||||
for text, conf in zip(texts, confidences):
|
||||
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
||||
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
|
||||
continue
|
||||
matched_words = [word for word in self.forbidden_words if word in text]
|
||||
if matched_words:
|
||||
has_violation = True
|
||||
violation_words.extend(matched_words)
|
||||
violation_confs.extend([conf] * len(matched_words))
|
||||
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误:OCR检测过程异常: {str(e)}")
|
||||
|
||||
return has_violation, violation_words, violation_confs
|
@ -1,47 +0,0 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
|
||||
class ViolationDetector:
|
||||
"""
|
||||
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
|
||||
"""
|
||||
def __init__(self, model_path):
|
||||
"""
|
||||
初始化检测器。
|
||||
|
||||
Args:
|
||||
model_path (str): YOLO .pt模型的路径。
|
||||
"""
|
||||
print(f"正在从 '{model_path}' 加载YOLO模型...")
|
||||
self.model = YOLO(model_path)
|
||||
print("YOLO模型加载成功。")
|
||||
|
||||
def detect(self, frame):
|
||||
"""
|
||||
对单帧图像进行目标检测。
|
||||
|
||||
Args:
|
||||
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
||||
|
||||
Returns:
|
||||
ultralytics.engine.results.Results: YOLO的检测结果对象。
|
||||
"""
|
||||
# conf可以根据您的模型效果进行调整
|
||||
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
|
||||
results = self.model(frame, conf=0.2)
|
||||
return results[0]
|
||||
|
||||
def draw_boxes(self, frame, result):
|
||||
"""
|
||||
在图像帧上绘制检测框。
|
||||
|
||||
Args:
|
||||
frame: 原始图像帧。
|
||||
result: YOLO的检测结果对象。
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 绘制了检测框的图像帧。
|
||||
"""
|
||||
# 使用YOLO自带的plot功能,方便快捷
|
||||
annotated_frame = result.plot()
|
||||
return annotated_frame
|
164
rtc/rtc.py
@ -1,164 +0,0 @@
|
||||
import queue
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import threading
|
||||
import time
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
||||
from aiortc.mediastreams import MediaStreamTrack
|
||||
|
||||
# 创建一个长度为1的队列,用于生产者和消费者之间的通信
|
||||
frame_queue = queue.Queue(maxsize=1)
|
||||
|
||||
|
||||
class VideoTrack(MediaStreamTrack):
|
||||
"""自定义视频轨道类,继承自MediaStreamTrack"""
|
||||
kind = "video"
|
||||
|
||||
def __init__(self, max_frames=100):
|
||||
super().__init__()
|
||||
self.frames = queue.Queue(maxsize=max_frames)
|
||||
|
||||
async def recv(self):
|
||||
return await super().recv()
|
||||
|
||||
|
||||
def webrtc_producer(webrtc_url):
|
||||
"""
|
||||
生产者方法:从WEBRTC读取视频帧并放入队列
|
||||
仅当队列空时才放入新帧,否则丢弃
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 创建RTCPeerConnection对象,不使用ICE服务器
|
||||
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
||||
video_track = VideoTrack()
|
||||
pc.addTrack(video_track)
|
||||
|
||||
@pc.on("track")
|
||||
async def on_track(track):
|
||||
if track.kind == "video":
|
||||
print("接收到视频轨道,开始接收视频帧")
|
||||
while True:
|
||||
# 从轨道接收视频帧
|
||||
frame = await track.recv()
|
||||
# 转换为BGR24格式的NumPy数组
|
||||
frame_bgr24 = frame.to_ndarray(format='bgr24')
|
||||
|
||||
# 检查队列是否为空,为空则加入,否则丢弃
|
||||
if frame_queue.empty():
|
||||
try:
|
||||
frame_queue.put_nowait(frame_bgr24)
|
||||
print("帧已放入队列")
|
||||
except queue.Full:
|
||||
print("队列已满,丢弃帧")
|
||||
else:
|
||||
print("队列非空,丢弃帧")
|
||||
|
||||
async def main():
|
||||
# 创建并发送SDP Offer
|
||||
offer = await pc.createOffer()
|
||||
print("已创建本地SDP Offer")
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
# 发送Offer到服务器并接收Answer
|
||||
async with aiohttp.ClientSession() as session:
|
||||
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
|
||||
async with session.post(
|
||||
webrtc_url,
|
||||
data=offer.sdp.encode(),
|
||||
headers={
|
||||
"Content-Type": "application/sdp",
|
||||
"Content-Length": str(len(offer.sdp))
|
||||
},
|
||||
ssl=False
|
||||
) as response:
|
||||
print("已接收到服务器的响应")
|
||||
answer_sdp = await response.text()
|
||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
||||
|
||||
# 保持连接
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
print("关闭RTCPeerConnection")
|
||||
await pc.close()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def frame_consumer(ip):
|
||||
"""
|
||||
消费者方法:从队列中读取帧并处理
|
||||
每次处理后休眠200ms模拟延迟
|
||||
"""
|
||||
print("消费者启动,开始等待帧...")
|
||||
try:
|
||||
while True:
|
||||
# 阻塞等待队列中的帧
|
||||
frame = frame_queue.get()
|
||||
print(f"消费帧,大小: {frame.shape}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
|
||||
|
||||
# 标记任务完成
|
||||
frame_queue.task_done()
|
||||
except KeyboardInterrupt:
|
||||
print("消费者退出")
|
||||
|
||||
|
||||
def start_webrtc_stream(ip, webrtc_url):
|
||||
"""
|
||||
启动WebRTC视频流处理的主方法
|
||||
参数: webrtc_url - WebRTC服务器地址
|
||||
"""
|
||||
print(f"开始连接到WebRTC服务器: {webrtc_url}")
|
||||
|
||||
# 启动生产者线程
|
||||
producer_thread = threading.Thread(
|
||||
target=webrtc_producer,
|
||||
args=(webrtc_url,),
|
||||
daemon=True,
|
||||
name="webrtc-producer"
|
||||
)
|
||||
|
||||
# 启动消费者线程
|
||||
consumer_thread = threading.Thread(
|
||||
target=frame_consumer(ip),
|
||||
daemon=True,
|
||||
name="frame-consumer"
|
||||
)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
print("生产者和消费者线程已启动")
|
||||
|
||||
try:
|
||||
# 保持主线程运行
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("程序正在退出...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
# 实际使用时替换为真实的WebRTC服务器地址
|
||||
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
|
||||
start_webrtc_stream(webrtc_server_url)
|
101
rtmp/rtmp.py
@ -1,101 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
|
||||
# 配置日志(与WHEP代码保持一致的日志风格)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理
|
||||
|
||||
Args:
|
||||
rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key)
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环)
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 必须指定FFmpeg后端,RTMP协议依赖该后端解析
|
||||
)
|
||||
|
||||
# 2. 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 3. 异步获取RTMP流基础信息(分辨率、帧率)
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况:部分RTMP流未返回帧率时默认30FPS
|
||||
fps = fps if fps > 0 else 30.0
|
||||
# 分辨率转为整数(视频尺寸必然是整数)
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息(与WHEP连接成功信息风格一致)
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 4. 初始化帧统计参数
|
||||
frame_count = 0 # 总接收帧数
|
||||
start_time = time.time() # 统计起始时间
|
||||
|
||||
# 5. 循环异步读取视频帧(核心逻辑)
|
||||
while True:
|
||||
# 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境)
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
# 检查帧是否读取成功(流中断/结束时ret为False)
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
# 帧计数累加
|
||||
frame_count += 1
|
||||
|
||||
# 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐)
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
# 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致)
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
|
||||
# 示例:yield frame (若需生成器模式,可调整函数为异步生成器)
|
||||
|
||||
# 8. 异常处理(覆盖用户中断、通用错误)
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行RTMP拉流任务(与WHEP一致的异步执行方式)
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
print(f"程序启动失败: {str(e)}")
|
@ -23,7 +23,7 @@ class FaceResponse(BaseModel):
|
||||
"""人脸记录响应模型(仍包含ID,由数据库生成后返回)"""
|
||||
id: int = Field(..., description="主键ID(数据库自增)")
|
||||
name: str = Field(None, description="名称")
|
||||
eigenvalue: str = Field(None, description="特征(暂为None)")
|
||||
eigenvalue: str | None = Field(None, description="特征(可为空)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException
|
||||
from fastapi import APIRouter, Query, HTTPException,Request
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
@ -108,7 +108,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
||||
# 原有接口保持不变
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||
async def create_device(device_data: DeviceCreateRequest):
|
||||
async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象
|
||||
# 原有代码保持不变
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -125,11 +125,10 @@ async def create_device(device_data: DeviceCreateRequest):
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备IP {device_data.ip} 已存在,返回已有设备信息",
|
||||
data=DeviceResponse(**existing_device)
|
||||
data=DeviceResponse(** existing_device)
|
||||
)
|
||||
|
||||
from fastapi import Request
|
||||
request = Request(scope={"type": "http"})
|
||||
# 直接使用注入的request对象获取用户代理
|
||||
user_agent = request.headers.get("User-Agent", "").lower()
|
||||
|
||||
if user_agent == "default":
|
||||
@ -184,7 +183,6 @@ async def create_device(device_data: DeviceCreateRequest):
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||
async def get_device_list(
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
|
@ -6,15 +6,15 @@ from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceRespons
|
||||
from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import get_current_user
|
||||
from schema.user_schema import UserResponse
|
||||
from ocr.feature_extraction import BinaryFaceFeatureHandler
|
||||
|
||||
from util.face_util import add_binary_data,get_average_feature
|
||||
#初始化实例
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/faces",
|
||||
tags=["人脸管理"]
|
||||
)
|
||||
|
||||
# 创建 BinaryFaceFeatureHandler 的实例
|
||||
binary_face_feature_handler = BinaryFaceFeatureHandler()
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@ -33,6 +33,8 @@ async def create_face(
|
||||
- ID 由数据库自动生成,无需前端传入
|
||||
- 暂不处理文件内容,eigenvalue 设为 None
|
||||
"""
|
||||
|
||||
# 调用你的方法
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
@ -45,14 +47,24 @@ async def create_face(
|
||||
# 把文件转为二进制数组
|
||||
file_content = await file.read()
|
||||
|
||||
# 调用人脸识别得到特征值
|
||||
# 计算特征值
|
||||
flag, eigenvalue = add_binary_data(file_content)
|
||||
|
||||
if flag == False:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="未检测到人脸"
|
||||
)
|
||||
|
||||
# 打印数组长度
|
||||
print(f"文件大小:{len(file_content)} 字节")
|
||||
|
||||
# 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None)
|
||||
insert_query = """
|
||||
INSERT INTO face (name, eigenvalue)
|
||||
VALUES (%s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (face_create.name, None))
|
||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
|
||||
conn.commit()
|
||||
|
||||
# 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录)
|
||||
@ -60,19 +72,45 @@ async def create_face(
|
||||
cursor.execute(select_new_query)
|
||||
created_face = cursor.fetchone()
|
||||
|
||||
if not created_face:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="创建人脸记录成功,但无法获取新创建的记录"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=201,
|
||||
message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})",
|
||||
data=FaceResponse(**created_face)
|
||||
data=FaceResponse(** created_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建人脸记录失败:{str(e)}") from e
|
||||
# 改为使用HTTPException
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"创建人脸记录失败:{str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
# 捕获其他可能的异常
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"服务器错误:{str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
await file.close() # 关闭文件流
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
|
||||
flag, eigenvalue = add_binary_data(file_content)
|
||||
if flag == False:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="未检测到人脸"
|
||||
)
|
||||
|
||||
# 将 eigenvalue 转为 str
|
||||
eigenvalue = str(eigenvalue)
|
||||
|
||||
# ------------------------------
|
||||
# 2. 获取单个人脸记录(不变,用自增ID查询)
|
||||
@ -104,18 +142,21 @@ async def get_face(
|
||||
data=FaceResponse(**face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询人脸记录失败:{str(e)}") from e
|
||||
# 改为使用HTTPException
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询人脸记录失败:{str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改)
|
||||
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
|
||||
# ------------------------------
|
||||
# 3. 获取所有人脸记录(不变)
|
||||
# ------------------------------
|
||||
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
||||
async def get_all_faces(
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -130,10 +171,13 @@ async def get_all_faces(
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="所有人脸记录查询成功",
|
||||
data=[FaceResponse(**face) for face in faces]
|
||||
data=[FaceResponse(** face) for face in faces]
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询所有人脸记录失败:{str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -194,7 +238,10 @@ async def update_face(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新人脸记录失败:{str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -234,7 +281,10 @@ async def delete_face(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"删除人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"删除人脸记录失败:{str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -249,38 +299,43 @@ def get_all_face_name_with_eigenvalue() -> dict:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 1. 建立数据库连接并获取游标(dictionary=True使结果以字典形式返回)
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 只查询需要的字段,提高效率
|
||||
# 2. 执行SQL查询:只获取name非空的记录,减少数据传输
|
||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall()
|
||||
faces = cursor.fetchall() # 返回结果:列表套字典,如 [{"name":"张三","eigenvalue":...}, ...]
|
||||
|
||||
# 先收集所有名称对应的特征值列表(处理重复名称)
|
||||
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
|
||||
name_to_eigenvalues = {}
|
||||
for face in faces:
|
||||
name = face["name"]
|
||||
eigenvalue = face["eigenvalue"]
|
||||
# 若名称已存在,追加特征值;否则新建列表存储
|
||||
if name in name_to_eigenvalues:
|
||||
name_to_eigenvalues[name].append(eigenvalue)
|
||||
else:
|
||||
name_to_eigenvalues[name] = [eigenvalue]
|
||||
|
||||
# 构建最终字典:重复名称取平均特征值,唯一名称直接取特征值
|
||||
# 4. 构建最终字典:重复名称取平均,唯一名称直接取特征值
|
||||
face_dict = {}
|
||||
for name, eigenvalues in name_to_eigenvalues.items():
|
||||
print("调用的特征值是:" + eigenvalues)
|
||||
|
||||
# 处理特征值:多个则求平均,单个则直接使用
|
||||
if len(eigenvalues) > 1:
|
||||
# 调用平均特征值计算方法
|
||||
face_dict[name] = binary_face_feature_handler.get_average_feature(eigenvalues)
|
||||
# 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入)
|
||||
face_dict[name] = get_average_feature(eigenvalues)
|
||||
else:
|
||||
# 取列表中唯一的特征值(避免value为列表类型)
|
||||
face_dict[name] = eigenvalues[0]
|
||||
|
||||
return face_dict
|
||||
|
||||
except MySQLError as e:
|
||||
# 捕获数据库异常,添加上下文信息后重新抛出(便于定位问题)
|
||||
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
|
||||
finally:
|
||||
# 确保资源释放
|
||||
# 5. 无论是否异常,均释放数据库连接和游标(避免资源泄漏)
|
||||
db.close_connection(conn, cursor)
|
145
util/face_util.py
Normal file
@ -0,0 +1,145 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
# 全局变量存储InsightFace引擎和特征列表
|
||||
_insightface_app = None
|
||||
_feature_list = []
|
||||
|
||||
|
||||
def init_insightface():
|
||||
"""初始化InsightFace引擎"""
|
||||
global _insightface_app
|
||||
try:
|
||||
print("正在初始化InsightFace引擎...")
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
_insightface_app = app
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def add_binary_data(binary_data):
|
||||
"""
|
||||
接收单张图片的二进制数据,提取特征并保存
|
||||
|
||||
参数:
|
||||
binary_data: 图片的二进制数据(bytes类型)
|
||||
|
||||
返回:
|
||||
成功提取特征时返回 (True, 特征值numpy数组)
|
||||
失败时返回 (False, None)
|
||||
"""
|
||||
global _insightface_app, _feature_list
|
||||
|
||||
if not _insightface_app:
|
||||
print("引擎未初始化,无法处理")
|
||||
return False, None
|
||||
|
||||
try:
|
||||
# 直接处理二进制数据:转换为图像格式
|
||||
img = Image.open(BytesIO(binary_data))
|
||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# 提取特征
|
||||
faces = _insightface_app.get(frame)
|
||||
if faces:
|
||||
# 获取当前提取的特征值
|
||||
current_feature = faces[0].embedding
|
||||
# 添加到特征列表
|
||||
_feature_list.append(current_feature)
|
||||
print(f"已累计 {len(_feature_list)} 个特征")
|
||||
# 返回成功标志和当前特征值
|
||||
return True, current_feature
|
||||
else:
|
||||
print("二进制数据中未检测到人脸")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
print(f"处理二进制数据出错: {e}")
|
||||
return False, None
|
||||
|
||||
|
||||
def get_average_feature(features=None):
|
||||
"""
|
||||
计算多个特征向量的平均值
|
||||
|
||||
参数:
|
||||
features: 可选,特征值列表。如果未提供,则使用全局存储的_feature_list
|
||||
每个元素可以是字符串格式或numpy数组
|
||||
|
||||
返回:
|
||||
单一平均特征向量的numpy数组,若无可计算数据则返回None
|
||||
"""
|
||||
global _feature_list
|
||||
|
||||
# 如果未提供features参数,则使用全局特征列表
|
||||
if features is None:
|
||||
features = _feature_list
|
||||
|
||||
try:
|
||||
# 验证输入是否为列表且不为空
|
||||
if not isinstance(features, list) or len(features) == 0:
|
||||
print("输入必须是包含至少一个特征值的列表")
|
||||
return None
|
||||
|
||||
# 处理每个特征值
|
||||
processed_features = []
|
||||
for i, embedding in enumerate(features):
|
||||
try:
|
||||
if isinstance(embedding, str):
|
||||
# 处理包含括号和逗号的字符串格式
|
||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||
else:
|
||||
embedding_np = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 验证特征值格式
|
||||
if len(embedding_np.shape) == 1:
|
||||
processed_features.append(embedding_np)
|
||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||
else:
|
||||
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
||||
|
||||
# 确保有有效的特征值
|
||||
if not processed_features:
|
||||
print("没有有效的特征值用于计算平均值")
|
||||
return None
|
||||
|
||||
# 检查所有特征向量维度是否相同
|
||||
dims = {feat.shape[0] for feat in processed_features}
|
||||
if len(dims) > 1:
|
||||
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
|
||||
return None
|
||||
|
||||
# 计算平均值
|
||||
avg_feature = np.mean(processed_features, axis=0)
|
||||
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
|
||||
|
||||
return avg_feature
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算平均特征值时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def clear_features():
|
||||
"""清空已存储的特征数据"""
|
||||
global _feature_list
|
||||
_feature_list = []
|
||||
print("已清空所有特征数据")
|
||||
|
||||
|
||||
def get_feature_list():
|
||||
"""获取当前存储的特征列表"""
|
||||
global _feature_list
|
||||
return _feature_list.copy() # 返回副本防止外部直接修改
|
482
ws.html
@ -1,482 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>WebSocket 测试工具</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
|
||||
}
|
||||
|
||||
body {
|
||||
max-width: 1200px;
|
||||
margin: 20px auto;
|
||||
padding: 0 20px;
|
||||
background-color: #f5f7fa;
|
||||
}
|
||||
|
||||
.container {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||
padding: 25px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #2c3e50;
|
||||
margin-bottom: 20px;
|
||||
font-size: 24px;
|
||||
border-bottom: 2px solid #3498db;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.status-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
padding: 12px 15px;
|
||||
background-color: #f8f9fa;
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
.status-label {
|
||||
font-weight: bold;
|
||||
color: #495057;
|
||||
}
|
||||
|
||||
.status-value {
|
||||
padding: 4px 10px;
|
||||
border-radius: 4px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.status-connected {
|
||||
background-color: #d4edda;
|
||||
color: #155724;
|
||||
}
|
||||
|
||||
.status-disconnected {
|
||||
background-color: #f8d7da;
|
||||
color: #721c24;
|
||||
}
|
||||
|
||||
.status-connecting {
|
||||
background-color: #fff3cd;
|
||||
color: #856404;
|
||||
}
|
||||
|
||||
.btn {
|
||||
padding: 8px 16px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background-color: #3498db;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background-color: #2980b9;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background-color: #e74c3c;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover {
|
||||
background-color: #c0392b;
|
||||
}
|
||||
|
||||
.btn-success {
|
||||
background-color: #2ecc71;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-success:hover {
|
||||
background-color: #27ae60;
|
||||
}
|
||||
|
||||
.control-group {
|
||||
display: flex;
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
color: #495057;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.input-group input, .input-group select {
|
||||
padding: 8px 12px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.message-area {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.message-input {
|
||||
width: 100%;
|
||||
height: 100px;
|
||||
padding: 12px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 6px;
|
||||
resize: none;
|
||||
font-size: 14px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.log-area {
|
||||
width: 100%;
|
||||
height: 300px;
|
||||
padding: 15px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 6px;
|
||||
background-color: #f8f9fa;
|
||||
overflow-y: auto;
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.log-item {
|
||||
margin-bottom: 8px;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 1px dashed #e9ecef;
|
||||
}
|
||||
|
||||
.log-time {
|
||||
color: #6c757d;
|
||||
font-size: 12px;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.log-send {
|
||||
color: #2980b9;
|
||||
}
|
||||
|
||||
.log-receive {
|
||||
color: #27ae60;
|
||||
}
|
||||
|
||||
.log-status {
|
||||
color: #856404;
|
||||
}
|
||||
|
||||
.log-error {
|
||||
color: #e74c3c;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>WebSocket 测试工具</h1>
|
||||
|
||||
<!-- 连接状态区 -->
|
||||
<div class="status-bar">
|
||||
<div class="status-label">连接状态:</div>
|
||||
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
|
||||
<div class="status-label">服务地址:</div>
|
||||
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
|
||||
<div class="status-label">连接时间:</div>
|
||||
<div id="connectTime" class="status-value">-</div>
|
||||
</div>
|
||||
|
||||
<!-- 控制按钮区 -->
|
||||
<div class="control-group">
|
||||
<button id="connectBtn" class="btn btn-primary">建立连接</button>
|
||||
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
|
||||
|
||||
<!-- 心跳控制 -->
|
||||
<div class="input-group">
|
||||
<label>自动心跳:</label>
|
||||
<select id="autoHeartbeat">
|
||||
<option value="on">开启</option>
|
||||
<option value="off">关闭</option>
|
||||
</select>
|
||||
<label>间隔(秒):</label>
|
||||
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
|
||||
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 自定义消息发送区 -->
|
||||
<div class="message-area">
|
||||
<h3>发送自定义消息</h3>
|
||||
<textarea id="messageInput" class="message-input"
|
||||
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
|
||||
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
|
||||
</div>
|
||||
|
||||
<!-- 日志显示区 -->
|
||||
<div class="message-area">
|
||||
<h3>消息日志</h3>
|
||||
<div id="logContainer" class="log-area">
|
||||
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
|
||||
</div>
|
||||
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// 全局变量
|
||||
let ws = null;
|
||||
let heartbeatTimer = null;
|
||||
const wsUrl = "ws://192.168.110.25:8000/ws";
|
||||
|
||||
// DOM 元素
|
||||
const connectionStatus = document.getElementById('connectionStatus');
|
||||
const connectTime = document.getElementById('connectTime');
|
||||
const connectBtn = document.getElementById('connectBtn');
|
||||
const disconnectBtn = document.getElementById('disconnectBtn');
|
||||
const sendMessageBtn = document.getElementById('sendMessageBtn');
|
||||
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
|
||||
const autoHeartbeat = document.getElementById('autoHeartbeat');
|
||||
const heartbeatInterval = document.getElementById('heartbeatInterval');
|
||||
const messageInput = document.getElementById('messageInput');
|
||||
const logContainer = document.getElementById('logContainer');
|
||||
const clearLogBtn = document.getElementById('clearLogBtn');
|
||||
|
||||
// 工具函数:添加日志
|
||||
function addLog(content, type = 'status') {
|
||||
const now = new Date().toLocaleString('zh-CN', {
|
||||
year: 'numeric', month: '2-digit', day: '2-digit',
|
||||
hour: '2-digit', minute: '2-digit', second: '2-digit'
|
||||
});
|
||||
const logItem = document.createElement('div');
|
||||
logItem.className = 'log-item';
|
||||
|
||||
let logClass = '';
|
||||
switch (type) {
|
||||
case 'send':
|
||||
logClass = 'log-send';
|
||||
break;
|
||||
case 'receive':
|
||||
logClass = 'log-receive';
|
||||
break;
|
||||
case 'error':
|
||||
logClass = 'log-error';
|
||||
break;
|
||||
default:
|
||||
logClass = 'log-status';
|
||||
}
|
||||
|
||||
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
|
||||
logContainer.appendChild(logItem);
|
||||
// 滚动到最新日志
|
||||
logContainer.scrollTop = logContainer.scrollHeight;
|
||||
}
|
||||
|
||||
// 工具函数:格式化JSON(便于日志显示)
|
||||
function formatJson(jsonStr) {
|
||||
try {
|
||||
const obj = JSON.parse(jsonStr);
|
||||
return JSON.stringify(obj, null, 2);
|
||||
} catch (e) {
|
||||
return jsonStr; // 非JSON格式直接返回
|
||||
}
|
||||
}
|
||||
|
||||
// 建立WebSocket连接
|
||||
function connectWebSocket() {
|
||||
if (ws) {
|
||||
addLog('已存在连接,无需重复建立', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
ws = new WebSocket(wsUrl);
|
||||
|
||||
// 连接成功
|
||||
ws.onopen = function () {
|
||||
connectionStatus.className = 'status-value status-connected';
|
||||
connectionStatus.textContent = '已连接';
|
||||
const now = new Date().toLocaleString('zh-CN');
|
||||
connectTime.textContent = now;
|
||||
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
|
||||
|
||||
// 更新按钮状态
|
||||
connectBtn.disabled = true;
|
||||
disconnectBtn.disabled = false;
|
||||
sendMessageBtn.disabled = false;
|
||||
|
||||
// 开启自动心跳(默认开启)
|
||||
if (autoHeartbeat.value === 'on') {
|
||||
startAutoHeartbeat();
|
||||
}
|
||||
};
|
||||
|
||||
// 接收消息
|
||||
ws.onmessage = function (event) {
|
||||
const message = event.data;
|
||||
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
|
||||
};
|
||||
|
||||
// 连接关闭
|
||||
ws.onclose = function (event) {
|
||||
connectionStatus.className = 'status-value status-disconnected';
|
||||
connectionStatus.textContent = '已断开';
|
||||
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
|
||||
|
||||
// 清除自动心跳
|
||||
stopAutoHeartbeat();
|
||||
|
||||
// 更新按钮状态
|
||||
connectBtn.disabled = false;
|
||||
disconnectBtn.disabled = true;
|
||||
sendMessageBtn.disabled = true;
|
||||
|
||||
// 重置WebSocket对象
|
||||
ws = null;
|
||||
};
|
||||
|
||||
// 连接错误
|
||||
ws.onerror = function (error) {
|
||||
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
|
||||
};
|
||||
|
||||
} catch (e) {
|
||||
addLog(`建立连接失败:${e.message}`, 'error');
|
||||
ws = null;
|
||||
}
|
||||
}
|
||||
|
||||
// 断开WebSocket连接
|
||||
function disconnectWebSocket() {
|
||||
if (!ws) {
|
||||
addLog('当前无连接,无需断开', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
ws.close(1000, '手动断开连接');
|
||||
}
|
||||
|
||||
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"})
|
||||
function sendHeartbeat() {
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
addLog('发送心跳失败:当前无有效连接', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const heartbeatMsg = {
|
||||
timestamp: Date.now(), // 当前毫秒时间戳
|
||||
type: "heartbeat"
|
||||
};
|
||||
const msgStr = JSON.stringify(heartbeatMsg);
|
||||
|
||||
ws.send(msgStr);
|
||||
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
|
||||
}
|
||||
|
||||
// 开启自动心跳
|
||||
function startAutoHeartbeat() {
|
||||
// 先停止已有定时器
|
||||
stopAutoHeartbeat();
|
||||
|
||||
const interval = parseInt(heartbeatInterval.value) * 1000;
|
||||
if (isNaN(interval) || interval < 10000) {
|
||||
addLog('自动心跳间隔无效,已重置为30秒', 'error');
|
||||
heartbeatInterval.value = 30;
|
||||
return startAutoHeartbeat();
|
||||
}
|
||||
|
||||
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}秒`, 'status');
|
||||
heartbeatTimer = setInterval(sendHeartbeat, interval);
|
||||
}
|
||||
|
||||
// 停止自动心跳
|
||||
function stopAutoHeartbeat() {
|
||||
if (heartbeatTimer) {
|
||||
clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = null;
|
||||
addLog('已停止自动心跳', 'status');
|
||||
}
|
||||
}
|
||||
|
||||
// 发送自定义消息
|
||||
function sendCustomMessage() {
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
addLog('发送消息失败:当前无有效连接', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const msgStr = messageInput.value.trim();
|
||||
if (!msgStr) {
|
||||
addLog('发送消息失败:消息内容不能为空', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 验证JSON格式(可选,仅提示不强制)
|
||||
JSON.parse(msgStr);
|
||||
ws.send(msgStr);
|
||||
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
|
||||
} catch (e) {
|
||||
addLog(`JSON格式错误:${e.message},仍尝试发送原始内容`, 'error');
|
||||
ws.send(msgStr);
|
||||
addLog(`发送自定义消息(非JSON):\n${msgStr}`, 'send');
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定按钮事件
|
||||
connectBtn.addEventListener('click', connectWebSocket);
|
||||
disconnectBtn.addEventListener('click', disconnectWebSocket);
|
||||
sendMessageBtn.addEventListener('click', sendCustomMessage);
|
||||
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
|
||||
clearLogBtn.addEventListener('click', () => {
|
||||
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
|
||||
});
|
||||
|
||||
// 自动心跳开关变更事件
|
||||
autoHeartbeat.addEventListener('change', function () {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
if (this.value === 'on') {
|
||||
startAutoHeartbeat();
|
||||
} else {
|
||||
stopAutoHeartbeat();
|
||||
}
|
||||
} else {
|
||||
addLog('需先建立有效连接才能控制自动心跳', 'error');
|
||||
// 重置选择
|
||||
this.value = 'off';
|
||||
}
|
||||
});
|
||||
|
||||
// 心跳间隔变更事件(实时生效)
|
||||
heartbeatInterval.addEventListener('change', function () {
|
||||
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
|
||||
startAutoHeartbeat();
|
||||
}
|
||||
});
|
||||
|
||||
// 快捷键支持(Ctrl+Enter发送消息)
|
||||
messageInput.addEventListener('keydown', function (e) {
|
||||
if (e.ctrlKey && e.key === 'Enter') {
|
||||
sendCustomMessage();
|
||||
e.preventDefault();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
356
ws/ws.py
@ -4,314 +4,300 @@ import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Optional, AsyncGenerator
|
||||
from concurrent.futures import ThreadPoolExecutor # 新增:显式线程池
|
||||
|
||||
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
|
||||
from service.device_action_service import add_device_action
|
||||
from schema.device_action_schema import DeviceActionCreate
|
||||
from core.all import detect
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
|
||||
from queue import Queue # 线程安全队列,无需额外Lock
|
||||
from core.all import load_model
|
||||
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
# -------------------------- 配置调整 --------------------------
|
||||
# 模型路径(建议改为环境变量)
|
||||
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
|
||||
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
|
||||
|
||||
# 核心优化:模型池大小(决定最大并发任务数,显存占用=大小×单模型显存)
|
||||
MODEL_POOL_SIZE = 5 # 示例:设为5,支持5个任务并行,显存会明显上升
|
||||
THREAD_POOL_SIZE = MODEL_POOL_SIZE * 2 # 线程池大小≥模型池,避免线程瓶颈
|
||||
|
||||
# 其他配置
|
||||
HEARTBEAT_INTERVAL = 30 # 心跳间隔(秒)
|
||||
# 配置常量
|
||||
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
|
||||
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
|
||||
WS_ENDPOINT = "/ws" # WebSocket端点
|
||||
FRAME_QUEUE_SIZE = 5 # 增大帧队列,允许缓存更多帧(避免丢帧)
|
||||
WS_ENDPOINT = "/ws" # WebSocket端点路径
|
||||
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
|
||||
|
||||
# -------------------------- 工具函数 --------------------------
|
||||
|
||||
# 工具函数:获取格式化时间字符串(统一时间戳格式)
|
||||
def get_current_time_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_current_time_file_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
# -------------------------- 模型池重构(核心修改1) --------------------------
|
||||
class ModelPool:
|
||||
def __init__(self, pool_size: int = MODEL_POOL_SIZE):
|
||||
self.pool = Queue(maxsize=pool_size)
|
||||
# 移除冗余Lock:Queue.get()/put()本身线程安全
|
||||
self._init_models(pool_size)
|
||||
print(f"[{get_current_time_str()}] 模型池初始化完成(共{pool_size}个实例,显存已预分配)")
|
||||
|
||||
def _init_models(self, pool_size: int):
|
||||
"""预加载所有模型实例(初始化时显存会一次性上升)"""
|
||||
for i in range(pool_size):
|
||||
try:
|
||||
detector = MultiModelViolationDetector(
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
self.pool.put(detector)
|
||||
print(f"[{get_current_time_str()}] 模型实例{i+1}/{pool_size}加载完成")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"模型实例{i+1}加载失败:{str(e)}")
|
||||
|
||||
def get_model(self) -> MultiModelViolationDetector:
|
||||
"""获取模型(阻塞直到有空闲实例,确保并发安全)"""
|
||||
return self.pool.get()
|
||||
|
||||
def return_model(self, detector: MultiModelViolationDetector):
|
||||
"""归还模型(立即释放资源供其他任务使用)"""
|
||||
self.pool.put(detector)
|
||||
|
||||
# -------------------------- 全局资源初始化 --------------------------
|
||||
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE) # 初始化模型池(预占显存)
|
||||
thread_pool = ThreadPoolExecutor( # 显式创建线程池(核心修改2)
|
||||
max_workers=THREAD_POOL_SIZE,
|
||||
thread_name_prefix="ModelWorker-" # 线程命名,便于调试
|
||||
)
|
||||
|
||||
# -------------------------- 客户端连接封装(核心修改3) --------------------------
|
||||
# 客户端连接封装
|
||||
class ClientConnection:
|
||||
def __init__(self, websocket: WebSocket, client_ip: str):
|
||||
self.websocket = websocket
|
||||
self.client_ip = client_ip
|
||||
self.last_heartbeat = datetime.datetime.now()
|
||||
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 增大队列
|
||||
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
|
||||
self.consumer_task: Optional[asyncio.Task] = None
|
||||
# 移除“客户端独占模型”:不再持有detector属性
|
||||
|
||||
def update_heartbeat(self):
|
||||
"""更新心跳时间(客户端发送心跳时调用)"""
|
||||
self.last_heartbeat = datetime.datetime.now()
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
"""判断客户端是否存活(心跳超时检查)"""
|
||||
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
||||
return timeout < HEARTBEAT_TIMEOUT
|
||||
|
||||
def start_consumer(self):
|
||||
"""启动帧消费任务(每个客户端一个独立任务)"""
|
||||
"""启动帧消费任务"""
|
||||
self.consumer_task = asyncio.create_task(self.consume_frames())
|
||||
return self.consumer_task
|
||||
|
||||
async def send_frame_permit(self):
|
||||
"""发送帧许可信号(允许客户端继续发帧)"""
|
||||
"""
|
||||
发送「帧发送许可信号」
|
||||
通知客户端可发送下一帧图像
|
||||
"""
|
||||
try:
|
||||
await self.websocket.send_json({
|
||||
frame_permit_msg = {
|
||||
"type": "frame",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip
|
||||
})
|
||||
}
|
||||
await self.websocket.send_json(frame_permit_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可发送失败 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
|
||||
|
||||
async def consume_frames(self) -> None:
|
||||
"""消费帧队列(并发核心:每帧临时借模型处理)"""
|
||||
"""消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理帧)"""
|
||||
try:
|
||||
while True:
|
||||
# 1. 从队列取帧(无帧时阻塞)
|
||||
# 1. 从队列取出帧(阻塞直到有帧可用)
|
||||
frame_data = await self.frame_queue.get()
|
||||
# 2. 立即发送下一帧许可(让客户端持续发帧,积累并发任务)
|
||||
await self.send_frame_permit()
|
||||
|
||||
# -------------------------- 核心修改:取出帧后立即发送下一帧许可 --------------------------
|
||||
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成
|
||||
# -----------------------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
# 3. 并行处理帧(核心:任务级借模型)
|
||||
# 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧)
|
||||
await self.process_frame(frame_data)
|
||||
finally:
|
||||
self.frame_queue.task_done() # 标记帧处理完成
|
||||
# 3. 标记帧任务完成(无论处理成功/失败,都需清理队列)
|
||||
self.frame_queue.task_done()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:消费逻辑错误 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}")
|
||||
|
||||
async def process_frame(self, frame_data: bytes) -> None:
|
||||
"""处理单帧(核心修改4:任务级借还模型)"""
|
||||
# 1. 临时借用模型(阻塞直到有空闲实例,显存随借用数上升)
|
||||
detector = model_pool.get_model()
|
||||
try:
|
||||
# 2. 二进制转OpenCV图像
|
||||
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
|
||||
# 二进制数据转OpenCV图像
|
||||
nparr = np.frombuffer(frame_data, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像解析失败")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
|
||||
return
|
||||
|
||||
# 3. 保存图像(可选)
|
||||
# 确保图像保存目录存在
|
||||
os.makedirs('images', exist_ok=True)
|
||||
|
||||
# 保存图像(按IP+时间戳命名,避免冲突)
|
||||
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
|
||||
try:
|
||||
cv2.imwrite(filename, img)
|
||||
|
||||
# 4. 显式线程池执行AI检测(真正并发,无线程瓶颈)
|
||||
loop = asyncio.get_running_loop()
|
||||
has_violation, violation_type, details = await loop.run_in_executor(
|
||||
thread_pool, # 用自定义线程池,避免默认线程不足
|
||||
detector.detect_violations, # 临时借用的模型
|
||||
img # 输入图像
|
||||
)
|
||||
|
||||
# 5. 违规处理(与原逻辑一致)
|
||||
print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
|
||||
has_violation, data, type = detect(img)
|
||||
print(has_violation)
|
||||
print(type)
|
||||
print(data)
|
||||
if has_violation:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规 - {violation_type}")
|
||||
# 违规次数更新(用线程池避免阻塞事件循环)
|
||||
await loop.run_in_executor(thread_pool, increment_alarm_count_by_ip, self.client_ip)
|
||||
# 发送危险通知
|
||||
await self.websocket.send_json({
|
||||
print(
|
||||
f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {type}, 详情: {data}")
|
||||
|
||||
# 调用违规次数加一方法
|
||||
try:
|
||||
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
|
||||
|
||||
# 发送「危险通知」
|
||||
danger_msg = {
|
||||
"type": "danger",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip,
|
||||
"violation_type": violation_type,
|
||||
"details": details
|
||||
})
|
||||
"client_ip": self.client_ip
|
||||
}
|
||||
await self.websocket.send_json(danger_msg)
|
||||
else:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无违规")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:未检测到违规")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧处理错误 - {str(e)}")
|
||||
finally:
|
||||
# 6. 无论成功/失败,强制归还模型(核心:释放资源供其他任务使用)
|
||||
model_pool.return_model(detector)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还(可复用)")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像处理错误 - {str(e)}")
|
||||
|
||||
# -------------------------- 全局状态与心跳 --------------------------
|
||||
|
||||
# 全局状态管理
|
||||
connected_clients: Dict[str, ClientConnection] = {}
|
||||
client_lock = asyncio.Lock() # 保护客户端字典的异步锁
|
||||
heartbeat_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法)
|
||||
async def heartbeat_checker():
|
||||
"""心跳检查(移除模型归还逻辑,因模型已任务级归还)"""
|
||||
while True:
|
||||
current_time = get_current_time_str()
|
||||
async with client_lock:
|
||||
# 筛选超时客户端
|
||||
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
||||
|
||||
if timeout_ips:
|
||||
print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时(IP:{timeout_ips})")
|
||||
for ip in timeout_ips:
|
||||
async with client_lock:
|
||||
conn = connected_clients.get(ip)
|
||||
if not conn:
|
||||
continue
|
||||
# 取消消费任务+关闭连接
|
||||
try:
|
||||
conn = connected_clients[ip]
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
await conn.websocket.close(code=1008, reason="心跳超时")
|
||||
# 标记离线(用线程池)
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(thread_pool, update_online_status_by_ip, ip, 0)
|
||||
await loop.run_in_executor(
|
||||
thread_pool, add_device_action, DeviceActionCreate(client_ip=ip, action=0)
|
||||
)
|
||||
connected_clients.pop(ip)
|
||||
print(f"[{current_time}] 客户端{ip}:超时离线(资源已清理)")
|
||||
|
||||
# 打印在线状态
|
||||
async with client_lock:
|
||||
# 超时设为离线并记录
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
|
||||
action_data = DeviceActionCreate(client_ip=ip, action=0)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
|
||||
except Exception as e:
|
||||
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
|
||||
finally:
|
||||
connected_clients.pop(ip, None)
|
||||
else:
|
||||
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
|
||||
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
# -------------------------- 应用生命周期(核心修改5:管理线程池) --------------------------
|
||||
|
||||
# 应用生命周期管理
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global heartbeat_task
|
||||
# 启动心跳任务
|
||||
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
||||
print(f"[{get_current_time_str()}] 心跳任务启动(ID:{id(heartbeat_task)})")
|
||||
print(f"[{get_current_time_str()}] 线程池启动(最大线程数:{THREAD_POOL_SIZE})")
|
||||
yield # 应用运行期间
|
||||
# 清理资源
|
||||
print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID:{id(heartbeat_task)})")
|
||||
yield
|
||||
if heartbeat_task and not heartbeat_task.done():
|
||||
heartbeat_task.cancel()
|
||||
try:
|
||||
await heartbeat_task
|
||||
print(f"[{get_current_time_str()}] 心跳任务已关闭")
|
||||
# 关闭线程池(等待所有任务完成)
|
||||
thread_pool.shutdown(wait=True)
|
||||
print(f"[{get_current_time_str()}] 线程池已关闭")
|
||||
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# -------------------------- WebSocket路由 --------------------------
|
||||
|
||||
# 消息处理工具函数
|
||||
async def send_heartbeat_ack(conn: ClientConnection):
|
||||
try:
|
||||
heartbeat_ack_msg = {
|
||||
"type": "heart",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": conn.client_ip
|
||||
}
|
||||
await conn.websocket.send_json(heartbeat_ack_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
|
||||
return True
|
||||
except Exception as e:
|
||||
connected_clients.pop(conn.client_ip, None)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def handle_text_msg(conn: ClientConnection, text: str):
|
||||
try:
|
||||
msg = json.loads(text)
|
||||
if msg.get("type") == "heart":
|
||||
conn.update_heartbeat()
|
||||
await send_heartbeat_ack(conn)
|
||||
else:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')})")
|
||||
except json.JSONDecodeError:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:无效JSON文本消息")
|
||||
|
||||
|
||||
async def handle_binary_msg(conn: ClientConnection, data: bytes):
|
||||
try:
|
||||
conn.frame_queue.put_nowait(data)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列")
|
||||
except asyncio.QueueFull:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据")
|
||||
|
||||
|
||||
# WebSocket路由配置
|
||||
ws_router = APIRouter()
|
||||
|
||||
|
||||
@ws_router.websocket(WS_ENDPOINT)
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
# 加载模型
|
||||
load_model()
|
||||
await websocket.accept()
|
||||
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||
current_time = get_current_time_str()
|
||||
print(f"[{current_time}] 客户端{client_ip}:连接建立")
|
||||
print(f"[{current_time}] 客户端{client_ip}:WebSocket连接已建立")
|
||||
|
||||
new_conn = None
|
||||
is_online_updated = False
|
||||
|
||||
try:
|
||||
# 处理重复连接(关闭旧连接)
|
||||
async with client_lock:
|
||||
# 处理重复连接
|
||||
if client_ip in connected_clients:
|
||||
old_conn = connected_clients[client_ip]
|
||||
if old_conn.consumer_task and not old_conn.consumer_task.done():
|
||||
old_conn.consumer_task.cancel()
|
||||
await old_conn.websocket.close(code=1008, reason="新连接抢占")
|
||||
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
|
||||
connected_clients.pop(client_ip)
|
||||
print(f"[{current_time}] 客户端{client_ip}:旧连接已关闭")
|
||||
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接")
|
||||
|
||||
# 创建新连接+启动消费任务
|
||||
# 注册新连接
|
||||
new_conn = ClientConnection(websocket, client_ip)
|
||||
connected_clients[client_ip] = new_conn
|
||||
new_conn.start_consumer()
|
||||
# 初始发送帧许可(让客户端立即发帧)
|
||||
# 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发)
|
||||
await new_conn.send_frame_permit()
|
||||
|
||||
# 标记客户端在线
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 1)
|
||||
await loop.run_in_executor(
|
||||
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=1)
|
||||
)
|
||||
# 标记上线并记录
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
|
||||
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作")
|
||||
is_online_updated = True
|
||||
async with client_lock:
|
||||
connected_clients[client_ip] = new_conn
|
||||
print(f"[{current_time}] 客户端{client_ip}:注册成功(在线数:{len(connected_clients)})")
|
||||
except Exception as e:
|
||||
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
|
||||
|
||||
# 消息循环(接收文本/二进制帧)
|
||||
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
|
||||
|
||||
# 消息循环
|
||||
while True:
|
||||
data = await websocket.receive()
|
||||
if "text" in data:
|
||||
# 处理文本消息(如心跳)
|
||||
try:
|
||||
msg = json.loads(data["text"])
|
||||
if msg.get("type") == "heart":
|
||||
new_conn.update_heartbeat()
|
||||
# 回复心跳确认
|
||||
await websocket.send_json({
|
||||
"type": "heart",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": client_ip
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:无效JSON")
|
||||
await handle_text_msg(new_conn, data["text"])
|
||||
elif "bytes" in data:
|
||||
# 处理二进制帧(图像)
|
||||
try:
|
||||
await new_conn.frame_queue.put(data["bytes"])
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧已入队(队列大小:{new_conn.frame_queue.qsize()})")
|
||||
except asyncio.QueueFull:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧队列满(丢弃当前帧)")
|
||||
await handle_binary_msg(new_conn, data["bytes"])
|
||||
|
||||
except WebSocketDisconnect as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开(代码:{e.code})")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code})")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
|
||||
finally:
|
||||
# 清理资源(无需归还模型,已在process_frame中归还)
|
||||
if new_conn and client_ip in connected_clients:
|
||||
async with client_lock:
|
||||
conn = connected_clients.get(client_ip)
|
||||
if conn:
|
||||
# 清理资源并标记离线
|
||||
if client_ip in connected_clients:
|
||||
conn = connected_clients[client_ip]
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
# 标记离线(仅当在线状态已更新时)
|
||||
|
||||
# 主动/异常断开时标记离线
|
||||
if is_online_updated:
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 0)
|
||||
await loop.run_in_executor(
|
||||
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=0)
|
||||
)
|
||||
connected_clients.pop(client_ip)
|
||||
async with client_lock:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源清理完成(在线数:{len(connected_clients)})")
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
||||
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
|
||||
|
||||
connected_clients.pop(client_ip, None)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")
|
||||
|