Compare commits
17 Commits
834de70547
...
master
Author | SHA1 | Date | |
---|---|---|---|
435b2a0e6c | |||
ae177ca14a | |||
d3c4820b73 | |||
532a9e75e9 | |||
0fe49bf829 | |||
2571da3c2d | |||
1dd832e18d | |||
8ceb92c572 | |||
9b3d20511a | |||
30bf7c9fcb | |||
ec6dbfde90 | |||
3ed73bd9eb | |||
08f8a0e44e | |||
b5d870a19c | |||
ea82a33a8f | |||
bae7785a97 | |||
49d2c71fdd |
2
.idea/Video.iml
generated
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
2
.idea/misc.xml
generated
@ -3,5 +3,5 @@
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="video" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||
</project>
|
@ -13,7 +13,3 @@ charset = utf8mb4
|
||||
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
|
||||
algorithm = HS256
|
||||
access_token_expire_minutes = 30
|
||||
|
||||
[live]
|
||||
rtmp_url = rtmp://192.168.110.65:1935/live/
|
||||
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=
|
||||
|
106
core/all.py
Normal file
@ -0,0 +1,106 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from core.establish import get_image_save_path
|
||||
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
|
||||
from core.face import load_model as faceLoadModel, detect as faceDetect
|
||||
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
|
||||
# 导入保存路径函数(根据实际文件位置调整导入路径)
|
||||
import numpy as np
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from ds.db import db
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
# 模型加载状态标记(避免重复加载)
|
||||
|
||||
|
||||
_model_loaded = False
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载所有检测模型(仅首次调用时执行)"""
|
||||
global _model_loaded
|
||||
if _model_loaded:
|
||||
print("模型已加载,无需重复执行")
|
||||
return
|
||||
|
||||
# 依次加载OCR、人脸、YOLO模型
|
||||
ocrLoadModel()
|
||||
faceLoadModel()
|
||||
yoloLoadModel()
|
||||
|
||||
_model_loaded = True
|
||||
print("所有检测模型加载完成")
|
||||
|
||||
|
||||
def save_db(model_type, client_ip, result):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 连接数据库
|
||||
conn = db.get_connection()
|
||||
# 往表插入数据
|
||||
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
|
||||
insert_query = """
|
||||
INSERT INTO device_danger (client_ip, type, result)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (client_ip, model_type, result))
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
|
||||
def detect(client_ip, frame):
|
||||
"""
|
||||
执行模型检测,检测到违规时按指定格式保存图片
|
||||
参数:
|
||||
frame: 待检测的图像帧(OpenCV格式,numpy.ndarray类型)
|
||||
返回:
|
||||
(检测结果布尔值, 检测详情, 检测模型类型)
|
||||
"""
|
||||
# 1. YOLO检测(优先级1)
|
||||
yolo_flag, yolo_result = yoloDetect(frame)
|
||||
print(f"YOLO检测结果:{yolo_result}")
|
||||
if yolo_flag:
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
# 打印时使用「显示用短路径」,符合需求格式
|
||||
print(f"✅ YOLO违规图片已保存:{display_path}")
|
||||
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
|
||||
return (True, yolo_result, "yolo")
|
||||
#
|
||||
# # 2. 人脸检测(优先级2)
|
||||
face_flag, face_result = faceDetect(frame)
|
||||
print(f"人脸检测结果:{face_result}")
|
||||
if face_flag:
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
# 打印时使用「显示用短路径」,符合需求格式
|
||||
print(f"✅ face违规图片已保存:{display_path}")
|
||||
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
|
||||
return (True, face_result, "face")
|
||||
|
||||
# 3. OCR检测(优先级3)
|
||||
ocr_flag, ocr_result = ocrDetect(frame)
|
||||
print(f"OCR检测结果:{ocr_result}")
|
||||
if ocr_flag:
|
||||
# 解构元组,保存用完整路径,打印用短路径
|
||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
# 打印时使用「显示用短路径」,符合需求格式
|
||||
print(f"✅ ocr违规图片已保存:{display_path}")
|
||||
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
|
||||
return (True, ocr_result, "ocr")
|
||||
# 4. 无违规内容(不保存图片)
|
||||
print(f"❌ 未检测到任何违规内容,不保存图片")
|
||||
return (False, "未检测到任何内容", "none")
|
120
core/establish.py
Normal file
@ -0,0 +1,120 @@
|
||||
import os
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from service.device_service import get_unique_client_ips
|
||||
def create_directory_structure():
|
||||
"""创建项目所需的目录结构,为所有客户端IP预创建基础目录"""
|
||||
try:
|
||||
# 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容)
|
||||
resource_dir = Path("resource")
|
||||
resource_dir.mkdir(exist_ok=True)
|
||||
print(f"确保resource目录存在: {resource_dir.absolute()}")
|
||||
|
||||
# 2. 在resource下创建dect文件夹
|
||||
dect_dir = resource_dir / "dect"
|
||||
dect_dir.mkdir(exist_ok=True)
|
||||
print(f"确保dect目录存在: {dect_dir.absolute()}")
|
||||
|
||||
# 3. 在dect下创建三个模型文件夹
|
||||
model_dirs = ["ocr", "face", "yolo"]
|
||||
for model in model_dirs:
|
||||
model_dir = dect_dir / model
|
||||
model_dir.mkdir(exist_ok=True)
|
||||
print(f"确保{model}模型目录存在: {model_dir.absolute()}")
|
||||
|
||||
# 4. 调用外部方法获取所有客户端IP地址
|
||||
try:
|
||||
# 调用外部ip_read()方法获取所有客户端IP地址列表
|
||||
all_ip_addresses = get_unique_client_ips()
|
||||
|
||||
# 确保返回的是列表类型
|
||||
if not isinstance(all_ip_addresses, list):
|
||||
all_ip_addresses = [all_ip_addresses]
|
||||
|
||||
# 过滤有效IP(去除空字符串和空格)
|
||||
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
|
||||
|
||||
if not valid_ips:
|
||||
print("警告: 未获取到有效的客户端IP地址")
|
||||
return
|
||||
|
||||
print(f"获取到的所有客户端IP地址: {valid_ips}")
|
||||
|
||||
# 5. 获取当前日期(年、月)
|
||||
now = datetime.datetime.now()
|
||||
current_year = str(now.year)
|
||||
current_month = str(now.month)
|
||||
|
||||
# 6. 为每个客户端IP在每个模型文件夹下创建年->月的基础目录结构
|
||||
for ip in valid_ips:
|
||||
# 处理IP地址中的特殊字符(将.替换为_,避免路径问题)
|
||||
safe_ip = ip.replace(".", "_")
|
||||
|
||||
for model in model_dirs:
|
||||
# 构建路径: resource/dect/{model}/{safe_ip}/{year}/{month}
|
||||
ip_dir = dect_dir / model / safe_ip
|
||||
year_dir = ip_dir / current_year
|
||||
month_dir = year_dir / current_month
|
||||
|
||||
# 递归创建目录(存在则跳过,不覆盖)
|
||||
month_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建基础目录结构时发生错误: {str(e)}")
|
||||
|
||||
|
||||
def get_image_save_path(model_type: str, client_ip: str) -> tuple:
|
||||
"""
|
||||
获取图片保存的「完整路径」和「显示用短路径」
|
||||
|
||||
参数:
|
||||
model_type: 模型类型,应为"ocr"、"face"或"yolo"
|
||||
client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101)
|
||||
|
||||
返回:
|
||||
元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "")
|
||||
"""
|
||||
try:
|
||||
# 1. 验证客户端IP有效性(检查是否在已知IP列表中)
|
||||
all_ip_addresses = get_unique_client_ips()
|
||||
if not isinstance(all_ip_addresses, list):
|
||||
all_ip_addresses = [all_ip_addresses]
|
||||
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
|
||||
|
||||
if client_ip.strip() not in valid_ips:
|
||||
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件")
|
||||
|
||||
# 2. 处理IP地址(与目录创建逻辑一致,将.替换为_)
|
||||
safe_ip = client_ip.strip().replace(".", "_")
|
||||
|
||||
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
|
||||
now = datetime.datetime.now()
|
||||
current_year = str(now.year)
|
||||
current_month = str(now.month)
|
||||
current_day = str(now.day)
|
||||
# 时间戳格式:年月日时分秒毫秒(如20250910143050123)
|
||||
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]
|
||||
|
||||
# 4. 定义基础目录(用于生成相对路径)
|
||||
base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀
|
||||
# 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日})
|
||||
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
|
||||
day_dir.mkdir(parents=True, exist_ok=True) # 确保日目录存在
|
||||
|
||||
# 5. 构建唯一文件名
|
||||
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
|
||||
|
||||
# 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印)
|
||||
full_path = day_dir / image_filename # 完整路径:resource/dect/.../xxx.jpg
|
||||
display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg(去掉resource/dect)
|
||||
|
||||
return str(full_path), str(display_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取图片保存路径时发生错误: {str(e)}")
|
||||
return "", ""
|
326
core/face.py
Normal file
@ -0,0 +1,326 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import gc
|
||||
import time
|
||||
import threading
|
||||
from PIL import Image
|
||||
from insightface.app import FaceAnalysis
|
||||
# 假设service.face_service中get_all_face_name_with_eigenvalue可获取人脸数据
|
||||
from service.face_service import get_all_face_name_with_eigenvalue
|
||||
|
||||
# GPU状态检查支持
|
||||
try:
|
||||
import pynvml
|
||||
|
||||
pynvml.nvmlInit()
|
||||
_nvml_available = True
|
||||
except ImportError:
|
||||
print("警告: pynvml库未安装,无法检测GPU状态,默认尝试使用GPU")
|
||||
_nvml_available = False
|
||||
|
||||
# 全局人脸引擎与特征库
|
||||
_face_app = None
|
||||
_known_faces_embeddings = {} # 姓名 -> 归一化特征值的映射
|
||||
_known_faces_names = [] # 已知人脸姓名列表
|
||||
|
||||
# GPU使用状态标记
|
||||
_using_gpu = False # 是否使用GPU
|
||||
_used_gpu_id = -1 # 使用的GPU ID(-1表示CPU)
|
||||
|
||||
# 资源管理变量
|
||||
_ref_count = 0 # 引擎引用计数(记录当前使用次数)
|
||||
_last_used_time = 0 # 最后一次使用引擎的时间
|
||||
_lock = threading.Lock() # 线程安全锁
|
||||
_release_timeout = 8 # 闲置超时时间(秒)
|
||||
_is_releasing = False # 资源释放中标记
|
||||
_monitor_thread_running = False # 监控线程运行标记
|
||||
|
||||
# 调试计数器
|
||||
_debug_counter = {
|
||||
"engine_created": 0, # 引擎创建次数
|
||||
"engine_released": 0, # 引擎释放次数
|
||||
"detection_calls": 0 # 检测函数调用次数
|
||||
}
|
||||
|
||||
|
||||
def check_gpu_availability(gpu_id, memory_threshold=0.7):
|
||||
"""检查指定GPU的内存使用率是否低于阈值(判定为“可用”)"""
|
||||
if not _nvml_available:
|
||||
return True
|
||||
try:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
|
||||
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
memory_usage = mem_info.used / mem_info.total
|
||||
return memory_usage < memory_threshold
|
||||
except Exception as e:
|
||||
print(f"检查GPU {gpu_id} 状态失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def select_best_gpu(preferred_gpus=[0, 1]):
|
||||
"""按优先级选择可用GPU,优先0号;均不可用则返回-1(CPU)"""
|
||||
for gpu_id in preferred_gpus:
|
||||
try:
|
||||
# 验证GPU是否存在
|
||||
if _nvml_available:
|
||||
pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
|
||||
# 验证GPU内存是否充足
|
||||
if check_gpu_availability(gpu_id):
|
||||
print(f"GPU {gpu_id} 可用,将使用该GPU")
|
||||
return gpu_id
|
||||
else:
|
||||
if gpu_id == 0:
|
||||
print("GPU 0 内存使用率过高,尝试其他GPU")
|
||||
except Exception as e:
|
||||
print(f"GPU {gpu_id} 不可用或访问失败: {e}")
|
||||
print("所有指定GPU均不可用,将使用CPU计算")
|
||||
return -1
|
||||
|
||||
|
||||
def _release_engine_resources():
|
||||
"""释放人脸引擎的所有资源(模型、特征库、GPU缓存等)"""
|
||||
global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names
|
||||
if not _face_app or _is_releasing:
|
||||
return
|
||||
|
||||
try:
|
||||
_is_releasing = True
|
||||
print("开始释放人脸引擎资源...")
|
||||
|
||||
# 释放InsightFace模型资源
|
||||
if hasattr(_face_app, "model"):
|
||||
_face_app.model = None # 显式置空模型引用
|
||||
_face_app = None # 释放引擎实例
|
||||
|
||||
# 清空人脸特征库
|
||||
_known_faces_embeddings.clear()
|
||||
_known_faces_names.clear()
|
||||
|
||||
_debug_counter["engine_released"] += 1
|
||||
print(f"人脸引擎已释放,调试统计: {_debug_counter}")
|
||||
|
||||
# 强制垃圾回收
|
||||
gc.collect()
|
||||
|
||||
# 清理各深度学习框架的GPU缓存
|
||||
# Torch 缓存清理
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
print("Torch GPU缓存已清理")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# TensorFlow 缓存清理
|
||||
try:
|
||||
import tensorflow as tf
|
||||
tf.keras.backend.clear_session()
|
||||
print("TensorFlow会话已清理")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# MXNet 缓存清理(InsightFace底层常用MXNet)
|
||||
try:
|
||||
import mxnet as mx
|
||||
mx.nd.waitall() # 等待所有计算完成并释放资源
|
||||
print("MXNet资源已等待释放")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
print(f"释放资源过程中出错: {e}")
|
||||
finally:
|
||||
_is_releasing = False
|
||||
|
||||
|
||||
def _resource_monitor_thread():
|
||||
"""后台监控线程:检测引擎闲置超时,触发资源释放"""
|
||||
global _ref_count, _last_used_time, _face_app, _monitor_thread_running
|
||||
_monitor_thread_running = True
|
||||
while _monitor_thread_running:
|
||||
time.sleep(2) # 缩短检查间隔,加快闲置检测响应
|
||||
with _lock:
|
||||
# 当“引擎存在 + 无引用 + 未在释放中”时,检查闲置时间
|
||||
if _face_app and _ref_count == 0 and not _is_releasing:
|
||||
idle_time = time.time() - _last_used_time
|
||||
if idle_time > _release_timeout:
|
||||
print(f"引擎闲置超时({idle_time:.1f}s > {_release_timeout}s),释放资源")
|
||||
_release_engine_resources()
|
||||
|
||||
|
||||
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
||||
"""加载人脸识别引擎及已知人脸特征库(默认优先用0号GPU)"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
|
||||
|
||||
# 启动后台监控线程(确保仅启动一次)
|
||||
if not _monitor_thread_running:
|
||||
threading.Thread(
|
||||
target=_resource_monitor_thread,
|
||||
daemon=True,
|
||||
name="FaceEngineMonitor"
|
||||
).start()
|
||||
print("人脸引擎监控线程已启动")
|
||||
|
||||
# 若正在释放资源,等待释放完成
|
||||
while _is_releasing:
|
||||
time.sleep(0.1)
|
||||
|
||||
# 若引擎已初始化,直接返回
|
||||
if _face_app:
|
||||
return True
|
||||
|
||||
# 初始化InsightFace引擎
|
||||
try:
|
||||
print("正在初始化InsightFace人脸识别引擎...")
|
||||
_face_app = FaceAnalysis(name="buffalo_l", root=os.path.expanduser("~/.insightface"))
|
||||
|
||||
# 选择GPU(优先用0号)
|
||||
ctx_id = 0
|
||||
if prefer_gpu:
|
||||
ctx_id = select_best_gpu(preferred_gpus)
|
||||
_using_gpu = ctx_id != -1
|
||||
_used_gpu_id = ctx_id if _using_gpu else -1
|
||||
|
||||
if _using_gpu:
|
||||
print(f"引擎初始化成功,将使用GPU {ctx_id} 计算")
|
||||
else:
|
||||
print("引擎初始化成功,将使用CPU计算")
|
||||
|
||||
# 准备模型(加载到指定设备)
|
||||
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
_debug_counter["engine_created"] += 1
|
||||
print(f"引擎调试统计: {_debug_counter}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"引擎初始化失败: {e}")
|
||||
return False
|
||||
|
||||
# 从服务加载已知人脸的姓名和特征值
|
||||
try:
|
||||
face_data = get_all_face_name_with_eigenvalue()
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 兼容“numpy数组”和“字符串”格式的特征值
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 清理字符串中的括号、换行等干扰符
|
||||
cleaned = eigenvalue_data.replace("[", "").replace("]", "").replace("\n", "").strip()
|
||||
# 分割并转换为浮点数数组
|
||||
values = [v for v in cleaned.split() if v] # 兼容空格/逗号分隔
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
|
||||
continue
|
||||
|
||||
# 特征值归一化(保证后续相似度计算的一致性)
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm != 0:
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库")
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载人脸特征库失败: {e}")
|
||||
|
||||
return _face_app is not None
|
||||
|
||||
|
||||
def detect(frame, similarity_threshold=0.4):
|
||||
"""
|
||||
检测并识别人脸
|
||||
返回:(是否匹配到已知人脸, 结果描述字符串)
|
||||
"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names, _ref_count, _last_used_time
|
||||
|
||||
# 校验输入帧有效性
|
||||
if frame is None or frame.size == 0:
|
||||
return (False, "无效的输入帧数据")
|
||||
|
||||
# 加锁并更新引用计数、最后使用时间
|
||||
engine = None
|
||||
with _lock:
|
||||
_ref_count += 1
|
||||
_last_used_time = time.time()
|
||||
_debug_counter["detection_calls"] += 1
|
||||
|
||||
# 若引擎未初始化且未在释放中,尝试初始化
|
||||
if not _face_app and not _is_releasing:
|
||||
if not load_model(prefer_gpu=True):
|
||||
# 初始化失败,恢复引用计数
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
return (False, "人脸引擎初始化失败")
|
||||
|
||||
engine = _face_app # 获取引擎引用
|
||||
|
||||
# 校验引擎可用性
|
||||
if not engine or len(_known_faces_names) == 0:
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
return (False, "人脸引擎不可用或特征库为空")
|
||||
|
||||
try:
|
||||
# GPU计算时,确保帧数据是连续内存(避免CUDA错误)
|
||||
if _using_gpu and engine is not None and not frame.flags.contiguous:
|
||||
frame = np.ascontiguousarray(frame)
|
||||
|
||||
# 执行人脸检测与特征提取
|
||||
faces = engine.get(frame)
|
||||
except Exception as e:
|
||||
print(f"人脸检测过程出错: {e}")
|
||||
# 出错时尝试重新初始化引擎(可能是GPU状态变化导致)
|
||||
print("尝试重新初始化人脸引擎...")
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
load_model(prefer_gpu=True)
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
result_parts = []
|
||||
has_matched_known_face = False # 是否有任意人脸匹配到已知库
|
||||
|
||||
for face in faces:
|
||||
# 归一化当前检测到的人脸特征
|
||||
face_embedding = face.embedding.astype(np.float32)
|
||||
norm = np.linalg.norm(face_embedding)
|
||||
if norm == 0:
|
||||
continue
|
||||
face_embedding = face_embedding / norm
|
||||
|
||||
# 与已知人脸特征逐一比对
|
||||
max_similarity, best_match_name = -1.0, "Unknown"
|
||||
for name in _known_faces_names:
|
||||
known_emb = _known_faces_embeddings[name]
|
||||
similarity = np.dot(face_embedding, known_emb) # 余弦相似度
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_match_name = name
|
||||
|
||||
# 判断是否匹配成功
|
||||
is_matched = max_similarity >= similarity_threshold
|
||||
if is_matched:
|
||||
has_matched_known_face = True
|
||||
|
||||
# 记录该人脸的检测结果
|
||||
bbox = face.bbox # 人脸边界框
|
||||
result_parts.append(
|
||||
f"{'匹配' if is_matched else '未匹配'}: {best_match_name} "
|
||||
f"(相似度: {max_similarity:.2f}, 边界框: {bbox.astype(int).tolist()})"
|
||||
)
|
||||
|
||||
# 构建最终结果字符串
|
||||
result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts)
|
||||
|
||||
# 释放引用计数(线程安全)
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
# 若仍有引用,更新最后使用时间;若引用为0,也立即标记(加快闲置检测)
|
||||
_last_used_time = time.time()
|
||||
|
||||
return (has_matched_known_face, result_str)
|
BIN
core/models/best.pt
Normal file
253
core/ocr.py
Normal file
@ -0,0 +1,253 @@
|
||||
import os
|
||||
import cv2
|
||||
import gc
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
from paddleocr import PaddleOCR
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
# 解决NumPy 1.20+版本中np.int已移除的兼容性问题
|
||||
try:
|
||||
if not hasattr(np, 'int'):
|
||||
np.int = int
|
||||
except Exception as e:
|
||||
print(f"处理NumPy兼容性时出错: {e}")
|
||||
|
||||
# 全局变量
|
||||
_ocr_engine = None
|
||||
_forbidden_words = set()
|
||||
_conf_threshold = 0.5
|
||||
|
||||
# 资源管理变量
|
||||
_ref_count = 0
|
||||
_last_used_time = 0
|
||||
_lock = threading.Lock()
|
||||
_release_timeout = 5 # 30秒无使用则释放
|
||||
_is_releasing = False # 标记是否正在释放
|
||||
|
||||
# 并行处理配置
|
||||
_max_workers = 4 # 并行处理的线程数
|
||||
|
||||
# 调试用计数器
|
||||
_debug_counter = {
|
||||
"created": 0,
|
||||
"released": 0,
|
||||
"detected": 0
|
||||
}
|
||||
|
||||
|
||||
def _release_engine():
|
||||
"""释放OCR引擎资源"""
|
||||
global _ocr_engine, _is_releasing
|
||||
if not _ocr_engine or _is_releasing:
|
||||
return
|
||||
|
||||
try:
|
||||
_is_releasing = True
|
||||
_ocr_engine = None
|
||||
_debug_counter["released"] += 1
|
||||
print(f"OCR engine released. Stats: {_debug_counter}")
|
||||
|
||||
# 清理GPU缓存
|
||||
gc.collect()
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import paddle
|
||||
if paddle.is_compiled_with_cuda():
|
||||
paddle.device.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
finally:
|
||||
_is_releasing = False
|
||||
|
||||
|
||||
def _monitor_thread():
|
||||
"""监控线程,优化检查逻辑"""
|
||||
global _ref_count, _last_used_time, _ocr_engine
|
||||
while True:
|
||||
time.sleep(5) # 每5秒检查一次
|
||||
with _lock:
|
||||
if _ocr_engine and _ref_count == 0 and not _is_releasing:
|
||||
elapsed = time.time() - _last_used_time
|
||||
if elapsed > _release_timeout:
|
||||
print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine")
|
||||
_release_engine()
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载违禁词列表和初始化监控线程"""
|
||||
global _forbidden_words
|
||||
|
||||
# 确保监控线程只启动一次
|
||||
if not any(t.name == "OCRMonitor" for t in threading.enumerate()):
|
||||
threading.Thread(target=_monitor_thread, daemon=True, name="OCRMonitor").start()
|
||||
print("OCR monitor thread started")
|
||||
|
||||
# 加载违禁词
|
||||
try:
|
||||
_forbidden_words = get_all_sensitive_words()
|
||||
print(f"Loaded {len(_forbidden_words)} forbidden words")
|
||||
except Exception as e:
|
||||
print(f"Forbidden words load error: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def detect(frame):
|
||||
"""OCR检测,支持并行处理"""
|
||||
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time, _max_workers
|
||||
|
||||
# 验证前置条件
|
||||
if not _forbidden_words:
|
||||
return (False, "违禁词未初始化")
|
||||
if frame is None or frame.size == 0:
|
||||
return (False, "无效帧数据")
|
||||
|
||||
# 增加引用计数并获取引擎实例
|
||||
engine = None
|
||||
with _lock:
|
||||
_ref_count += 1
|
||||
_last_used_time = time.time()
|
||||
_debug_counter["detected"] += 1
|
||||
|
||||
# 初始化引擎(如果未初始化且不在释放中)
|
||||
if not _ocr_engine and not _is_releasing:
|
||||
try:
|
||||
# 初始化PaddleOCR,设置并行处理参数
|
||||
_ocr_engine = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang="ch",
|
||||
show_log=False,
|
||||
use_gpu=True,
|
||||
max_text_length=1024,
|
||||
threads=_max_workers
|
||||
)
|
||||
_debug_counter["created"] += 1
|
||||
print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}")
|
||||
except Exception as e:
|
||||
print(f"OCR model load failed: {e}")
|
||||
_ref_count -= 1
|
||||
return (False, f"引擎初始化失败: {str(e)}")
|
||||
|
||||
engine = _ocr_engine
|
||||
|
||||
# 检查引擎是否可用
|
||||
if not engine:
|
||||
with _lock:
|
||||
_ref_count -= 1
|
||||
return (False, "OCR引擎不可用")
|
||||
|
||||
try:
|
||||
# 执行OCR检测
|
||||
ocr_res = engine.ocr(frame, cls=True)
|
||||
|
||||
# 验证OCR结果格式
|
||||
if not ocr_res or not isinstance(ocr_res, list):
|
||||
return (False, "无OCR结果")
|
||||
|
||||
# 处理OCR结果 - 兼容多种格式
|
||||
texts = []
|
||||
confs = []
|
||||
for line in ocr_res:
|
||||
if line is None:
|
||||
continue
|
||||
|
||||
# 处理line可能是列表或直接是文本信息的情况
|
||||
if isinstance(line, list):
|
||||
items_to_process = line
|
||||
else:
|
||||
items_to_process = [line]
|
||||
|
||||
for item in items_to_process:
|
||||
# 精确识别并忽略图片坐标位置信息 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
if isinstance(item, list) and len(item) == 4: # 四边形有4个顶点
|
||||
is_coordinate = True
|
||||
for point in item:
|
||||
# 每个顶点应该是包含2个数字的列表
|
||||
if not (isinstance(point, list) and len(point) == 2 and
|
||||
all(isinstance(coord, (int, float)) for coord in point)):
|
||||
is_coordinate = False
|
||||
break
|
||||
if is_coordinate:
|
||||
continue # 是坐标信息,直接忽略
|
||||
|
||||
# 跳过纯数字列表(其他可能的坐标形式)
|
||||
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
|
||||
continue
|
||||
|
||||
# 处理元组形式的文本和置信度 (text, confidence)
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
text, conf = item
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
|
||||
# 处理列表形式的[坐标信息, (text, confidence)]
|
||||
if isinstance(item, list) and len(item) >= 2:
|
||||
# 尝试从列表中提取文本和置信度
|
||||
text_data = item[1]
|
||||
if isinstance(text_data, tuple) and len(text_data) == 2:
|
||||
text, conf = text_data
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
elif isinstance(text_data, str):
|
||||
# 只有文本没有置信度的情况
|
||||
texts.append(text_data.strip())
|
||||
confs.append(1.0) # 默认最高置信度
|
||||
continue
|
||||
|
||||
# 无法识别的格式,记录日志
|
||||
print(f"无法解析的OCR结果格式: {item}")
|
||||
|
||||
if len(texts) != len(confs):
|
||||
return (False, "OCR结果格式异常")
|
||||
|
||||
# 筛选违禁词
|
||||
vio_info = []
|
||||
for txt, conf in zip(texts, confs):
|
||||
if conf < _conf_threshold:
|
||||
continue
|
||||
matched = [w for w in _forbidden_words if w in txt]
|
||||
if matched:
|
||||
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
|
||||
|
||||
# 构建结果
|
||||
has_text = len(texts) > 0
|
||||
has_violation = len(vio_info) > 0
|
||||
|
||||
if not has_text:
|
||||
return (False, "未识别到文本")
|
||||
elif has_violation:
|
||||
return (True, "; ".join(vio_info))
|
||||
else:
|
||||
return (False, "未检测到违禁词")
|
||||
|
||||
except Exception as e:
|
||||
print(f"OCR detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
finally:
|
||||
# 减少引用计数,确保线程安全
|
||||
with _lock:
|
||||
_ref_count = max(0, _ref_count - 1)
|
||||
if _ref_count > 0:
|
||||
_last_used_time = time.time()
|
||||
|
||||
|
||||
def batch_detect(frames):
|
||||
"""批量检测接口,充分利用并行能力"""
|
||||
results = []
|
||||
for frame in frames:
|
||||
results.append(detect(frame))
|
||||
return results
|
137
core/rtc.py
@ -1,137 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
import aiohttp
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
|
||||
import logging
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("whep_video_puller")
|
||||
|
||||
|
||||
async def whep_pull_video_stream(ip,whep_url):
|
||||
"""
|
||||
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
|
||||
Args:
|
||||
whep_url: WHEP端点的URL
|
||||
"""
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
# 添加连接状态变化监听
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
print(f"连接状态: {pc.connectionState}")
|
||||
|
||||
# 添加ICE连接状态变化监听
|
||||
@pc.on("iceconnectionstatechange")
|
||||
async def on_iceconnectionstatechange():
|
||||
print(f"ICE连接状态: {pc.iceConnectionState}")
|
||||
|
||||
# 添加视频接收器
|
||||
pc.addTransceiver("video", direction="recvonly")
|
||||
|
||||
# 处理接收到的视频轨道
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
print(f"接收到轨道: {track.kind}")
|
||||
if track.kind == "video":
|
||||
print(f"轨道ID: {track.id}")
|
||||
print(f"轨道就绪状态: {track.readyState}")
|
||||
# 创建异步任务来处理视频帧
|
||||
asyncio.ensure_future(handle_video_track(track))
|
||||
|
||||
async def handle_video_track(track):
|
||||
"""处理视频轨道,接收并打印每一帧"""
|
||||
frame_count = 0
|
||||
print("开始处理视频轨道...")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 尝试接收帧
|
||||
frame = await track.recv()
|
||||
frame_count += 1
|
||||
print(f"收到原始帧 (第{frame_count}帧)")
|
||||
|
||||
# 打印帧的基本信息
|
||||
if hasattr(frame, 'width') and hasattr(frame, 'height'):
|
||||
print(f" 尺寸: {frame.width}x{frame.height}")
|
||||
if hasattr(frame, 'time_base'):
|
||||
print(f" 时间基准: {frame.time_base}")
|
||||
if hasattr(frame, 'pts'):
|
||||
print(f" 显示时间戳: {frame.pts}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
except Exception as e:
|
||||
print(f"接收帧时出错: {e}")
|
||||
# 等待一段时间后重试
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 创建offer
|
||||
offer = await pc.createOffer()
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
print(f"本地SDP信息:\n{offer.sdp}")
|
||||
|
||||
# 通过HTTP POST发送offer到WHEP端点
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
whep_url,
|
||||
data=offer.sdp,
|
||||
headers={"Content-Type": "application/sdp"}
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
print(f"WHEP服务器返回错误: {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
raise Exception(f"WHEP服务器返回错误: {response.status}")
|
||||
|
||||
# 获取answer SDP
|
||||
answer_sdp = await response.text()
|
||||
|
||||
# 创建RTCSessionDescription对象
|
||||
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
|
||||
|
||||
print(f"收到远程SDP:\n{answer_sdp}")
|
||||
|
||||
# 设置远程描述
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
print("连接已建立,开始接收视频流...")
|
||||
|
||||
# 保持连接,直到用户中断
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
# 检查连接状态
|
||||
print(f"当前连接状态: {pc.connectionState}")
|
||||
except KeyboardInterrupt:
|
||||
print("用户中断,关闭连接...")
|
||||
finally:
|
||||
await pc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 替换为你的WHEP端点URL
|
||||
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行拉流任务
|
||||
asyncio.run(whep_pull_video_stream(WHEP_URL))
|
112
core/rtmp.py
@ -1,112 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
|
||||
# 配置文件相对路径(根据实际目录结构调整)
|
||||
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹
|
||||
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
|
||||
KNOWN_FACES_DIR = "../ocr/known_faces"
|
||||
|
||||
# 创建检测器实例
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并进行违规检测
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 异步打开RTMP流
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
|
||||
)
|
||||
|
||||
# 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 获取RTMP流基础信息
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况
|
||||
fps = fps if fps > 0 else 30.0
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 初始化帧统计参数
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
# 循环读取视频帧
|
||||
while True:
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 打印当前帧信息
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
if frame is not None:
|
||||
has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
if has_violation:
|
||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
else:
|
||||
print("未检测到任何违规内容")
|
||||
else:
|
||||
print(f"无法读取测试图像")
|
||||
|
||||
# 每100帧统计一次实际接收帧率
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
print(f"程序启动失败: {str(e)}")
|
54
core/yolo.py
Normal file
@ -0,0 +1,54 @@
|
||||
import os
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
# 全局变量
|
||||
_yolo_model = None
|
||||
|
||||
|
||||
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载YOLO目标检测模型"""
|
||||
global _yolo_model
|
||||
|
||||
try:
|
||||
_yolo_model = YOLO(model_path)
|
||||
except Exception as e:
|
||||
print(f"YOLO model load failed: {e}")
|
||||
return False
|
||||
|
||||
return True if _yolo_model else False
|
||||
|
||||
|
||||
def detect(frame, conf_threshold=0.2):
|
||||
"""YOLO目标检测、返回(是否识别到, 结果字符串)"""
|
||||
global _yolo_model
|
||||
|
||||
if not _yolo_model or frame is None:
|
||||
return (False, "未初始化或无效帧")
|
||||
|
||||
try:
|
||||
results = _yolo_model(frame, conf=conf_threshold)
|
||||
# 检查是否有检测结果
|
||||
has_results = len(results[0].boxes) > 0 if results else False
|
||||
|
||||
if not has_results:
|
||||
return (False, "未检测到目标")
|
||||
|
||||
# 构建结果字符串
|
||||
result_parts = []
|
||||
for box in results[0].boxes:
|
||||
cls = int(box.cls[0])
|
||||
conf = float(box.conf[0])
|
||||
bbox = [float(x) for x in box.xyxy[0]]
|
||||
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
||||
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
|
||||
|
||||
result_str = "; ".join(result_parts)
|
||||
return (has_results, result_str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"YOLO detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
@ -14,4 +14,3 @@ config.read(config_path, encoding="utf-8")
|
||||
SERVER_CONFIG = config["server"]
|
||||
MYSQL_CONFIG = config["mysql"]
|
||||
JWT_CONFIG = config["jwt"]
|
||||
LIVE_CONFIG = config["live"]
|
||||
|
25
main.py
@ -1,11 +1,18 @@
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from PIL import Image # 正确导入
|
||||
import numpy as np
|
||||
|
||||
import uvicorn
|
||||
from PIL import Image
|
||||
from fastapi import FastAPI
|
||||
from core.all import load_model,detect
|
||||
from ds.config import SERVER_CONFIG
|
||||
from middle.error_handler import global_exception_handler
|
||||
from service.user_service import router as user_router
|
||||
from service.sensitive_service import router as sensitive_router
|
||||
from service.face_service import router as face_router
|
||||
from service.device_service import router as device_router
|
||||
from ws.ws import ws_router, lifespan
|
||||
from core.establish import create_directory_structure
|
||||
|
||||
# ------------------------------
|
||||
# 初始化 FastAPI 应用、指定生命周期管理
|
||||
@ -22,6 +29,8 @@ app = FastAPI(
|
||||
# ------------------------------
|
||||
app.include_router(user_router)
|
||||
app.include_router(device_router)
|
||||
app.include_router(face_router)
|
||||
app.include_router(sensitive_router)
|
||||
app.include_router(ws_router)
|
||||
|
||||
# ------------------------------
|
||||
@ -33,11 +42,21 @@ app.add_exception_handler(Exception, global_exception_handler)
|
||||
# 启动服务
|
||||
# ------------------------------
|
||||
if __name__ == "__main__":
|
||||
# -------------------------- 配置调整 --------------------------
|
||||
# 模型配置路径(建议改为环境变量)
|
||||
YOLO_MODEL_PATH = r"/core/models\best.pt"
|
||||
OCR_CONFIG_PATH = r"/core/config\config.yaml"
|
||||
|
||||
create_directory_structure()
|
||||
|
||||
# 初始化项目(默认端口设为8000、避免初始化失败时port未定义)
|
||||
port = int(SERVER_CONFIG.get("port", 8000))
|
||||
|
||||
# 启动 UVicorn 服务
|
||||
uvicorn.run(
|
||||
app="main:app",
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
workers=8,
|
||||
ws="websockets"
|
||||
)
|
||||
|
@ -8,7 +8,6 @@ from passlib.context import CryptContext
|
||||
|
||||
from ds.config import JWT_CONFIG
|
||||
from ds.db import db
|
||||
from service.user_service import UserResponse
|
||||
|
||||
# ------------------------------
|
||||
# 密码加密配置
|
||||
@ -22,9 +21,10 @@ SECRET_KEY = JWT_CONFIG["secret_key"]
|
||||
ALGORITHM = JWT_CONFIG["algorithm"]
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
||||
|
||||
# OAuth2 依赖(从请求头获取 Token、格式:Bearer <token>)
|
||||
# OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 密码工具函数
|
||||
# ------------------------------
|
||||
@ -32,10 +32,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证明文密码与加密密码是否匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""对明文密码进行 bcrypt 加密"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# JWT 工具函数
|
||||
# ------------------------------
|
||||
@ -53,11 +55,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 认证依赖(获取当前登录用户)
|
||||
# ------------------------------
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
|
||||
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
||||
# 延迟导入、打破循环依赖
|
||||
from schema.user_schema import UserResponse # 在这里导入
|
||||
|
||||
# 认证失败异常
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -8,7 +8,7 @@ from schema.response_schema import APIResponse
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""全局异常处理器:所有未捕获的异常都会在这里统一处理"""
|
||||
"""全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
|
||||
# 1. 请求参数验证错误(Pydantic 校验失败)
|
||||
if isinstance(exc, RequestValidationError):
|
||||
error_details = []
|
||||
@ -18,7 +18,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=APIResponse(
|
||||
code=400,
|
||||
message=f"请求参数错误:{'; '.join(error_details)}",
|
||||
message=f"请求参数错误: {'; '.join(error_details)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
@ -52,7 +52,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=APIResponse(
|
||||
code=500,
|
||||
message=f"数据库错误:{str(exc)}",
|
||||
message=f"数据库错误: {str(exc)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
@ -62,7 +62,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=APIResponse(
|
||||
code=500,
|
||||
message=f"服务器内部错误:{str(exc)}",
|
||||
message=f"服务器内部错误: {str(exc)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
|
@ -1,139 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
|
||||
class FaceRecognizer:
|
||||
"""
|
||||
封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。
|
||||
"""
|
||||
|
||||
def __init__(self, known_faces_dir: str):
|
||||
self.known_faces_dir = known_faces_dir
|
||||
self.app = self._initialize_insightface()
|
||||
self.known_faces_embeddings = {}
|
||||
self.known_faces_names = []
|
||||
self._load_known_faces()
|
||||
|
||||
def _initialize_insightface(self):
|
||||
"""初始化InsightFace FaceAnalysis应用"""
|
||||
print("初始化InsightFace引擎...")
|
||||
try:
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
print("请检查依赖是否安装及模型是否可访问")
|
||||
return None
|
||||
|
||||
def _load_known_faces(self):
|
||||
"""加载已知人脸特征"""
|
||||
if not os.path.exists(self.known_faces_dir):
|
||||
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
|
||||
os.makedirs(self.known_faces_dir, exist_ok=True)
|
||||
return
|
||||
|
||||
print(f"从目录加载人脸特征: {self.known_faces_dir}")
|
||||
for person_name in os.listdir(self.known_faces_dir):
|
||||
person_dir = os.path.join(self.known_faces_dir, person_name)
|
||||
if os.path.isdir(person_dir):
|
||||
print(f"处理人物: {person_name}")
|
||||
embeddings = []
|
||||
for filename in os.listdir(person_dir):
|
||||
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
||||
image_path = os.path.join(person_dir, filename)
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
print(f"无法读取图片: {image_path},已跳过")
|
||||
continue
|
||||
|
||||
faces = self.app.get(img)
|
||||
if faces:
|
||||
embeddings.append(faces[0].embedding)
|
||||
print(f"提取特征成功: {filename}")
|
||||
else:
|
||||
print(f"未检测到人脸: {filename},已跳过")
|
||||
except Exception as e:
|
||||
print(f"处理图片出错 {image_path}: {e}")
|
||||
|
||||
if embeddings:
|
||||
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
|
||||
self.known_faces_names.append(person_name)
|
||||
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
|
||||
else:
|
||||
print(f"人物 {person_name} 无有效特征,已跳过")
|
||||
print(f"人脸加载完成,共 {len(self.known_faces_names)} 人")
|
||||
|
||||
def recognize(self, frame, threshold=0.4):
|
||||
"""识别人脸并返回结果"""
|
||||
if not self.app or not self.known_faces_names:
|
||||
return False, None, None
|
||||
|
||||
faces = self.app.get(frame)
|
||||
if not faces:
|
||||
return False, None, None
|
||||
|
||||
for face in faces:
|
||||
for known_name in self.known_faces_names:
|
||||
known_embedding = self.known_faces_embeddings[known_name]
|
||||
|
||||
embedding1 = face.embedding.astype(np.float32)
|
||||
embedding2 = known_embedding.astype(np.float32)
|
||||
|
||||
dot_product = np.dot(embedding1, embedding2)
|
||||
norm_embedding1 = np.linalg.norm(embedding1)
|
||||
norm_embedding2 = np.linalg.norm(embedding2)
|
||||
|
||||
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
|
||||
dot_product / (norm_embedding1 * norm_embedding2)
|
||||
)
|
||||
|
||||
if similarity >= threshold:
|
||||
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
|
||||
return True, known_name, similarity
|
||||
|
||||
return False, None, None
|
||||
|
||||
def test_single_image(self, image_path: str, threshold=0.4):
|
||||
"""测试单张图片识别"""
|
||||
if not os.path.exists(image_path):
|
||||
print(f"图片不存在: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
frame = cv2.imread(image_path)
|
||||
if frame is None:
|
||||
print(f"无法读取图片: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
result, name, similarity = self.recognize(frame, threshold)
|
||||
|
||||
if result:
|
||||
print(f"识别结果: {name} (相似度: {similarity:.4f})")
|
||||
|
||||
faces = self.app.get(frame)
|
||||
for face in faces:
|
||||
bbox = face.bbox.astype(int)
|
||||
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
||||
text = f"{name}: {similarity:.2f}"
|
||||
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||
|
||||
cv2.imshow('识别结果', frame)
|
||||
print("按任意键关闭窗口...")
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
print("未识别到已知人脸")
|
||||
|
||||
return result, name, similarity
|
||||
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
|
||||
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
|
||||
# recognizer.test_single_image(test_image_path, threshold=0.4)
|
@ -1,18 +0,0 @@
|
||||
打倒习近平
|
||||
打到习近平
|
||||
打倒毛泽东
|
||||
打到毛泽东
|
||||
打到主席
|
||||
打倒主席
|
||||
打到共产主义
|
||||
打倒共产主义
|
||||
打到共产党
|
||||
打倒共产党
|
||||
胖猫
|
||||
法轮功
|
||||
法轮大法
|
||||
法轮大法好
|
||||
法轮功大法好
|
||||
法轮
|
||||
李洪志
|
||||
习近平
|
Before Width: | Height: | Size: 195 KiB |
Before Width: | Height: | Size: 208 KiB |
Before Width: | Height: | Size: 657 KiB |
Before Width: | Height: | Size: 53 KiB |
Before Width: | Height: | Size: 8.1 KiB |
Before Width: | Height: | Size: 14 KiB |
Before Width: | Height: | Size: 58 KiB |
Before Width: | Height: | Size: 4.9 KiB |
Before Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 155 KiB |
Before Width: | Height: | Size: 386 KiB |
Before Width: | Height: | Size: 1.4 MiB |
Before Width: | Height: | Size: 62 KiB |
@ -1,49 +0,0 @@
|
||||
#日志文件
|
||||
import logging
|
||||
import sys
|
||||
|
||||
def setup_logger():
|
||||
"""
|
||||
配置一个全局日志记录器,支持输出到控制台和文件。
|
||||
"""
|
||||
# 创建一个日志记录器
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger("ViolationDetectorLogger")
|
||||
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
|
||||
|
||||
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
|
||||
# --- 控制台处理器 ---
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
# 对于控制台,我们只显示INFO及以上级别的信息
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
|
||||
# --- 文件处理器 ---
|
||||
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
|
||||
# 对于文件,我们记录所有DEBUG及以上级别的信息
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
|
||||
# 将处理器添加到日志记录器
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
# 创建并导出logger实例
|
||||
logger = setup_logger()
|
@ -1,136 +0,0 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from .ocr_violation_detector import OCRViolationDetector
|
||||
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
||||
from .face_recognizer import FaceRecognizer
|
||||
|
||||
class MultiModelViolationDetector:
|
||||
"""
|
||||
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
yolo_model_path: str,
|
||||
known_faces_dir: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化所有检测模型
|
||||
"""
|
||||
# 初始化OCR检测器
|
||||
self.ocr_detector = OCRViolationDetector(
|
||||
forbidden_words_path=forbidden_words_path,
|
||||
ocr_config_path=ocr_config_path,
|
||||
ocr_confidence_threshold=ocr_confidence_threshold
|
||||
)
|
||||
|
||||
# 初始化人脸识别器
|
||||
self.face_recognizer = FaceRecognizer(
|
||||
known_faces_dir=known_faces_dir
|
||||
)
|
||||
|
||||
# 初始化YOLO检测器
|
||||
self.yolo_detector = YoloViolationDetector(
|
||||
model_path=yolo_model_path
|
||||
)
|
||||
|
||||
print("多模型违规检测器初始化完成")
|
||||
|
||||
def detect_violations(self, frame):
|
||||
"""
|
||||
串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果
|
||||
"""
|
||||
# 1. 首先进行OCR违禁词检测
|
||||
try:
|
||||
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
|
||||
if ocr_has_violation:
|
||||
details = {
|
||||
"words": ocr_words,
|
||||
"confidences": ocr_confs
|
||||
}
|
||||
print(f"警告: OCR检测到违禁内容: {details}")
|
||||
return (True, "ocr", details)
|
||||
except Exception as e:
|
||||
print(f"错误: OCR检测出错: {str(e)}")
|
||||
|
||||
# 2. 接着进行人脸识别检测
|
||||
try:
|
||||
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
|
||||
if face_has_violation:
|
||||
details = {
|
||||
"name": face_name,
|
||||
"similarity": face_similarity
|
||||
}
|
||||
print(f"警告: 人脸识别到违规人员: {details}")
|
||||
return (True, "face", details)
|
||||
except Exception as e:
|
||||
print(f"错误: 人脸识别出错: {str(e)}")
|
||||
|
||||
# 3. 最后进行YOLO目标检测
|
||||
try:
|
||||
yolo_results = self.yolo_detector.detect(frame)
|
||||
if len(yolo_results.boxes) > 0:
|
||||
details = {
|
||||
"classes": yolo_results.names,
|
||||
"boxes": yolo_results.boxes.xyxy.tolist(),
|
||||
"confidences": yolo_results.boxes.conf.tolist(),
|
||||
"class_ids": yolo_results.boxes.cls.tolist()
|
||||
}
|
||||
print(f"警告: YOLO检测到违规目标: {details}")
|
||||
return (True, "yolo", details)
|
||||
except Exception as e:
|
||||
print(f"错误: YOLO检测出错: {str(e)}")
|
||||
|
||||
# 所有检测均未发现违规
|
||||
return (False, None, None)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
"""加载YAML配置文件"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"错误: 配置文件未找到: {config_path}")
|
||||
raise
|
||||
except yaml.YAMLError as e:
|
||||
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"错误: 加载配置文件出错: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 加载配置文件
|
||||
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
|
||||
#
|
||||
# # 初始化多模型检测器
|
||||
# detector = MultiModelViolationDetector(
|
||||
# forbidden_words_path=config["forbidden_words_path"],
|
||||
# ocr_config_path=config["ocr_config_path"],
|
||||
# yolo_model_path=config["yolo_model_path"],
|
||||
# known_faces_dir=config["known_faces_dir"],
|
||||
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
|
||||
# )
|
||||
#
|
||||
# # 读取测试图像(可替换为视频帧读取逻辑)
|
||||
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
|
||||
# if test_image_path:
|
||||
# frame = cv2.imread(test_image_path)
|
||||
#
|
||||
# if frame is not None:
|
||||
# has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
# if has_violation:
|
||||
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
# else:
|
||||
# print("未检测到任何违规内容")
|
||||
# else:
|
||||
# print(f"无法读取测试图像: {test_image_path}")
|
||||
# else:
|
||||
# print("配置文件中未指定测试图像路径")
|
@ -1,178 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
from rapidocr import RapidOCR
|
||||
|
||||
|
||||
class OCRViolationDetector:
|
||||
"""
|
||||
封装RapidOCR引擎,用于检测图像帧中的违禁词。
|
||||
核心功能:加载违禁词、初始化OCR引擎、单帧图像违禁词检测
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化OCR引擎和违禁词列表。
|
||||
|
||||
Args:
|
||||
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
|
||||
ocr_config_path (str): OCR配置文件(如1.yaml)的路径。
|
||||
ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。
|
||||
"""
|
||||
# 加载违禁词
|
||||
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
|
||||
|
||||
# 初始化RapidOCR引擎
|
||||
self.ocr_engine = self._initialize_ocr(ocr_config_path)
|
||||
|
||||
# 校验核心依赖是否就绪
|
||||
self._check_dependencies()
|
||||
|
||||
# 设置置信度阈值(限制在0~1范围)
|
||||
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
|
||||
print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
|
||||
|
||||
def _load_forbidden_words(self, path: str) -> set:
|
||||
"""
|
||||
从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码)
|
||||
"""
|
||||
forbidden_words = set()
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(path):
|
||||
print(f"错误:违禁词文件不存在: {path}")
|
||||
return forbidden_words
|
||||
|
||||
# 读取文件并处理内容
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
forbidden_words = {
|
||||
line.strip() for line in f
|
||||
if line.strip() # 跳过空行或纯空格行
|
||||
}
|
||||
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
|
||||
except UnicodeDecodeError:
|
||||
print(f"错误:违禁词文件编码错误(需UTF-8): {path}")
|
||||
except PermissionError:
|
||||
print(f"错误:无权限读取违禁词文件: {path}")
|
||||
except Exception as e:
|
||||
print(f"错误:加载违禁词失败: {str(e)}")
|
||||
|
||||
return forbidden_words
|
||||
|
||||
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
|
||||
"""
|
||||
初始化RapidOCR引擎(校验配置文件、捕获初始化异常)
|
||||
"""
|
||||
print("开始初始化RapidOCR引擎...")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
print(f"错误:OCR配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
# 初始化OCR引擎
|
||||
try:
|
||||
ocr_engine = RapidOCR(config_path=config_path)
|
||||
print("RapidOCR引擎初始化成功")
|
||||
return ocr_engine
|
||||
except ImportError:
|
||||
print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
|
||||
except Exception as e:
|
||||
print(f"错误:RapidOCR初始化失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
def _check_dependencies(self) -> None:
|
||||
"""校验OCR引擎和违禁词列表是否就绪"""
|
||||
if not self.ocr_engine:
|
||||
print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
|
||||
if not self.forbidden_words:
|
||||
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
|
||||
|
||||
def detect(self, frame) -> tuple[bool, list, list]:
|
||||
"""
|
||||
对单帧图像进行OCR违禁词检测(核心方法)
|
||||
|
||||
Args:
|
||||
frame: 输入图像帧(NumPy数组,BGR格式,cv2读取的图像)。
|
||||
|
||||
Returns:
|
||||
tuple[bool, list, list]:
|
||||
- 第一个元素:是否检测到违禁词(True/False);
|
||||
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
|
||||
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
|
||||
"""
|
||||
# 初始化返回结果
|
||||
has_violation = False
|
||||
violation_words = []
|
||||
violation_confs = []
|
||||
|
||||
# 前置校验
|
||||
if frame is None or frame.size == 0:
|
||||
print("警告:输入图像帧为空或无效,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if not self.ocr_engine or not self.forbidden_words:
|
||||
print("OCR引擎未就绪或违禁词为空,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
try:
|
||||
# 执行OCR识别
|
||||
print("开始执行OCR识别...")
|
||||
ocr_result = self.ocr_engine(frame)
|
||||
print(f"RapidOCR原始结果: {ocr_result}")
|
||||
|
||||
# 校验OCR结果是否有效
|
||||
if ocr_result is None:
|
||||
print("OCR识别未返回任何结果(图像无文本或识别失败)")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 检查txts和scores是否存在且不为None
|
||||
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
|
||||
print("警告:OCR结果中txts为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
|
||||
print("警告:OCR结果中scores为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 转为列表并去None
|
||||
if not isinstance(ocr_result.txts, (list, tuple)):
|
||||
print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}")
|
||||
texts = []
|
||||
else:
|
||||
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
|
||||
|
||||
if not isinstance(ocr_result.scores, (list, tuple)):
|
||||
print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}")
|
||||
confidences = []
|
||||
else:
|
||||
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
|
||||
|
||||
# 校验文本和置信度列表长度是否一致
|
||||
if len(texts) != len(confidences):
|
||||
print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if len(texts) == 0:
|
||||
print("OCR未识别到任何有效文本")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 遍历识别结果,筛选违禁词
|
||||
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
|
||||
for text, conf in zip(texts, confidences):
|
||||
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
||||
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
|
||||
continue
|
||||
matched_words = [word for word in self.forbidden_words if word in text]
|
||||
if matched_words:
|
||||
has_violation = True
|
||||
violation_words.extend(matched_words)
|
||||
violation_confs.extend([conf] * len(matched_words))
|
||||
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误:OCR检测过程异常: {str(e)}")
|
||||
|
||||
return has_violation, violation_words, violation_confs
|
@ -1,47 +0,0 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
|
||||
class ViolationDetector:
|
||||
"""
|
||||
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
|
||||
"""
|
||||
def __init__(self, model_path):
|
||||
"""
|
||||
初始化检测器。
|
||||
|
||||
Args:
|
||||
model_path (str): YOLO .pt模型的路径。
|
||||
"""
|
||||
print(f"正在从 '{model_path}' 加载YOLO模型...")
|
||||
self.model = YOLO(model_path)
|
||||
print("YOLO模型加载成功。")
|
||||
|
||||
def detect(self, frame):
|
||||
"""
|
||||
对单帧图像进行目标检测。
|
||||
|
||||
Args:
|
||||
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
||||
|
||||
Returns:
|
||||
ultralytics.engine.results.Results: YOLO的检测结果对象。
|
||||
"""
|
||||
# conf可以根据您的模型效果进行调整
|
||||
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
|
||||
results = self.model(frame, conf=0.2)
|
||||
return results[0]
|
||||
|
||||
def draw_boxes(self, frame, result):
|
||||
"""
|
||||
在图像帧上绘制检测框。
|
||||
|
||||
Args:
|
||||
frame: 原始图像帧。
|
||||
result: YOLO的检测结果对象。
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 绘制了检测框的图像帧。
|
||||
"""
|
||||
# 使用YOLO自带的plot功能,方便快捷
|
||||
annotated_frame = result.plot()
|
||||
return annotated_frame
|
164
rtc/rtc.py
@ -1,164 +0,0 @@
|
||||
import queue
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import threading
|
||||
import time
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
||||
from aiortc.mediastreams import MediaStreamTrack
|
||||
|
||||
# 创建一个长度为1的队列,用于生产者和消费者之间的通信
|
||||
frame_queue = queue.Queue(maxsize=1)
|
||||
|
||||
|
||||
class VideoTrack(MediaStreamTrack):
|
||||
"""自定义视频轨道类,继承自MediaStreamTrack"""
|
||||
kind = "video"
|
||||
|
||||
def __init__(self, max_frames=100):
|
||||
super().__init__()
|
||||
self.frames = queue.Queue(maxsize=max_frames)
|
||||
|
||||
async def recv(self):
|
||||
return await super().recv()
|
||||
|
||||
|
||||
def webrtc_producer(webrtc_url):
|
||||
"""
|
||||
生产者方法:从WEBRTC读取视频帧并放入队列
|
||||
仅当队列空时才放入新帧,否则丢弃
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 创建RTCPeerConnection对象,不使用ICE服务器
|
||||
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
||||
video_track = VideoTrack()
|
||||
pc.addTrack(video_track)
|
||||
|
||||
@pc.on("track")
|
||||
async def on_track(track):
|
||||
if track.kind == "video":
|
||||
print("接收到视频轨道,开始接收视频帧")
|
||||
while True:
|
||||
# 从轨道接收视频帧
|
||||
frame = await track.recv()
|
||||
# 转换为BGR24格式的NumPy数组
|
||||
frame_bgr24 = frame.to_ndarray(format='bgr24')
|
||||
|
||||
# 检查队列是否为空,为空则加入,否则丢弃
|
||||
if frame_queue.empty():
|
||||
try:
|
||||
frame_queue.put_nowait(frame_bgr24)
|
||||
print("帧已放入队列")
|
||||
except queue.Full:
|
||||
print("队列已满,丢弃帧")
|
||||
else:
|
||||
print("队列非空,丢弃帧")
|
||||
|
||||
async def main():
|
||||
# 创建并发送SDP Offer
|
||||
offer = await pc.createOffer()
|
||||
print("已创建本地SDP Offer")
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
# 发送Offer到服务器并接收Answer
|
||||
async with aiohttp.ClientSession() as session:
|
||||
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
|
||||
async with session.post(
|
||||
webrtc_url,
|
||||
data=offer.sdp.encode(),
|
||||
headers={
|
||||
"Content-Type": "application/sdp",
|
||||
"Content-Length": str(len(offer.sdp))
|
||||
},
|
||||
ssl=False
|
||||
) as response:
|
||||
print("已接收到服务器的响应")
|
||||
answer_sdp = await response.text()
|
||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
||||
|
||||
# 保持连接
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
print("关闭RTCPeerConnection")
|
||||
await pc.close()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def frame_consumer(ip):
|
||||
"""
|
||||
消费者方法:从队列中读取帧并处理
|
||||
每次处理后休眠200ms模拟延迟
|
||||
"""
|
||||
print("消费者启动,开始等待帧...")
|
||||
try:
|
||||
while True:
|
||||
# 阻塞等待队列中的帧
|
||||
frame = frame_queue.get()
|
||||
print(f"消费帧,大小: {frame.shape}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
|
||||
|
||||
# 标记任务完成
|
||||
frame_queue.task_done()
|
||||
except KeyboardInterrupt:
|
||||
print("消费者退出")
|
||||
|
||||
|
||||
def start_webrtc_stream(ip, webrtc_url):
|
||||
"""
|
||||
启动WebRTC视频流处理的主方法
|
||||
参数: webrtc_url - WebRTC服务器地址
|
||||
"""
|
||||
print(f"开始连接到WebRTC服务器: {webrtc_url}")
|
||||
|
||||
# 启动生产者线程
|
||||
producer_thread = threading.Thread(
|
||||
target=webrtc_producer,
|
||||
args=(webrtc_url,),
|
||||
daemon=True,
|
||||
name="webrtc-producer"
|
||||
)
|
||||
|
||||
# 启动消费者线程
|
||||
consumer_thread = threading.Thread(
|
||||
target=frame_consumer(ip),
|
||||
daemon=True,
|
||||
name="frame-consumer"
|
||||
)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
print("生产者和消费者线程已启动")
|
||||
|
||||
try:
|
||||
# 保持主线程运行
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("程序正在退出...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
# 实际使用时替换为真实的WebRTC服务器地址
|
||||
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
|
||||
start_webrtc_stream(webrtc_server_url)
|
101
rtmp/rtmp.py
@ -1,101 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
|
||||
# 配置日志(与WHEP代码保持一致的日志风格)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理
|
||||
|
||||
Args:
|
||||
rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key)
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环)
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 必须指定FFmpeg后端,RTMP协议依赖该后端解析
|
||||
)
|
||||
|
||||
# 2. 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 3. 异步获取RTMP流基础信息(分辨率、帧率)
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况:部分RTMP流未返回帧率时默认30FPS
|
||||
fps = fps if fps > 0 else 30.0
|
||||
# 分辨率转为整数(视频尺寸必然是整数)
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息(与WHEP连接成功信息风格一致)
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 4. 初始化帧统计参数
|
||||
frame_count = 0 # 总接收帧数
|
||||
start_time = time.time() # 统计起始时间
|
||||
|
||||
# 5. 循环异步读取视频帧(核心逻辑)
|
||||
while True:
|
||||
# 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境)
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
# 检查帧是否读取成功(流中断/结束时ret为False)
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
# 帧计数累加
|
||||
frame_count += 1
|
||||
|
||||
# 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐)
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
# 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致)
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
|
||||
# 示例:yield frame (若需生成器模式,可调整函数为异步生成器)
|
||||
|
||||
# 8. 异常处理(覆盖用户中断、通用错误)
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行RTMP拉流任务(与WHEP一致的异步执行方式)
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
print(f"程序启动失败: {str(e)}")
|
36
schema/device_action_schema.py
Normal file
@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceActionCreate(BaseModel):
|
||||
"""设备操作记录创建模型(0=离线、1=上线)"""
|
||||
client_ip: str = Field(..., description="客户端IP")
|
||||
action: int = Field(..., ge=0, le=1, description="操作状态(0=离线、1=上线)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(单条记录)
|
||||
# ------------------------------
|
||||
class DeviceActionResponse(BaseModel):
|
||||
"""设备操作记录响应模型(与自增表对齐)"""
|
||||
id: int = Field(..., description="自增主键ID")
|
||||
client_ip: Optional[str] = Field(None, description="客户端IP")
|
||||
action: Optional[int] = Field(None, description="操作状态(0=离线、1=上线)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
# 支持从数据库结果直接转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 列表响应模型(仅含 total + device_actions)
|
||||
# ------------------------------
|
||||
class DeviceActionListResponse(BaseModel):
|
||||
"""设备操作记录列表(仅核心返回字段)"""
|
||||
total: int = Field(..., description="总记录数")
|
||||
device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表")
|
@ -1,4 +1,3 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
@ -6,42 +5,31 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型(前端传参校验)
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceCreateRequest(BaseModel):
|
||||
"""设备流信息创建请求模型"""
|
||||
"""设备流信息创建请求模型(与数据库表字段对齐)"""
|
||||
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
|
||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||
params: Optional[Dict] = Field(None, description="设备详细信息")
|
||||
|
||||
|
||||
def md5_encrypt(text: str) -> str:
|
||||
"""对字符串进行MD5加密"""
|
||||
if not text:
|
||||
return ""
|
||||
md5_hash = hashlib.md5()
|
||||
md5_hash.update(text.encode('utf-8'))
|
||||
return md5_hash.hexdigest()
|
||||
params: Optional[Dict] = Field(None, description="设备详细信息(JSON格式)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(后端返回设备数据)
|
||||
# 响应模型(后端返回数据)- 严格对齐数据库表字段
|
||||
# ------------------------------
|
||||
class DeviceResponse(BaseModel):
|
||||
"""设备流信息响应模型(字段与表结构完全对齐)"""
|
||||
id: int = Field(..., description="设备ID")
|
||||
"""设备流信息响应模型(与数据库表字段完全一致)"""
|
||||
id: int = Field(..., description="设备主键ID")
|
||||
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
|
||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||
rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址")
|
||||
live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址")
|
||||
detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址")
|
||||
device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)")
|
||||
device_type: Optional[str] = Field(None, description="设备类型")
|
||||
alarm_count: int = Field(..., description="报警次数")
|
||||
params: Optional[str] = Field(None, description="设备详细信息")
|
||||
params: Optional[str] = Field(None, description="设备详细信息(JSON字符串)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
# 支持从数据库查询结果转换
|
||||
# 支持从数据库查询结果直接转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
|
@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
|
||||
# 请求模型(前端传参校验)
|
||||
# ------------------------------
|
||||
class FaceCreateRequest(BaseModel):
|
||||
"""创建人脸记录请求模型(无需ID,由数据库自增)"""
|
||||
name: str = Field(None, max_length=255, description="名称(可选,最长255字符)")
|
||||
"""创建人脸记录请求模型(无需ID、由数据库自增)"""
|
||||
name: str = Field(None, max_length=255, description="名称(可选、最长255字符)")
|
||||
|
||||
|
||||
class FaceUpdateRequest(BaseModel):
|
||||
@ -20,10 +20,10 @@ class FaceUpdateRequest(BaseModel):
|
||||
# 响应模型(后端返回数据)
|
||||
# ------------------------------
|
||||
class FaceResponse(BaseModel):
|
||||
"""人脸记录响应模型(仍包含ID,由数据库生成后返回)"""
|
||||
"""人脸记录响应模型(仍包含ID、由数据库生成后返回)"""
|
||||
id: int = Field(..., description="主键ID(数据库自增)")
|
||||
name: str = Field(None, description="名称")
|
||||
eigenvalue: str = Field(None, description="特征(暂为None)")
|
||||
eigenvalue: str | None = Field(None, description="特征(可为空)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
|
@ -5,9 +5,9 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class APIResponse(BaseModel):
|
||||
"""统一 API 响应模型(所有接口必返此格式)"""
|
||||
code: int = Field(..., description="状态码:200=成功、4xx=客户端错误、5xx=服务端错误")
|
||||
message: str = Field(..., description="响应信息:成功/错误描述")
|
||||
data: Optional[Any] = Field(None, description="响应数据:成功时返回、错误时为 None")
|
||||
code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
|
||||
message: str = Field(..., description="响应信息: 成功/错误描述")
|
||||
data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
|
||||
|
||||
# Pydantic V2 配置(支持从 ORM 对象转换)
|
||||
model_config = {"from_attributes": True}
|
||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
|
||||
# ------------------------------
|
||||
class SensitiveCreateRequest(BaseModel):
|
||||
"""创建敏感信息记录请求模型"""
|
||||
# 移除了id字段,由数据库自动生成
|
||||
# 移除了id字段、由数据库自动生成
|
||||
name: str = Field(None, max_length=255, description="名称")
|
||||
|
||||
|
||||
|
158
service/device_action_service.py
Normal file
@ -0,0 +1,158 @@
|
||||
from fastapi import APIRouter, Query, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from schema.device_action_schema import (
|
||||
DeviceActionCreate,
|
||||
DeviceActionResponse,
|
||||
DeviceActionListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
# 路由配置
|
||||
router = APIRouter(
|
||||
prefix="/device/actions",
|
||||
tags=["设备操作记录"]
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部方法: 新增设备操作记录(适配id自增)
|
||||
# ------------------------------
|
||||
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
|
||||
"""
|
||||
新增设备操作记录(内部方法、非接口)
|
||||
:param action_data: 含client_ip和action(0/1)
|
||||
:return: 新增的完整记录
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入SQL(id自增、依赖数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO device_action
|
||||
(client_ip, action, created_at, updated_at)
|
||||
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
action_data.client_ip,
|
||||
action_data.action
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取新增记录(通过自增ID查询)
|
||||
new_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
|
||||
new_action = cursor.fetchone()
|
||||
|
||||
return DeviceActionResponse(**new_action)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"新增记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口: 分页查询操作记录列表(仅返回 total + device_actions)
|
||||
# ------------------------------
|
||||
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
|
||||
async def get_device_action_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认1"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
|
||||
client_ip: str = Query(None, description="按客户端IP筛选"),
|
||||
action: int = Query(None, ge=0, le=1, description="按状态筛选(0=离线、1=上线)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 构建筛选条件(参数化查询、避免注入)
|
||||
where_clause = []
|
||||
params = []
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if action is not None:
|
||||
where_clause.append("action = %s")
|
||||
params.append(action)
|
||||
|
||||
# 2. 查询总记录数(用于返回 total)
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 分页查询记录(按创建时间倒序、确保最新记录在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM device_action"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset]) # 追加分页参数(page/page_size仅用于查询、不返回)
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 4. 仅返回 total + device_actions
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
|
||||
async def get_device_actions_by_ip(
|
||||
client_ip: str = Path(..., description="客户端IP地址")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 查询总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
|
||||
cursor.execute(count_sql, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 查询该IP的所有记录(按创建时间倒序)
|
||||
list_sql = """
|
||||
SELECT * FROM device_action
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
cursor.execute(list_sql, (client_ip,))
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 3. 返回结果
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
@ -1,25 +1,11 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
|
||||
from fastapi import HTTPException, Query, APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Query, HTTPException,Request
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.config import LIVE_CONFIG
|
||||
from ds.db import db
|
||||
from middle.auth_middleware import get_current_user
|
||||
# 注意:导入的Schema已更新字段
|
||||
from schema.device_schema import (
|
||||
DeviceCreateRequest,
|
||||
DeviceResponse,
|
||||
DeviceListResponse,
|
||||
md5_encrypt
|
||||
)
|
||||
from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse
|
||||
from schema.response_schema import APIResponse
|
||||
from schema.user_schema import UserResponse
|
||||
|
||||
# 导入之前封装的WEBRTC处理函数
|
||||
from core.rtmp import rtmp_pull_video_stream
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/devices",
|
||||
@ -27,65 +13,127 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
# 在后台线程中运行WEBRTC处理
|
||||
def run_webrtc_processing(ip, webrtc_url):
|
||||
try:
|
||||
print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}")
|
||||
rtmp_pull_video_stream(webrtc_url)
|
||||
except Exception as e:
|
||||
print(f"WEBRTC处理出错: {str(e)}")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 1. 创建设备信息
|
||||
# 内部工具方法 - 通过客户端IP增加设备报警次数
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||
async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
||||
"""
|
||||
通过客户端IP增加设备的报警次数(内部服务方法)
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:return: 操作是否成功
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
if not cursor.fetchone():
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 报警次数加1、并更新时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET alarm_count = alarm_count + 1,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (client_ip,))
|
||||
conn.commit()
|
||||
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新报警次数失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 通过客户端IP更新设备在线状态
|
||||
# ------------------------------
|
||||
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
||||
"""
|
||||
通过客户端IP更新设备的在线状态(内部服务方法)
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:param online_status: 在线状态(1-在线、0-离线)
|
||||
:return: 操作是否成功
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
# 验证状态值有效性
|
||||
if online_status not in (0, 1):
|
||||
raise ValueError("在线状态必须是0(离线)或1(在线)")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
if not cursor.fetchone():
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 更新在线状态和时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET device_online_status = %s,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (online_status, client_ip))
|
||||
conn.commit()
|
||||
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 原有接口保持不变
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||
async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象
|
||||
# 原有代码保持不变
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查client_ip是否已存在
|
||||
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
|
||||
existing_device = cursor.fetchone()
|
||||
if existing_device:
|
||||
# 设备创建成功后,在后台线程启动WEBRTC流处理
|
||||
threading.Thread(
|
||||
target=run_webrtc_processing,
|
||||
# args=(device_data.ip, existing_device["live_webrtc_url"]),
|
||||
args=(device_data.ip, existing_device["rtmp_push_url"]),
|
||||
|
||||
daemon=True # 设为守护线程,主程序退出时自动结束
|
||||
).start()
|
||||
# IP已存在时返回该设备信息
|
||||
# 更新设备状态为在线
|
||||
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
|
||||
# 返回信息
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"客户端IP {device_data.ip} 已存在",
|
||||
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
|
||||
data=DeviceResponse(** existing_device)
|
||||
)
|
||||
|
||||
# 获取RTMP URL和WEBRTC URL配置
|
||||
rtmp_url = str(LIVE_CONFIG.get("rtmp_url", ""))
|
||||
webrtc_url = str(LIVE_CONFIG.get("webrtc_url", ""))
|
||||
|
||||
# 将设备详细信息(params)转换为JSON字符串
|
||||
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||
|
||||
# 对JSON字符串进行MD5加密
|
||||
device_md5 = md5_encrypt(device_params_json) if device_params_json else ""
|
||||
|
||||
# 解析User-Agent获取设备类型
|
||||
# 直接使用注入的request对象获取用户代理
|
||||
user_agent = request.headers.get("User-Agent", "").lower()
|
||||
|
||||
# 优先处理User-Agent为default的情况
|
||||
if user_agent == "default":
|
||||
# 检查params中是否存在os键
|
||||
if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params:
|
||||
device_type = device_data.params["os"]
|
||||
else:
|
||||
device_type = "unknown"
|
||||
device_type = device_data.params.get("os") if (
|
||||
device_data.params and isinstance(device_data.params, dict)) else "unknown"
|
||||
elif "windows" in user_agent:
|
||||
device_type = "windows"
|
||||
elif "android" in user_agent:
|
||||
@ -95,22 +143,16 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
else:
|
||||
device_type = "unknown"
|
||||
|
||||
# 构建完整的WEBRTC URL
|
||||
full_webrtc_url = webrtc_url + device_md5
|
||||
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||
|
||||
# SQL插入语句
|
||||
insert_query = """
|
||||
INSERT INTO devices
|
||||
(client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url,
|
||||
device_online_status, device_type, alarm_count, params)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
(client_ip, hostname, device_online_status, device_type, alarm_count, params)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
device_data.ip,
|
||||
device_data.hostname,
|
||||
rtmp_url + device_md5,
|
||||
full_webrtc_url, # 存储完整的WEBRTC URL
|
||||
"",
|
||||
1,
|
||||
device_type,
|
||||
0,
|
||||
@ -118,28 +160,22 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚创建的设备信息
|
||||
device_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
|
||||
device = cursor.fetchone()
|
||||
new_device = cursor.fetchone()
|
||||
|
||||
# 设备创建成功后,在后台线程启动WEBRTC流处理
|
||||
threading.Thread(
|
||||
target=run_webrtc_processing,
|
||||
args=(device_data.ip, full_webrtc_url),
|
||||
daemon=True # 设为守护线程,主程序退出时自动结束
|
||||
).start()
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="设备创建成功,已开始处理WEBRTC流",
|
||||
data=DeviceResponse(**device)
|
||||
message="设备创建成功",
|
||||
data=DeviceResponse(**new_device)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建设备失败:{str(e)}") from e
|
||||
raise Exception(f"创建设备失败: {str(e)}") from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise Exception(f"设备信息JSON序列化失败:{str(e)}") from e
|
||||
raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
@ -147,140 +183,82 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 2. 获取设备列表
|
||||
# ------------------------------
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备列表")
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||
async def get_device_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数"),
|
||||
device_type: str = Query(None, description="设备类型筛选"),
|
||||
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选(1-在线、0-离线)")
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
device_type: str = Query(None, description="按设备类型筛选"),
|
||||
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
||||
):
|
||||
# 原有代码保持不变
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建查询条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if device_type:
|
||||
where_clause.append("device_type = %s")
|
||||
params.append(device_type)
|
||||
|
||||
if online_status is not None:
|
||||
where_clause.append("device_online_status = %s")
|
||||
params.append(online_status)
|
||||
|
||||
# 总条数查询
|
||||
count_query = "SELECT COUNT(*) as total FROM devices"
|
||||
count_query = "SELECT COUNT(*) AS total FROM devices"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页查询(SELECT * 会自动匹配表字段、响应模型已对齐)
|
||||
offset = (page - 1) * page_size
|
||||
query = "SELECT * FROM devices"
|
||||
list_query = "SELECT * FROM devices"
|
||||
if where_clause:
|
||||
query += " WHERE " + " AND ".join(where_clause)
|
||||
query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(query, params)
|
||||
devices = cursor.fetchall()
|
||||
|
||||
# 响应模型已更新为params字段、直接转换即可
|
||||
device_list = [DeviceResponse(**device) for device in devices]
|
||||
cursor.execute(list_query, params)
|
||||
device_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备列表成功",
|
||||
data=DeviceListResponse(total=total, devices=device_list)
|
||||
data=DeviceListResponse(
|
||||
total=total,
|
||||
devices=[DeviceResponse(**device) for device in device_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败:{str(e)}") from e
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 3. 获取单个设备详情
|
||||
# ------------------------------
|
||||
@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情")
|
||||
async def get_device_detail(
|
||||
device_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
def get_unique_client_ips() -> list[str]:
|
||||
"""
|
||||
获取所有去重的客户端IP列表
|
||||
|
||||
:return: 去重后的客户端IP字符串列表,如果没有数据则返回空列表
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询设备信息(SELECT * 匹配表字段)
|
||||
query = "SELECT * FROM devices WHERE id = %s"
|
||||
cursor.execute(query, (device_id,))
|
||||
device = cursor.fetchone()
|
||||
# 查询去重的客户端IP
|
||||
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
|
||||
if not device:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"设备ID为 {device_id} 的设备不存在"
|
||||
)
|
||||
# 提取结果并转换为字符串列表
|
||||
results = cursor.fetchall()
|
||||
return [item['client_ip'] for item in results]
|
||||
|
||||
# 响应模型已更新为params字段
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备详情成功",
|
||||
data=DeviceResponse(**device)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备详情失败:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 4. 删除设备信息
|
||||
# ------------------------------
|
||||
@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息")
|
||||
async def delete_device(
|
||||
device_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在
|
||||
cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"设备ID为 {device_id} 的设备不存在"
|
||||
)
|
||||
|
||||
# 执行删除
|
||||
delete_query = "DELETE FROM devices WHERE id = %s"
|
||||
cursor.execute(delete_query, (device_id,))
|
||||
conn.commit()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备ID为 {device_id} 的设备已成功删除",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"删除设备失败:{str(e)}") from e
|
||||
raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
@ -7,32 +7,38 @@ from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import get_current_user
|
||||
from schema.user_schema import UserResponse
|
||||
|
||||
from util.face_util import add_binary_data,get_average_feature
|
||||
#初始化实例
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/faces",
|
||||
tags=["人脸管理"]
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 1. 创建人脸记录(核心修正:ID 数据库自增,前端无需传)
|
||||
# 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传)
|
||||
# ------------------------------
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件,ID自增)")
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增)")
|
||||
async def create_face(
|
||||
# 前端仅需传:name(可选,Form格式)、file(必传,文件)
|
||||
# 前端仅需传: name(可选、Form格式)、file(必传、文件)
|
||||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||||
file: UploadFile = File(..., description="人脸文件(必传,暂不处理内容)")
|
||||
file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容)")
|
||||
):
|
||||
"""
|
||||
创建人脸记录:
|
||||
创建人脸记录:
|
||||
- 需登录认证
|
||||
- 前端传参:multipart/form-data 表单(name 可选,file 必传)
|
||||
- ID 由数据库自动生成,无需前端传入
|
||||
- 暂不处理文件内容,eigenvalue 设为 None
|
||||
- 前端传参: multipart/form-data 表单(name 可选、file 必传)
|
||||
- ID 由数据库自动生成、无需前端传入
|
||||
- 暂不处理文件内容、eigenvalue 设为 None
|
||||
"""
|
||||
|
||||
# 调用你的方法
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 1. 用模型校验 name(仅校验长度,无需ID)
|
||||
# 1. 用模型校验 name(仅校验长度、无需ID)
|
||||
face_create = FaceCreateRequest(name=name)
|
||||
|
||||
conn = db.get_connection()
|
||||
@ -41,42 +47,77 @@ async def create_face(
|
||||
# 把文件转为二进制数组
|
||||
file_content = await file.read()
|
||||
|
||||
# 调用人脸识别得到特征值
|
||||
# 计算特征值
|
||||
flag, eigenvalue = add_binary_data(file_content)
|
||||
|
||||
if flag == False:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="未检测到人脸"
|
||||
)
|
||||
|
||||
# 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None)
|
||||
# 打印数组长度
|
||||
print(f"文件大小: {len(file_content)} 字节")
|
||||
|
||||
# 2. 插入数据库: 无需传 ID(自增)、只传 name 和 eigenvalue(None)
|
||||
insert_query = """
|
||||
INSERT INTO face (name, eigenvalue)
|
||||
VALUES (%s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (face_create.name, None))
|
||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
|
||||
conn.commit()
|
||||
|
||||
# 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录)
|
||||
# 3. 获取数据库自动生成的 ID(关键: 用 LAST_INSERT_ID() 查刚插入的记录)
|
||||
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
|
||||
cursor.execute(select_new_query)
|
||||
created_face = cursor.fetchone()
|
||||
|
||||
if not created_face:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="创建人脸记录成功、但无法获取新创建的记录"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=201,
|
||||
message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})",
|
||||
message=f"人脸记录创建成功(ID: {created_face['id']}、文件名: {file.filename})",
|
||||
data=FaceResponse(** created_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建人脸记录失败:{str(e)}") from e
|
||||
# 改为使用HTTPException
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"创建人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
# 捕获其他可能的异常
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"服务器错误: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
await file.close() # 关闭文件流
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
|
||||
flag, eigenvalue = add_binary_data(file_content)
|
||||
if flag == False:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="未检测到人脸"
|
||||
)
|
||||
|
||||
# 将 eigenvalue 转为 str
|
||||
eigenvalue = str(eigenvalue)
|
||||
|
||||
# ------------------------------
|
||||
# 2. 获取单个人脸记录(不变,用自增ID查询)
|
||||
# 2. 获取单个人脸记录(不变、用自增ID查询)
|
||||
# ------------------------------
|
||||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||||
async def get_face(
|
||||
face_id: int, # 这里的 ID 是数据库自增的,前端从创建响应中获取
|
||||
face_id: int, # 这里的 ID 是数据库自增的、前端从创建响应中获取
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
conn = None
|
||||
@ -101,18 +142,21 @@ async def get_face(
|
||||
data=FaceResponse(**face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询人脸记录失败:{str(e)}") from e
|
||||
# 改为使用HTTPException
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改)
|
||||
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
|
||||
# ------------------------------
|
||||
# 3. 获取所有人脸记录(不变)
|
||||
# ------------------------------
|
||||
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
||||
async def get_all_faces(
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -130,13 +174,16 @@ async def get_all_faces(
|
||||
data=[FaceResponse(** face) for face in faces]
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询所有人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 4. 更新人脸记录(不变,用自增ID更新)
|
||||
# 4. 更新人脸记录(不变、用自增ID更新)
|
||||
# ------------------------------
|
||||
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
||||
async def update_face(
|
||||
@ -191,13 +238,16 @@ async def update_face(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 5. 删除人脸记录(不变,用自增ID删除)
|
||||
# 5. 删除人脸记录(不变、用自增ID删除)
|
||||
# ------------------------------
|
||||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||||
async def delete_face(
|
||||
@ -231,39 +281,61 @@ async def delete_face(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"删除人脸记录失败:{str(e)}") from e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"删除人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
def get_all_face_name_with_eigenvalue() -> dict:
|
||||
"""
|
||||
获取所有人脸的名称及其对应的特征值,组成字典返回
|
||||
获取所有人脸的名称及其对应的特征值、组成字典返回
|
||||
key: 人脸名称(name)
|
||||
value: 人脸特征值(eigenvalue)
|
||||
注:过滤掉name为None的记录,避免字典key为None的情况
|
||||
value: 人脸特征值(eigenvalue)、若名称重复则返回平均特征值
|
||||
注: 过滤掉name为None的记录、避免字典key为None的情况
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 1. 建立数据库连接并获取游标(dictionary=True使结果以字典形式返回)
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 只查询需要的字段,提高效率
|
||||
# 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
|
||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall()
|
||||
faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...]
|
||||
|
||||
# 构建name到eigenvalue的映射字典
|
||||
face_dict = {
|
||||
face["name"]: face["eigenvalue"]
|
||||
for face in faces
|
||||
}
|
||||
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
|
||||
name_to_eigenvalues = {}
|
||||
for face in faces:
|
||||
name = face["name"]
|
||||
eigenvalue = face["eigenvalue"]
|
||||
# 若名称已存在、追加特征值;否则新建列表存储
|
||||
if name in name_to_eigenvalues:
|
||||
name_to_eigenvalues[name].append(eigenvalue)
|
||||
else:
|
||||
name_to_eigenvalues[name] = [eigenvalue]
|
||||
|
||||
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
|
||||
face_dict = {}
|
||||
for name, eigenvalues in name_to_eigenvalues.items():
|
||||
|
||||
# 处理特征值: 多个则求平均、单个则直接使用
|
||||
if len(eigenvalues) > 1:
|
||||
# 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入)
|
||||
face_dict[name] = get_average_feature(eigenvalues)
|
||||
else:
|
||||
# 取列表中唯一的特征值(避免value为列表类型)
|
||||
face_dict[name] = eigenvalues[0]
|
||||
|
||||
return face_dict
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
|
||||
# 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题)
|
||||
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
|
||||
finally:
|
||||
# 确保资源释放
|
||||
# 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
|
||||
db.close_connection(conn, cursor)
|
@ -21,7 +21,7 @@ router = APIRouter(
|
||||
async def create_sensitive(
|
||||
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
|
||||
"""
|
||||
创建敏感信息记录:
|
||||
创建敏感信息记录:
|
||||
- 需登录认证
|
||||
- 插入新的敏感信息记录到数据库(ID由数据库自动生成)
|
||||
- 返回创建成功信息
|
||||
@ -32,7 +32,7 @@ async def create_sensitive(
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入新敏感信息记录到数据库(不包含ID,由数据库自动生成)
|
||||
# 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO sensitives (name)
|
||||
VALUES (%s)
|
||||
@ -56,7 +56,7 @@ async def create_sensitive(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建敏感信息记录失败:{str(e)}") from e
|
||||
raise Exception(f"创建敏感信息记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -71,7 +71,7 @@ async def get_sensitive(
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
获取单个敏感信息记录:
|
||||
获取单个敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID查询敏感信息记录
|
||||
- 返回查询到的敏感信息
|
||||
@ -98,7 +98,7 @@ async def get_sensitive(
|
||||
data=SensitiveResponse(**sensitive)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询敏感信息记录失败:{str(e)}") from e
|
||||
raise Exception(f"查询敏感信息记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -109,7 +109,7 @@ async def get_sensitive(
|
||||
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
|
||||
async def get_all_sensitives():
|
||||
"""
|
||||
获取所有敏感信息记录:
|
||||
获取所有敏感信息记录:
|
||||
- 需登录认证
|
||||
- 查询所有敏感信息记录(不需要分页)
|
||||
- 返回所有敏感信息列表
|
||||
@ -130,7 +130,7 @@ async def get_all_sensitives():
|
||||
data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询所有敏感信息记录失败:{str(e)}") from e
|
||||
raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -145,7 +145,7 @@ async def update_sensitive(
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
更新敏感信息记录:
|
||||
更新敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID更新敏感信息记录
|
||||
- 返回更新后的敏感信息
|
||||
@ -203,7 +203,7 @@ async def update_sensitive(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新敏感信息记录失败:{str(e)}") from e
|
||||
raise Exception(f"更新敏感信息记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -217,7 +217,7 @@ async def delete_sensitive(
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
删除敏感信息记录:
|
||||
删除敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID删除敏感信息记录
|
||||
- 返回删除成功信息
|
||||
@ -251,14 +251,14 @@ async def delete_sensitive(
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"删除敏感信息记录失败:{str(e)}") from e
|
||||
raise Exception(f"删除敏感信息记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
def get_all_sensitive_words() -> list[str]:
|
||||
"""
|
||||
获取所有敏感词,返回字符串数组
|
||||
获取所有敏感词、返回字符串数组
|
||||
|
||||
返回:
|
||||
list[str]: 包含所有敏感词的数组
|
||||
@ -273,7 +273,7 @@ def get_all_sensitive_words() -> list[str]:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 执行查询,只获取敏感词字段
|
||||
# 执行查询、只获取敏感词字段
|
||||
query = "SELECT name FROM sensitives ORDER BY id"
|
||||
cursor.execute(query)
|
||||
sensitive_records = cursor.fetchall()
|
||||
@ -283,7 +283,7 @@ def get_all_sensitive_words() -> list[str]:
|
||||
|
||||
except MySQLError as e:
|
||||
# 数据库错误处理
|
||||
raise MySQLError(f"查询敏感词失败:{str(e)}") from e
|
||||
raise MySQLError(f"查询敏感词失败: {str(e)}") from e
|
||||
finally:
|
||||
# 确保资源正确释放
|
||||
db.close_connection(conn, cursor)
|
@ -27,7 +27,7 @@ router = APIRouter(
|
||||
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||
async def user_register(request: UserRegisterRequest):
|
||||
"""
|
||||
用户注册:
|
||||
用户注册:
|
||||
- 校验用户名是否已存在
|
||||
- 加密密码后插入数据库
|
||||
- 返回注册成功信息
|
||||
@ -67,7 +67,7 @@ async def user_register(request: UserRegisterRequest):
|
||||
)
|
||||
except MySQLError as e:
|
||||
conn.rollback() # 数据库错误时回滚事务
|
||||
raise Exception(f"注册失败:{str(e)}") from e
|
||||
raise Exception(f"注册失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -78,7 +78,7 @@ async def user_register(request: UserRegisterRequest):
|
||||
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
||||
async def user_login(request: UserLoginRequest):
|
||||
"""
|
||||
用户登录:
|
||||
用户登录:
|
||||
- 校验用户名是否存在
|
||||
- 校验密码是否正确
|
||||
- 生成 JWT Token 并返回
|
||||
@ -89,7 +89,7 @@ async def user_login(request: UserLoginRequest):
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 修复:SQL查询添加 created_at 和 updated_at 字段
|
||||
# 修复: SQL查询添加 created_at 和 updated_at 字段
|
||||
query = """
|
||||
SELECT id, username, password, created_at, updated_at
|
||||
FROM users
|
||||
@ -129,7 +129,7 @@ async def user_login(request: UserLoginRequest):
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"登录失败:{str(e)}") from e
|
||||
raise Exception(f"登录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@ -142,8 +142,8 @@ async def get_current_user_info(
|
||||
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
||||
):
|
||||
"""
|
||||
获取当前登录用户信息:
|
||||
- 需在请求头携带 Token(格式:Bearer <token>)
|
||||
获取当前登录用户信息:
|
||||
- 需在请求头携带 Token(格式: Bearer <token>)
|
||||
- 认证通过后返回用户信息
|
||||
"""
|
||||
return APIResponse(
|
||||
|
145
util/face_util.py
Normal file
@ -0,0 +1,145 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
# 全局变量存储InsightFace引擎和特征列表
|
||||
_insightface_app = None
|
||||
_feature_list = []
|
||||
|
||||
|
||||
def init_insightface():
|
||||
"""初始化InsightFace引擎"""
|
||||
global _insightface_app
|
||||
try:
|
||||
print("正在初始化InsightFace引擎...")
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
_insightface_app = app
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def add_binary_data(binary_data):
|
||||
"""
|
||||
接收单张图片的二进制数据、提取特征并保存
|
||||
|
||||
参数:
|
||||
binary_data: 图片的二进制数据(bytes类型)
|
||||
|
||||
返回:
|
||||
成功提取特征时返回 (True, 特征值numpy数组)
|
||||
失败时返回 (False, None)
|
||||
"""
|
||||
global _insightface_app, _feature_list
|
||||
|
||||
if not _insightface_app:
|
||||
print("引擎未初始化、无法处理")
|
||||
return False, None
|
||||
|
||||
try:
|
||||
# 直接处理二进制数据: 转换为图像格式
|
||||
img = Image.open(BytesIO(binary_data))
|
||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# 提取特征
|
||||
faces = _insightface_app.get(frame)
|
||||
if faces:
|
||||
# 获取当前提取的特征值
|
||||
current_feature = faces[0].embedding
|
||||
# 添加到特征列表
|
||||
_feature_list.append(current_feature)
|
||||
print(f"已累计 {len(_feature_list)} 个特征")
|
||||
# 返回成功标志和当前特征值
|
||||
return True, current_feature
|
||||
else:
|
||||
print("二进制数据中未检测到人脸")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
print(f"处理二进制数据出错: {e}")
|
||||
return False, None
|
||||
|
||||
|
||||
def get_average_feature(features=None):
|
||||
"""
|
||||
计算多个特征向量的平均值
|
||||
|
||||
参数:
|
||||
features: 可选、特征值列表。如果未提供、则使用全局存储的_feature_list
|
||||
每个元素可以是字符串格式或numpy数组
|
||||
|
||||
返回:
|
||||
单一平均特征向量的numpy数组、若无可计算数据则返回None
|
||||
"""
|
||||
global _feature_list
|
||||
|
||||
# 如果未提供features参数、则使用全局特征列表
|
||||
if features is None:
|
||||
features = _feature_list
|
||||
|
||||
try:
|
||||
# 验证输入是否为列表且不为空
|
||||
if not isinstance(features, list) or len(features) == 0:
|
||||
print("输入必须是包含至少一个特征值的列表")
|
||||
return None
|
||||
|
||||
# 处理每个特征值
|
||||
processed_features = []
|
||||
for i, embedding in enumerate(features):
|
||||
try:
|
||||
if isinstance(embedding, str):
|
||||
# 处理包含括号和逗号的字符串格式
|
||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||
else:
|
||||
embedding_np = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 验证特征值格式
|
||||
if len(embedding_np.shape) == 1:
|
||||
processed_features.append(embedding_np)
|
||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||
else:
|
||||
print(f"跳过第 {i + 1} 个特征值、不是一维数组")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
||||
|
||||
# 确保有有效的特征值
|
||||
if not processed_features:
|
||||
print("没有有效的特征值用于计算平均值")
|
||||
return None
|
||||
|
||||
# 检查所有特征向量维度是否相同
|
||||
dims = {feat.shape[0] for feat in processed_features}
|
||||
if len(dims) > 1:
|
||||
print(f"特征值维度不一致、无法计算平均值。检测到的维度: {dims}")
|
||||
return None
|
||||
|
||||
# 计算平均值
|
||||
avg_feature = np.mean(processed_features, axis=0)
|
||||
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量、维度: {avg_feature.shape[0]}")
|
||||
|
||||
return avg_feature
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算平均特征值时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def clear_features():
|
||||
"""清空已存储的特征数据"""
|
||||
global _feature_list
|
||||
_feature_list = []
|
||||
print("已清空所有特征数据")
|
||||
|
||||
|
||||
def get_feature_list():
|
||||
"""获取当前存储的特征列表"""
|
||||
global _feature_list
|
||||
return _feature_list.copy() # 返回副本防止外部直接修改
|
482
ws.html
@ -1,482 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>WebSocket 测试工具</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
|
||||
}
|
||||
|
||||
body {
|
||||
max-width: 1200px;
|
||||
margin: 20px auto;
|
||||
padding: 0 20px;
|
||||
background-color: #f5f7fa;
|
||||
}
|
||||
|
||||
.container {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||
padding: 25px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #2c3e50;
|
||||
margin-bottom: 20px;
|
||||
font-size: 24px;
|
||||
border-bottom: 2px solid #3498db;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.status-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
padding: 12px 15px;
|
||||
background-color: #f8f9fa;
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
.status-label {
|
||||
font-weight: bold;
|
||||
color: #495057;
|
||||
}
|
||||
|
||||
.status-value {
|
||||
padding: 4px 10px;
|
||||
border-radius: 4px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.status-connected {
|
||||
background-color: #d4edda;
|
||||
color: #155724;
|
||||
}
|
||||
|
||||
.status-disconnected {
|
||||
background-color: #f8d7da;
|
||||
color: #721c24;
|
||||
}
|
||||
|
||||
.status-connecting {
|
||||
background-color: #fff3cd;
|
||||
color: #856404;
|
||||
}
|
||||
|
||||
.btn {
|
||||
padding: 8px 16px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background-color: #3498db;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background-color: #2980b9;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background-color: #e74c3c;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover {
|
||||
background-color: #c0392b;
|
||||
}
|
||||
|
||||
.btn-success {
|
||||
background-color: #2ecc71;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-success:hover {
|
||||
background-color: #27ae60;
|
||||
}
|
||||
|
||||
.control-group {
|
||||
display: flex;
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
color: #495057;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.input-group input, .input-group select {
|
||||
padding: 8px 12px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.message-area {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.message-input {
|
||||
width: 100%;
|
||||
height: 100px;
|
||||
padding: 12px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 6px;
|
||||
resize: none;
|
||||
font-size: 14px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.log-area {
|
||||
width: 100%;
|
||||
height: 300px;
|
||||
padding: 15px;
|
||||
border: 1px solid #ced4da;
|
||||
border-radius: 6px;
|
||||
background-color: #f8f9fa;
|
||||
overflow-y: auto;
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.log-item {
|
||||
margin-bottom: 8px;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 1px dashed #e9ecef;
|
||||
}
|
||||
|
||||
.log-time {
|
||||
color: #6c757d;
|
||||
font-size: 12px;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
.log-send {
|
||||
color: #2980b9;
|
||||
}
|
||||
|
||||
.log-receive {
|
||||
color: #27ae60;
|
||||
}
|
||||
|
||||
.log-status {
|
||||
color: #856404;
|
||||
}
|
||||
|
||||
.log-error {
|
||||
color: #e74c3c;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>WebSocket 测试工具</h1>
|
||||
|
||||
<!-- 连接状态区 -->
|
||||
<div class="status-bar">
|
||||
<div class="status-label">连接状态:</div>
|
||||
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
|
||||
<div class="status-label">服务地址:</div>
|
||||
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
|
||||
<div class="status-label">连接时间:</div>
|
||||
<div id="connectTime" class="status-value">-</div>
|
||||
</div>
|
||||
|
||||
<!-- 控制按钮区 -->
|
||||
<div class="control-group">
|
||||
<button id="connectBtn" class="btn btn-primary">建立连接</button>
|
||||
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
|
||||
|
||||
<!-- 心跳控制 -->
|
||||
<div class="input-group">
|
||||
<label>自动心跳:</label>
|
||||
<select id="autoHeartbeat">
|
||||
<option value="on">开启</option>
|
||||
<option value="off">关闭</option>
|
||||
</select>
|
||||
<label>间隔(秒):</label>
|
||||
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
|
||||
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 自定义消息发送区 -->
|
||||
<div class="message-area">
|
||||
<h3>发送自定义消息</h3>
|
||||
<textarea id="messageInput" class="message-input"
|
||||
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
|
||||
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
|
||||
</div>
|
||||
|
||||
<!-- 日志显示区 -->
|
||||
<div class="message-area">
|
||||
<h3>消息日志</h3>
|
||||
<div id="logContainer" class="log-area">
|
||||
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
|
||||
</div>
|
||||
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// 全局变量
|
||||
let ws = null;
|
||||
let heartbeatTimer = null;
|
||||
const wsUrl = "ws://192.168.110.25:8000/ws";
|
||||
|
||||
// DOM 元素
|
||||
const connectionStatus = document.getElementById('connectionStatus');
|
||||
const connectTime = document.getElementById('connectTime');
|
||||
const connectBtn = document.getElementById('connectBtn');
|
||||
const disconnectBtn = document.getElementById('disconnectBtn');
|
||||
const sendMessageBtn = document.getElementById('sendMessageBtn');
|
||||
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
|
||||
const autoHeartbeat = document.getElementById('autoHeartbeat');
|
||||
const heartbeatInterval = document.getElementById('heartbeatInterval');
|
||||
const messageInput = document.getElementById('messageInput');
|
||||
const logContainer = document.getElementById('logContainer');
|
||||
const clearLogBtn = document.getElementById('clearLogBtn');
|
||||
|
||||
// 工具函数:添加日志
|
||||
function addLog(content, type = 'status') {
|
||||
const now = new Date().toLocaleString('zh-CN', {
|
||||
year: 'numeric', month: '2-digit', day: '2-digit',
|
||||
hour: '2-digit', minute: '2-digit', second: '2-digit'
|
||||
});
|
||||
const logItem = document.createElement('div');
|
||||
logItem.className = 'log-item';
|
||||
|
||||
let logClass = '';
|
||||
switch (type) {
|
||||
case 'send':
|
||||
logClass = 'log-send';
|
||||
break;
|
||||
case 'receive':
|
||||
logClass = 'log-receive';
|
||||
break;
|
||||
case 'error':
|
||||
logClass = 'log-error';
|
||||
break;
|
||||
default:
|
||||
logClass = 'log-status';
|
||||
}
|
||||
|
||||
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
|
||||
logContainer.appendChild(logItem);
|
||||
// 滚动到最新日志
|
||||
logContainer.scrollTop = logContainer.scrollHeight;
|
||||
}
|
||||
|
||||
// 工具函数:格式化JSON(便于日志显示)
|
||||
function formatJson(jsonStr) {
|
||||
try {
|
||||
const obj = JSON.parse(jsonStr);
|
||||
return JSON.stringify(obj, null, 2);
|
||||
} catch (e) {
|
||||
return jsonStr; // 非JSON格式直接返回
|
||||
}
|
||||
}
|
||||
|
||||
// 建立WebSocket连接
|
||||
function connectWebSocket() {
|
||||
if (ws) {
|
||||
addLog('已存在连接,无需重复建立', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
ws = new WebSocket(wsUrl);
|
||||
|
||||
// 连接成功
|
||||
ws.onopen = function () {
|
||||
connectionStatus.className = 'status-value status-connected';
|
||||
connectionStatus.textContent = '已连接';
|
||||
const now = new Date().toLocaleString('zh-CN');
|
||||
connectTime.textContent = now;
|
||||
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
|
||||
|
||||
// 更新按钮状态
|
||||
connectBtn.disabled = true;
|
||||
disconnectBtn.disabled = false;
|
||||
sendMessageBtn.disabled = false;
|
||||
|
||||
// 开启自动心跳(默认开启)
|
||||
if (autoHeartbeat.value === 'on') {
|
||||
startAutoHeartbeat();
|
||||
}
|
||||
};
|
||||
|
||||
// 接收消息
|
||||
ws.onmessage = function (event) {
|
||||
const message = event.data;
|
||||
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
|
||||
};
|
||||
|
||||
// 连接关闭
|
||||
ws.onclose = function (event) {
|
||||
connectionStatus.className = 'status-value status-disconnected';
|
||||
connectionStatus.textContent = '已断开';
|
||||
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
|
||||
|
||||
// 清除自动心跳
|
||||
stopAutoHeartbeat();
|
||||
|
||||
// 更新按钮状态
|
||||
connectBtn.disabled = false;
|
||||
disconnectBtn.disabled = true;
|
||||
sendMessageBtn.disabled = true;
|
||||
|
||||
// 重置WebSocket对象
|
||||
ws = null;
|
||||
};
|
||||
|
||||
// 连接错误
|
||||
ws.onerror = function (error) {
|
||||
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
|
||||
};
|
||||
|
||||
} catch (e) {
|
||||
addLog(`建立连接失败:${e.message}`, 'error');
|
||||
ws = null;
|
||||
}
|
||||
}
|
||||
|
||||
// 断开WebSocket连接
|
||||
function disconnectWebSocket() {
|
||||
if (!ws) {
|
||||
addLog('当前无连接,无需断开', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
ws.close(1000, '手动断开连接');
|
||||
}
|
||||
|
||||
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"})
|
||||
function sendHeartbeat() {
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
addLog('发送心跳失败:当前无有效连接', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const heartbeatMsg = {
|
||||
timestamp: Date.now(), // 当前毫秒时间戳
|
||||
type: "heartbeat"
|
||||
};
|
||||
const msgStr = JSON.stringify(heartbeatMsg);
|
||||
|
||||
ws.send(msgStr);
|
||||
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
|
||||
}
|
||||
|
||||
// 开启自动心跳
|
||||
function startAutoHeartbeat() {
|
||||
// 先停止已有定时器
|
||||
stopAutoHeartbeat();
|
||||
|
||||
const interval = parseInt(heartbeatInterval.value) * 1000;
|
||||
if (isNaN(interval) || interval < 10000) {
|
||||
addLog('自动心跳间隔无效,已重置为30秒', 'error');
|
||||
heartbeatInterval.value = 30;
|
||||
return startAutoHeartbeat();
|
||||
}
|
||||
|
||||
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}秒`, 'status');
|
||||
heartbeatTimer = setInterval(sendHeartbeat, interval);
|
||||
}
|
||||
|
||||
// 停止自动心跳
|
||||
function stopAutoHeartbeat() {
|
||||
if (heartbeatTimer) {
|
||||
clearInterval(heartbeatTimer);
|
||||
heartbeatTimer = null;
|
||||
addLog('已停止自动心跳', 'status');
|
||||
}
|
||||
}
|
||||
|
||||
// 发送自定义消息
|
||||
function sendCustomMessage() {
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
addLog('发送消息失败:当前无有效连接', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const msgStr = messageInput.value.trim();
|
||||
if (!msgStr) {
|
||||
addLog('发送消息失败:消息内容不能为空', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 验证JSON格式(可选,仅提示不强制)
|
||||
JSON.parse(msgStr);
|
||||
ws.send(msgStr);
|
||||
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
|
||||
} catch (e) {
|
||||
addLog(`JSON格式错误:${e.message},仍尝试发送原始内容`, 'error');
|
||||
ws.send(msgStr);
|
||||
addLog(`发送自定义消息(非JSON):\n${msgStr}`, 'send');
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定按钮事件
|
||||
connectBtn.addEventListener('click', connectWebSocket);
|
||||
disconnectBtn.addEventListener('click', disconnectWebSocket);
|
||||
sendMessageBtn.addEventListener('click', sendCustomMessage);
|
||||
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
|
||||
clearLogBtn.addEventListener('click', () => {
|
||||
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
|
||||
});
|
||||
|
||||
// 自动心跳开关变更事件
|
||||
autoHeartbeat.addEventListener('change', function () {
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
if (this.value === 'on') {
|
||||
startAutoHeartbeat();
|
||||
} else {
|
||||
stopAutoHeartbeat();
|
||||
}
|
||||
} else {
|
||||
addLog('需先建立有效连接才能控制自动心跳', 'error');
|
||||
// 重置选择
|
||||
this.value = 'off';
|
||||
}
|
||||
});
|
||||
|
||||
// 心跳间隔变更事件(实时生效)
|
||||
heartbeatInterval.addEventListener('change', function () {
|
||||
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
|
||||
startAutoHeartbeat();
|
||||
}
|
||||
});
|
||||
|
||||
// 快捷键支持(Ctrl+Enter发送消息)
|
||||
messageInput.addEventListener('keydown', function (e) {
|
||||
if (e.ctrlKey && e.key === 'Enter') {
|
||||
sendCustomMessage();
|
||||
e.preventDefault();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
338
ws/ws.py
@ -3,288 +3,296 @@ import datetime
|
||||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Optional, AsyncGenerator
|
||||
from typing import Dict, Optional
|
||||
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
|
||||
from service.device_action_service import add_device_action
|
||||
from schema.device_action_schema import DeviceActionCreate
|
||||
from core.all import detect, load_model
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
|
||||
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
# 配置文件相对路径(根据实际目录结构调整)
|
||||
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
|
||||
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
|
||||
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
|
||||
|
||||
# 创建检测器实例
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
|
||||
# -------------------------- 配置常量 --------------------------
|
||||
# 配置常量
|
||||
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
|
||||
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
|
||||
WS_ENDPOINT = "/ws" # WebSocket端点路径
|
||||
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制(保持1,确保单帧处理)
|
||||
|
||||
# -------------------------- 核心数据结构与全局变量 --------------------------
|
||||
ws_router = APIRouter()
|
||||
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
|
||||
|
||||
|
||||
# 客户端连接封装(包含帧队列)
|
||||
# 工具函数: 获取格式化时间字符串
|
||||
def get_current_time_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_current_time_file_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
|
||||
# 客户端连接封装
|
||||
class ClientConnection:
|
||||
def __init__(self, websocket: WebSocket, client_ip: str):
|
||||
self.websocket = websocket
|
||||
self.client_ip = client_ip
|
||||
self.client_ip = client_ip # 已初始化客户端IP,用于传递给detect
|
||||
self.last_heartbeat = datetime.datetime.now()
|
||||
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 帧队列,长度为1
|
||||
self.consumer_task: Optional[asyncio.Task] = None # 消费者任务
|
||||
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
|
||||
self.consumer_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 更新心跳时间
|
||||
def update_heartbeat(self):
|
||||
"""更新心跳时间"""
|
||||
self.last_heartbeat = datetime.datetime.now()
|
||||
|
||||
# 检查是否存活(超时返回False)
|
||||
def is_alive(self) -> bool:
|
||||
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
||||
return timeout < HEARTBEAT_TIMEOUT
|
||||
"""判断客户端是否存活"""
|
||||
timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
|
||||
return timeout_seconds < HEARTBEAT_TIMEOUT
|
||||
|
||||
# 启动帧消费任务
|
||||
def start_consumer(self):
|
||||
"""启动帧消费任务"""
|
||||
self.consumer_task = asyncio.create_task(self.consume_frames())
|
||||
return self.consumer_task
|
||||
|
||||
# ---------- 新增:发送“允许发送二进制帧”的信号给客户端 ----------
|
||||
async def send_allow_send_frame(self):
|
||||
"""向客户端发送JSON信号,通知其可发送下一帧二进制数据"""
|
||||
async def send_frame_permit(self):
|
||||
"""发送帧发送许可信号"""
|
||||
try:
|
||||
allow_msg = {
|
||||
"type": "allow_send_frame", # 信号类型,与客户端约定
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"status": "ready", # 表示服务器已准备好接收下一帧
|
||||
"client_ip": self.client_ip # 可选:便于客户端确认自身身份
|
||||
frame_permit_msg = {
|
||||
"type": "frame",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip
|
||||
}
|
||||
await self.websocket.send_json(allow_msg)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:已发送「允许发送帧」信号")
|
||||
await self.websocket.send_json(frame_permit_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号")
|
||||
except Exception as e:
|
||||
# 发送失败大概率是客户端已断开,不影响主流程,仅日志记录
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:发送「允许发送帧」信号失败 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
|
||||
|
||||
# 帧消费协程
|
||||
async def consume_frames(self) -> None:
|
||||
"""从队列中获取帧并进行处理,处理完后通知客户端可发送下一帧"""
|
||||
"""消费队列中的帧并处理"""
|
||||
try:
|
||||
while True:
|
||||
# 从队列获取帧数据(队列空时会阻塞,等待客户端发送)
|
||||
# 取出帧并立即发送下一帧许可
|
||||
frame_data = await self.frame_queue.get()
|
||||
await self.send_frame_permit()
|
||||
|
||||
try:
|
||||
# 处理帧数据
|
||||
await self.process_frame(frame_data)
|
||||
finally:
|
||||
# 标记任务完成(队列计数-1,此时队列回到空状态)
|
||||
self.frame_queue.task_done()
|
||||
# ---------- 修改:处理完当前帧后,立即通知客户端可发送下一帧 ----------
|
||||
await self.send_allow_send_frame()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧消费任务已取消")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
|
||||
except Exception as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧处理错误 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
|
||||
|
||||
async def process_frame(self, frame_data: bytes) -> None:
|
||||
"""处理单帧图像数据(原有逻辑不变)"""
|
||||
# 将二进制数据转换为NumPy数组(uint8类型)
|
||||
"""处理单帧图像数据(核心修改:detect函数传入 client_ip + img 双参数)"""
|
||||
# 二进制转OpenCV图像
|
||||
nparr = np.frombuffer(frame_data, np.uint8)
|
||||
# 解码为图像,返回与cv2.imread相同的格式(BGR通道的ndarray)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
# 确保images文件夹存在
|
||||
if not os.path.exists('images'):
|
||||
os.makedirs('images')
|
||||
|
||||
# 生成唯一的文件名,包含时间戳和客户端IP,避免文件名冲突
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"images/{self.client_ip.replace('.', '_')}_{timestamp}.jpg"
|
||||
if img is None:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像数据")
|
||||
return
|
||||
|
||||
try:
|
||||
# 保存图像到本地
|
||||
cv2.imwrite(filename, img)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}")
|
||||
# -------------------------- 核心修改:按要求传入参数(1.client_ip 2.img) --------------------------
|
||||
# detect函数参数顺序:第一个为client_ip,第二个为图像数据img
|
||||
# 保持返回值解包(是否违规, 结果数据, 检测器类型)不变
|
||||
has_violation, data, detector_type = await asyncio.to_thread(
|
||||
detect, # 调用检测函数
|
||||
self.client_ip, # 第一个参数:客户端IP(新增,按需求顺序)
|
||||
img # 第二个参数:图像数据(原参数,调整顺序)
|
||||
)
|
||||
# -------------------------------------------------------------------------------------
|
||||
|
||||
# 进行检测
|
||||
if img is not None:
|
||||
has_violation, violation_type, details = detector.detect_violations(img)
|
||||
# 打印检测结果(包含客户端IP,与传入参数对应)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
|
||||
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}")
|
||||
|
||||
# 处理违规逻辑(逻辑不变,基于detect返回结果执行)
|
||||
if has_violation:
|
||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
# 发送检测结果回客户端(原有逻辑不变)
|
||||
await self.websocket.send_json({
|
||||
"type": "detection_result",
|
||||
"has_violation": has_violation,
|
||||
"violation_type": violation_type,
|
||||
"details": details,
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
})
|
||||
else:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:未检测到任何违规内容")
|
||||
else:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:无法解析图像数据")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
|
||||
f"类型: {detector_type}, 详情: {data}")
|
||||
|
||||
# 违规次数+1
|
||||
try:
|
||||
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1")
|
||||
except Exception as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:图像处理错误 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
|
||||
|
||||
# 发送危险通知
|
||||
danger_msg = {
|
||||
"type": "danger",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip,
|
||||
"detail": data
|
||||
}
|
||||
await self.websocket.send_json(danger_msg)
|
||||
else:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
|
||||
|
||||
|
||||
# 全局连接管理(IP -> 连接实例)
|
||||
# 全局状态管理
|
||||
connected_clients: Dict[str, ClientConnection] = {}
|
||||
# 心跳任务(全局引用,用于关闭时清理)
|
||||
heartbeat_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
# -------------------------- 心跳检查逻辑(原有逻辑不变) --------------------------
|
||||
# 心跳检查任务
|
||||
async def heartbeat_checker():
|
||||
while True:
|
||||
now = datetime.datetime.now()
|
||||
# 1. 筛选超时客户端(避免遍历中修改字典)
|
||||
current_time = get_current_time_str()
|
||||
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
||||
|
||||
# 2. 处理超时连接(关闭+移除)
|
||||
if timeout_ips:
|
||||
print(f"[{now:%H:%M:%S}] 心跳检查:{len(timeout_ips)}个客户端超时({timeout_ips})")
|
||||
print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时(IP: {timeout_ips})")
|
||||
for ip in timeout_ips:
|
||||
try:
|
||||
# 取消消费者任务
|
||||
if connected_clients[ip].consumer_task and not connected_clients[ip].consumer_task.done():
|
||||
connected_clients[ip].consumer_task.cancel()
|
||||
await connected_clients[ip].websocket.close(code=1008, reason="心跳超时")
|
||||
conn = connected_clients[ip]
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
await conn.websocket.close(code=1008, reason="心跳超时")
|
||||
|
||||
# 标记离线
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
|
||||
action_data = DeviceActionCreate(client_ip=ip, action=0)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{current_time}] 客户端{ip}: 已标记为离线并记录操作")
|
||||
except Exception as e:
|
||||
print(f"[{current_time}] 客户端{ip}: 离线状态更新失败 - {str(e)}")
|
||||
finally:
|
||||
connected_clients.pop(ip, None)
|
||||
else:
|
||||
print(f"[{now:%H:%M:%S}] 心跳检查:{len(connected_clients)}个客户端在线,无超时")
|
||||
print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
|
||||
|
||||
# 3. 等待下一轮检查
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
|
||||
# -------------------------- 应用生命周期(原有逻辑不变) --------------------------
|
||||
# 应用生命周期管理
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global heartbeat_task
|
||||
# 启动心跳任务
|
||||
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务启动(ID:{id(heartbeat_task)})")
|
||||
print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID: {id(heartbeat_task)})")
|
||||
yield
|
||||
# 关闭时取消心跳任务
|
||||
if heartbeat_task and not heartbeat_task.done():
|
||||
heartbeat_task.cancel()
|
||||
try:
|
||||
await heartbeat_task
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务已取消")
|
||||
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# -------------------------- 消息处理(文本/心跳逻辑不变,二进制逻辑保留) --------------------------
|
||||
async def send_heartbeat_ack(client_ip: str):
|
||||
"""回复心跳确认(原有逻辑不变)"""
|
||||
if client_ip not in connected_clients:
|
||||
return False
|
||||
# 消息处理工具函数
|
||||
async def send_heartbeat_ack(conn: ClientConnection):
|
||||
try:
|
||||
ack = {
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"type": "heartbeat"
|
||||
heartbeat_ack_msg = {
|
||||
"type": "heart",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": conn.client_ip
|
||||
}
|
||||
await connected_clients[client_ip].websocket.send_json(ack)
|
||||
await conn.websocket.send_json(heartbeat_ack_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送心跳确认")
|
||||
return True
|
||||
except Exception:
|
||||
connected_clients.pop(client_ip, None)
|
||||
except Exception as e:
|
||||
connected_clients.pop(conn.client_ip, None)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认发送失败 - {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection):
|
||||
"""处理文本消息(核心:心跳+JSON解析,原有逻辑不变)"""
|
||||
async def handle_text_msg(conn: ClientConnection, text: str):
|
||||
try:
|
||||
msg = json.loads(text)
|
||||
# 仅处理心跳类型消息
|
||||
if msg.get("type") == "heartbeat":
|
||||
if msg.get("type") == "heart":
|
||||
conn.update_heartbeat()
|
||||
await send_heartbeat_ack(client_ip)
|
||||
await send_heartbeat_ack(conn)
|
||||
else:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到文本消息:{msg}")
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本消息类型({msg.get('type')})")
|
||||
except json.JSONDecodeError:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:无效JSON消息")
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON文本消息")
|
||||
|
||||
|
||||
async def handle_binary_msg(client_ip: str, data: bytes):
|
||||
"""处理二进制消息(原有逻辑不变,因客户端仅在收到允许信号后发送,队列不会满)"""
|
||||
if client_ip not in connected_clients:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接不存在,丢弃{len(data)}字节数据")
|
||||
return
|
||||
|
||||
conn = connected_clients[client_ip]
|
||||
|
||||
# 检查队列是否已满(理论上不会触发,因客户端按信号发送)
|
||||
if conn.frame_queue.full():
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列已满,丢弃{len(data)}字节数据")
|
||||
return
|
||||
|
||||
# 队列未满,添加帧到队列
|
||||
async def handle_binary_msg(conn: ClientConnection, data: bytes):
|
||||
try:
|
||||
conn.frame_queue.put_nowait(data)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:已接收{len(data)}字节二进制数据,加入队列")
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像数据({len(data)}字节)已加入队列")
|
||||
except asyncio.QueueFull:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列突然满了,丢弃{len(data)}字节数据")
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满、丢弃当前图像数据")
|
||||
|
||||
|
||||
# WebSocket路由配置
|
||||
ws_router = APIRouter()
|
||||
|
||||
|
||||
# -------------------------- WebSocket核心端点(修改连接初始化逻辑) --------------------------
|
||||
@ws_router.websocket(WS_ENDPOINT)
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
# 接受连接 + 获取客户端IP
|
||||
load_model() # 加载检测模型(仅在连接建立时加载一次,避免重复加载)
|
||||
await websocket.accept()
|
||||
client_ip = websocket.client.host if websocket.client else "unknown"
|
||||
now = datetime.datetime.now()
|
||||
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功")
|
||||
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||
current_time = get_current_time_str()
|
||||
print(f"[{current_time}] 客户端{client_ip}: WebSocket连接已建立")
|
||||
|
||||
is_online_updated = False
|
||||
|
||||
consumer_task = None
|
||||
try:
|
||||
# 处理重复连接(关闭旧连接)
|
||||
# 处理重复连接(同一IP断开旧连接)
|
||||
if client_ip in connected_clients:
|
||||
# 取消旧连接的消费者任务
|
||||
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
|
||||
connected_clients[client_ip].consumer_task.cancel()
|
||||
await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接")
|
||||
old_conn = connected_clients[client_ip]
|
||||
if old_conn.consumer_task and not old_conn.consumer_task.done():
|
||||
old_conn.consumer_task.cancel()
|
||||
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
|
||||
connected_clients.pop(client_ip)
|
||||
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:关闭旧连接")
|
||||
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
|
||||
|
||||
# 注册新连接
|
||||
# 注册新连接(绑定client_ip和WebSocket)
|
||||
new_conn = ClientConnection(websocket, client_ip)
|
||||
connected_clients[client_ip] = new_conn
|
||||
new_conn.start_consumer() # 启动帧消费任务
|
||||
await new_conn.send_frame_permit() # 发送首次帧许可
|
||||
|
||||
# 启动帧消费任务
|
||||
consumer_task = new_conn.start_consumer()
|
||||
# ---------- 修改:客户端刚连接时,队列空,立即发送「允许发送帧」信号 ----------
|
||||
await new_conn.send_allow_send_frame()
|
||||
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,已启动帧消费任务,当前在线{len(connected_clients)}个")
|
||||
# 标记客户端上线
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
|
||||
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线并记录操作")
|
||||
is_online_updated = True
|
||||
except Exception as e:
|
||||
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
|
||||
|
||||
# 循环接收消息(原有逻辑不变)
|
||||
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
|
||||
|
||||
# 消息循环(持续接收客户端消息)
|
||||
while True:
|
||||
data = await websocket.receive()
|
||||
if "text" in data:
|
||||
await handle_text_msg(client_ip, data["text"], new_conn)
|
||||
await handle_text_msg(new_conn, data["text"])
|
||||
elif "bytes" in data:
|
||||
await handle_binary_msg(client_ip, data["bytes"])
|
||||
await handle_binary_msg(new_conn, data["bytes"])
|
||||
|
||||
# 异常处理(断开/错误)
|
||||
except WebSocketDisconnect as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:主动断开(代码:{e.code})")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})")
|
||||
except Exception as e:
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接异常({str(e)[:50]})")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
|
||||
finally:
|
||||
# 清理连接和任务
|
||||
# 清理资源(断开后标记离线+删除连接)
|
||||
if client_ip in connected_clients:
|
||||
# 取消消费者任务
|
||||
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
|
||||
connected_clients[client_ip].consumer_task.cancel()
|
||||
conn = connected_clients[client_ip]
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
|
||||
# 仅当上线状态更新成功时,才执行离线更新
|
||||
if is_online_updated:
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
||||
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后已标记为离线")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
|
||||
|
||||
connected_clients.pop(client_ip, None)
|
||||
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}个")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}")
|