from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter from fastapi.responses import FileResponse import os import logging from functools import wraps from pathlib import Path from fastapi.middleware.cors import CORSMiddleware from typing import Annotated router = APIRouter( prefix="/api/file", tags=["文件管理"] ) # ------------------------------ # 4. 路径配置 # ------------------------------ CURRENT_FILE_PATH = Path(__file__).resolve() PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录 # 资源目录定义 BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录 BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录 BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录 # 确保资源目录存在 for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]: if not os.path.exists(dir_path): os.makedirs(dir_path, exist_ok=True) print(f"[创建目录] {dir_path}") # ------------------------------ # 5. 安全依赖项(替代Flask装饰器) # ------------------------------ def safe_path_check(root_dir: str): """ 安全路径校验依赖项: 1. 禁止路径遍历(确保请求文件在根目录内) 2. 校验文件存在且为有效文件(非目录) 3. 限制文件大小(模型200MB,图片10MB) """ async def dependency(request: Request, resource_path: str): # 统一路径分隔符 resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) # 拼接完整路径 full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) # 校验1:禁止路径遍历 if not full_file_path.startswith(root_dir): print(f"[安全检查] 禁止路径遍历!IP:{request.client.host} | 请求路径:{resource_path}") raise HTTPException(status_code=403, detail="非法路径访问") # 校验2:文件存在且为有效文件 if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path): print(f"[资源错误] 文件不存在/非文件!IP:{request.client.host} | 路径:{full_file_path}") raise HTTPException(status_code=404, detail="文件不存在") # 校验3:文件大小限制 max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024 if os.path.getsize(full_file_path) > max_size: print(f"[大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.client.host} | 路径:{full_file_path}") raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)") return full_file_path return dependency # ------------------------------ # 6. 核心接口 # ------------------------------ @router.get("/model/download/{resource_path:path}", summary="模型下载接口") async def download_model( resource_path: str, full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))], request: Request ): """模型下载接口(仅允许 .pt 格式,强制浏览器下载)""" try: dir_path, file_name = os.path.split(full_file_path) # 额外校验:仅允许 YOLO 模型格式(.pt) if not file_name.lower().endswith(".pt"): print(f"[格式错误] 非 .pt 模型文件!IP:{request.client.host} | 文件名:{file_name}") raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件") print(f"[模型下载] 尝试下载!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") # 强制下载 return FileResponse( full_file_path, filename=file_name, media_type="application/octet-stream" ) except HTTPException: raise except Exception as e: print(f"[下载异常] IP:{request.client.host} | 错误:{str(e)}") raise HTTPException(status_code=500, detail="服务器内部错误") @router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口") async def get_face_image( resource_path: str, full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))], request: Request ): """人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)""" try: dir_path, file_name = os.path.split(full_file_path) # 图片格式校验 allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") if not file_name.lower().endswith(allowed_ext): print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") raise HTTPException(status_code=415, detail="仅支持常见图片格式") print(f"[人脸图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") return FileResponse(full_file_path) except HTTPException: raise except Exception as e: print(f"[人脸图片异常] IP:{request.client.host} | 错误:{str(e)}") @router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口") async def get_dect_image( resource_path: str, full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], request: Request ): """检测图片访问接口(允许浏览器预览,仅支持常见图片格式)""" try: dir_path, file_name = os.path.split(full_file_path) # 图片格式校验 allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") if not file_name.lower().endswith(allowed_ext): print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") raise HTTPException(status_code=415, detail="仅支持常见图片格式") print(f"[检测图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") return FileResponse(full_file_path) except HTTPException: raise except Exception as e: print(f"[检测图片异常] IP:{request.client.host} | 错误:{str(e)}") raise HTTPException(status_code=500, detail="服务器内部错误") @router.get("/images/{resource_path:path}", summary="兼容旧接口") async def get_compatible_image( resource_path: str, full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], request: Request ): """兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)""" try: dir_path, file_name = os.path.split(full_file_path) # 图片格式校验 allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") if not file_name.lower().endswith(allowed_ext): print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") raise HTTPException(status_code=415, detail="仅支持常见图片格式") print(f"[兼容图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") return FileResponse(full_file_path) except HTTPException: raise except Exception as e: print(f"[兼容图片异常] IP:{request.client.host} | 错误:{str(e)}") raise HTTPException(status_code=500, detail="服务器内部错误")