初始化

This commit is contained in:
ZZX9599
2025-09-02 18:51:50 +08:00
commit fe1b33a6e5
30 changed files with 1607 additions and 0 deletions

Binary file not shown.

Binary file not shown.

96
middle/auth_middleware.py Normal file
View File

@ -0,0 +1,96 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from ds.config import JWT_CONFIG
from ds.db import db
from service.user_service import UserResponse
# ------------------------------
# 密码加密配置
# ------------------------------
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# ------------------------------
# JWT 配置
# ------------------------------
SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------
# 密码工具函数
# ------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证明文密码与加密密码是否匹配"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""对明文密码进行 bcrypt 加密"""
return pwd_context.hash(password)
# ------------------------------
# JWT 工具函数
# ------------------------------
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""生成 JWT Token"""
to_encode = data.copy()
# 设置过期时间
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
# 添加过期时间到 Token 数据
to_encode.update({"exp": expire})
# 生成 Token
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# ------------------------------
# 认证依赖(获取当前登录用户)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
# 认证失败异常
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 无效或已过期",
headers={"WWW-Authenticate": "Bearer"},
)
try:
# 解码 Token
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# 获取 Token 中的用户名
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
# 从数据库查询用户(验证用户是否存在)
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
cursor.execute(query, (username,))
user = cursor.fetchone()
if user is None:
raise credentials_exception # 用户不存在
# 转换为 UserResponse 模型(自动校验字段)
return UserResponse(** user)
except Exception as e:
raise credentials_exception from e
finally:
db.close_connection(conn, cursor)

68
middle/error_handler.py Normal file
View File

@ -0,0 +1,68 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import HTTPException, RequestValidationError
from mysql.connector import Error as MySQLError
from jose import JWTError
from schema.response_schema import APIResponse
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理器:所有未捕获的异常都会在这里统一处理"""
# 1. 请求参数验证错误Pydantic 校验失败)
if isinstance(exc, RequestValidationError):
error_details = []
for err in exc.errors():
error_details.append(f"{err['loc'][1]}: {err['msg']}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=APIResponse(
code=400,
message=f"请求参数错误:{'; '.join(error_details)}",
data=None
).model_dump()
)
# 2. HTTP 异常(主动抛出的业务错误、如 401/404
if isinstance(exc, HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=APIResponse(
code=exc.status_code,
message=exc.detail,
data=None
).model_dump()
)
# 3. JWT 相关错误Token 无效/过期)
if isinstance(exc, JWTError):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=APIResponse(
code=401,
message="Token 无效或已过期",
data=None
).model_dump(),
headers={"WWW-Authenticate": "Bearer"},
)
# 4. MySQL 数据库错误
if isinstance(exc, MySQLError):
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"数据库错误:{str(exc)}",
data=None
).model_dump()
)
# 5. 其他未知错误(兜底处理)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"服务器内部错误:{str(exc)}",
data=None
).model_dump()
)