Files
video/service/file_service.py
2025-09-16 20:17:48 +08:00

175 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter
from fastapi.responses import FileResponse
import os
import logging
from functools import wraps
from pathlib import Path
from fastapi.middleware.cors import CORSMiddleware
from typing import Annotated
router = APIRouter(
prefix="/api/file",
tags=["文件管理"]
)
# ------------------------------
# 4. 路径配置
# ------------------------------
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录
# 资源目录定义
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
# 确保资源目录存在
for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]:
if not os.path.exists(dir_path):
os.makedirs(dir_path, exist_ok=True)
print(f"[创建目录] {dir_path}")
# ------------------------------
# 5. 安全依赖项替代Flask装饰器
# ------------------------------
def safe_path_check(root_dir: str):
"""
安全路径校验依赖项:
1. 禁止路径遍历(确保请求文件在根目录内)
2. 校验文件存在且为有效文件(非目录)
3. 限制文件大小模型200MB图片10MB
"""
async def dependency(request: Request, resource_path: str):
# 统一路径分隔符
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
# 拼接完整路径
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
# 校验1禁止路径遍历
if not full_file_path.startswith(root_dir):
print(f"[安全检查] 禁止路径遍历IP{request.client.host} | 请求路径:{resource_path}")
raise HTTPException(status_code=403, detail="非法路径访问")
# 校验2文件存在且为有效文件
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
print(f"[资源错误] 文件不存在/非文件IP{request.client.host} | 路径:{full_file_path}")
raise HTTPException(status_code=404, detail="文件不存在")
# 校验3文件大小限制
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
if os.path.getsize(full_file_path) > max_size:
print(f"[大小超限] 文件超过{max_size//1024//1024}MBIP{request.client.host} | 路径:{full_file_path}")
raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)")
return full_file_path
return dependency
# ------------------------------
# 6. 核心接口
# ------------------------------
@router.get("/model/download/{resource_path:path}", summary="模型下载接口")
async def download_model(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))],
request: Request
):
"""模型下载接口(仅允许 .pt 格式,强制浏览器下载)"""
try:
dir_path, file_name = os.path.split(full_file_path)
# 额外校验:仅允许 YOLO 模型格式(.pt
if not file_name.lower().endswith(".pt"):
print(f"[格式错误] 非 .pt 模型文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件")
print(f"[模型下载] 尝试下载IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
# 强制下载
return FileResponse(
full_file_path,
filename=file_name,
media_type="application/octet-stream"
)
except HTTPException:
raise
except Exception as e:
print(f"[下载异常] IP{request.client.host} | 错误:{str(e)}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口")
async def get_face_image(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))],
request: Request
):
"""人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
try:
dir_path, file_name = os.path.split(full_file_path)
# 图片格式校验
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext):
print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
print(f"[人脸图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e:
print(f"[人脸图片异常] IP{request.client.host} | 错误:{str(e)}")
@router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口")
async def get_dect_image(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
request: Request
):
"""检测图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
try:
dir_path, file_name = os.path.split(full_file_path)
# 图片格式校验
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext):
print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
print(f"[检测图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e:
print(f"[检测图片异常] IP{request.client.host} | 错误:{str(e)}")
raise HTTPException(status_code=500, detail="服务器内部错误")
@router.get("/images/{resource_path:path}", summary="兼容旧接口")
async def get_compatible_image(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
request: Request
):
"""兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)"""
try:
dir_path, file_name = os.path.split(full_file_path)
# 图片格式校验
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext):
print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
print(f"[兼容图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e:
print(f"[兼容图片异常] IP{request.client.host} | 错误:{str(e)}")
raise HTTPException(status_code=500, detail="服务器内部错误")