yolo模型识别不到
This commit is contained in:
10
core/all.py
10
core/all.py
@ -67,7 +67,7 @@ def detect(client_ip, frame):
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ yolo违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
|
||||
save_db(model_type="yolo", client_ip=client_ip, result=str(display_path))
|
||||
return (True, yolo_result, "yolo")
|
||||
|
||||
# 2. 人脸检测
|
||||
@ -77,17 +77,19 @@ def detect(client_ip, frame):
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ face违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
|
||||
save_db(model_type="face", client_ip=client_ip, result=str(display_path))
|
||||
return (True, face_result, "face")
|
||||
|
||||
# 3. OCR检测
|
||||
ocr_flag, ocr_result = ocrDetect(frame)
|
||||
if ocr_flag:
|
||||
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了
|
||||
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)
|
||||
print(f"✅ ocr违规图片已保存:{display_path}")
|
||||
# 这里改了
|
||||
if full_save_path:
|
||||
cv2.imwrite(full_save_path, frame)
|
||||
print(f"✅ ocr违规图片已保存:{display_path}") # 日志也修正
|
||||
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
|
||||
save_db(model_type="ocr", client_ip=client_ip, result=str(display_path))
|
||||
return (True, ocr_result, "ocr")
|
||||
|
||||
# 4. 无违规内容(不保存图片)
|
||||
|
||||
@ -1,30 +1,29 @@
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from service.device_service import get_unique_client_ips
|
||||
|
||||
|
||||
def create_directory_structure():
|
||||
"""创建项目所需的目录结构,为所有客户端IP预创建基础目录"""
|
||||
try:
|
||||
# 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容)
|
||||
resource_dir = Path("resource")
|
||||
resource_dir.mkdir(exist_ok=True)
|
||||
# print(f"确保resource目录存在: {resource_dir.absolute()}")
|
||||
|
||||
# 2. 在resource下创建dect文件夹
|
||||
dect_dir = resource_dir / "dect"
|
||||
dect_dir.mkdir(exist_ok=True)
|
||||
# print(f"确保dect目录存在: {dect_dir.absolute()}")
|
||||
|
||||
# 3. 在dect下创建三个模型文件夹
|
||||
model_dirs = ["ocr", "face", "yolo"]
|
||||
for model in model_dirs:
|
||||
model_dir = dect_dir / model
|
||||
model_dir.mkdir(exist_ok=True)
|
||||
# print(f"确保{model}模型目录存在: {model_dir.absolute()}")
|
||||
|
||||
# 4. 调用外部方法获取所有客户端IP地址
|
||||
try:
|
||||
# 调用外部ip_read()方法获取所有客户端IP地址列表
|
||||
all_ip_addresses = get_unique_client_ips()
|
||||
|
||||
# 确保返回的是列表类型
|
||||
@ -58,7 +57,6 @@ def create_directory_structure():
|
||||
|
||||
# 递归创建目录(存在则跳过,不覆盖)
|
||||
month_dir.mkdir(parents=True, exist_ok=True)
|
||||
# print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
|
||||
@ -67,52 +65,68 @@ def create_directory_structure():
|
||||
print(f"创建基础目录结构时发生错误: {str(e)}")
|
||||
|
||||
|
||||
def get_image_save_path(model_type: str, client_ip: str) -> tuple:
|
||||
def get_image_save_path(model_type: str, client_ip: str) -> Tuple[str, str]:
|
||||
"""
|
||||
获取图片保存的「完整路径」和「显示用短路径」
|
||||
获取图片保存的「本地完整路径」和「带路由前缀的显示路径」
|
||||
|
||||
参数:
|
||||
model_type: 模型类型,应为"ocr"、"face"或"yolo"
|
||||
client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101)
|
||||
|
||||
返回:
|
||||
元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "")
|
||||
元组 (本地完整保存路径, 带/api/file/前缀的显示路径);若出错则返回 ("", "")
|
||||
"""
|
||||
try:
|
||||
# 验证模型类型有效性
|
||||
valid_models = ["ocr", "face", "yolo"]
|
||||
if model_type not in valid_models:
|
||||
raise ValueError(f"无效的模型类型: {model_type},必须是{valid_models}之一")
|
||||
|
||||
# 1. 验证客户端IP有效性(检查是否在已知IP列表中)
|
||||
all_ip_addresses = get_unique_client_ips()
|
||||
if not isinstance(all_ip_addresses, list):
|
||||
all_ip_addresses = [all_ip_addresses]
|
||||
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
|
||||
|
||||
if client_ip.strip() not in valid_ips:
|
||||
client_ip_stripped = client_ip.strip()
|
||||
if client_ip_stripped not in valid_ips:
|
||||
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件")
|
||||
|
||||
# 2. 处理IP地址(与目录创建逻辑一致,将.替换为_)
|
||||
safe_ip = client_ip.strip().replace(".", "_")
|
||||
# 2. 处理IP地址(将.替换为_,避免路径问题)
|
||||
safe_ip = client_ip_stripped.replace(".", "_")
|
||||
|
||||
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
|
||||
now = datetime.datetime.now()
|
||||
current_year = str(now.year)
|
||||
current_month = str(now.month)
|
||||
current_day = str(now.day)
|
||||
# 时间戳格式:年月日时分秒毫秒(如20250910143050123)
|
||||
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]
|
||||
current_month = str(now.month).zfill(2) # 确保月份为两位数
|
||||
current_day = str(now.day).zfill(2) # 确保日期为两位数
|
||||
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] # 取毫秒级时间戳
|
||||
|
||||
# 4. 定义基础目录(用于生成相对路径)
|
||||
base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀
|
||||
base_dir = Path("resource") / "dect"
|
||||
# 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日})
|
||||
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
|
||||
day_dir.mkdir(parents=True, exist_ok=True) # 确保日目录存在
|
||||
day_dir.mkdir(parents=True, exist_ok=True) # 确保目录存在
|
||||
|
||||
# 5. 构建唯一文件名
|
||||
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
|
||||
|
||||
# 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印)
|
||||
full_path = day_dir / image_filename # 完整路径:resource/dect/.../xxx.jpg
|
||||
display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg(去掉resource/dect)
|
||||
# 6. 生成「本地完整路径」(使用系统路径,但在字符串表示时统一为正斜杠)
|
||||
local_full_path = day_dir / image_filename
|
||||
# 转换为字符串并统一使用正斜杠
|
||||
local_full_path_str = str(local_full_path).replace("\\", "/")
|
||||
|
||||
return str(full_path), str(display_path)
|
||||
# 7. 生成带路由前缀的显示路径(核心修改部分)
|
||||
# 获取项目根目录(base_dir是resource/dect,向上两级即为项目根目录)
|
||||
project_root = base_dir.parents[1]
|
||||
# 计算相对于项目根目录的路径(包含resource/dect层级)
|
||||
relative_path = local_full_path.relative_to(project_root)
|
||||
# 转换为字符串并统一使用正斜杠
|
||||
relative_path_str = str(relative_path).replace("\\", "/")
|
||||
# 拼接路由前缀
|
||||
routed_display_path = f"/api/file/{relative_path_str}"
|
||||
|
||||
return local_full_path_str, routed_display_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取图片保存路径时发生错误: {str(e)}")
|
||||
|
||||
85
main.py
85
main.py
@ -1,10 +1,8 @@
|
||||
import uvicorn
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from service.file_service import app as flask_app
|
||||
|
||||
# 原有业务导入
|
||||
from core.all import load_model
|
||||
@ -15,32 +13,12 @@ from service.sensitive_service import router as sensitive_router
|
||||
from service.face_service import router as face_router
|
||||
from service.device_service import router as device_router
|
||||
from service.model_service import router as model_router
|
||||
from service.file_service import router as file_router
|
||||
from service.device_danger_service import router as device_danger_router
|
||||
from ws.ws import ws_router, lifespan
|
||||
from core.establish import create_directory_structure
|
||||
|
||||
|
||||
# Flask 服务启动函数(不变)
|
||||
def start_flask_service():
|
||||
try:
|
||||
print(f"\n[Flask 服务] 准备启动,端口:5000")
|
||||
print(f"[Flask 服务] 访问示例:http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n")
|
||||
|
||||
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
|
||||
if not os.path.exists(BASE_IMAGE_DIR):
|
||||
print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
|
||||
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
|
||||
|
||||
flask_app.run(
|
||||
host="0.0.0.0",
|
||||
port=5000,
|
||||
debug=False,
|
||||
use_reloader=False
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[Flask 服务] 启动失败:{str(e)}")
|
||||
|
||||
|
||||
# 初始化 FastAPI 应用(新增 CORS 配置)
|
||||
# 初始化 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="内容安全审核后台",
|
||||
description="含图片访问服务和动态模型管理",
|
||||
@ -48,38 +26,33 @@ app = FastAPI(
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# ------------------------------
|
||||
# 新增:完整 CORS 配置(解决跨域问题)
|
||||
# ------------------------------
|
||||
# 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等)
|
||||
ALLOWED_ORIGINS = [
|
||||
# "http://localhost:8080", # 前端本地开发地址(必改,填实际前端地址)
|
||||
# "http://127.0.0.1:8080",
|
||||
# "http://服务器IP:8080", # 部署后前端地址(如适用)
|
||||
"*" #表示允许所有域名(开发环境可用,生产环境不推荐)
|
||||
"*"
|
||||
]
|
||||
|
||||
# 2. 配置 CORS 中间件
|
||||
# 配置 CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
|
||||
allow_credentials=True, # 允许携带 Cookie(如需登录态则必开)
|
||||
allow_methods=["*"], # 允许所有 HTTP 方法(包括 PUT/DELETE)
|
||||
allow_headers=["*"], # 允许所有请求头(包括 Content-Type)
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
||||
allow_headers=["*"], # 允许所有请求头
|
||||
)
|
||||
|
||||
# 注册路由(不变)
|
||||
# 注册路由
|
||||
app.include_router(user_router)
|
||||
app.include_router(device_router)
|
||||
app.include_router(face_router)
|
||||
app.include_router(sensitive_router)
|
||||
app.include_router(model_router) # 模型管理路由
|
||||
app.include_router(model_router)
|
||||
app.include_router(file_router)
|
||||
app.include_router(device_danger_router)
|
||||
app.include_router(ws_router)
|
||||
|
||||
# 注册全局异常处理器(不变)
|
||||
# 注册全局异常处理器
|
||||
app.add_exception_handler(Exception, global_exception_handler)
|
||||
|
||||
# 主服务启动入口(不变)
|
||||
# 主服务启动入口
|
||||
if __name__ == "__main__":
|
||||
create_directory_structure()
|
||||
print(f"[初始化] 目录结构创建完成")
|
||||
@ -89,11 +62,11 @@ if __name__ == "__main__":
|
||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||
print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}")
|
||||
|
||||
# # 模型路径配置
|
||||
# YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt")
|
||||
# OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml")
|
||||
# print(f"[初始化] 默认YOLO模型路径:{YOLO_MODEL_PATH}")
|
||||
# print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}")
|
||||
# 确保图片目录存在(原Flask服务负责的目录)
|
||||
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
|
||||
if not os.path.exists(BASE_IMAGE_DIR):
|
||||
print(f"[初始化] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
|
||||
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
|
||||
|
||||
# 加载检测模型
|
||||
try:
|
||||
@ -105,23 +78,7 @@ if __name__ == "__main__":
|
||||
except Exception as e:
|
||||
print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)")
|
||||
|
||||
|
||||
|
||||
# 2. 启动 Flask 服务(子线程)
|
||||
flask_thread = threading.Thread(
|
||||
target=start_flask_service,
|
||||
daemon=True
|
||||
)
|
||||
flask_thread.start()
|
||||
|
||||
# 等待 Flask 初始化
|
||||
time.sleep(1)
|
||||
if flask_thread.is_alive():
|
||||
print(f"[Flask 服务] 启动成功(运行中)")
|
||||
else:
|
||||
print(f"[Flask 服务] 启动失败!图片访问不可用")
|
||||
|
||||
# 3. 启动 FastAPI 主服务
|
||||
# 启动 FastAPI 主服务(仅使用8000端口)
|
||||
port = int(SERVER_CONFIG.get("port", 8000))
|
||||
print(f"\n[FastAPI 服务] 准备启动,端口:{port}")
|
||||
print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n")
|
||||
@ -133,4 +90,4 @@ if __name__ == "__main__":
|
||||
workers=1,
|
||||
ws="websockets",
|
||||
reload=False
|
||||
)
|
||||
)
|
||||
|
||||
33
schema/device_danger_schema.py
Normal file
33
schema/device_danger_schema.py
Normal file
@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceDangerCreateRequest(BaseModel):
|
||||
"""设备危险记录创建请求模型"""
|
||||
client_ip: str = Field(..., max_length=100, description="设备IP地址(必须与devices表中IP对应)")
|
||||
type: str = Field(..., max_length=50, description="危险类型(如:病毒检测、端口异常、权限泄露等)")
|
||||
result: str = Field(..., description="危险检测结果/处理结果(如:检测到木马病毒,已隔离;端口22异常开放,已关闭)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型
|
||||
# ------------------------------
|
||||
class DeviceDangerResponse(BaseModel):
|
||||
"""单条设备危险记录响应模型(与device_danger表字段对齐,updated_at允许为null)"""
|
||||
id: int = Field(..., description="危险记录主键ID")
|
||||
client_ip: str = Field(..., max_length=100, description="设备IP地址")
|
||||
type: str = Field(..., max_length=50, description="危险类型")
|
||||
result: str = Field(..., description="危险检测结果/处理结果")
|
||||
created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)")
|
||||
updated_at: Optional[datetime] = Field(None, description="记录更新时间(数据库中该字段当前为null)")
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceDangerListResponse(BaseModel):
|
||||
"""设备危险记录列表响应模型(带分页)"""
|
||||
total: int = Field(..., description="危险记录总数")
|
||||
dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表")
|
||||
@ -28,7 +28,6 @@ class DeviceResponse(BaseModel):
|
||||
params: Optional[str] = Field(None, description="扩展参数(JSON字符串)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
model_config = {"from_attributes": True} # 支持从数据库结果直接转换
|
||||
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from schema.response_schema import APIResponse
|
||||
|
||||
# 路由配置
|
||||
router = APIRouter(
|
||||
prefix="/device/actions",
|
||||
prefix="/api/device/actions",
|
||||
tags=["设备操作记录"]
|
||||
)
|
||||
|
||||
|
||||
267
service/device_danger_service.py
Normal file
267
service/device_danger_service.py
Normal file
@ -0,0 +1,267 @@
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_danger_schema import (
|
||||
DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
# 路由初始化(前缀与设备管理相关,标签区分功能)
|
||||
router = APIRouter(
|
||||
prefix="/api/devices/dangers",
|
||||
tags=["设备管理-危险记录"]
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 检查设备是否存在(复用设备表逻辑)
|
||||
# ------------------------------
|
||||
def check_device_exist(client_ip: str) -> bool:
|
||||
"""
|
||||
检查指定IP的设备是否在devices表中存在
|
||||
|
||||
:param client_ip: 设备IP地址
|
||||
:return: 存在返回True,不存在返回False
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("设备IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
return cursor.fetchone() is not None
|
||||
except MySQLError as e:
|
||||
raise Exception(f"检查设备存在性失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 创建设备危险记录(核心插入逻辑)
|
||||
# ------------------------------
|
||||
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
|
||||
"""
|
||||
内部工具方法:向device_danger表插入新的危险记录
|
||||
|
||||
:param danger_data: 危险记录创建请求数据
|
||||
:return: 创建成功的危险记录模型对象
|
||||
"""
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(danger_data.client_ip):
|
||||
raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入危险记录(id自增,时间自动填充)
|
||||
insert_query = """
|
||||
INSERT INTO device_danger
|
||||
(client_ip, type, result, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
danger_data.client_ip,
|
||||
danger_data.type,
|
||||
danger_data.result
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚创建的记录(用自增ID查询)
|
||||
danger_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
|
||||
new_danger = cursor.fetchone()
|
||||
|
||||
return DeviceDangerResponse(**new_danger)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"插入危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口1:创建设备危险记录
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备危险记录")
|
||||
@encrypt_response()
|
||||
async def add_device_danger(danger_data: DeviceDangerCreateRequest):
|
||||
try:
|
||||
# 调用内部方法创建记录
|
||||
new_danger = create_danger_record(danger_data)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备[{danger_data.client_ip}]危险记录创建成功",
|
||||
data=new_danger
|
||||
)
|
||||
except ValueError as e:
|
||||
# 设备不存在等业务异常
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
# 数据库异常等系统错误
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口2:获取危险记录列表(支持多条件筛选+分页)
|
||||
# ------------------------------
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
|
||||
@encrypt_response()
|
||||
async def get_danger_list(
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
|
||||
danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"),
|
||||
start_date: date = Query(None, description="按创建时间筛选(开始日期,格式YYYY-MM-DD)"),
|
||||
end_date: date = Query(None, description="按创建时间筛选(结束日期,格式YYYY-MM-DD)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建筛选条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if danger_type:
|
||||
where_clause.append("type = %s")
|
||||
params.append(danger_type)
|
||||
if start_date:
|
||||
where_clause.append("DATE(created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 1. 统计符合条件的总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询记录(按创建时间倒序,最新的在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_query = "SELECT * FROM device_danger"
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
# 转换为响应模型
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录列表成功",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口3:获取单个设备的所有危险记录
|
||||
# ------------------------------
|
||||
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
|
||||
# @encrypt_response()
|
||||
async def get_device_dangers(
|
||||
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间")
|
||||
):
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(client_ip):
|
||||
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 统计该设备的危险记录总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
|
||||
cursor.execute(count_query, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询该设备的危险记录
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT * FROM device_danger
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
cursor.execute(list_query, (client_ip, page_size, offset))
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口4:根据ID获取单个危险记录详情
|
||||
# ------------------------------
|
||||
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
|
||||
@encrypt_response()
|
||||
async def get_danger_detail(
|
||||
danger_id: int = Path(..., ge=1, description="危险记录ID")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询单个危险记录
|
||||
query = "SELECT * FROM device_danger WHERE id = %s"
|
||||
cursor.execute(query, (danger_id,))
|
||||
danger = cursor.fetchone()
|
||||
|
||||
if not danger:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录详情成功",
|
||||
data=DeviceDangerResponse(**danger)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
@ -13,7 +13,7 @@ from schema.device_schema import (
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/devices",
|
||||
prefix="/api/devices",
|
||||
tags=["设备管理"]
|
||||
)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from schema.response_schema import APIResponse
|
||||
from util.face_util import add_binary_data, get_average_feature
|
||||
from util.file_util import save_face_to_up_images
|
||||
|
||||
router = APIRouter(prefix="/faces", tags=["人脸管理"])
|
||||
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
|
||||
|
||||
|
||||
# ------------------------------
|
||||
|
||||
@ -1,276 +1,174 @@
|
||||
from flask import Flask, send_from_directory, abort, request
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
import os
|
||||
import logging
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from flask_cors import CORS
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from typing import Annotated
|
||||
|
||||
# 配置日志(保持原有格式)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化 Flask 应用(供 main.py 导入)
|
||||
app = Flask(__name__)
|
||||
|
||||
# ------------------------------
|
||||
# 核心修改:与 FastAPI 对齐的跨域配置
|
||||
# ------------------------------
|
||||
# 1. 允许的前端域名(根据实际环境修改,生产环境删除 "*")
|
||||
ALLOWED_ORIGINS = [
|
||||
# "http://localhost:8080", # 本地前端开发地址
|
||||
# "http://127.0.0.1:8080",
|
||||
# "http://服务器IP:8080", # 部署后前端地址
|
||||
"*"
|
||||
]
|
||||
|
||||
# 2. 配置 CORS(与 FastAPI 规则完全对齐)
|
||||
CORS(
|
||||
app,
|
||||
resources={
|
||||
r"/*": {
|
||||
"origins": ALLOWED_ORIGINS,
|
||||
"allow_credentials": True,
|
||||
"methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
}
|
||||
},
|
||||
router = APIRouter(
|
||||
prefix="/api/file",
|
||||
tags=["文件管理"]
|
||||
)
|
||||
|
||||
# ------------------------------
|
||||
# 核心路径配置(关键修改:修正 PROJECT_ROOT 计算)
|
||||
# 原问题:file_service.py 在 service 文件夹内,需向上跳一级到项目根目录
|
||||
# 4. 路径配置
|
||||
# ------------------------------
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve() # 当前文件路径:service/file_service.py
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录(service 文件夹的父目录)
|
||||
# 资源目录(现在正确指向项目根目录下的文件夹)
|
||||
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 根目录/resource/dect
|
||||
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 根目录/up_images
|
||||
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 根目录/resource/models
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录
|
||||
|
||||
# 资源目录定义
|
||||
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
|
||||
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
|
||||
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
|
||||
|
||||
# 确保资源目录存在
|
||||
for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]:
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
print(f"[创建目录] {dir_path}")
|
||||
|
||||
# ------------------------------
|
||||
# 安全检查装饰器(不变)
|
||||
# 5. 安全依赖项(替代Flask装饰器)
|
||||
# ------------------------------
|
||||
def safe_path_check(root_dir: str):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
resource_path = kwargs.get('resource_path', '').strip()
|
||||
# 统一路径分隔符(兼容 Windows \ 和 Linux /)
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
# 拼接完整路径(防止路径遍历)
|
||||
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
|
||||
logger.debug(
|
||||
f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}"
|
||||
)
|
||||
|
||||
# 1. 禁止路径遍历(确保请求文件在根目录内)
|
||||
if not full_file_path.startswith(root_dir):
|
||||
logger.warning(
|
||||
f"[Flask 安全拦截] 非法路径遍历!IP:{request.remote_addr} | 请求路径:{resource_path}"
|
||||
)
|
||||
abort(403)
|
||||
|
||||
# 2. 检查文件存在且为有效文件(非目录)
|
||||
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
|
||||
logger.warning(
|
||||
f"[Flask 资源错误] 文件不存在/非文件!IP:{request.remote_addr} | 路径:{full_file_path}"
|
||||
)
|
||||
abort(404)
|
||||
|
||||
# 3. 限制文件大小(模型200MB,图片10MB)
|
||||
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
|
||||
if os.path.getsize(full_file_path) > max_size:
|
||||
logger.warning(
|
||||
f"[Flask 大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.remote_addr} | 路径:{full_file_path}"
|
||||
)
|
||||
abort(413)
|
||||
|
||||
# 安全检查通过,传递根目录给视图函数
|
||||
return func(*args, **kwargs, root_dir=root_dir)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
# ------------------------------
|
||||
# 1. 模型下载接口(/model/download/*)
|
||||
# ------------------------------
|
||||
@app.route('/model/download/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_MODEL_DIR)
|
||||
def download_model(resource_path, root_dir):
|
||||
try:
|
||||
"""
|
||||
安全路径校验依赖项:
|
||||
1. 禁止路径遍历(确保请求文件在根目录内)
|
||||
2. 校验文件存在且为有效文件(非目录)
|
||||
3. 限制文件大小(模型200MB,图片10MB)
|
||||
"""
|
||||
async def dependency(request: Request, resource_path: str):
|
||||
# 统一路径分隔符
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
# 拼接完整路径
|
||||
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
|
||||
|
||||
# 仅允许 .pt 格式(YOLO 模型)
|
||||
if not file_name.lower().endswith('.pt'):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非 .pt 模型文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
# 校验1:禁止路径遍历
|
||||
if not full_file_path.startswith(root_dir):
|
||||
print(f"[安全检查] 禁止路径遍历!IP:{request.client.host} | 请求路径:{resource_path}")
|
||||
raise HTTPException(status_code=403, detail="非法路径访问")
|
||||
|
||||
logger.info(
|
||||
f"[Flask 模型下载] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
# 校验2:文件存在且为有效文件
|
||||
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
|
||||
print(f"[资源错误] 文件不存在/非文件!IP:{request.client.host} | 路径:{full_file_path}")
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
# 校验3:文件大小限制
|
||||
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
|
||||
if os.path.getsize(full_file_path) > max_size:
|
||||
print(f"[大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.client.host} | 路径:{full_file_path}")
|
||||
raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)")
|
||||
|
||||
return full_file_path
|
||||
return dependency
|
||||
|
||||
# ------------------------------
|
||||
# 6. 核心接口
|
||||
# ------------------------------
|
||||
@router.get("/model/download/{resource_path:path}", summary="模型下载接口")
|
||||
async def download_model(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))],
|
||||
request: Request
|
||||
):
|
||||
"""模型下载接口(仅允许 .pt 格式,强制浏览器下载)"""
|
||||
try:
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 额外校验:仅允许 YOLO 模型格式(.pt)
|
||||
if not file_name.lower().endswith(".pt"):
|
||||
print(f"[格式错误] 非 .pt 模型文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件")
|
||||
|
||||
print(f"[模型下载] 尝试下载!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
# 强制下载
|
||||
return FileResponse(
|
||||
full_file_path,
|
||||
filename=file_name,
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
# 强制浏览器下载(而非预览)
|
||||
return send_from_directory(
|
||||
full_dir,
|
||||
file_name,
|
||||
as_attachment=True,
|
||||
mimetype="application/octet-stream"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 模型下载异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
print(f"[下载异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
# ------------------------------
|
||||
# 2. 人脸图片访问接口(/up_images/*)
|
||||
# ------------------------------
|
||||
@app.route('/up_images/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES)
|
||||
def get_face_image(resource_path, root_dir):
|
||||
|
||||
@router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口")
|
||||
async def get_face_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))],
|
||||
request: Request
|
||||
):
|
||||
"""人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
|
||||
try:
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 仅允许常见图片格式
|
||||
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
|
||||
logger.info(
|
||||
f"[Flask 人脸图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
)
|
||||
|
||||
# 允许浏览器预览图片
|
||||
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||
print(f"[人脸图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 人脸图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
print(f"[人脸图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
|
||||
# ------------------------------
|
||||
# 3. 检测图片访问接口(/resource/dect/*)
|
||||
# ------------------------------
|
||||
@app.route('/resource/dect/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
|
||||
def get_dect_image(resource_path, root_dir):
|
||||
|
||||
@router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口")
|
||||
async def get_dect_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
|
||||
request: Request
|
||||
):
|
||||
"""检测图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
|
||||
try:
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 仅允许常见图片格式
|
||||
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
|
||||
logger.info(
|
||||
f"[Flask 检测图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
)
|
||||
|
||||
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||
print(f"[检测图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 检测图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
print(f"[检测图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
# ------------------------------
|
||||
# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*)
|
||||
# ------------------------------
|
||||
@app.route('/images/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
|
||||
def get_compatible_image(resource_path, root_dir):
|
||||
|
||||
@router.get("/images/{resource_path:path}", summary="兼容旧接口")
|
||||
async def get_compatible_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
|
||||
request: Request
|
||||
):
|
||||
"""兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)"""
|
||||
try:
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
|
||||
logger.info(
|
||||
f"[Flask 兼容图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
)
|
||||
|
||||
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
print(f"[兼容图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 兼容图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
|
||||
# ------------------------------
|
||||
# 全局错误处理器(不变)
|
||||
# ------------------------------
|
||||
@app.errorhandler(403)
|
||||
def forbidden_error(error):
|
||||
return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403
|
||||
|
||||
@app.errorhandler(404)
|
||||
def not_found_error(error):
|
||||
return "❌ 资源不存在:请检查URL路径(IP、目录、文件名)是否正确", 404
|
||||
|
||||
@app.errorhandler(413)
|
||||
def too_large_error(error):
|
||||
return "❌ 文件过大:图片最大10MB,模型最大200MB", 413
|
||||
|
||||
@app.errorhandler(415)
|
||||
def unsupported_type_error(error):
|
||||
return "❌ 不支持的文件类型:图片支持png/jpg/jpeg/gif/bmp,模型仅支持pt", 415
|
||||
|
||||
@app.errorhandler(500)
|
||||
def server_error(error):
|
||||
return "❌ 服务器内部错误:请联系管理员查看后台日志", 500
|
||||
|
||||
# ------------------------------
|
||||
# Flask 独立启动入口(供测试,实际由 main.py 子线程启动)
|
||||
# ------------------------------
|
||||
if __name__ == '__main__':
|
||||
# 确保所有资源目录存在
|
||||
required_dirs = [
|
||||
(BASE_IMAGE_DIR_DECT, "检测图片目录"),
|
||||
(BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"),
|
||||
(BASE_MODEL_DIR, "模型文件目录")
|
||||
]
|
||||
for dir_path, dir_desc in required_dirs:
|
||||
if not os.path.exists(dir_path):
|
||||
logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}")
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
# 启动提示
|
||||
logger.info("\n[Flask 服务启动成功!] 支持的接口:")
|
||||
logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt")
|
||||
logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg")
|
||||
logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n")
|
||||
|
||||
# 启动服务(禁用 debug 和自动重载)
|
||||
app.run(
|
||||
host="0.0.0.0",
|
||||
port=5000,
|
||||
debug=False,
|
||||
use_reloader=False
|
||||
)
|
||||
print(f"[兼容图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
@ -38,7 +38,7 @@ _yolo_model = None
|
||||
_current_model_version = None # 模型版本标识
|
||||
_current_conf_threshold = 0.8 # 默认置信度初始值
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||
router = APIRouter(prefix="/api/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 服务重启核心工具函数(保持不变)
|
||||
|
||||
@ -16,7 +16,7 @@ from schema.user_schema import UserResponse
|
||||
|
||||
# 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类)
|
||||
router = APIRouter(
|
||||
prefix="/sensitives",
|
||||
prefix="/api/sensitives",
|
||||
tags=["敏感信息管理"]
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from middle.auth_middleware import (
|
||||
|
||||
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
||||
router = APIRouter(
|
||||
prefix="/users",
|
||||
prefix="/api/users",
|
||||
tags=["用户管理"]
|
||||
)
|
||||
|
||||
|
||||
@ -12,7 +12,8 @@ def save_face_to_up_images(
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
|
||||
确保db_path以up_images开头,且统一使用正斜杠
|
||||
确保db_path以 /api/file/up_images 开头,且统一使用正斜杠
|
||||
本地不创建/api/file/文件夹,仅URL访问时使用该前缀路由
|
||||
|
||||
参数:
|
||||
client_ip: 客户端IP(原始格式,如192.168.1.101)
|
||||
@ -21,10 +22,10 @@ def save_face_to_up_images(
|
||||
image_format: 图片格式(默认jpg)
|
||||
|
||||
返回:
|
||||
字典:success(是否成功)、db_path(存数据库的相对路径)、local_abs_path(本地绝对路径)、msg(提示)
|
||||
字典:success(是否成功)、db_path(存数据库的路径,带/api/file/前缀)、local_abs_path(本地绝对路径)、msg(提示)
|
||||
"""
|
||||
try:
|
||||
# 1. 基础参数校验
|
||||
# 1. 基础参数校验(不变)
|
||||
if not client_ip.strip():
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"}
|
||||
if not image_bytes:
|
||||
@ -32,53 +33,54 @@ def save_face_to_up_images(
|
||||
if image_format.lower() not in ["jpg", "jpeg", "png"]:
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"}
|
||||
|
||||
# 2. 处理特殊字符(避免路径错误)
|
||||
# 2. 处理特殊字符(避免路径错误)(不变)
|
||||
safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_
|
||||
safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名"
|
||||
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
|
||||
|
||||
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
|
||||
root_dir = Path("up_images").resolve() # 转为绝对路径(如D:/Git/bin/video/up_images)
|
||||
root_dir = Path("up_images").resolve()
|
||||
if not root_dir.exists():
|
||||
root_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"[FileUtil] 已创建up_images根目录:{root_dir}")
|
||||
|
||||
# 4. 构建文件层级路径(确保在root_dir子目录下)
|
||||
# 4. 构建文件层级路径(确保在root_dir子目录下)(不变)
|
||||
ip_dir = root_dir / safe_ip
|
||||
face_name_dir = ip_dir / safe_face_name
|
||||
face_name_dir.mkdir(parents=True, exist_ok=True) # 自动创建目录
|
||||
print(f"[FileUtil] 图片存储目录:{face_name_dir}")
|
||||
face_name_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"[FileUtil] 图片存储目录(本地):{face_name_dir}")
|
||||
|
||||
# 5. 生成唯一文件名(毫秒级时间戳)
|
||||
# 5. 生成唯一文件名(毫秒级时间戳)(不变)
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
|
||||
|
||||
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
|
||||
local_abs_path = face_name_dir / image_filename
|
||||
|
||||
# 6. 计算路径(确保所有路径都是绝对路径且在root_dir下)
|
||||
local_abs_path = face_name_dir / image_filename # 绝对路径
|
||||
|
||||
# 验证路径是否在root_dir下(防止路径穿越攻击)
|
||||
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
|
||||
raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}")
|
||||
|
||||
# 数据库存储路径:强制包含up_images前缀,统一使用正斜杠
|
||||
relative_path = local_abs_path.relative_to(root_dir.parent) # 相对于root_dir的父目录
|
||||
db_path = str(relative_path).replace("\\", "/") # 此时会包含up_images部分
|
||||
# 数据库存储路径:核心修改——在原有relative_path前添加 /api/file/ 前缀
|
||||
relative_path = local_abs_path.relative_to(root_dir.parent)
|
||||
|
||||
# 7. 写入图片文件
|
||||
relative_path_str = str(relative_path).replace("\\", "/")
|
||||
# 2. 再拼接/api/file/前缀
|
||||
db_path = f"/api/file/{relative_path_str}"
|
||||
|
||||
# 7. 写入图片文件(不变)
|
||||
with open(local_abs_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
print(f"[FileUtil] 图片保存成功:")
|
||||
print(f" 数据库路径:{db_path}")
|
||||
print(f" 本地绝对路径:{local_abs_path}")
|
||||
print(f" 数据库路径(带/api/file/前缀):{db_path}")
|
||||
print(f" 本地绝对路径(无/api/file/):{local_abs_path}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"db_path": db_path, # 格式为 up_images/192_168_110_31/小龙/xxx.jpg
|
||||
"local_abs_path": str(local_abs_path), # 本地绝对路径(完整路径)
|
||||
"db_path": db_path,
|
||||
"local_abs_path": str(local_abs_path),
|
||||
"msg": "图片保存成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"图片保存失败:{str(e)}"
|
||||
print(f"[FileUtil] 错误:{error_msg}")
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}
|
||||
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}
|
||||
129
ws/ws.py
129
ws/ws.py
@ -17,17 +17,16 @@ from service.device_action_service import add_device_action
|
||||
from schema.device_action_schema import DeviceActionCreate
|
||||
from core.all import detect, load_model
|
||||
|
||||
# -------------------------- 1. AES 加密解密工具(固定密钥)--------------------------
|
||||
# -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息)--------------------------
|
||||
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥(32字节)
|
||||
AES_BLOCK_SIZE = 16 # AES固定块大小
|
||||
|
||||
|
||||
def aes_encrypt(plaintext: str) -> dict:
|
||||
"""AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)"""
|
||||
"""AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)- 仅用于服务器发消息"""
|
||||
try:
|
||||
iv = os.urandom(AES_BLOCK_SIZE) # 随机IV(16字节)
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
|
||||
# 明文填充+加密+Base64编码
|
||||
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
|
||||
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
|
||||
iv_base64 = base64.b64encode(iv).decode("utf-8")
|
||||
@ -40,20 +39,6 @@ def aes_encrypt(plaintext: str) -> dict:
|
||||
raise Exception(f"AES加密失败: {str(e)}") from e
|
||||
|
||||
|
||||
def aes_decrypt(encrypted_dict: dict) -> str:
|
||||
"""AES-CBC解密:输入加密字典,返回原始文本"""
|
||||
try:
|
||||
# Base64解码密文和IV
|
||||
ciphertext = base64.b64decode(encrypted_dict["ciphertext"])
|
||||
iv = base64.b64decode(encrypted_dict["iv"])
|
||||
# 解密+去除填充
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
|
||||
decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8")
|
||||
return decrypted
|
||||
except Exception as e:
|
||||
raise Exception(f"AES解密失败: {str(e)}") from e
|
||||
|
||||
|
||||
# -------------------------- 2. 配置常量(保持原有)--------------------------
|
||||
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
|
||||
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
|
||||
@ -72,7 +57,7 @@ def get_current_time_file_str() -> str:
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
|
||||
# -------------------------- 4. 客户端连接封装(新增消息加密)--------------------------
|
||||
# -------------------------- 4. 客户端连接封装(服务器发消息仍加密,接收消息改明文)--------------------------
|
||||
class ClientConnection:
|
||||
def __init__(self, websocket: WebSocket, client_ip: str):
|
||||
self.websocket = websocket
|
||||
@ -96,28 +81,25 @@ class ClientConnection:
|
||||
return self.consumer_task
|
||||
|
||||
async def send_frame_permit(self):
|
||||
"""发送加密的帧许可信号"""
|
||||
"""发送加密的帧许可信号(服务器→客户端:加密)"""
|
||||
try:
|
||||
# 1. 构建原始消息
|
||||
frame_permit_msg = {
|
||||
"type": "frame",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip
|
||||
}
|
||||
# 2. AES加密消息
|
||||
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg))
|
||||
# 3. 发送加密消息
|
||||
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) # 保持加密
|
||||
await self.websocket.send_json(encrypted_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}")
|
||||
|
||||
async def consume_frames(self) -> None:
|
||||
"""消费队列中的帧并处理"""
|
||||
"""消费队列中的明文图像帧并处理"""
|
||||
try:
|
||||
while True:
|
||||
frame_data = await self.frame_queue.get()
|
||||
await self.send_frame_permit() # 发送下一帧许可
|
||||
await self.send_frame_permit() # 回复仍加密
|
||||
try:
|
||||
await self.process_frame(frame_data)
|
||||
finally:
|
||||
@ -128,23 +110,22 @@ class ClientConnection:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}")
|
||||
|
||||
async def process_frame(self, frame_data: bytes) -> None:
|
||||
"""处理单帧图像(含加密危险通知)"""
|
||||
# 二进制转OpenCV图像
|
||||
"""处理明文图像帧(危险通知仍加密发送)"""
|
||||
# 二进制转OpenCV图像(客户端发的是明文二进制,直接解析)
|
||||
nparr = np.frombuffer(frame_data, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像")
|
||||
return
|
||||
|
||||
try:
|
||||
# 调用检测函数(client_ip + img 双参数)
|
||||
has_violation, data, detector_type = await asyncio.to_thread(
|
||||
detect, self.client_ip, img
|
||||
)
|
||||
print(
|
||||
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}")
|
||||
|
||||
# 处理违规逻辑(发送加密危险通知)
|
||||
# 违规通知:服务器→客户端,仍加密
|
||||
if has_violation:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}")
|
||||
# 违规次数+1
|
||||
@ -154,19 +135,17 @@ class ClientConnection:
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
|
||||
|
||||
# 1. 构建原始危险通知
|
||||
# 构建危险通知并加密发送
|
||||
danger_msg = {
|
||||
"type": "danger",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip,
|
||||
"detail": data
|
||||
}
|
||||
# 2. AES加密通知
|
||||
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg))
|
||||
# 3. 发送加密通知
|
||||
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) # 保持加密
|
||||
await self.websocket.send_json(encrypted_danger_msg)
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}")
|
||||
|
||||
|
||||
# -------------------------- 5. 全局状态与心跳管理(保持原有)--------------------------
|
||||
@ -178,7 +157,6 @@ async def heartbeat_checker():
|
||||
"""全局心跳检查任务"""
|
||||
while True:
|
||||
current_time = get_current_time_str()
|
||||
# 筛选超时客户端
|
||||
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
||||
|
||||
if timeout_ips:
|
||||
@ -186,11 +164,9 @@ async def heartbeat_checker():
|
||||
for ip in timeout_ips:
|
||||
try:
|
||||
conn = connected_clients[ip]
|
||||
# 取消消费任务+关闭连接
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
await conn.websocket.close(code=1008, reason="心跳超时")
|
||||
# 标记离线
|
||||
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
|
||||
action_data = DeviceActionCreate(client_ip=ip, action=0)
|
||||
await asyncio.to_thread(add_device_action, action_data)
|
||||
@ -205,19 +181,16 @@ async def heartbeat_checker():
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
|
||||
# -------------------------- 6. 消息处理工具(新增消息解密)--------------------------
|
||||
# -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑)--------------------------
|
||||
async def send_heartbeat_ack(conn: ClientConnection):
|
||||
"""发送加密的心跳确认"""
|
||||
"""发送加密的心跳确认(服务器→客户端:加密)"""
|
||||
try:
|
||||
# 1. 构建原始心跳确认
|
||||
heartbeat_ack_msg = {
|
||||
"type": "heart",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": conn.client_ip
|
||||
}
|
||||
# 2. AES加密
|
||||
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg))
|
||||
# 3. 发送
|
||||
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) # 保持加密
|
||||
await conn.websocket.send_json(encrypted_msg)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认")
|
||||
return True
|
||||
@ -228,44 +201,22 @@ async def send_heartbeat_ack(conn: ClientConnection):
|
||||
|
||||
|
||||
async def handle_text_msg(conn: ClientConnection, text: str):
|
||||
"""处理加密的文本消息(如心跳)"""
|
||||
"""处理客户端明文文本消息(如心跳)- 关键修改:无需解密,直接解析JSON"""
|
||||
try:
|
||||
# 1. 解析加密字典
|
||||
encrypted_dict = json.loads(text)
|
||||
# 2. AES解密
|
||||
decrypted_text = aes_decrypt(encrypted_dict)
|
||||
# 3. 解析业务消息
|
||||
msg = json.loads(decrypted_text)
|
||||
|
||||
# 客户端发的是明文JSON,直接解析(删除原解密步骤)
|
||||
msg = json.loads(text)
|
||||
if msg.get("type") == "heart":
|
||||
conn.update_heartbeat()
|
||||
await send_heartbeat_ack(conn)
|
||||
await send_heartbeat_ack(conn) # 服务器回复仍加密
|
||||
else:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')})")
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')})")
|
||||
except json.JSONDecodeError:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式")
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}")
|
||||
|
||||
|
||||
async def handle_binary_msg(conn: ClientConnection, data: str):
|
||||
"""处理加密的图像消息(客户端需先转Base64+加密)"""
|
||||
try:
|
||||
# 1. 解密得到Base64编码的图像
|
||||
encrypted_dict = json.loads(data)
|
||||
decrypted_base64 = aes_decrypt(encrypted_dict)
|
||||
# 2. Base64解码为二进制图像
|
||||
frame_data = base64.b64decode(decrypted_base64)
|
||||
# 3. 加入帧队列
|
||||
conn.frame_queue.put_nowait(frame_data)
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队")
|
||||
except asyncio.QueueFull:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满,丢弃数据")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}")
|
||||
|
||||
|
||||
# -------------------------- 7. WebSocket路由与生命周期(保持原有结构)--------------------------
|
||||
# -------------------------- 7. WebSocket路由与生命周期(关键修改:处理明文二进制图像)--------------------------
|
||||
ws_router = APIRouter()
|
||||
|
||||
|
||||
@ -276,7 +227,6 @@ async def lifespan(app: FastAPI):
|
||||
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
||||
print(f"[{get_current_time_str()}] 心跳检查任务启动(ID: {id(heartbeat_task)})")
|
||||
yield
|
||||
# 关闭时清理
|
||||
if heartbeat_task and not heartbeat_task.done():
|
||||
heartbeat_task.cancel()
|
||||
await heartbeat_task
|
||||
@ -285,8 +235,8 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
@ws_router.websocket(WS_ENDPOINT)
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
"""WebSocket连接处理入口"""
|
||||
load_model() # 加载检测模型(仅一次)
|
||||
"""WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像"""
|
||||
load_model() # 加载检测模型(建议移到全局,避免重复加载)
|
||||
await websocket.accept()
|
||||
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||
current_time = get_current_time_str()
|
||||
@ -306,8 +256,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
# 注册新连接
|
||||
new_conn = ClientConnection(websocket, client_ip)
|
||||
connected_clients[client_ip] = new_conn
|
||||
new_conn.start_consumer() # 启动帧消费
|
||||
await new_conn.send_frame_permit() # 发送首次帧许可
|
||||
new_conn.start_consumer()
|
||||
await new_conn.send_frame_permit() # 首次许可仍加密
|
||||
|
||||
# 标记客户端上线
|
||||
try:
|
||||
@ -321,28 +271,33 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}")
|
||||
|
||||
# 消息循环(接收客户端消息)
|
||||
# 消息循环:接收客户端明文消息(关键修改)
|
||||
while True:
|
||||
data = await websocket.receive()
|
||||
if "text" in data:
|
||||
# 处理加密文本消息(心跳、客户端指令)
|
||||
# 处理客户端明文文本(如心跳:{"type":"heart",...})
|
||||
await handle_text_msg(new_conn, data["text"])
|
||||
elif "bytes" in data:
|
||||
# 兼容客户端发送二进制:先转Base64再处理
|
||||
base64_data = base64.b64encode(data["bytes"]).decode("utf-8")
|
||||
await handle_binary_msg(new_conn, base64_data)
|
||||
# 处理客户端明文二进制图像(直接入队,无需解密)
|
||||
frame_data = data["bytes"]
|
||||
try:
|
||||
new_conn.frame_queue.put_nowait(frame_data)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队")
|
||||
except asyncio.QueueFull:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}")
|
||||
|
||||
except WebSocketDisconnect as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})")
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
|
||||
finally:
|
||||
# 清理资源(断开后处理)
|
||||
# 清理资源
|
||||
if client_ip in connected_clients:
|
||||
conn = connected_clients[client_ip]
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
# 仅上线成功的客户端,才标记离线
|
||||
if is_online_updated:
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
||||
@ -352,4 +307,4 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
except Exception as e:
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}")
|
||||
connected_clients.pop(client_ip, None)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")
|
||||
Reference in New Issue
Block a user