最新可用
2
.idea/Video.iml
generated
@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
|||||||
<component name="Black">
|
<component name="Black">
|
||||||
<option name="sdkName" value="video" />
|
<option name="sdkName" value="video" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
@ -15,5 +15,5 @@ algorithm = HS256
|
|||||||
access_token_expire_minutes = 30
|
access_token_expire_minutes = 30
|
||||||
|
|
||||||
[live]
|
[live]
|
||||||
rtmp_url = rtmp://192.168.110.65:1935/live/
|
rtmp_url = rtmp://192.168.110.25:1935/live/
|
||||||
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=
|
webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=
|
||||||
|
45
core/all.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
|
||||||
|
from core.face import load_model as faceLoadModel, detect as faceDetect
|
||||||
|
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
|
||||||
|
|
||||||
|
# 添加一个标记变量,用于监控load_model是否已被调用
|
||||||
|
_model_loaded = False
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
global _model_loaded
|
||||||
|
|
||||||
|
# 如果已经调用过,直接忽略
|
||||||
|
if _model_loaded:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 首次调用时加载模型
|
||||||
|
ocrLoadModel()
|
||||||
|
faceLoadModel()
|
||||||
|
yoloLoadModel()
|
||||||
|
|
||||||
|
# 标记为已调用
|
||||||
|
_model_loaded = True
|
||||||
|
|
||||||
|
|
||||||
|
def detect(frame):
|
||||||
|
# 先进行YOLO检测
|
||||||
|
yolo_flag, yolo_result = yoloDetect(frame)
|
||||||
|
print("YOLO检测结果:", yolo_result)
|
||||||
|
if yolo_flag:
|
||||||
|
return (True, yolo_result, "yolo")
|
||||||
|
|
||||||
|
# YOLO未检测到,进行人脸检测
|
||||||
|
face_flag, face_result = faceDetect(frame)
|
||||||
|
print("人脸检测结果:", face_result)
|
||||||
|
if face_flag:
|
||||||
|
return (True, face_result, "face")
|
||||||
|
|
||||||
|
# 人脸未检测到,进行OCR检测
|
||||||
|
ocr_flag, ocr_result = ocrDetect(frame)
|
||||||
|
print("OCR检测结果:", ocr_result)
|
||||||
|
if ocr_flag:
|
||||||
|
return (True, ocr_result, "ocr")
|
||||||
|
|
||||||
|
# 所有检测都未检测到
|
||||||
|
return (False, "未检测到任何内容", "none")
|
113
core/face.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PIL import Image # 确保正确导入Image类
|
||||||
|
from insightface.app import FaceAnalysis
|
||||||
|
# 导入获取人脸信息的服务
|
||||||
|
from service.face_service import get_all_face_name_with_eigenvalue
|
||||||
|
|
||||||
|
# 全局变量
|
||||||
|
_face_app = None
|
||||||
|
_known_faces_embeddings = {} # 存储姓名到特征值的映射
|
||||||
|
_known_faces_names = [] # 存储所有已知姓名
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
"""加载人脸识别模型及已知人脸特征库"""
|
||||||
|
global _face_app, _known_faces_embeddings, _known_faces_names
|
||||||
|
|
||||||
|
# 初始化InsightFace模型
|
||||||
|
try:
|
||||||
|
_face_app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface'))
|
||||||
|
_face_app.prepare(ctx_id=0, det_size=(640, 640))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Face model load failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 从服务获取所有人脸姓名和特征值
|
||||||
|
try:
|
||||||
|
face_data = get_all_face_name_with_eigenvalue()
|
||||||
|
|
||||||
|
# 处理获取到的人脸数据
|
||||||
|
for person_name, eigenvalue_data in face_data.items():
|
||||||
|
# 处理特征值数据 - 兼容数组和字符串两种格式
|
||||||
|
if isinstance(eigenvalue_data, np.ndarray):
|
||||||
|
# 如果已经是numpy数组,直接使用
|
||||||
|
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||||
|
elif isinstance(eigenvalue_data, str):
|
||||||
|
# 清理字符串:移除方括号、换行符和多余空格
|
||||||
|
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
|
||||||
|
# 按空格或逗号分割(处理可能的不同分隔符)
|
||||||
|
values = [v for v in cleaned.split() if v]
|
||||||
|
# 转换为数组
|
||||||
|
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||||
|
else:
|
||||||
|
# 不支持的类型
|
||||||
|
print(f"Unsupported eigenvalue type for {person_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 归一化处理
|
||||||
|
norm = np.linalg.norm(eigenvalue)
|
||||||
|
if norm != 0:
|
||||||
|
eigenvalue = eigenvalue / norm
|
||||||
|
|
||||||
|
_known_faces_embeddings[person_name] = eigenvalue
|
||||||
|
_known_faces_names.append(person_name)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading face data from service: {e}")
|
||||||
|
|
||||||
|
return True if _face_app else False
|
||||||
|
|
||||||
|
|
||||||
|
def detect(frame, threshold=0.4):
|
||||||
|
"""检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)"""
|
||||||
|
global _face_app, _known_faces_embeddings, _known_faces_names
|
||||||
|
|
||||||
|
if not _face_app or not _known_faces_names or frame is None:
|
||||||
|
return (False, "未初始化或无效帧")
|
||||||
|
|
||||||
|
try:
|
||||||
|
faces = _face_app.get(frame)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Face detect error: {e}")
|
||||||
|
return (False, f"检测错误: {str(e)}")
|
||||||
|
|
||||||
|
result_parts = []
|
||||||
|
has_matched = False # 新增标记:是否有匹配到的已知人脸
|
||||||
|
|
||||||
|
for face in faces:
|
||||||
|
# 特征归一化
|
||||||
|
embedding = face.embedding.astype(np.float32)
|
||||||
|
norm = np.linalg.norm(embedding)
|
||||||
|
if norm == 0:
|
||||||
|
continue
|
||||||
|
embedding = embedding / norm
|
||||||
|
|
||||||
|
# 对比已知人脸
|
||||||
|
max_sim, best_name = -1.0, "Unknown"
|
||||||
|
for name in _known_faces_names:
|
||||||
|
known_emb = _known_faces_embeddings[name]
|
||||||
|
sim = np.dot(embedding, known_emb)
|
||||||
|
if sim > max_sim:
|
||||||
|
max_sim = sim
|
||||||
|
best_name = name
|
||||||
|
|
||||||
|
# 判断匹配结果
|
||||||
|
is_match = max_sim >= threshold
|
||||||
|
if is_match:
|
||||||
|
has_matched = True # 只要有一个匹配成功,就标记为True
|
||||||
|
|
||||||
|
bbox = face.bbox
|
||||||
|
result_parts.append(
|
||||||
|
f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建结果字符串
|
||||||
|
if not result_parts:
|
||||||
|
result_str = "未检测到人脸"
|
||||||
|
else:
|
||||||
|
result_str = "; ".join(result_parts)
|
||||||
|
|
||||||
|
# 第一个返回值改为:是否匹配到已知人脸
|
||||||
|
return (has_matched, result_str)
|
BIN
core/models/best.pt
Normal file
76
core/ocr.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
from rapidocr import RapidOCR
|
||||||
|
from service.sensitive_service import get_all_sensitive_words
|
||||||
|
|
||||||
|
# 全局变量
|
||||||
|
_ocr_engine = None
|
||||||
|
_forbidden_words = set()
|
||||||
|
_conf_threshold = 0.5
|
||||||
|
|
||||||
|
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
"""加载OCR引擎及违禁词列表"""
|
||||||
|
global _ocr_engine, _forbidden_words, _conf_threshold
|
||||||
|
|
||||||
|
# 加载违禁词
|
||||||
|
try:
|
||||||
|
_forbidden_words = get_all_sensitive_words()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Forbidden words load error: {e}")
|
||||||
|
|
||||||
|
# 初始化OCR引擎
|
||||||
|
if not os.path.exists(ocr_config_path):
|
||||||
|
print(f"OCR config not found: {ocr_config_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ocr_engine = RapidOCR(config_path=ocr_config_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"OCR model load failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True if _ocr_engine else False
|
||||||
|
|
||||||
|
|
||||||
|
def detect(frame):
|
||||||
|
"""OCR检测并筛选违禁词,返回(是否检测到违禁词, 结果字符串)"""
|
||||||
|
if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
|
||||||
|
return (False, "未初始化或无效帧")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ocr_res = _ocr_engine(frame)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"OCR detect error: {e}")
|
||||||
|
return (False, f"检测错误: {str(e)}")
|
||||||
|
|
||||||
|
if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'):
|
||||||
|
return (False, "无OCR结果")
|
||||||
|
|
||||||
|
# 处理OCR结果
|
||||||
|
texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)]
|
||||||
|
confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))]
|
||||||
|
if len(texts) != len(confs):
|
||||||
|
return (False, "OCR结果格式异常")
|
||||||
|
|
||||||
|
# 筛选违禁词
|
||||||
|
vio_info = []
|
||||||
|
for txt, conf in zip(texts, confs):
|
||||||
|
if conf < _conf_threshold:
|
||||||
|
continue
|
||||||
|
matched = [w for w in _forbidden_words if w in txt]
|
||||||
|
if matched:
|
||||||
|
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
|
||||||
|
|
||||||
|
# 构建结果字符串
|
||||||
|
has_text = len(texts) > 0
|
||||||
|
has_violation = len(vio_info) > 0
|
||||||
|
|
||||||
|
if not has_text:
|
||||||
|
return (False, "未识别到文本")
|
||||||
|
elif has_violation:
|
||||||
|
return (True, "; ".join(vio_info))
|
||||||
|
else:
|
||||||
|
return (False, "未检测到违禁词")
|
137
core/rtc.py
@ -1,137 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
|
||||||
import aiohttp
|
|
||||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# 创建检测器实例
|
|
||||||
detector = OCRViolationDetector(
|
|
||||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
|
||||||
ocr_confidence_threshold=0.7,
|
|
||||||
log_level=logging.INFO,
|
|
||||||
log_file="ocr_detection.log"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger("whep_video_puller")
|
|
||||||
|
|
||||||
|
|
||||||
async def whep_pull_video_stream(ip,whep_url):
|
|
||||||
"""
|
|
||||||
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
whep_url: WHEP端点的URL
|
|
||||||
"""
|
|
||||||
pc = RTCPeerConnection()
|
|
||||||
|
|
||||||
# 添加连接状态变化监听
|
|
||||||
@pc.on("connectionstatechange")
|
|
||||||
async def on_connectionstatechange():
|
|
||||||
print(f"连接状态: {pc.connectionState}")
|
|
||||||
|
|
||||||
# 添加ICE连接状态变化监听
|
|
||||||
@pc.on("iceconnectionstatechange")
|
|
||||||
async def on_iceconnectionstatechange():
|
|
||||||
print(f"ICE连接状态: {pc.iceConnectionState}")
|
|
||||||
|
|
||||||
# 添加视频接收器
|
|
||||||
pc.addTransceiver("video", direction="recvonly")
|
|
||||||
|
|
||||||
# 处理接收到的视频轨道
|
|
||||||
@pc.on("track")
|
|
||||||
def on_track(track):
|
|
||||||
print(f"接收到轨道: {track.kind}")
|
|
||||||
if track.kind == "video":
|
|
||||||
print(f"轨道ID: {track.id}")
|
|
||||||
print(f"轨道就绪状态: {track.readyState}")
|
|
||||||
# 创建异步任务来处理视频帧
|
|
||||||
asyncio.ensure_future(handle_video_track(track))
|
|
||||||
|
|
||||||
async def handle_video_track(track):
|
|
||||||
"""处理视频轨道,接收并打印每一帧"""
|
|
||||||
frame_count = 0
|
|
||||||
print("开始处理视频轨道...")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# 尝试接收帧
|
|
||||||
frame = await track.recv()
|
|
||||||
frame_count += 1
|
|
||||||
print(f"收到原始帧 (第{frame_count}帧)")
|
|
||||||
|
|
||||||
# 打印帧的基本信息
|
|
||||||
if hasattr(frame, 'width') and hasattr(frame, 'height'):
|
|
||||||
print(f" 尺寸: {frame.width}x{frame.height}")
|
|
||||||
if hasattr(frame, 'time_base'):
|
|
||||||
print(f" 时间基准: {frame.time_base}")
|
|
||||||
if hasattr(frame, 'pts'):
|
|
||||||
print(f" 显示时间戳: {frame.pts}")
|
|
||||||
|
|
||||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
|
||||||
|
|
||||||
# 输出检测结果
|
|
||||||
if has_violation:
|
|
||||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
|
||||||
for word, conf in zip(violations, confidences):
|
|
||||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
|
||||||
else:
|
|
||||||
detector.logger.info("图片中未检测到违禁词")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"接收帧时出错: {e}")
|
|
||||||
# 等待一段时间后重试
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 创建offer
|
|
||||||
offer = await pc.createOffer()
|
|
||||||
await pc.setLocalDescription(offer)
|
|
||||||
|
|
||||||
print(f"本地SDP信息:\n{offer.sdp}")
|
|
||||||
|
|
||||||
# 通过HTTP POST发送offer到WHEP端点
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
whep_url,
|
|
||||||
data=offer.sdp,
|
|
||||||
headers={"Content-Type": "application/sdp"}
|
|
||||||
) as response:
|
|
||||||
if response.status != 201:
|
|
||||||
print(f"WHEP服务器返回错误: {response.status}")
|
|
||||||
print(f"响应内容: {await response.text()}")
|
|
||||||
raise Exception(f"WHEP服务器返回错误: {response.status}")
|
|
||||||
|
|
||||||
# 获取answer SDP
|
|
||||||
answer_sdp = await response.text()
|
|
||||||
|
|
||||||
# 创建RTCSessionDescription对象
|
|
||||||
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
|
|
||||||
|
|
||||||
print(f"收到远程SDP:\n{answer_sdp}")
|
|
||||||
|
|
||||||
# 设置远程描述
|
|
||||||
await pc.setRemoteDescription(answer)
|
|
||||||
|
|
||||||
print("连接已建立,开始接收视频流...")
|
|
||||||
|
|
||||||
# 保持连接,直到用户中断
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
# 检查连接状态
|
|
||||||
print(f"当前连接状态: {pc.connectionState}")
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("用户中断,关闭连接...")
|
|
||||||
finally:
|
|
||||||
await pc.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 替换为你的WHEP端点URL
|
|
||||||
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
|
|
||||||
|
|
||||||
# 运行拉流任务
|
|
||||||
asyncio.run(whep_pull_video_stream(WHEP_URL))
|
|
112
core/rtmp.py
@ -1,112 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import cv2
|
|
||||||
import time
|
|
||||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
|
||||||
|
|
||||||
|
|
||||||
# 配置文件相对路径(根据实际目录结构调整)
|
|
||||||
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹
|
|
||||||
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
|
|
||||||
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
|
|
||||||
KNOWN_FACES_DIR = "../ocr/known_faces"
|
|
||||||
|
|
||||||
# 创建检测器实例
|
|
||||||
detector = MultiModelViolationDetector(
|
|
||||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
|
||||||
ocr_config_path=OCR_CONFIG_PATH,
|
|
||||||
yolo_model_path=YOLO_MODEL_PATH,
|
|
||||||
known_faces_dir=KNOWN_FACES_DIR,
|
|
||||||
ocr_confidence_threshold=0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger("rtmp_video_puller")
|
|
||||||
|
|
||||||
|
|
||||||
async def rtmp_pull_video_stream(rtmp_url):
|
|
||||||
"""
|
|
||||||
通过RTMP从指定URL拉取视频流并进行违规检测
|
|
||||||
"""
|
|
||||||
cap = None # 初始化视频捕获对象
|
|
||||||
try:
|
|
||||||
# 异步打开RTMP流
|
|
||||||
cap = await asyncio.to_thread(
|
|
||||||
cv2.VideoCapture,
|
|
||||||
rtmp_url,
|
|
||||||
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查RTMP流是否成功打开
|
|
||||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
|
||||||
if not is_opened:
|
|
||||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
|
||||||
|
|
||||||
# 获取RTMP流基础信息
|
|
||||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
|
||||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
|
||||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
|
||||||
|
|
||||||
# 处理异常情况
|
|
||||||
fps = fps if fps > 0 else 30.0
|
|
||||||
width, height = int(width), int(height)
|
|
||||||
|
|
||||||
# 打印流初始化成功信息
|
|
||||||
print(f"RTMP流状态: 已成功连接")
|
|
||||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
|
||||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
|
||||||
|
|
||||||
# 初始化帧统计参数
|
|
||||||
frame_count = 0
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# 循环读取视频帧
|
|
||||||
while True:
|
|
||||||
ret, frame = await asyncio.to_thread(cap.read)
|
|
||||||
|
|
||||||
if not ret:
|
|
||||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
|
||||||
break
|
|
||||||
|
|
||||||
frame_count += 1
|
|
||||||
|
|
||||||
# 打印当前帧信息
|
|
||||||
print(f"收到帧 (第{frame_count}帧)")
|
|
||||||
print(f" 帧尺寸: {width}x{height}")
|
|
||||||
print(f" 配置帧率: {fps:.2f} FPS")
|
|
||||||
|
|
||||||
if frame is not None:
|
|
||||||
has_violation, violation_type, details = detector.detect_violations(frame)
|
|
||||||
if has_violation:
|
|
||||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
|
||||||
else:
|
|
||||||
print("未检测到任何违规内容")
|
|
||||||
else:
|
|
||||||
print(f"无法读取测试图像")
|
|
||||||
|
|
||||||
# 每100帧统计一次实际接收帧率
|
|
||||||
if frame_count % 100 == 0:
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
actual_fps = frame_count / elapsed_time
|
|
||||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
|
||||||
print(f"错误信息: {str(e)}")
|
|
||||||
finally:
|
|
||||||
if cap is not None:
|
|
||||||
await asyncio.to_thread(cap.release)
|
|
||||||
print(f"\n资源释放: RTMP流已关闭")
|
|
||||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
|
||||||
|
|
||||||
try:
|
|
||||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"程序启动失败: {str(e)}")
|
|
55
core/yolo.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# 全局变量
|
||||||
|
_yolo_model = None
|
||||||
|
|
||||||
|
|
||||||
|
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
"""加载YOLO目标检测模型"""
|
||||||
|
global _yolo_model
|
||||||
|
|
||||||
|
try:
|
||||||
|
_yolo_model = YOLO(model_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"YOLO model load failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True if _yolo_model else False
|
||||||
|
|
||||||
|
|
||||||
|
def detect(frame, conf_threshold=0.2):
|
||||||
|
"""YOLO目标检测,返回(是否识别到, 结果字符串)"""
|
||||||
|
global _yolo_model
|
||||||
|
|
||||||
|
if not _yolo_model or frame is None:
|
||||||
|
return (False, "未初始化或无效帧")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = _yolo_model(frame, conf=conf_threshold)
|
||||||
|
# 检查是否有检测结果
|
||||||
|
has_results = len(results[0].boxes) > 0 if results else False
|
||||||
|
|
||||||
|
if not has_results:
|
||||||
|
return (False, "未检测到目标")
|
||||||
|
|
||||||
|
# 构建结果字符串
|
||||||
|
result_parts = []
|
||||||
|
for box in results[0].boxes:
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
conf = float(box.conf[0])
|
||||||
|
bbox = [float(x) for x in box.xyxy[0]]
|
||||||
|
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
||||||
|
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
|
||||||
|
|
||||||
|
result_str = "; ".join(result_parts)
|
||||||
|
return (has_results, result_str)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"YOLO detect error: {e}")
|
||||||
|
return (False, f"检测错误: {str(e)}")
|
23
main.py
@ -1,12 +1,19 @@
|
|||||||
import uvicorn
|
from PIL import Image # 正确导入
|
||||||
from fastapi import FastAPI
|
import numpy as np
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from PIL import Image
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from core.all import load_model,detect
|
||||||
from ds.config import SERVER_CONFIG
|
from ds.config import SERVER_CONFIG
|
||||||
from middle.error_handler import global_exception_handler
|
from middle.error_handler import global_exception_handler
|
||||||
from service.user_service import router as user_router
|
from service.user_service import router as user_router
|
||||||
|
from service.sensitive_service import router as sensitive_router
|
||||||
|
from service.face_service import router as face_router
|
||||||
from service.device_service import router as device_router
|
from service.device_service import router as device_router
|
||||||
from ws.ws import ws_router, lifespan
|
from ws.ws import ws_router, lifespan
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 初始化 FastAPI 应用、指定生命周期管理
|
# 初始化 FastAPI 应用、指定生命周期管理
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -22,6 +29,8 @@ app = FastAPI(
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
app.include_router(user_router)
|
app.include_router(user_router)
|
||||||
app.include_router(device_router)
|
app.include_router(device_router)
|
||||||
|
app.include_router(face_router)
|
||||||
|
app.include_router(sensitive_router)
|
||||||
app.include_router(ws_router)
|
app.include_router(ws_router)
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -33,11 +42,19 @@ app.add_exception_handler(Exception, global_exception_handler)
|
|||||||
# 启动服务
|
# 启动服务
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# -------------------------- 配置调整 --------------------------
|
||||||
|
# 模型配置路径(建议改为环境变量)
|
||||||
|
YOLO_MODEL_PATH = r"/core/models\best.pt"
|
||||||
|
OCR_CONFIG_PATH = r"/core/config\config.yaml"
|
||||||
|
|
||||||
|
# 初始化项目(默认端口设为8000,避免初始化失败时port未定义)
|
||||||
port = int(SERVER_CONFIG.get("port", 8000))
|
port = int(SERVER_CONFIG.get("port", 8000))
|
||||||
|
|
||||||
|
# 启动 UVicorn 服务
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app="main:app",
|
app="main:app",
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=port,
|
port=port,
|
||||||
reload=True,
|
workers=8,
|
||||||
ws="websockets"
|
ws="websockets"
|
||||||
)
|
)
|
||||||
|
@ -8,7 +8,8 @@ from passlib.context import CryptContext
|
|||||||
|
|
||||||
from ds.config import JWT_CONFIG
|
from ds.config import JWT_CONFIG
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
from service.user_service import UserResponse
|
|
||||||
|
# 移除这里的 from service.user_service import UserResponse 导入
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 密码加密配置
|
# 密码加密配置
|
||||||
@ -25,6 +26,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
|||||||
# OAuth2 依赖(从请求头获取 Token、格式:Bearer <token>)
|
# OAuth2 依赖(从请求头获取 Token、格式:Bearer <token>)
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 密码工具函数
|
# 密码工具函数
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -32,10 +34,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|||||||
"""验证明文密码与加密密码是否匹配"""
|
"""验证明文密码与加密密码是否匹配"""
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
def get_password_hash(password: str) -> str:
|
def get_password_hash(password: str) -> str:
|
||||||
"""对明文密码进行 bcrypt 加密"""
|
"""对明文密码进行 bcrypt 加密"""
|
||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# JWT 工具函数
|
# JWT 工具函数
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -53,11 +57,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
|||||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 认证依赖(获取当前登录用户)
|
# 认证依赖(获取当前登录用户)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
|
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
|
||||||
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
||||||
|
# 延迟导入,打破循环依赖
|
||||||
|
from schema.user_schema import UserResponse # 在这里导入
|
||||||
|
|
||||||
# 认证失败异常
|
# 认证失败异常
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -89,7 +97,7 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
|
|||||||
raise credentials_exception # 用户不存在
|
raise credentials_exception # 用户不存在
|
||||||
|
|
||||||
# 转换为 UserResponse 模型(自动校验字段)
|
# 转换为 UserResponse 模型(自动校验字段)
|
||||||
return UserResponse(** user)
|
return UserResponse(**user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise credentials_exception from e
|
raise credentials_exception from e
|
||||||
finally:
|
finally:
|
||||||
|
@ -1,139 +0,0 @@
|
|||||||
import os
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import insightface
|
|
||||||
from insightface.app import FaceAnalysis
|
|
||||||
|
|
||||||
|
|
||||||
class FaceRecognizer:
|
|
||||||
"""
|
|
||||||
封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, known_faces_dir: str):
|
|
||||||
self.known_faces_dir = known_faces_dir
|
|
||||||
self.app = self._initialize_insightface()
|
|
||||||
self.known_faces_embeddings = {}
|
|
||||||
self.known_faces_names = []
|
|
||||||
self._load_known_faces()
|
|
||||||
|
|
||||||
def _initialize_insightface(self):
|
|
||||||
"""初始化InsightFace FaceAnalysis应用"""
|
|
||||||
print("初始化InsightFace引擎...")
|
|
||||||
try:
|
|
||||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
|
||||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
|
||||||
print("InsightFace引擎初始化完成")
|
|
||||||
return app
|
|
||||||
except Exception as e:
|
|
||||||
print(f"InsightFace初始化失败: {e}")
|
|
||||||
print("请检查依赖是否安装及模型是否可访问")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _load_known_faces(self):
|
|
||||||
"""加载已知人脸特征"""
|
|
||||||
if not os.path.exists(self.known_faces_dir):
|
|
||||||
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
|
|
||||||
os.makedirs(self.known_faces_dir, exist_ok=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"从目录加载人脸特征: {self.known_faces_dir}")
|
|
||||||
for person_name in os.listdir(self.known_faces_dir):
|
|
||||||
person_dir = os.path.join(self.known_faces_dir, person_name)
|
|
||||||
if os.path.isdir(person_dir):
|
|
||||||
print(f"处理人物: {person_name}")
|
|
||||||
embeddings = []
|
|
||||||
for filename in os.listdir(person_dir):
|
|
||||||
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
|
||||||
image_path = os.path.join(person_dir, filename)
|
|
||||||
try:
|
|
||||||
img = cv2.imread(image_path)
|
|
||||||
if img is None:
|
|
||||||
print(f"无法读取图片: {image_path},已跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
faces = self.app.get(img)
|
|
||||||
if faces:
|
|
||||||
embeddings.append(faces[0].embedding)
|
|
||||||
print(f"提取特征成功: {filename}")
|
|
||||||
else:
|
|
||||||
print(f"未检测到人脸: {filename},已跳过")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"处理图片出错 {image_path}: {e}")
|
|
||||||
|
|
||||||
if embeddings:
|
|
||||||
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
|
|
||||||
self.known_faces_names.append(person_name)
|
|
||||||
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
|
|
||||||
else:
|
|
||||||
print(f"人物 {person_name} 无有效特征,已跳过")
|
|
||||||
print(f"人脸加载完成,共 {len(self.known_faces_names)} 人")
|
|
||||||
|
|
||||||
def recognize(self, frame, threshold=0.4):
|
|
||||||
"""识别人脸并返回结果"""
|
|
||||||
if not self.app or not self.known_faces_names:
|
|
||||||
return False, None, None
|
|
||||||
|
|
||||||
faces = self.app.get(frame)
|
|
||||||
if not faces:
|
|
||||||
return False, None, None
|
|
||||||
|
|
||||||
for face in faces:
|
|
||||||
for known_name in self.known_faces_names:
|
|
||||||
known_embedding = self.known_faces_embeddings[known_name]
|
|
||||||
|
|
||||||
embedding1 = face.embedding.astype(np.float32)
|
|
||||||
embedding2 = known_embedding.astype(np.float32)
|
|
||||||
|
|
||||||
dot_product = np.dot(embedding1, embedding2)
|
|
||||||
norm_embedding1 = np.linalg.norm(embedding1)
|
|
||||||
norm_embedding2 = np.linalg.norm(embedding2)
|
|
||||||
|
|
||||||
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
|
|
||||||
dot_product / (norm_embedding1 * norm_embedding2)
|
|
||||||
)
|
|
||||||
|
|
||||||
if similarity >= threshold:
|
|
||||||
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
|
|
||||||
return True, known_name, similarity
|
|
||||||
|
|
||||||
return False, None, None
|
|
||||||
|
|
||||||
def test_single_image(self, image_path: str, threshold=0.4):
|
|
||||||
"""测试单张图片识别"""
|
|
||||||
if not os.path.exists(image_path):
|
|
||||||
print(f"图片不存在: {image_path}")
|
|
||||||
return False, None, None
|
|
||||||
|
|
||||||
frame = cv2.imread(image_path)
|
|
||||||
if frame is None:
|
|
||||||
print(f"无法读取图片: {image_path}")
|
|
||||||
return False, None, None
|
|
||||||
|
|
||||||
result, name, similarity = self.recognize(frame, threshold)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
print(f"识别结果: {name} (相似度: {similarity:.4f})")
|
|
||||||
|
|
||||||
faces = self.app.get(frame)
|
|
||||||
for face in faces:
|
|
||||||
bbox = face.bbox.astype(int)
|
|
||||||
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
|
||||||
text = f"{name}: {similarity:.2f}"
|
|
||||||
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
|
|
||||||
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
|
||||||
|
|
||||||
cv2.imshow('识别结果', frame)
|
|
||||||
print("按任意键关闭窗口...")
|
|
||||||
cv2.waitKey(0)
|
|
||||||
cv2.destroyAllWindows()
|
|
||||||
else:
|
|
||||||
print("未识别到已知人脸")
|
|
||||||
|
|
||||||
return result, name, similarity
|
|
||||||
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
|
|
||||||
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
|
|
||||||
# recognizer.test_single_image(test_image_path, threshold=0.4)
|
|
@ -1,156 +0,0 @@
|
|||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import insightface
|
|
||||||
from insightface.app import FaceAnalysis
|
|
||||||
from io import BytesIO
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
class BinaryFaceFeatureHandler:
|
|
||||||
"""
|
|
||||||
专门处理图片二进制数据的特征提取器,支持分批次接收二进制数据并累积计算平均特征
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.app = self._init_insightface()
|
|
||||||
self.feature_list = [] # 存储所有图片二进制数据提取的特征
|
|
||||||
|
|
||||||
def _init_insightface(self):
|
|
||||||
"""初始化InsightFace引擎"""
|
|
||||||
try:
|
|
||||||
print("正在初始化InsightFace引擎...")
|
|
||||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
|
||||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
|
||||||
print("InsightFace引擎初始化完成")
|
|
||||||
return app
|
|
||||||
except Exception as e:
|
|
||||||
print(f"InsightFace初始化失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add_binary_data(self, binary_data):
|
|
||||||
"""
|
|
||||||
接收单张图片的二进制数据,提取特征并保存
|
|
||||||
|
|
||||||
参数:
|
|
||||||
binary_data: 图片的二进制数据(bytes类型)
|
|
||||||
|
|
||||||
返回:
|
|
||||||
成功提取特征时返回 (True, 特征值numpy数组)
|
|
||||||
失败时返回 (False, None)
|
|
||||||
"""
|
|
||||||
if not self.app:
|
|
||||||
print("引擎未初始化,无法处理")
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 直接处理二进制数据:转换为图像格式
|
|
||||||
img = Image.open(BytesIO(binary_data))
|
|
||||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
|
||||||
|
|
||||||
# 提取特征
|
|
||||||
faces = self.app.get(frame)
|
|
||||||
if faces:
|
|
||||||
# 获取当前提取的特征值
|
|
||||||
current_feature = faces[0].embedding
|
|
||||||
# 添加到特征列表
|
|
||||||
self.feature_list.append(current_feature)
|
|
||||||
print(f"已累计 {len(self.feature_list)} 个特征")
|
|
||||||
# 返回成功标志和当前特征值
|
|
||||||
return True,current_feature
|
|
||||||
else:
|
|
||||||
print("二进制数据中未检测到人脸")
|
|
||||||
return False, None
|
|
||||||
except Exception as e:
|
|
||||||
print(f"处理二进制数据出错: {e}")
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
def get_average_feature(self, features):
|
|
||||||
"""
|
|
||||||
计算多个特征向量的平均值
|
|
||||||
|
|
||||||
参数:
|
|
||||||
features: 特征值列表,每个元素可以是字符串格式或numpy数组
|
|
||||||
例如: [feature1, feature2, ...]
|
|
||||||
返回:
|
|
||||||
单一平均特征向量的numpy数组,若无可计算数据则返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 验证输入是否为列表且不为空
|
|
||||||
if not isinstance(features, list) or len(features) == 0:
|
|
||||||
print("输入必须是包含至少一个特征值的列表")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 处理每个特征值
|
|
||||||
processed_features = []
|
|
||||||
for i, embedding in enumerate(features):
|
|
||||||
try:
|
|
||||||
if isinstance(embedding, str):
|
|
||||||
# 处理包含括号和逗号的字符串格式
|
|
||||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
|
||||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
|
||||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
|
||||||
else:
|
|
||||||
embedding_np = np.array(embedding, dtype=np.float32)
|
|
||||||
|
|
||||||
# 验证特征值格式
|
|
||||||
if len(embedding_np.shape) == 1:
|
|
||||||
processed_features.append(embedding_np)
|
|
||||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
|
||||||
else:
|
|
||||||
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
|
||||||
|
|
||||||
# 确保有有效的特征值
|
|
||||||
if not processed_features:
|
|
||||||
print("没有有效的特征值用于计算平均值")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 检查所有特征向量维度是否相同
|
|
||||||
dims = {feat.shape[0] for feat in processed_features}
|
|
||||||
if len(dims) > 1:
|
|
||||||
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 计算平均值
|
|
||||||
avg_feature = np.mean(processed_features, axis=0)
|
|
||||||
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
|
|
||||||
|
|
||||||
return avg_feature
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"计算平均特征值时出错: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# def clear(self):
|
|
||||||
# """清空已存储的特征数据"""
|
|
||||||
# self.feature_list = []
|
|
||||||
# print("已清空所有特征数据")
|
|
||||||
|
|
||||||
|
|
||||||
# # 使用示例
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# handler = BinaryFaceFeatureHandler()
|
|
||||||
#
|
|
||||||
# # 模拟接收图片二进制数据
|
|
||||||
# try:
|
|
||||||
# # 第一次接收
|
|
||||||
# with open(r"D:\Git\bin\video\ocr\known_faces\B\102-f.jpg_1140x855.jpg", "rb") as f:
|
|
||||||
# bin_data1 = f.read()
|
|
||||||
# success, feature1 = handler.add_binary_data(bin_data1)
|
|
||||||
# if success:
|
|
||||||
# print(f"第一次提取的特征值前5个: {feature1[:5]}")
|
|
||||||
#
|
|
||||||
# # 第二次接收
|
|
||||||
# with open(r"D:\Git\bin\video\ocr\known_faces\B\104-1.jpg", "rb") as f:
|
|
||||||
# bin_data2 = f.read()
|
|
||||||
# success, feature2 = handler.add_binary_data(bin_data2)
|
|
||||||
# if success:
|
|
||||||
# print(f"第二次提取的特征值前5个: {feature2[:5]}")
|
|
||||||
#
|
|
||||||
# # 计算平均值
|
|
||||||
# avg_feature = handler.get_average_feature()
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"处理过程出错: {e}")
|
|
@ -1,18 +0,0 @@
|
|||||||
打倒习近平
|
|
||||||
打到习近平
|
|
||||||
打倒毛泽东
|
|
||||||
打到毛泽东
|
|
||||||
打到主席
|
|
||||||
打倒主席
|
|
||||||
打到共产主义
|
|
||||||
打倒共产主义
|
|
||||||
打到共产党
|
|
||||||
打倒共产党
|
|
||||||
胖猫
|
|
||||||
法轮功
|
|
||||||
法轮大法
|
|
||||||
法轮大法好
|
|
||||||
法轮功大法好
|
|
||||||
法轮
|
|
||||||
李洪志
|
|
||||||
习近平
|
|
Before Width: | Height: | Size: 195 KiB |
Before Width: | Height: | Size: 208 KiB |
Before Width: | Height: | Size: 657 KiB |
Before Width: | Height: | Size: 53 KiB |
Before Width: | Height: | Size: 8.1 KiB |
Before Width: | Height: | Size: 14 KiB |
Before Width: | Height: | Size: 58 KiB |
Before Width: | Height: | Size: 4.9 KiB |
Before Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 155 KiB |
Before Width: | Height: | Size: 386 KiB |
Before Width: | Height: | Size: 1.4 MiB |
Before Width: | Height: | Size: 62 KiB |
@ -1,49 +0,0 @@
|
|||||||
#日志文件
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
|
|
||||||
def setup_logger():
|
|
||||||
"""
|
|
||||||
配置一个全局日志记录器,支持输出到控制台和文件。
|
|
||||||
"""
|
|
||||||
# 创建一个日志记录器
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
logger = logging.getLogger("ViolationDetectorLogger")
|
|
||||||
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
|
|
||||||
|
|
||||||
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
|
|
||||||
if logger.hasHandlers():
|
|
||||||
return logger
|
|
||||||
|
|
||||||
# --- 控制台处理器 ---
|
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
# 对于控制台,我们只显示INFO及以上级别的信息
|
|
||||||
console_handler.setLevel(logging.INFO)
|
|
||||||
console_formatter = logging.Formatter(
|
|
||||||
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
|
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
|
||||||
)
|
|
||||||
console_handler.setFormatter(console_formatter)
|
|
||||||
|
|
||||||
# --- 文件处理器 ---
|
|
||||||
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
|
|
||||||
# 对于文件,我们记录所有DEBUG及以上级别的信息
|
|
||||||
file_handler.setLevel(logging.DEBUG)
|
|
||||||
file_formatter = logging.Formatter(
|
|
||||||
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
|
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(file_formatter)
|
|
||||||
|
|
||||||
# 将处理器添加到日志记录器
|
|
||||||
logger.addHandler(console_handler)
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
|
|
||||||
return logger
|
|
||||||
|
|
||||||
# 创建并导出logger实例
|
|
||||||
logger = setup_logger()
|
|
@ -1,136 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import yaml
|
|
||||||
from pathlib import Path
|
|
||||||
from .ocr_violation_detector import OCRViolationDetector
|
|
||||||
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
|
||||||
from .face_recognizer import FaceRecognizer
|
|
||||||
|
|
||||||
class MultiModelViolationDetector:
|
|
||||||
"""
|
|
||||||
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
forbidden_words_path: str,
|
|
||||||
ocr_config_path: str,
|
|
||||||
yolo_model_path: str,
|
|
||||||
known_faces_dir: str,
|
|
||||||
ocr_confidence_threshold: float = 0.5):
|
|
||||||
"""
|
|
||||||
初始化所有检测模型
|
|
||||||
"""
|
|
||||||
# 初始化OCR检测器
|
|
||||||
self.ocr_detector = OCRViolationDetector(
|
|
||||||
forbidden_words_path=forbidden_words_path,
|
|
||||||
ocr_config_path=ocr_config_path,
|
|
||||||
ocr_confidence_threshold=ocr_confidence_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
# 初始化人脸识别器
|
|
||||||
self.face_recognizer = FaceRecognizer(
|
|
||||||
known_faces_dir=known_faces_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
# 初始化YOLO检测器
|
|
||||||
self.yolo_detector = YoloViolationDetector(
|
|
||||||
model_path=yolo_model_path
|
|
||||||
)
|
|
||||||
|
|
||||||
print("多模型违规检测器初始化完成")
|
|
||||||
|
|
||||||
def detect_violations(self, frame):
|
|
||||||
"""
|
|
||||||
串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果
|
|
||||||
"""
|
|
||||||
# 1. 首先进行OCR违禁词检测
|
|
||||||
try:
|
|
||||||
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
|
|
||||||
if ocr_has_violation:
|
|
||||||
details = {
|
|
||||||
"words": ocr_words,
|
|
||||||
"confidences": ocr_confs
|
|
||||||
}
|
|
||||||
print(f"警告: OCR检测到违禁内容: {details}")
|
|
||||||
return (True, "ocr", details)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误: OCR检测出错: {str(e)}")
|
|
||||||
|
|
||||||
# 2. 接着进行人脸识别检测
|
|
||||||
try:
|
|
||||||
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
|
|
||||||
if face_has_violation:
|
|
||||||
details = {
|
|
||||||
"name": face_name,
|
|
||||||
"similarity": face_similarity
|
|
||||||
}
|
|
||||||
print(f"警告: 人脸识别到违规人员: {details}")
|
|
||||||
return (True, "face", details)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误: 人脸识别出错: {str(e)}")
|
|
||||||
|
|
||||||
# 3. 最后进行YOLO目标检测
|
|
||||||
try:
|
|
||||||
yolo_results = self.yolo_detector.detect(frame)
|
|
||||||
if len(yolo_results.boxes) > 0:
|
|
||||||
details = {
|
|
||||||
"classes": yolo_results.names,
|
|
||||||
"boxes": yolo_results.boxes.xyxy.tolist(),
|
|
||||||
"confidences": yolo_results.boxes.conf.tolist(),
|
|
||||||
"class_ids": yolo_results.boxes.cls.tolist()
|
|
||||||
}
|
|
||||||
print(f"警告: YOLO检测到违规目标: {details}")
|
|
||||||
return (True, "yolo", details)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误: YOLO检测出错: {str(e)}")
|
|
||||||
|
|
||||||
# 所有检测均未发现违规
|
|
||||||
return (False, None, None)
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: str) -> dict:
|
|
||||||
"""加载YAML配置文件"""
|
|
||||||
try:
|
|
||||||
with open(config_path, 'r', encoding='utf-8') as f:
|
|
||||||
return yaml.safe_load(f)
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"错误: 配置文件未找到: {config_path}")
|
|
||||||
raise
|
|
||||||
except yaml.YAMLError as e:
|
|
||||||
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误: 加载配置文件出错: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# 使用示例
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# # 加载配置文件
|
|
||||||
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
|
|
||||||
#
|
|
||||||
# # 初始化多模型检测器
|
|
||||||
# detector = MultiModelViolationDetector(
|
|
||||||
# forbidden_words_path=config["forbidden_words_path"],
|
|
||||||
# ocr_config_path=config["ocr_config_path"],
|
|
||||||
# yolo_model_path=config["yolo_model_path"],
|
|
||||||
# known_faces_dir=config["known_faces_dir"],
|
|
||||||
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# # 读取测试图像(可替换为视频帧读取逻辑)
|
|
||||||
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
|
|
||||||
# if test_image_path:
|
|
||||||
# frame = cv2.imread(test_image_path)
|
|
||||||
#
|
|
||||||
# if frame is not None:
|
|
||||||
# has_violation, violation_type, details = detector.detect_violations(frame)
|
|
||||||
# if has_violation:
|
|
||||||
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
|
||||||
# else:
|
|
||||||
# print("未检测到任何违规内容")
|
|
||||||
# else:
|
|
||||||
# print(f"无法读取测试图像: {test_image_path}")
|
|
||||||
# else:
|
|
||||||
# print("配置文件中未指定测试图像路径")
|
|
@ -1,178 +0,0 @@
|
|||||||
import os
|
|
||||||
import cv2
|
|
||||||
from rapidocr import RapidOCR
|
|
||||||
|
|
||||||
|
|
||||||
class OCRViolationDetector:
|
|
||||||
"""
|
|
||||||
封装RapidOCR引擎,用于检测图像帧中的违禁词。
|
|
||||||
核心功能:加载违禁词、初始化OCR引擎、单帧图像违禁词检测
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
forbidden_words_path: str,
|
|
||||||
ocr_config_path: str,
|
|
||||||
ocr_confidence_threshold: float = 0.5):
|
|
||||||
"""
|
|
||||||
初始化OCR引擎和违禁词列表。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
|
|
||||||
ocr_config_path (str): OCR配置文件(如1.yaml)的路径。
|
|
||||||
ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。
|
|
||||||
"""
|
|
||||||
# 加载违禁词
|
|
||||||
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
|
|
||||||
|
|
||||||
# 初始化RapidOCR引擎
|
|
||||||
self.ocr_engine = self._initialize_ocr(ocr_config_path)
|
|
||||||
|
|
||||||
# 校验核心依赖是否就绪
|
|
||||||
self._check_dependencies()
|
|
||||||
|
|
||||||
# 设置置信度阈值(限制在0~1范围)
|
|
||||||
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
|
|
||||||
print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
|
|
||||||
|
|
||||||
def _load_forbidden_words(self, path: str) -> set:
|
|
||||||
"""
|
|
||||||
从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码)
|
|
||||||
"""
|
|
||||||
forbidden_words = set()
|
|
||||||
|
|
||||||
# 检查文件是否存在
|
|
||||||
if not os.path.exists(path):
|
|
||||||
print(f"错误:违禁词文件不存在: {path}")
|
|
||||||
return forbidden_words
|
|
||||||
|
|
||||||
# 读取文件并处理内容
|
|
||||||
try:
|
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
|
||||||
forbidden_words = {
|
|
||||||
line.strip() for line in f
|
|
||||||
if line.strip() # 跳过空行或纯空格行
|
|
||||||
}
|
|
||||||
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
print(f"错误:违禁词文件编码错误(需UTF-8): {path}")
|
|
||||||
except PermissionError:
|
|
||||||
print(f"错误:无权限读取违禁词文件: {path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误:加载违禁词失败: {str(e)}")
|
|
||||||
|
|
||||||
return forbidden_words
|
|
||||||
|
|
||||||
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
|
|
||||||
"""
|
|
||||||
初始化RapidOCR引擎(校验配置文件、捕获初始化异常)
|
|
||||||
"""
|
|
||||||
print("开始初始化RapidOCR引擎...")
|
|
||||||
|
|
||||||
# 检查配置文件是否存在
|
|
||||||
if not os.path.exists(config_path):
|
|
||||||
print(f"错误:OCR配置文件不存在: {config_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 初始化OCR引擎
|
|
||||||
try:
|
|
||||||
ocr_engine = RapidOCR(config_path=config_path)
|
|
||||||
print("RapidOCR引擎初始化成功")
|
|
||||||
return ocr_engine
|
|
||||||
except ImportError:
|
|
||||||
print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误:RapidOCR初始化失败: {str(e)}")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _check_dependencies(self) -> None:
|
|
||||||
"""校验OCR引擎和违禁词列表是否就绪"""
|
|
||||||
if not self.ocr_engine:
|
|
||||||
print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
|
|
||||||
if not self.forbidden_words:
|
|
||||||
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
|
|
||||||
|
|
||||||
def detect(self, frame) -> tuple[bool, list, list]:
|
|
||||||
"""
|
|
||||||
对单帧图像进行OCR违禁词检测(核心方法)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frame: 输入图像帧(NumPy数组,BGR格式,cv2读取的图像)。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[bool, list, list]:
|
|
||||||
- 第一个元素:是否检测到违禁词(True/False);
|
|
||||||
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
|
|
||||||
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
|
|
||||||
"""
|
|
||||||
# 初始化返回结果
|
|
||||||
has_violation = False
|
|
||||||
violation_words = []
|
|
||||||
violation_confs = []
|
|
||||||
|
|
||||||
# 前置校验
|
|
||||||
if frame is None or frame.size == 0:
|
|
||||||
print("警告:输入图像帧为空或无效,跳过OCR检测")
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
||||||
if not self.ocr_engine or not self.forbidden_words:
|
|
||||||
print("OCR引擎未就绪或违禁词为空,跳过OCR检测")
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 执行OCR识别
|
|
||||||
print("开始执行OCR识别...")
|
|
||||||
ocr_result = self.ocr_engine(frame)
|
|
||||||
print(f"RapidOCR原始结果: {ocr_result}")
|
|
||||||
|
|
||||||
# 校验OCR结果是否有效
|
|
||||||
if ocr_result is None:
|
|
||||||
print("OCR识别未返回任何结果(图像无文本或识别失败)")
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
||||||
|
|
||||||
# 检查txts和scores是否存在且不为None
|
|
||||||
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
|
|
||||||
print("警告:OCR结果中txts为None或不存在")
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
||||||
|
|
||||||
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
|
|
||||||
print("警告:OCR结果中scores为None或不存在")
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
||||||
|
|
||||||
# 转为列表并去None
|
|
||||||
if not isinstance(ocr_result.txts, (list, tuple)):
|
|
||||||
print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}")
|
|
||||||
texts = []
|
|
||||||
else:
|
|
||||||
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
|
|
||||||
|
|
||||||
if not isinstance(ocr_result.scores, (list, tuple)):
|
|
||||||
print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}")
|
|
||||||
confidences = []
|
|
||||||
else:
|
|
||||||
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
|
|
||||||
|
|
||||||
# 校验文本和置信度列表长度是否一致
|
|
||||||
if len(texts) != len(confidences):
|
|
||||||
print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
||||||
if len(texts) == 0:
|
|
||||||
print("OCR未识别到任何有效文本")
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
||||||
|
|
||||||
# 遍历识别结果,筛选违禁词
|
|
||||||
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
|
|
||||||
for text, conf in zip(texts, confidences):
|
|
||||||
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
|
||||||
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
|
|
||||||
continue
|
|
||||||
matched_words = [word for word in self.forbidden_words if word in text]
|
|
||||||
if matched_words:
|
|
||||||
has_violation = True
|
|
||||||
violation_words.extend(matched_words)
|
|
||||||
violation_confs.extend([conf] * len(matched_words))
|
|
||||||
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"错误:OCR检测过程异常: {str(e)}")
|
|
||||||
|
|
||||||
return has_violation, violation_words, violation_confs
|
|
@ -1,47 +0,0 @@
|
|||||||
from ultralytics import YOLO
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
class ViolationDetector:
|
|
||||||
"""
|
|
||||||
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
|
|
||||||
"""
|
|
||||||
def __init__(self, model_path):
|
|
||||||
"""
|
|
||||||
初始化检测器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path (str): YOLO .pt模型的路径。
|
|
||||||
"""
|
|
||||||
print(f"正在从 '{model_path}' 加载YOLO模型...")
|
|
||||||
self.model = YOLO(model_path)
|
|
||||||
print("YOLO模型加载成功。")
|
|
||||||
|
|
||||||
def detect(self, frame):
|
|
||||||
"""
|
|
||||||
对单帧图像进行目标检测。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ultralytics.engine.results.Results: YOLO的检测结果对象。
|
|
||||||
"""
|
|
||||||
# conf可以根据您的模型效果进行调整
|
|
||||||
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
|
|
||||||
results = self.model(frame, conf=0.2)
|
|
||||||
return results[0]
|
|
||||||
|
|
||||||
def draw_boxes(self, frame, result):
|
|
||||||
"""
|
|
||||||
在图像帧上绘制检测框。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frame: 原始图像帧。
|
|
||||||
result: YOLO的检测结果对象。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
numpy.ndarray: 绘制了检测框的图像帧。
|
|
||||||
"""
|
|
||||||
# 使用YOLO自带的plot功能,方便快捷
|
|
||||||
annotated_frame = result.plot()
|
|
||||||
return annotated_frame
|
|
164
rtc/rtc.py
@ -1,164 +0,0 @@
|
|||||||
import queue
|
|
||||||
import asyncio
|
|
||||||
import aiohttp
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
|
||||||
from aiortc.mediastreams import MediaStreamTrack
|
|
||||||
|
|
||||||
# 创建一个长度为1的队列,用于生产者和消费者之间的通信
|
|
||||||
frame_queue = queue.Queue(maxsize=1)
|
|
||||||
|
|
||||||
|
|
||||||
class VideoTrack(MediaStreamTrack):
|
|
||||||
"""自定义视频轨道类,继承自MediaStreamTrack"""
|
|
||||||
kind = "video"
|
|
||||||
|
|
||||||
def __init__(self, max_frames=100):
|
|
||||||
super().__init__()
|
|
||||||
self.frames = queue.Queue(maxsize=max_frames)
|
|
||||||
|
|
||||||
async def recv(self):
|
|
||||||
return await super().recv()
|
|
||||||
|
|
||||||
|
|
||||||
def webrtc_producer(webrtc_url):
|
|
||||||
"""
|
|
||||||
生产者方法:从WEBRTC读取视频帧并放入队列
|
|
||||||
仅当队列空时才放入新帧,否则丢弃
|
|
||||||
"""
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
# 创建RTCPeerConnection对象,不使用ICE服务器
|
|
||||||
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
|
||||||
video_track = VideoTrack()
|
|
||||||
pc.addTrack(video_track)
|
|
||||||
|
|
||||||
@pc.on("track")
|
|
||||||
async def on_track(track):
|
|
||||||
if track.kind == "video":
|
|
||||||
print("接收到视频轨道,开始接收视频帧")
|
|
||||||
while True:
|
|
||||||
# 从轨道接收视频帧
|
|
||||||
frame = await track.recv()
|
|
||||||
# 转换为BGR24格式的NumPy数组
|
|
||||||
frame_bgr24 = frame.to_ndarray(format='bgr24')
|
|
||||||
|
|
||||||
# 检查队列是否为空,为空则加入,否则丢弃
|
|
||||||
if frame_queue.empty():
|
|
||||||
try:
|
|
||||||
frame_queue.put_nowait(frame_bgr24)
|
|
||||||
print("帧已放入队列")
|
|
||||||
except queue.Full:
|
|
||||||
print("队列已满,丢弃帧")
|
|
||||||
else:
|
|
||||||
print("队列非空,丢弃帧")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# 创建并发送SDP Offer
|
|
||||||
offer = await pc.createOffer()
|
|
||||||
print("已创建本地SDP Offer")
|
|
||||||
await pc.setLocalDescription(offer)
|
|
||||||
|
|
||||||
# 发送Offer到服务器并接收Answer
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
|
|
||||||
async with session.post(
|
|
||||||
webrtc_url,
|
|
||||||
data=offer.sdp.encode(),
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/sdp",
|
|
||||||
"Content-Length": str(len(offer.sdp))
|
|
||||||
},
|
|
||||||
ssl=False
|
|
||||||
) as response:
|
|
||||||
print("已接收到服务器的响应")
|
|
||||||
answer_sdp = await response.text()
|
|
||||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
|
||||||
|
|
||||||
# 保持连接
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
print("关闭RTCPeerConnection")
|
|
||||||
await pc.close()
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(main())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
def frame_consumer(ip):
|
|
||||||
"""
|
|
||||||
消费者方法:从队列中读取帧并处理
|
|
||||||
每次处理后休眠200ms模拟延迟
|
|
||||||
"""
|
|
||||||
print("消费者启动,开始等待帧...")
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# 阻塞等待队列中的帧
|
|
||||||
frame = frame_queue.get()
|
|
||||||
print(f"消费帧,大小: {frame.shape}")
|
|
||||||
|
|
||||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
|
||||||
|
|
||||||
|
|
||||||
# 输出检测结果
|
|
||||||
if has_violation:
|
|
||||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
|
||||||
for word, conf in zip(violations, confidences):
|
|
||||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
|
||||||
else:
|
|
||||||
detector.logger.info("图片中未检测到违禁词")
|
|
||||||
|
|
||||||
|
|
||||||
# 标记任务完成
|
|
||||||
frame_queue.task_done()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("消费者退出")
|
|
||||||
|
|
||||||
|
|
||||||
def start_webrtc_stream(ip, webrtc_url):
|
|
||||||
"""
|
|
||||||
启动WebRTC视频流处理的主方法
|
|
||||||
参数: webrtc_url - WebRTC服务器地址
|
|
||||||
"""
|
|
||||||
print(f"开始连接到WebRTC服务器: {webrtc_url}")
|
|
||||||
|
|
||||||
# 启动生产者线程
|
|
||||||
producer_thread = threading.Thread(
|
|
||||||
target=webrtc_producer,
|
|
||||||
args=(webrtc_url,),
|
|
||||||
daemon=True,
|
|
||||||
name="webrtc-producer"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 启动消费者线程
|
|
||||||
consumer_thread = threading.Thread(
|
|
||||||
target=frame_consumer(ip),
|
|
||||||
daemon=True,
|
|
||||||
name="frame-consumer"
|
|
||||||
)
|
|
||||||
|
|
||||||
producer_thread.start()
|
|
||||||
consumer_thread.start()
|
|
||||||
print("生产者和消费者线程已启动")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 保持主线程运行
|
|
||||||
while True:
|
|
||||||
time.sleep(1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("程序正在退出...")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 示例用法
|
|
||||||
# 实际使用时替换为真实的WebRTC服务器地址
|
|
||||||
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
|
|
||||||
start_webrtc_stream(webrtc_server_url)
|
|
101
rtmp/rtmp.py
@ -1,101 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import cv2
|
|
||||||
import time
|
|
||||||
|
|
||||||
# 配置日志(与WHEP代码保持一致的日志风格)
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger("rtmp_video_puller")
|
|
||||||
|
|
||||||
|
|
||||||
async def rtmp_pull_video_stream(rtmp_url):
|
|
||||||
"""
|
|
||||||
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
|
|
||||||
功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key)
|
|
||||||
"""
|
|
||||||
cap = None # 初始化视频捕获对象
|
|
||||||
try:
|
|
||||||
# 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环)
|
|
||||||
cap = await asyncio.to_thread(
|
|
||||||
cv2.VideoCapture,
|
|
||||||
rtmp_url,
|
|
||||||
cv2.CAP_FFMPEG # 必须指定FFmpeg后端,RTMP协议依赖该后端解析
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 检查RTMP流是否成功打开
|
|
||||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
|
||||||
if not is_opened:
|
|
||||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
|
||||||
|
|
||||||
# 3. 异步获取RTMP流基础信息(分辨率、帧率)
|
|
||||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
|
||||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
|
||||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
|
||||||
|
|
||||||
# 处理异常情况:部分RTMP流未返回帧率时默认30FPS
|
|
||||||
fps = fps if fps > 0 else 30.0
|
|
||||||
# 分辨率转为整数(视频尺寸必然是整数)
|
|
||||||
width, height = int(width), int(height)
|
|
||||||
|
|
||||||
# 打印流初始化成功信息(与WHEP连接成功信息风格一致)
|
|
||||||
print(f"RTMP流状态: 已成功连接")
|
|
||||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
|
||||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
|
||||||
|
|
||||||
# 4. 初始化帧统计参数
|
|
||||||
frame_count = 0 # 总接收帧数
|
|
||||||
start_time = time.time() # 统计起始时间
|
|
||||||
|
|
||||||
# 5. 循环异步读取视频帧(核心逻辑)
|
|
||||||
while True:
|
|
||||||
# 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境)
|
|
||||||
ret, frame = await asyncio.to_thread(cap.read)
|
|
||||||
|
|
||||||
# 检查帧是否读取成功(流中断/结束时ret为False)
|
|
||||||
if not ret:
|
|
||||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 帧计数累加
|
|
||||||
frame_count += 1
|
|
||||||
|
|
||||||
# 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐)
|
|
||||||
print(f"收到帧 (第{frame_count}帧)")
|
|
||||||
print(f" 帧尺寸: {width}x{height}")
|
|
||||||
print(f" 配置帧率: {fps:.2f} FPS")
|
|
||||||
|
|
||||||
# 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致)
|
|
||||||
if frame_count % 100 == 0:
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
|
|
||||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
|
||||||
|
|
||||||
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
|
|
||||||
# 示例:yield frame (若需生成器模式,可调整函数为异步生成器)
|
|
||||||
|
|
||||||
# 8. 异常处理(覆盖用户中断、通用错误)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
|
||||||
except Exception as e:
|
|
||||||
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
|
|
||||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
|
||||||
print(f"错误信息: {str(e)}")
|
|
||||||
finally:
|
|
||||||
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
|
|
||||||
if cap is not None:
|
|
||||||
await asyncio.to_thread(cap.release)
|
|
||||||
print(f"\n资源释放: RTMP流已关闭")
|
|
||||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
|
||||||
|
|
||||||
# 运行RTMP拉流任务(与WHEP一致的异步执行方式)
|
|
||||||
try:
|
|
||||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"程序启动失败: {str(e)}")
|
|
@ -23,7 +23,7 @@ class FaceResponse(BaseModel):
|
|||||||
"""人脸记录响应模型(仍包含ID,由数据库生成后返回)"""
|
"""人脸记录响应模型(仍包含ID,由数据库生成后返回)"""
|
||||||
id: int = Field(..., description="主键ID(数据库自增)")
|
id: int = Field(..., description="主键ID(数据库自增)")
|
||||||
name: str = Field(None, description="名称")
|
name: str = Field(None, description="名称")
|
||||||
eigenvalue: str = Field(None, description="特征(暂为None)")
|
eigenvalue: str | None = Field(None, description="特征(可为空)")
|
||||||
created_at: datetime = Field(..., description="记录创建时间")
|
created_at: datetime = Field(..., description="记录创建时间")
|
||||||
updated_at: datetime = Field(..., description="记录更新时间")
|
updated_at: datetime = Field(..., description="记录更新时间")
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, HTTPException
|
from fastapi import APIRouter, Query, HTTPException,Request
|
||||||
from mysql.connector import Error as MySQLError
|
from mysql.connector import Error as MySQLError
|
||||||
|
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
@ -108,7 +108,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
|||||||
# 原有接口保持不变
|
# 原有接口保持不变
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||||
async def create_device(device_data: DeviceCreateRequest):
|
async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象
|
||||||
# 原有代码保持不变
|
# 原有代码保持不变
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
@ -125,11 +125,10 @@ async def create_device(device_data: DeviceCreateRequest):
|
|||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message=f"设备IP {device_data.ip} 已存在,返回已有设备信息",
|
message=f"设备IP {device_data.ip} 已存在,返回已有设备信息",
|
||||||
data=DeviceResponse(**existing_device)
|
data=DeviceResponse(** existing_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi import Request
|
# 直接使用注入的request对象获取用户代理
|
||||||
request = Request(scope={"type": "http"})
|
|
||||||
user_agent = request.headers.get("User-Agent", "").lower()
|
user_agent = request.headers.get("User-Agent", "").lower()
|
||||||
|
|
||||||
if user_agent == "default":
|
if user_agent == "default":
|
||||||
@ -184,7 +183,6 @@ async def create_device(device_data: DeviceCreateRequest):
|
|||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||||
async def get_device_list(
|
async def get_device_list(
|
||||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||||
|
@ -6,15 +6,15 @@ from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceRespons
|
|||||||
from schema.response_schema import APIResponse
|
from schema.response_schema import APIResponse
|
||||||
from middle.auth_middleware import get_current_user
|
from middle.auth_middleware import get_current_user
|
||||||
from schema.user_schema import UserResponse
|
from schema.user_schema import UserResponse
|
||||||
from ocr.feature_extraction import BinaryFaceFeatureHandler
|
|
||||||
|
from util.face_util import add_binary_data,get_average_feature
|
||||||
|
#初始化实例
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/faces",
|
prefix="/faces",
|
||||||
tags=["人脸管理"]
|
tags=["人脸管理"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 BinaryFaceFeatureHandler 的实例
|
|
||||||
binary_face_feature_handler = BinaryFaceFeatureHandler()
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -33,6 +33,8 @@ async def create_face(
|
|||||||
- ID 由数据库自动生成,无需前端传入
|
- ID 由数据库自动生成,无需前端传入
|
||||||
- 暂不处理文件内容,eigenvalue 设为 None
|
- 暂不处理文件内容,eigenvalue 设为 None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 调用你的方法
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
@ -45,14 +47,24 @@ async def create_face(
|
|||||||
# 把文件转为二进制数组
|
# 把文件转为二进制数组
|
||||||
file_content = await file.read()
|
file_content = await file.read()
|
||||||
|
|
||||||
# 调用人脸识别得到特征值
|
# 计算特征值
|
||||||
|
flag, eigenvalue = add_binary_data(file_content)
|
||||||
|
|
||||||
|
if flag == False:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="未检测到人脸"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打印数组长度
|
||||||
|
print(f"文件大小:{len(file_content)} 字节")
|
||||||
|
|
||||||
# 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None)
|
# 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None)
|
||||||
insert_query = """
|
insert_query = """
|
||||||
INSERT INTO face (name, eigenvalue)
|
INSERT INTO face (name, eigenvalue)
|
||||||
VALUES (%s, %s)
|
VALUES (%s, %s)
|
||||||
"""
|
"""
|
||||||
cursor.execute(insert_query, (face_create.name, None))
|
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
# 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录)
|
# 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录)
|
||||||
@ -60,19 +72,45 @@ async def create_face(
|
|||||||
cursor.execute(select_new_query)
|
cursor.execute(select_new_query)
|
||||||
created_face = cursor.fetchone()
|
created_face = cursor.fetchone()
|
||||||
|
|
||||||
|
if not created_face:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="创建人脸记录成功,但无法获取新创建的记录"
|
||||||
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=201,
|
code=201,
|
||||||
message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})",
|
message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})",
|
||||||
data=FaceResponse(**created_face)
|
data=FaceResponse(** created_face)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"创建人脸记录失败:{str(e)}") from e
|
# 改为使用HTTPException
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"创建人脸记录失败:{str(e)}"
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
# 捕获其他可能的异常
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"服务器错误:{str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
await file.close() # 关闭文件流
|
await file.close() # 关闭文件流
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
|
||||||
|
flag, eigenvalue = add_binary_data(file_content)
|
||||||
|
if flag == False:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="未检测到人脸"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将 eigenvalue 转为 str
|
||||||
|
eigenvalue = str(eigenvalue)
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 2. 获取单个人脸记录(不变,用自增ID查询)
|
# 2. 获取单个人脸记录(不变,用自增ID查询)
|
||||||
@ -104,18 +142,21 @@ async def get_face(
|
|||||||
data=FaceResponse(**face)
|
data=FaceResponse(**face)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"查询人脸记录失败:{str(e)}") from e
|
# 改为使用HTTPException
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"查询人脸记录失败:{str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改)
|
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 3. 获取所有人脸记录(不变)
|
# 3. 获取所有人脸记录(不变)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
||||||
async def get_all_faces(
|
async def get_all_faces(
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
|
||||||
):
|
):
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
@ -130,10 +171,13 @@ async def get_all_faces(
|
|||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="所有人脸记录查询成功",
|
message="所有人脸记录查询成功",
|
||||||
data=[FaceResponse(**face) for face in faces]
|
data=[FaceResponse(** face) for face in faces]
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"查询所有人脸记录失败:{str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -194,7 +238,10 @@ async def update_face(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"更新人脸记录失败:{str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"更新人脸记录失败:{str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -234,7 +281,10 @@ async def delete_face(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"删除人脸记录失败:{str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"删除人脸记录失败:{str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -249,38 +299,43 @@ def get_all_face_name_with_eigenvalue() -> dict:
|
|||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
|
# 1. 建立数据库连接并获取游标(dictionary=True使结果以字典形式返回)
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 只查询需要的字段,提高效率
|
# 2. 执行SQL查询:只获取name非空的记录,减少数据传输
|
||||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
faces = cursor.fetchall()
|
faces = cursor.fetchall() # 返回结果:列表套字典,如 [{"name":"张三","eigenvalue":...}, ...]
|
||||||
|
|
||||||
# 先收集所有名称对应的特征值列表(处理重复名称)
|
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
|
||||||
name_to_eigenvalues = {}
|
name_to_eigenvalues = {}
|
||||||
for face in faces:
|
for face in faces:
|
||||||
name = face["name"]
|
name = face["name"]
|
||||||
eigenvalue = face["eigenvalue"]
|
eigenvalue = face["eigenvalue"]
|
||||||
|
# 若名称已存在,追加特征值;否则新建列表存储
|
||||||
if name in name_to_eigenvalues:
|
if name in name_to_eigenvalues:
|
||||||
name_to_eigenvalues[name].append(eigenvalue)
|
name_to_eigenvalues[name].append(eigenvalue)
|
||||||
else:
|
else:
|
||||||
name_to_eigenvalues[name] = [eigenvalue]
|
name_to_eigenvalues[name] = [eigenvalue]
|
||||||
|
|
||||||
# 构建最终字典:重复名称取平均特征值,唯一名称直接取特征值
|
# 4. 构建最终字典:重复名称取平均,唯一名称直接取特征值
|
||||||
face_dict = {}
|
face_dict = {}
|
||||||
for name, eigenvalues in name_to_eigenvalues.items():
|
for name, eigenvalues in name_to_eigenvalues.items():
|
||||||
print("调用的特征值是:" + eigenvalues)
|
|
||||||
|
# 处理特征值:多个则求平均,单个则直接使用
|
||||||
if len(eigenvalues) > 1:
|
if len(eigenvalues) > 1:
|
||||||
# 调用平均特征值计算方法
|
# 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入)
|
||||||
face_dict[name] = binary_face_feature_handler.get_average_feature(eigenvalues)
|
face_dict[name] = get_average_feature(eigenvalues)
|
||||||
else:
|
else:
|
||||||
|
# 取列表中唯一的特征值(避免value为列表类型)
|
||||||
face_dict[name] = eigenvalues[0]
|
face_dict[name] = eigenvalues[0]
|
||||||
|
|
||||||
return face_dict
|
return face_dict
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
|
# 捕获数据库异常,添加上下文信息后重新抛出(便于定位问题)
|
||||||
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
|
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
# 确保资源释放
|
# 5. 无论是否异常,均释放数据库连接和游标(避免资源泄漏)
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
145
util/face_util.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import insightface
|
||||||
|
from insightface.app import FaceAnalysis
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# 全局变量存储InsightFace引擎和特征列表
|
||||||
|
_insightface_app = None
|
||||||
|
_feature_list = []
|
||||||
|
|
||||||
|
|
||||||
|
def init_insightface():
|
||||||
|
"""初始化InsightFace引擎"""
|
||||||
|
global _insightface_app
|
||||||
|
try:
|
||||||
|
print("正在初始化InsightFace引擎...")
|
||||||
|
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||||
|
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||||
|
print("InsightFace引擎初始化完成")
|
||||||
|
_insightface_app = app
|
||||||
|
return app
|
||||||
|
except Exception as e:
|
||||||
|
print(f"InsightFace初始化失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def add_binary_data(binary_data):
|
||||||
|
"""
|
||||||
|
接收单张图片的二进制数据,提取特征并保存
|
||||||
|
|
||||||
|
参数:
|
||||||
|
binary_data: 图片的二进制数据(bytes类型)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
成功提取特征时返回 (True, 特征值numpy数组)
|
||||||
|
失败时返回 (False, None)
|
||||||
|
"""
|
||||||
|
global _insightface_app, _feature_list
|
||||||
|
|
||||||
|
if not _insightface_app:
|
||||||
|
print("引擎未初始化,无法处理")
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 直接处理二进制数据:转换为图像格式
|
||||||
|
img = Image.open(BytesIO(binary_data))
|
||||||
|
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
faces = _insightface_app.get(frame)
|
||||||
|
if faces:
|
||||||
|
# 获取当前提取的特征值
|
||||||
|
current_feature = faces[0].embedding
|
||||||
|
# 添加到特征列表
|
||||||
|
_feature_list.append(current_feature)
|
||||||
|
print(f"已累计 {len(_feature_list)} 个特征")
|
||||||
|
# 返回成功标志和当前特征值
|
||||||
|
return True, current_feature
|
||||||
|
else:
|
||||||
|
print("二进制数据中未检测到人脸")
|
||||||
|
return False, None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理二进制数据出错: {e}")
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
|
def get_average_feature(features=None):
|
||||||
|
"""
|
||||||
|
计算多个特征向量的平均值
|
||||||
|
|
||||||
|
参数:
|
||||||
|
features: 可选,特征值列表。如果未提供,则使用全局存储的_feature_list
|
||||||
|
每个元素可以是字符串格式或numpy数组
|
||||||
|
|
||||||
|
返回:
|
||||||
|
单一平均特征向量的numpy数组,若无可计算数据则返回None
|
||||||
|
"""
|
||||||
|
global _feature_list
|
||||||
|
|
||||||
|
# 如果未提供features参数,则使用全局特征列表
|
||||||
|
if features is None:
|
||||||
|
features = _feature_list
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 验证输入是否为列表且不为空
|
||||||
|
if not isinstance(features, list) or len(features) == 0:
|
||||||
|
print("输入必须是包含至少一个特征值的列表")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 处理每个特征值
|
||||||
|
processed_features = []
|
||||||
|
for i, embedding in enumerate(features):
|
||||||
|
try:
|
||||||
|
if isinstance(embedding, str):
|
||||||
|
# 处理包含括号和逗号的字符串格式
|
||||||
|
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||||
|
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||||
|
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
embedding_np = np.array(embedding, dtype=np.float32)
|
||||||
|
|
||||||
|
# 验证特征值格式
|
||||||
|
if len(embedding_np.shape) == 1:
|
||||||
|
processed_features.append(embedding_np)
|
||||||
|
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||||
|
else:
|
||||||
|
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
||||||
|
|
||||||
|
# 确保有有效的特征值
|
||||||
|
if not processed_features:
|
||||||
|
print("没有有效的特征值用于计算平均值")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查所有特征向量维度是否相同
|
||||||
|
dims = {feat.shape[0] for feat in processed_features}
|
||||||
|
if len(dims) > 1:
|
||||||
|
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算平均值
|
||||||
|
avg_feature = np.mean(processed_features, axis=0)
|
||||||
|
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
|
||||||
|
|
||||||
|
return avg_feature
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"计算平均特征值时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def clear_features():
|
||||||
|
"""清空已存储的特征数据"""
|
||||||
|
global _feature_list
|
||||||
|
_feature_list = []
|
||||||
|
print("已清空所有特征数据")
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_list():
|
||||||
|
"""获取当前存储的特征列表"""
|
||||||
|
global _feature_list
|
||||||
|
return _feature_list.copy() # 返回副本防止外部直接修改
|
482
ws.html
@ -1,482 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="zh-CN">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>WebSocket 测试工具</title>
|
|
||||||
<style>
|
|
||||||
* {
|
|
||||||
box-sizing: border-box;
|
|
||||||
margin: 0;
|
|
||||||
padding: 0;
|
|
||||||
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
|
|
||||||
}
|
|
||||||
|
|
||||||
body {
|
|
||||||
max-width: 1200px;
|
|
||||||
margin: 20px auto;
|
|
||||||
padding: 0 20px;
|
|
||||||
background-color: #f5f7fa;
|
|
||||||
}
|
|
||||||
|
|
||||||
.container {
|
|
||||||
background: white;
|
|
||||||
border-radius: 8px;
|
|
||||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
|
||||||
padding: 25px;
|
|
||||||
margin-bottom: 20px;
|
|
||||||
}
|
|
||||||
|
|
||||||
h1 {
|
|
||||||
color: #2c3e50;
|
|
||||||
margin-bottom: 20px;
|
|
||||||
font-size: 24px;
|
|
||||||
border-bottom: 2px solid #3498db;
|
|
||||||
padding-bottom: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-bar {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 15px;
|
|
||||||
margin-bottom: 20px;
|
|
||||||
padding: 12px 15px;
|
|
||||||
background-color: #f8f9fa;
|
|
||||||
border-radius: 6px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-label {
|
|
||||||
font-weight: bold;
|
|
||||||
color: #495057;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-value {
|
|
||||||
padding: 4px 10px;
|
|
||||||
border-radius: 4px;
|
|
||||||
font-weight: bold;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-connected {
|
|
||||||
background-color: #d4edda;
|
|
||||||
color: #155724;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-disconnected {
|
|
||||||
background-color: #f8d7da;
|
|
||||||
color: #721c24;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-connecting {
|
|
||||||
background-color: #fff3cd;
|
|
||||||
color: #856404;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn {
|
|
||||||
padding: 8px 16px;
|
|
||||||
border: none;
|
|
||||||
border-radius: 4px;
|
|
||||||
cursor: pointer;
|
|
||||||
font-size: 14px;
|
|
||||||
font-weight: 500;
|
|
||||||
transition: background-color 0.2s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-primary {
|
|
||||||
background-color: #3498db;
|
|
||||||
color: white;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-primary:hover {
|
|
||||||
background-color: #2980b9;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-danger {
|
|
||||||
background-color: #e74c3c;
|
|
||||||
color: white;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-danger:hover {
|
|
||||||
background-color: #c0392b;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-success {
|
|
||||||
background-color: #2ecc71;
|
|
||||||
color: white;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-success:hover {
|
|
||||||
background-color: #27ae60;
|
|
||||||
}
|
|
||||||
|
|
||||||
.control-group {
|
|
||||||
display: flex;
|
|
||||||
gap: 15px;
|
|
||||||
margin-bottom: 20px;
|
|
||||||
align-items: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-group {
|
|
||||||
display: flex;
|
|
||||||
gap: 10px;
|
|
||||||
align-items: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-group label {
|
|
||||||
color: #495057;
|
|
||||||
font-weight: 500;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-group input, .input-group select {
|
|
||||||
padding: 8px 12px;
|
|
||||||
border: 1px solid #ced4da;
|
|
||||||
border-radius: 4px;
|
|
||||||
font-size: 14px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-area {
|
|
||||||
margin-top: 20px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-input {
|
|
||||||
width: 100%;
|
|
||||||
height: 100px;
|
|
||||||
padding: 12px;
|
|
||||||
border: 1px solid #ced4da;
|
|
||||||
border-radius: 6px;
|
|
||||||
resize: none;
|
|
||||||
font-size: 14px;
|
|
||||||
margin-bottom: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-area {
|
|
||||||
width: 100%;
|
|
||||||
height: 300px;
|
|
||||||
padding: 15px;
|
|
||||||
border: 1px solid #ced4da;
|
|
||||||
border-radius: 6px;
|
|
||||||
background-color: #f8f9fa;
|
|
||||||
overflow-y: auto;
|
|
||||||
font-size: 14px;
|
|
||||||
line-height: 1.6;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-item {
|
|
||||||
margin-bottom: 8px;
|
|
||||||
padding-bottom: 8px;
|
|
||||||
border-bottom: 1px dashed #e9ecef;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-time {
|
|
||||||
color: #6c757d;
|
|
||||||
font-size: 12px;
|
|
||||||
margin-right: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-send {
|
|
||||||
color: #2980b9;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-receive {
|
|
||||||
color: #27ae60;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-status {
|
|
||||||
color: #856404;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-error {
|
|
||||||
color: #e74c3c;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="container">
|
|
||||||
<h1>WebSocket 测试工具</h1>
|
|
||||||
|
|
||||||
<!-- 连接状态区 -->
|
|
||||||
<div class="status-bar">
|
|
||||||
<div class="status-label">连接状态:</div>
|
|
||||||
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
|
|
||||||
<div class="status-label">服务地址:</div>
|
|
||||||
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
|
|
||||||
<div class="status-label">连接时间:</div>
|
|
||||||
<div id="connectTime" class="status-value">-</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 控制按钮区 -->
|
|
||||||
<div class="control-group">
|
|
||||||
<button id="connectBtn" class="btn btn-primary">建立连接</button>
|
|
||||||
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
|
|
||||||
|
|
||||||
<!-- 心跳控制 -->
|
|
||||||
<div class="input-group">
|
|
||||||
<label>自动心跳:</label>
|
|
||||||
<select id="autoHeartbeat">
|
|
||||||
<option value="on">开启</option>
|
|
||||||
<option value="off">关闭</option>
|
|
||||||
</select>
|
|
||||||
<label>间隔(秒):</label>
|
|
||||||
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
|
|
||||||
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 自定义消息发送区 -->
|
|
||||||
<div class="message-area">
|
|
||||||
<h3>发送自定义消息</h3>
|
|
||||||
<textarea id="messageInput" class="message-input"
|
|
||||||
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
|
|
||||||
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 日志显示区 -->
|
|
||||||
<div class="message-area">
|
|
||||||
<h3>消息日志</h3>
|
|
||||||
<div id="logContainer" class="log-area">
|
|
||||||
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
|
|
||||||
</div>
|
|
||||||
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script>
|
|
||||||
// 全局变量
|
|
||||||
let ws = null;
|
|
||||||
let heartbeatTimer = null;
|
|
||||||
const wsUrl = "ws://192.168.110.25:8000/ws";
|
|
||||||
|
|
||||||
// DOM 元素
|
|
||||||
const connectionStatus = document.getElementById('connectionStatus');
|
|
||||||
const connectTime = document.getElementById('connectTime');
|
|
||||||
const connectBtn = document.getElementById('connectBtn');
|
|
||||||
const disconnectBtn = document.getElementById('disconnectBtn');
|
|
||||||
const sendMessageBtn = document.getElementById('sendMessageBtn');
|
|
||||||
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
|
|
||||||
const autoHeartbeat = document.getElementById('autoHeartbeat');
|
|
||||||
const heartbeatInterval = document.getElementById('heartbeatInterval');
|
|
||||||
const messageInput = document.getElementById('messageInput');
|
|
||||||
const logContainer = document.getElementById('logContainer');
|
|
||||||
const clearLogBtn = document.getElementById('clearLogBtn');
|
|
||||||
|
|
||||||
// 工具函数:添加日志
|
|
||||||
function addLog(content, type = 'status') {
|
|
||||||
const now = new Date().toLocaleString('zh-CN', {
|
|
||||||
year: 'numeric', month: '2-digit', day: '2-digit',
|
|
||||||
hour: '2-digit', minute: '2-digit', second: '2-digit'
|
|
||||||
});
|
|
||||||
const logItem = document.createElement('div');
|
|
||||||
logItem.className = 'log-item';
|
|
||||||
|
|
||||||
let logClass = '';
|
|
||||||
switch (type) {
|
|
||||||
case 'send':
|
|
||||||
logClass = 'log-send';
|
|
||||||
break;
|
|
||||||
case 'receive':
|
|
||||||
logClass = 'log-receive';
|
|
||||||
break;
|
|
||||||
case 'error':
|
|
||||||
logClass = 'log-error';
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
logClass = 'log-status';
|
|
||||||
}
|
|
||||||
|
|
||||||
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
|
|
||||||
logContainer.appendChild(logItem);
|
|
||||||
// 滚动到最新日志
|
|
||||||
logContainer.scrollTop = logContainer.scrollHeight;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 工具函数:格式化JSON(便于日志显示)
|
|
||||||
function formatJson(jsonStr) {
|
|
||||||
try {
|
|
||||||
const obj = JSON.parse(jsonStr);
|
|
||||||
return JSON.stringify(obj, null, 2);
|
|
||||||
} catch (e) {
|
|
||||||
return jsonStr; // 非JSON格式直接返回
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立WebSocket连接
|
|
||||||
function connectWebSocket() {
|
|
||||||
if (ws) {
|
|
||||||
addLog('已存在连接,无需重复建立', 'error');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
ws = new WebSocket(wsUrl);
|
|
||||||
|
|
||||||
// 连接成功
|
|
||||||
ws.onopen = function () {
|
|
||||||
connectionStatus.className = 'status-value status-connected';
|
|
||||||
connectionStatus.textContent = '已连接';
|
|
||||||
const now = new Date().toLocaleString('zh-CN');
|
|
||||||
connectTime.textContent = now;
|
|
||||||
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
|
|
||||||
|
|
||||||
// 更新按钮状态
|
|
||||||
connectBtn.disabled = true;
|
|
||||||
disconnectBtn.disabled = false;
|
|
||||||
sendMessageBtn.disabled = false;
|
|
||||||
|
|
||||||
// 开启自动心跳(默认开启)
|
|
||||||
if (autoHeartbeat.value === 'on') {
|
|
||||||
startAutoHeartbeat();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// 接收消息
|
|
||||||
ws.onmessage = function (event) {
|
|
||||||
const message = event.data;
|
|
||||||
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
|
|
||||||
};
|
|
||||||
|
|
||||||
// 连接关闭
|
|
||||||
ws.onclose = function (event) {
|
|
||||||
connectionStatus.className = 'status-value status-disconnected';
|
|
||||||
connectionStatus.textContent = '已断开';
|
|
||||||
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
|
|
||||||
|
|
||||||
// 清除自动心跳
|
|
||||||
stopAutoHeartbeat();
|
|
||||||
|
|
||||||
// 更新按钮状态
|
|
||||||
connectBtn.disabled = false;
|
|
||||||
disconnectBtn.disabled = true;
|
|
||||||
sendMessageBtn.disabled = true;
|
|
||||||
|
|
||||||
// 重置WebSocket对象
|
|
||||||
ws = null;
|
|
||||||
};
|
|
||||||
|
|
||||||
// 连接错误
|
|
||||||
ws.onerror = function (error) {
|
|
||||||
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
|
|
||||||
};
|
|
||||||
|
|
||||||
} catch (e) {
|
|
||||||
addLog(`建立连接失败:${e.message}`, 'error');
|
|
||||||
ws = null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 断开WebSocket连接
|
|
||||||
function disconnectWebSocket() {
|
|
||||||
if (!ws) {
|
|
||||||
addLog('当前无连接,无需断开', 'error');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ws.close(1000, '手动断开连接');
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"})
|
|
||||||
function sendHeartbeat() {
|
|
||||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
|
||||||
addLog('发送心跳失败:当前无有效连接', 'error');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const heartbeatMsg = {
|
|
||||||
timestamp: Date.now(), // 当前毫秒时间戳
|
|
||||||
type: "heartbeat"
|
|
||||||
};
|
|
||||||
const msgStr = JSON.stringify(heartbeatMsg);
|
|
||||||
|
|
||||||
ws.send(msgStr);
|
|
||||||
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
|
|
||||||
}
|
|
||||||
|
|
||||||
// 开启自动心跳
|
|
||||||
function startAutoHeartbeat() {
|
|
||||||
// 先停止已有定时器
|
|
||||||
stopAutoHeartbeat();
|
|
||||||
|
|
||||||
const interval = parseInt(heartbeatInterval.value) * 1000;
|
|
||||||
if (isNaN(interval) || interval < 10000) {
|
|
||||||
addLog('自动心跳间隔无效,已重置为30秒', 'error');
|
|
||||||
heartbeatInterval.value = 30;
|
|
||||||
return startAutoHeartbeat();
|
|
||||||
}
|
|
||||||
|
|
||||||
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}秒`, 'status');
|
|
||||||
heartbeatTimer = setInterval(sendHeartbeat, interval);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 停止自动心跳
|
|
||||||
function stopAutoHeartbeat() {
|
|
||||||
if (heartbeatTimer) {
|
|
||||||
clearInterval(heartbeatTimer);
|
|
||||||
heartbeatTimer = null;
|
|
||||||
addLog('已停止自动心跳', 'status');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送自定义消息
|
|
||||||
function sendCustomMessage() {
|
|
||||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
|
||||||
addLog('发送消息失败:当前无有效连接', 'error');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgStr = messageInput.value.trim();
|
|
||||||
if (!msgStr) {
|
|
||||||
addLog('发送消息失败:消息内容不能为空', 'error');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
// 验证JSON格式(可选,仅提示不强制)
|
|
||||||
JSON.parse(msgStr);
|
|
||||||
ws.send(msgStr);
|
|
||||||
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
|
|
||||||
} catch (e) {
|
|
||||||
addLog(`JSON格式错误:${e.message},仍尝试发送原始内容`, 'error');
|
|
||||||
ws.send(msgStr);
|
|
||||||
addLog(`发送自定义消息(非JSON):\n${msgStr}`, 'send');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 绑定按钮事件
|
|
||||||
connectBtn.addEventListener('click', connectWebSocket);
|
|
||||||
disconnectBtn.addEventListener('click', disconnectWebSocket);
|
|
||||||
sendMessageBtn.addEventListener('click', sendCustomMessage);
|
|
||||||
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
|
|
||||||
clearLogBtn.addEventListener('click', () => {
|
|
||||||
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
|
|
||||||
});
|
|
||||||
|
|
||||||
// 自动心跳开关变更事件
|
|
||||||
autoHeartbeat.addEventListener('change', function () {
|
|
||||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
|
||||||
if (this.value === 'on') {
|
|
||||||
startAutoHeartbeat();
|
|
||||||
} else {
|
|
||||||
stopAutoHeartbeat();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
addLog('需先建立有效连接才能控制自动心跳', 'error');
|
|
||||||
// 重置选择
|
|
||||||
this.value = 'off';
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 心跳间隔变更事件(实时生效)
|
|
||||||
heartbeatInterval.addEventListener('change', function () {
|
|
||||||
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
|
|
||||||
startAutoHeartbeat();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 快捷键支持(Ctrl+Enter发送消息)
|
|
||||||
messageInput.addEventListener('keydown', function (e) {
|
|
||||||
if (e.ctrlKey && e.key === 'Enter') {
|
|
||||||
sendCustomMessage();
|
|
||||||
e.preventDefault();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
356
ws/ws.py
@ -4,314 +4,300 @@ import json
|
|||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Dict, Optional, AsyncGenerator
|
from typing import Dict, Optional, AsyncGenerator
|
||||||
from concurrent.futures import ThreadPoolExecutor # 新增:显式线程池
|
|
||||||
|
|
||||||
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
|
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
|
||||||
from service.device_action_service import add_device_action
|
from service.device_action_service import add_device_action
|
||||||
from schema.device_action_schema import DeviceActionCreate
|
from schema.device_action_schema import DeviceActionCreate
|
||||||
|
from core.all import detect
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
|
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
|
||||||
from queue import Queue # 线程安全队列,无需额外Lock
|
from core.all import load_model
|
||||||
|
|
||||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
# 配置常量
|
||||||
|
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
|
||||||
# -------------------------- 配置调整 --------------------------
|
|
||||||
# 模型路径(建议改为环境变量)
|
|
||||||
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
|
|
||||||
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
|
|
||||||
|
|
||||||
# 核心优化:模型池大小(决定最大并发任务数,显存占用=大小×单模型显存)
|
|
||||||
MODEL_POOL_SIZE = 5 # 示例:设为5,支持5个任务并行,显存会明显上升
|
|
||||||
THREAD_POOL_SIZE = MODEL_POOL_SIZE * 2 # 线程池大小≥模型池,避免线程瓶颈
|
|
||||||
|
|
||||||
# 其他配置
|
|
||||||
HEARTBEAT_INTERVAL = 30 # 心跳间隔(秒)
|
|
||||||
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
|
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
|
||||||
WS_ENDPOINT = "/ws" # WebSocket端点
|
WS_ENDPOINT = "/ws" # WebSocket端点路径
|
||||||
FRAME_QUEUE_SIZE = 5 # 增大帧队列,允许缓存更多帧(避免丢帧)
|
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
|
||||||
|
|
||||||
# -------------------------- 工具函数 --------------------------
|
|
||||||
|
# 工具函数:获取格式化时间字符串(统一时间戳格式)
|
||||||
def get_current_time_str() -> str:
|
def get_current_time_str() -> str:
|
||||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
|
||||||
def get_current_time_file_str() -> str:
|
def get_current_time_file_str() -> str:
|
||||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
|
||||||
# -------------------------- 模型池重构(核心修改1) --------------------------
|
|
||||||
class ModelPool:
|
|
||||||
def __init__(self, pool_size: int = MODEL_POOL_SIZE):
|
|
||||||
self.pool = Queue(maxsize=pool_size)
|
|
||||||
# 移除冗余Lock:Queue.get()/put()本身线程安全
|
|
||||||
self._init_models(pool_size)
|
|
||||||
print(f"[{get_current_time_str()}] 模型池初始化完成(共{pool_size}个实例,显存已预分配)")
|
|
||||||
|
|
||||||
def _init_models(self, pool_size: int):
|
# 客户端连接封装
|
||||||
"""预加载所有模型实例(初始化时显存会一次性上升)"""
|
|
||||||
for i in range(pool_size):
|
|
||||||
try:
|
|
||||||
detector = MultiModelViolationDetector(
|
|
||||||
ocr_config_path=OCR_CONFIG_PATH,
|
|
||||||
yolo_model_path=YOLO_MODEL_PATH,
|
|
||||||
ocr_confidence_threshold=0.5
|
|
||||||
)
|
|
||||||
self.pool.put(detector)
|
|
||||||
print(f"[{get_current_time_str()}] 模型实例{i+1}/{pool_size}加载完成")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"模型实例{i+1}加载失败:{str(e)}")
|
|
||||||
|
|
||||||
def get_model(self) -> MultiModelViolationDetector:
|
|
||||||
"""获取模型(阻塞直到有空闲实例,确保并发安全)"""
|
|
||||||
return self.pool.get()
|
|
||||||
|
|
||||||
def return_model(self, detector: MultiModelViolationDetector):
|
|
||||||
"""归还模型(立即释放资源供其他任务使用)"""
|
|
||||||
self.pool.put(detector)
|
|
||||||
|
|
||||||
# -------------------------- 全局资源初始化 --------------------------
|
|
||||||
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE) # 初始化模型池(预占显存)
|
|
||||||
thread_pool = ThreadPoolExecutor( # 显式创建线程池(核心修改2)
|
|
||||||
max_workers=THREAD_POOL_SIZE,
|
|
||||||
thread_name_prefix="ModelWorker-" # 线程命名,便于调试
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------------------------- 客户端连接封装(核心修改3) --------------------------
|
|
||||||
class ClientConnection:
|
class ClientConnection:
|
||||||
def __init__(self, websocket: WebSocket, client_ip: str):
|
def __init__(self, websocket: WebSocket, client_ip: str):
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
self.client_ip = client_ip
|
self.client_ip = client_ip
|
||||||
self.last_heartbeat = datetime.datetime.now()
|
self.last_heartbeat = datetime.datetime.now()
|
||||||
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 增大队列
|
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
|
||||||
self.consumer_task: Optional[asyncio.Task] = None
|
self.consumer_task: Optional[asyncio.Task] = None
|
||||||
# 移除“客户端独占模型”:不再持有detector属性
|
|
||||||
|
|
||||||
def update_heartbeat(self):
|
def update_heartbeat(self):
|
||||||
|
"""更新心跳时间(客户端发送心跳时调用)"""
|
||||||
self.last_heartbeat = datetime.datetime.now()
|
self.last_heartbeat = datetime.datetime.now()
|
||||||
|
|
||||||
def is_alive(self) -> bool:
|
def is_alive(self) -> bool:
|
||||||
|
"""判断客户端是否存活(心跳超时检查)"""
|
||||||
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
||||||
return timeout < HEARTBEAT_TIMEOUT
|
return timeout < HEARTBEAT_TIMEOUT
|
||||||
|
|
||||||
def start_consumer(self):
|
def start_consumer(self):
|
||||||
"""启动帧消费任务(每个客户端一个独立任务)"""
|
"""启动帧消费任务"""
|
||||||
self.consumer_task = asyncio.create_task(self.consume_frames())
|
self.consumer_task = asyncio.create_task(self.consume_frames())
|
||||||
return self.consumer_task
|
return self.consumer_task
|
||||||
|
|
||||||
async def send_frame_permit(self):
|
async def send_frame_permit(self):
|
||||||
"""发送帧许可信号(允许客户端继续发帧)"""
|
"""
|
||||||
|
发送「帧发送许可信号」
|
||||||
|
通知客户端可发送下一帧图像
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
await self.websocket.send_json({
|
frame_permit_msg = {
|
||||||
"type": "frame",
|
"type": "frame",
|
||||||
"timestamp": get_current_time_str(),
|
"timestamp": get_current_time_str(),
|
||||||
"client_ip": self.client_ip
|
"client_ip": self.client_ip
|
||||||
})
|
}
|
||||||
|
await self.websocket.send_json(frame_permit_msg)
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可发送失败 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
|
||||||
|
|
||||||
async def consume_frames(self) -> None:
|
async def consume_frames(self) -> None:
|
||||||
"""消费帧队列(并发核心:每帧临时借模型处理)"""
|
"""消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理帧)"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# 1. 从队列取帧(无帧时阻塞)
|
# 1. 从队列取出帧(阻塞直到有帧可用)
|
||||||
frame_data = await self.frame_queue.get()
|
frame_data = await self.frame_queue.get()
|
||||||
# 2. 立即发送下一帧许可(让客户端持续发帧,积累并发任务)
|
|
||||||
await self.send_frame_permit()
|
# -------------------------- 核心修改:取出帧后立即发送下一帧许可 --------------------------
|
||||||
|
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成
|
||||||
|
# -----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 3. 并行处理帧(核心:任务级借模型)
|
# 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧)
|
||||||
await self.process_frame(frame_data)
|
await self.process_frame(frame_data)
|
||||||
finally:
|
finally:
|
||||||
self.frame_queue.task_done() # 标记帧处理完成
|
# 3. 标记帧任务完成(无论处理成功/失败,都需清理队列)
|
||||||
|
self.frame_queue.task_done()
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:消费逻辑错误 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}")
|
||||||
|
|
||||||
async def process_frame(self, frame_data: bytes) -> None:
|
async def process_frame(self, frame_data: bytes) -> None:
|
||||||
"""处理单帧(核心修改4:任务级借还模型)"""
|
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
|
||||||
# 1. 临时借用模型(阻塞直到有空闲实例,显存随借用数上升)
|
# 二进制数据转OpenCV图像
|
||||||
detector = model_pool.get_model()
|
|
||||||
try:
|
|
||||||
# 2. 二进制转OpenCV图像
|
|
||||||
nparr = np.frombuffer(frame_data, np.uint8)
|
nparr = np.frombuffer(frame_data, np.uint8)
|
||||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
if img is None:
|
if img is None:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像解析失败")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 3. 保存图像(可选)
|
# 确保图像保存目录存在
|
||||||
os.makedirs('images', exist_ok=True)
|
os.makedirs('images', exist_ok=True)
|
||||||
|
|
||||||
|
# 保存图像(按IP+时间戳命名,避免冲突)
|
||||||
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
|
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
|
||||||
|
try:
|
||||||
cv2.imwrite(filename, img)
|
cv2.imwrite(filename, img)
|
||||||
|
print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
|
||||||
# 4. 显式线程池执行AI检测(真正并发,无线程瓶颈)
|
has_violation, data, type = detect(img)
|
||||||
loop = asyncio.get_running_loop()
|
print(has_violation)
|
||||||
has_violation, violation_type, details = await loop.run_in_executor(
|
print(type)
|
||||||
thread_pool, # 用自定义线程池,避免默认线程不足
|
print(data)
|
||||||
detector.detect_violations, # 临时借用的模型
|
|
||||||
img # 输入图像
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. 违规处理(与原逻辑一致)
|
|
||||||
if has_violation:
|
if has_violation:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规 - {violation_type}")
|
print(
|
||||||
# 违规次数更新(用线程池避免阻塞事件循环)
|
f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {type}, 详情: {data}")
|
||||||
await loop.run_in_executor(thread_pool, increment_alarm_count_by_ip, self.client_ip)
|
|
||||||
# 发送危险通知
|
# 调用违规次数加一方法
|
||||||
await self.websocket.send_json({
|
try:
|
||||||
|
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
|
||||||
|
|
||||||
|
# 发送「危险通知」
|
||||||
|
danger_msg = {
|
||||||
"type": "danger",
|
"type": "danger",
|
||||||
"timestamp": get_current_time_str(),
|
"timestamp": get_current_time_str(),
|
||||||
"client_ip": self.client_ip,
|
"client_ip": self.client_ip
|
||||||
"violation_type": violation_type,
|
}
|
||||||
"details": details
|
await self.websocket.send_json(danger_msg)
|
||||||
})
|
|
||||||
else:
|
else:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无违规")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:未检测到违规")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧处理错误 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像处理错误 - {str(e)}")
|
||||||
finally:
|
|
||||||
# 6. 无论成功/失败,强制归还模型(核心:释放资源供其他任务使用)
|
|
||||||
model_pool.return_model(detector)
|
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还(可复用)")
|
|
||||||
|
|
||||||
# -------------------------- 全局状态与心跳 --------------------------
|
|
||||||
|
# 全局状态管理
|
||||||
connected_clients: Dict[str, ClientConnection] = {}
|
connected_clients: Dict[str, ClientConnection] = {}
|
||||||
client_lock = asyncio.Lock() # 保护客户端字典的异步锁
|
|
||||||
heartbeat_task: Optional[asyncio.Task] = None
|
heartbeat_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
|
||||||
|
# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法)
|
||||||
async def heartbeat_checker():
|
async def heartbeat_checker():
|
||||||
"""心跳检查(移除模型归还逻辑,因模型已任务级归还)"""
|
|
||||||
while True:
|
while True:
|
||||||
current_time = get_current_time_str()
|
current_time = get_current_time_str()
|
||||||
async with client_lock:
|
|
||||||
# 筛选超时客户端
|
|
||||||
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
||||||
|
|
||||||
|
if timeout_ips:
|
||||||
|
print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时(IP:{timeout_ips})")
|
||||||
for ip in timeout_ips:
|
for ip in timeout_ips:
|
||||||
async with client_lock:
|
try:
|
||||||
conn = connected_clients.get(ip)
|
conn = connected_clients[ip]
|
||||||
if not conn:
|
|
||||||
continue
|
|
||||||
# 取消消费任务+关闭连接
|
|
||||||
if conn.consumer_task and not conn.consumer_task.done():
|
if conn.consumer_task and not conn.consumer_task.done():
|
||||||
conn.consumer_task.cancel()
|
conn.consumer_task.cancel()
|
||||||
await conn.websocket.close(code=1008, reason="心跳超时")
|
await conn.websocket.close(code=1008, reason="心跳超时")
|
||||||
# 标记离线(用线程池)
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
await loop.run_in_executor(thread_pool, update_online_status_by_ip, ip, 0)
|
|
||||||
await loop.run_in_executor(
|
|
||||||
thread_pool, add_device_action, DeviceActionCreate(client_ip=ip, action=0)
|
|
||||||
)
|
|
||||||
connected_clients.pop(ip)
|
|
||||||
print(f"[{current_time}] 客户端{ip}:超时离线(资源已清理)")
|
|
||||||
|
|
||||||
# 打印在线状态
|
# 超时设为离线并记录
|
||||||
async with client_lock:
|
try:
|
||||||
|
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
|
||||||
|
action_data = DeviceActionCreate(client_ip=ip, action=0)
|
||||||
|
await asyncio.to_thread(add_device_action, action_data)
|
||||||
|
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
|
||||||
|
finally:
|
||||||
|
connected_clients.pop(ip, None)
|
||||||
|
else:
|
||||||
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
|
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
|
||||||
|
|
||||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||||
|
|
||||||
# -------------------------- 应用生命周期(核心修改5:管理线程池) --------------------------
|
|
||||||
|
# 应用生命周期管理
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global heartbeat_task
|
global heartbeat_task
|
||||||
# 启动心跳任务
|
|
||||||
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
||||||
print(f"[{get_current_time_str()}] 心跳任务启动(ID:{id(heartbeat_task)})")
|
print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID:{id(heartbeat_task)})")
|
||||||
print(f"[{get_current_time_str()}] 线程池启动(最大线程数:{THREAD_POOL_SIZE})")
|
yield
|
||||||
yield # 应用运行期间
|
|
||||||
# 清理资源
|
|
||||||
if heartbeat_task and not heartbeat_task.done():
|
if heartbeat_task and not heartbeat_task.done():
|
||||||
heartbeat_task.cancel()
|
heartbeat_task.cancel()
|
||||||
|
try:
|
||||||
await heartbeat_task
|
await heartbeat_task
|
||||||
print(f"[{get_current_time_str()}] 心跳任务已关闭")
|
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
|
||||||
# 关闭线程池(等待所有任务完成)
|
except asyncio.CancelledError:
|
||||||
thread_pool.shutdown(wait=True)
|
pass
|
||||||
print(f"[{get_current_time_str()}] 线程池已关闭")
|
|
||||||
|
|
||||||
# -------------------------- WebSocket路由 --------------------------
|
|
||||||
|
# 消息处理工具函数
|
||||||
|
async def send_heartbeat_ack(conn: ClientConnection):
|
||||||
|
try:
|
||||||
|
heartbeat_ack_msg = {
|
||||||
|
"type": "heart",
|
||||||
|
"timestamp": get_current_time_str(),
|
||||||
|
"client_ip": conn.client_ip
|
||||||
|
}
|
||||||
|
await conn.websocket.send_json(heartbeat_ack_msg)
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
connected_clients.pop(conn.client_ip, None)
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_text_msg(conn: ClientConnection, text: str):
|
||||||
|
try:
|
||||||
|
msg = json.loads(text)
|
||||||
|
if msg.get("type") == "heart":
|
||||||
|
conn.update_heartbeat()
|
||||||
|
await send_heartbeat_ack(conn)
|
||||||
|
else:
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')})")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:无效JSON文本消息")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_binary_msg(conn: ClientConnection, data: bytes):
|
||||||
|
try:
|
||||||
|
conn.frame_queue.put_nowait(data)
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列")
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据")
|
||||||
|
|
||||||
|
|
||||||
|
# WebSocket路由配置
|
||||||
ws_router = APIRouter()
|
ws_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@ws_router.websocket(WS_ENDPOINT)
|
@ws_router.websocket(WS_ENDPOINT)
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
# 加载模型
|
||||||
|
load_model()
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||||
current_time = get_current_time_str()
|
current_time = get_current_time_str()
|
||||||
print(f"[{current_time}] 客户端{client_ip}:连接建立")
|
print(f"[{current_time}] 客户端{client_ip}:WebSocket连接已建立")
|
||||||
|
|
||||||
new_conn = None
|
|
||||||
is_online_updated = False
|
is_online_updated = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 处理重复连接(关闭旧连接)
|
# 处理重复连接
|
||||||
async with client_lock:
|
|
||||||
if client_ip in connected_clients:
|
if client_ip in connected_clients:
|
||||||
old_conn = connected_clients[client_ip]
|
old_conn = connected_clients[client_ip]
|
||||||
if old_conn.consumer_task and not old_conn.consumer_task.done():
|
if old_conn.consumer_task and not old_conn.consumer_task.done():
|
||||||
old_conn.consumer_task.cancel()
|
old_conn.consumer_task.cancel()
|
||||||
await old_conn.websocket.close(code=1008, reason="新连接抢占")
|
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
|
||||||
connected_clients.pop(client_ip)
|
connected_clients.pop(client_ip)
|
||||||
print(f"[{current_time}] 客户端{client_ip}:旧连接已关闭")
|
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接")
|
||||||
|
|
||||||
# 创建新连接+启动消费任务
|
# 注册新连接
|
||||||
new_conn = ClientConnection(websocket, client_ip)
|
new_conn = ClientConnection(websocket, client_ip)
|
||||||
|
connected_clients[client_ip] = new_conn
|
||||||
new_conn.start_consumer()
|
new_conn.start_consumer()
|
||||||
# 初始发送帧许可(让客户端立即发帧)
|
# 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发)
|
||||||
await new_conn.send_frame_permit()
|
await new_conn.send_frame_permit()
|
||||||
|
|
||||||
# 标记客户端在线
|
# 标记上线并记录
|
||||||
loop = asyncio.get_running_loop()
|
try:
|
||||||
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 1)
|
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
|
||||||
await loop.run_in_executor(
|
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
|
||||||
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=1)
|
await asyncio.to_thread(add_device_action, action_data)
|
||||||
)
|
print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作")
|
||||||
is_online_updated = True
|
is_online_updated = True
|
||||||
async with client_lock:
|
except Exception as e:
|
||||||
connected_clients[client_ip] = new_conn
|
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
|
||||||
print(f"[{current_time}] 客户端{client_ip}:注册成功(在线数:{len(connected_clients)})")
|
|
||||||
|
|
||||||
# 消息循环(接收文本/二进制帧)
|
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
|
||||||
|
|
||||||
|
# 消息循环
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive()
|
data = await websocket.receive()
|
||||||
if "text" in data:
|
if "text" in data:
|
||||||
# 处理文本消息(如心跳)
|
await handle_text_msg(new_conn, data["text"])
|
||||||
try:
|
|
||||||
msg = json.loads(data["text"])
|
|
||||||
if msg.get("type") == "heart":
|
|
||||||
new_conn.update_heartbeat()
|
|
||||||
# 回复心跳确认
|
|
||||||
await websocket.send_json({
|
|
||||||
"type": "heart",
|
|
||||||
"timestamp": get_current_time_str(),
|
|
||||||
"client_ip": client_ip
|
|
||||||
})
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:无效JSON")
|
|
||||||
elif "bytes" in data:
|
elif "bytes" in data:
|
||||||
# 处理二进制帧(图像)
|
await handle_binary_msg(new_conn, data["bytes"])
|
||||||
try:
|
|
||||||
await new_conn.frame_queue.put(data["bytes"])
|
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧已入队(队列大小:{new_conn.frame_queue.qsize()})")
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧队列满(丢弃当前帧)")
|
|
||||||
|
|
||||||
except WebSocketDisconnect as e:
|
except WebSocketDisconnect as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开(代码:{e.code})")
|
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
|
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
|
||||||
finally:
|
finally:
|
||||||
# 清理资源(无需归还模型,已在process_frame中归还)
|
# 清理资源并标记离线
|
||||||
if new_conn and client_ip in connected_clients:
|
if client_ip in connected_clients:
|
||||||
async with client_lock:
|
conn = connected_clients[client_ip]
|
||||||
conn = connected_clients.get(client_ip)
|
|
||||||
if conn:
|
|
||||||
if conn.consumer_task and not conn.consumer_task.done():
|
if conn.consumer_task and not conn.consumer_task.done():
|
||||||
conn.consumer_task.cancel()
|
conn.consumer_task.cancel()
|
||||||
# 标记离线(仅当在线状态已更新时)
|
|
||||||
|
# 主动/异常断开时标记离线
|
||||||
if is_online_updated:
|
if is_online_updated:
|
||||||
loop = asyncio.get_running_loop()
|
try:
|
||||||
await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 0)
|
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
||||||
await loop.run_in_executor(
|
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
|
||||||
thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=0)
|
await asyncio.to_thread(add_device_action, action_data)
|
||||||
)
|
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
|
||||||
connected_clients.pop(client_ip)
|
except Exception as e:
|
||||||
async with client_lock:
|
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源清理完成(在线数:{len(connected_clients)})")
|
|
||||||
|
connected_clients.pop(client_ip, None)
|
||||||
|
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")
|
||||||
|