yolo模型识别不到
This commit is contained in:
@ -12,7 +12,7 @@ from schema.response_schema import APIResponse
|
||||
|
||||
# 路由配置
|
||||
router = APIRouter(
|
||||
prefix="/device/actions",
|
||||
prefix="/api/device/actions",
|
||||
tags=["设备操作记录"]
|
||||
)
|
||||
|
||||
|
||||
267
service/device_danger_service.py
Normal file
267
service/device_danger_service.py
Normal file
@ -0,0 +1,267 @@
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_danger_schema import (
|
||||
DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
# 路由初始化(前缀与设备管理相关,标签区分功能)
|
||||
router = APIRouter(
|
||||
prefix="/api/devices/dangers",
|
||||
tags=["设备管理-危险记录"]
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 检查设备是否存在(复用设备表逻辑)
|
||||
# ------------------------------
|
||||
def check_device_exist(client_ip: str) -> bool:
|
||||
"""
|
||||
检查指定IP的设备是否在devices表中存在
|
||||
|
||||
:param client_ip: 设备IP地址
|
||||
:return: 存在返回True,不存在返回False
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("设备IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
return cursor.fetchone() is not None
|
||||
except MySQLError as e:
|
||||
raise Exception(f"检查设备存在性失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 创建设备危险记录(核心插入逻辑)
|
||||
# ------------------------------
|
||||
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
|
||||
"""
|
||||
内部工具方法:向device_danger表插入新的危险记录
|
||||
|
||||
:param danger_data: 危险记录创建请求数据
|
||||
:return: 创建成功的危险记录模型对象
|
||||
"""
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(danger_data.client_ip):
|
||||
raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入危险记录(id自增,时间自动填充)
|
||||
insert_query = """
|
||||
INSERT INTO device_danger
|
||||
(client_ip, type, result, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
danger_data.client_ip,
|
||||
danger_data.type,
|
||||
danger_data.result
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚创建的记录(用自增ID查询)
|
||||
danger_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
|
||||
new_danger = cursor.fetchone()
|
||||
|
||||
return DeviceDangerResponse(**new_danger)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"插入危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口1:创建设备危险记录
|
||||
# ------------------------------
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备危险记录")
|
||||
@encrypt_response()
|
||||
async def add_device_danger(danger_data: DeviceDangerCreateRequest):
|
||||
try:
|
||||
# 调用内部方法创建记录
|
||||
new_danger = create_danger_record(danger_data)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备[{danger_data.client_ip}]危险记录创建成功",
|
||||
data=new_danger
|
||||
)
|
||||
except ValueError as e:
|
||||
# 设备不存在等业务异常
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
# 数据库异常等系统错误
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口2:获取危险记录列表(支持多条件筛选+分页)
|
||||
# ------------------------------
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
|
||||
@encrypt_response()
|
||||
async def get_danger_list(
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
|
||||
danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"),
|
||||
start_date: date = Query(None, description="按创建时间筛选(开始日期,格式YYYY-MM-DD)"),
|
||||
end_date: date = Query(None, description="按创建时间筛选(结束日期,格式YYYY-MM-DD)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建筛选条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if danger_type:
|
||||
where_clause.append("type = %s")
|
||||
params.append(danger_type)
|
||||
if start_date:
|
||||
where_clause.append("DATE(created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 1. 统计符合条件的总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询记录(按创建时间倒序,最新的在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_query = "SELECT * FROM device_danger"
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
# 转换为响应模型
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录列表成功",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口3:获取单个设备的所有危险记录
|
||||
# ------------------------------
|
||||
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
|
||||
# @encrypt_response()
|
||||
async def get_device_dangers(
|
||||
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
|
||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间")
|
||||
):
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(client_ip):
|
||||
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 统计该设备的危险记录总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
|
||||
cursor.execute(count_query, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询该设备的危险记录
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT * FROM device_danger
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
cursor.execute(list_query, (client_ip, page_size, offset))
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 接口4:根据ID获取单个危险记录详情
|
||||
# ------------------------------
|
||||
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
|
||||
@encrypt_response()
|
||||
async def get_danger_detail(
|
||||
danger_id: int = Path(..., ge=1, description="危险记录ID")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询单个危险记录
|
||||
query = "SELECT * FROM device_danger WHERE id = %s"
|
||||
cursor.execute(query, (danger_id,))
|
||||
danger = cursor.fetchone()
|
||||
|
||||
if not danger:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录详情成功",
|
||||
data=DeviceDangerResponse(**danger)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
@ -13,7 +13,7 @@ from schema.device_schema import (
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/devices",
|
||||
prefix="/api/devices",
|
||||
tags=["设备管理"]
|
||||
)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from schema.response_schema import APIResponse
|
||||
from util.face_util import add_binary_data, get_average_feature
|
||||
from util.file_util import save_face_to_up_images
|
||||
|
||||
router = APIRouter(prefix="/faces", tags=["人脸管理"])
|
||||
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
|
||||
|
||||
|
||||
# ------------------------------
|
||||
|
||||
@ -1,276 +1,174 @@
|
||||
from flask import Flask, send_from_directory, abort, request
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
import os
|
||||
import logging
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from flask_cors import CORS
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from typing import Annotated
|
||||
|
||||
# 配置日志(保持原有格式)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化 Flask 应用(供 main.py 导入)
|
||||
app = Flask(__name__)
|
||||
|
||||
# ------------------------------
|
||||
# 核心修改:与 FastAPI 对齐的跨域配置
|
||||
# ------------------------------
|
||||
# 1. 允许的前端域名(根据实际环境修改,生产环境删除 "*")
|
||||
ALLOWED_ORIGINS = [
|
||||
# "http://localhost:8080", # 本地前端开发地址
|
||||
# "http://127.0.0.1:8080",
|
||||
# "http://服务器IP:8080", # 部署后前端地址
|
||||
"*"
|
||||
]
|
||||
|
||||
# 2. 配置 CORS(与 FastAPI 规则完全对齐)
|
||||
CORS(
|
||||
app,
|
||||
resources={
|
||||
r"/*": {
|
||||
"origins": ALLOWED_ORIGINS,
|
||||
"allow_credentials": True,
|
||||
"methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
}
|
||||
},
|
||||
router = APIRouter(
|
||||
prefix="/api/file",
|
||||
tags=["文件管理"]
|
||||
)
|
||||
|
||||
# ------------------------------
|
||||
# 核心路径配置(关键修改:修正 PROJECT_ROOT 计算)
|
||||
# 原问题:file_service.py 在 service 文件夹内,需向上跳一级到项目根目录
|
||||
# 4. 路径配置
|
||||
# ------------------------------
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve() # 当前文件路径:service/file_service.py
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录(service 文件夹的父目录)
|
||||
# 资源目录(现在正确指向项目根目录下的文件夹)
|
||||
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 根目录/resource/dect
|
||||
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 根目录/up_images
|
||||
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 根目录/resource/models
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录
|
||||
|
||||
# 资源目录定义
|
||||
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
|
||||
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
|
||||
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
|
||||
|
||||
# 确保资源目录存在
|
||||
for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]:
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
print(f"[创建目录] {dir_path}")
|
||||
|
||||
# ------------------------------
|
||||
# 安全检查装饰器(不变)
|
||||
# 5. 安全依赖项(替代Flask装饰器)
|
||||
# ------------------------------
|
||||
def safe_path_check(root_dir: str):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
resource_path = kwargs.get('resource_path', '').strip()
|
||||
# 统一路径分隔符(兼容 Windows \ 和 Linux /)
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
# 拼接完整路径(防止路径遍历)
|
||||
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
|
||||
logger.debug(
|
||||
f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}"
|
||||
)
|
||||
|
||||
# 1. 禁止路径遍历(确保请求文件在根目录内)
|
||||
if not full_file_path.startswith(root_dir):
|
||||
logger.warning(
|
||||
f"[Flask 安全拦截] 非法路径遍历!IP:{request.remote_addr} | 请求路径:{resource_path}"
|
||||
)
|
||||
abort(403)
|
||||
|
||||
# 2. 检查文件存在且为有效文件(非目录)
|
||||
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
|
||||
logger.warning(
|
||||
f"[Flask 资源错误] 文件不存在/非文件!IP:{request.remote_addr} | 路径:{full_file_path}"
|
||||
)
|
||||
abort(404)
|
||||
|
||||
# 3. 限制文件大小(模型200MB,图片10MB)
|
||||
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
|
||||
if os.path.getsize(full_file_path) > max_size:
|
||||
logger.warning(
|
||||
f"[Flask 大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.remote_addr} | 路径:{full_file_path}"
|
||||
)
|
||||
abort(413)
|
||||
|
||||
# 安全检查通过,传递根目录给视图函数
|
||||
return func(*args, **kwargs, root_dir=root_dir)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
# ------------------------------
|
||||
# 1. 模型下载接口(/model/download/*)
|
||||
# ------------------------------
|
||||
@app.route('/model/download/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_MODEL_DIR)
|
||||
def download_model(resource_path, root_dir):
|
||||
try:
|
||||
"""
|
||||
安全路径校验依赖项:
|
||||
1. 禁止路径遍历(确保请求文件在根目录内)
|
||||
2. 校验文件存在且为有效文件(非目录)
|
||||
3. 限制文件大小(模型200MB,图片10MB)
|
||||
"""
|
||||
async def dependency(request: Request, resource_path: str):
|
||||
# 统一路径分隔符
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
# 拼接完整路径
|
||||
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
|
||||
|
||||
# 仅允许 .pt 格式(YOLO 模型)
|
||||
if not file_name.lower().endswith('.pt'):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非 .pt 模型文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
# 校验1:禁止路径遍历
|
||||
if not full_file_path.startswith(root_dir):
|
||||
print(f"[安全检查] 禁止路径遍历!IP:{request.client.host} | 请求路径:{resource_path}")
|
||||
raise HTTPException(status_code=403, detail="非法路径访问")
|
||||
|
||||
logger.info(
|
||||
f"[Flask 模型下载] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
# 校验2:文件存在且为有效文件
|
||||
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
|
||||
print(f"[资源错误] 文件不存在/非文件!IP:{request.client.host} | 路径:{full_file_path}")
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
# 校验3:文件大小限制
|
||||
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
|
||||
if os.path.getsize(full_file_path) > max_size:
|
||||
print(f"[大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.client.host} | 路径:{full_file_path}")
|
||||
raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)")
|
||||
|
||||
return full_file_path
|
||||
return dependency
|
||||
|
||||
# ------------------------------
|
||||
# 6. 核心接口
|
||||
# ------------------------------
|
||||
@router.get("/model/download/{resource_path:path}", summary="模型下载接口")
|
||||
async def download_model(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))],
|
||||
request: Request
|
||||
):
|
||||
"""模型下载接口(仅允许 .pt 格式,强制浏览器下载)"""
|
||||
try:
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 额外校验:仅允许 YOLO 模型格式(.pt)
|
||||
if not file_name.lower().endswith(".pt"):
|
||||
print(f"[格式错误] 非 .pt 模型文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件")
|
||||
|
||||
print(f"[模型下载] 尝试下载!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
# 强制下载
|
||||
return FileResponse(
|
||||
full_file_path,
|
||||
filename=file_name,
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
# 强制浏览器下载(而非预览)
|
||||
return send_from_directory(
|
||||
full_dir,
|
||||
file_name,
|
||||
as_attachment=True,
|
||||
mimetype="application/octet-stream"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 模型下载异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
print(f"[下载异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
# ------------------------------
|
||||
# 2. 人脸图片访问接口(/up_images/*)
|
||||
# ------------------------------
|
||||
@app.route('/up_images/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES)
|
||||
def get_face_image(resource_path, root_dir):
|
||||
|
||||
@router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口")
|
||||
async def get_face_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))],
|
||||
request: Request
|
||||
):
|
||||
"""人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
|
||||
try:
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 仅允许常见图片格式
|
||||
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
|
||||
logger.info(
|
||||
f"[Flask 人脸图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
)
|
||||
|
||||
# 允许浏览器预览图片
|
||||
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||
print(f"[人脸图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 人脸图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
print(f"[人脸图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
|
||||
# ------------------------------
|
||||
# 3. 检测图片访问接口(/resource/dect/*)
|
||||
# ------------------------------
|
||||
@app.route('/resource/dect/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
|
||||
def get_dect_image(resource_path, root_dir):
|
||||
|
||||
@router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口")
|
||||
async def get_dect_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
|
||||
request: Request
|
||||
):
|
||||
"""检测图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
|
||||
try:
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
# 仅允许常见图片格式
|
||||
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
|
||||
logger.info(
|
||||
f"[Flask 检测图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
)
|
||||
|
||||
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||
print(f"[检测图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 检测图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
print(f"[检测图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
# ------------------------------
|
||||
# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*)
|
||||
# ------------------------------
|
||||
@app.route('/images/<path:resource_path>')
|
||||
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
|
||||
def get_compatible_image(resource_path, root_dir):
|
||||
|
||||
@router.get("/images/{resource_path:path}", summary="兼容旧接口")
|
||||
async def get_compatible_image(
|
||||
resource_path: str,
|
||||
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
|
||||
request: Request
|
||||
):
|
||||
"""兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)"""
|
||||
try:
|
||||
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||
dir_path, file_name = os.path.split(resource_path)
|
||||
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||
dir_path, file_name = os.path.split(full_file_path)
|
||||
|
||||
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||
# 图片格式校验
|
||||
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
|
||||
if not file_name.lower().endswith(allowed_ext):
|
||||
logger.warning(
|
||||
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||
)
|
||||
abort(415)
|
||||
|
||||
logger.info(
|
||||
f"[Flask 兼容图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||
)
|
||||
|
||||
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||
print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}")
|
||||
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
|
||||
print(f"[兼容图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
|
||||
|
||||
return FileResponse(full_file_path)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Flask 兼容图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||
)
|
||||
abort(500)
|
||||
|
||||
# ------------------------------
|
||||
# 全局错误处理器(不变)
|
||||
# ------------------------------
|
||||
@app.errorhandler(403)
|
||||
def forbidden_error(error):
|
||||
return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403
|
||||
|
||||
@app.errorhandler(404)
|
||||
def not_found_error(error):
|
||||
return "❌ 资源不存在:请检查URL路径(IP、目录、文件名)是否正确", 404
|
||||
|
||||
@app.errorhandler(413)
|
||||
def too_large_error(error):
|
||||
return "❌ 文件过大:图片最大10MB,模型最大200MB", 413
|
||||
|
||||
@app.errorhandler(415)
|
||||
def unsupported_type_error(error):
|
||||
return "❌ 不支持的文件类型:图片支持png/jpg/jpeg/gif/bmp,模型仅支持pt", 415
|
||||
|
||||
@app.errorhandler(500)
|
||||
def server_error(error):
|
||||
return "❌ 服务器内部错误:请联系管理员查看后台日志", 500
|
||||
|
||||
# ------------------------------
|
||||
# Flask 独立启动入口(供测试,实际由 main.py 子线程启动)
|
||||
# ------------------------------
|
||||
if __name__ == '__main__':
|
||||
# 确保所有资源目录存在
|
||||
required_dirs = [
|
||||
(BASE_IMAGE_DIR_DECT, "检测图片目录"),
|
||||
(BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"),
|
||||
(BASE_MODEL_DIR, "模型文件目录")
|
||||
]
|
||||
for dir_path, dir_desc in required_dirs:
|
||||
if not os.path.exists(dir_path):
|
||||
logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}")
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
# 启动提示
|
||||
logger.info("\n[Flask 服务启动成功!] 支持的接口:")
|
||||
logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt")
|
||||
logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg")
|
||||
logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n")
|
||||
|
||||
# 启动服务(禁用 debug 和自动重载)
|
||||
app.run(
|
||||
host="0.0.0.0",
|
||||
port=5000,
|
||||
debug=False,
|
||||
use_reloader=False
|
||||
)
|
||||
print(f"[兼容图片异常] IP:{request.client.host} | 错误:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail="服务器内部错误")
|
||||
|
||||
@ -38,7 +38,7 @@ _yolo_model = None
|
||||
_current_model_version = None # 模型版本标识
|
||||
_current_conf_threshold = 0.8 # 默认置信度初始值
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||
router = APIRouter(prefix="/api/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 服务重启核心工具函数(保持不变)
|
||||
|
||||
@ -16,7 +16,7 @@ from schema.user_schema import UserResponse
|
||||
|
||||
# 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类)
|
||||
router = APIRouter(
|
||||
prefix="/sensitives",
|
||||
prefix="/api/sensitives",
|
||||
tags=["敏感信息管理"]
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from middle.auth_middleware import (
|
||||
|
||||
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
||||
router = APIRouter(
|
||||
prefix="/users",
|
||||
prefix="/api/users",
|
||||
tags=["用户管理"]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user