Files
video/middle/auth_middleware.py
2025-09-04 22:59:27 +08:00

104 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from ds.config import JWT_CONFIG
from ds.db import db
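# UserResponse is intentionally NOT imported at module level (the former `from service.user_service import UserResponse` was removed); get_current_user imports it lazily to break a circular import.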
# ------------------------------
# Password hashing
# ------------------------------
# bcrypt is the only (and therefore default) scheme. Per passlib, deprecated="auto"
# marks every non-default scheme as deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# ------------------------------
# JWT settings, all sourced from ds.config.JWT_CONFIG
# ------------------------------
SECRET_KEY, ALGORITHM = JWT_CONFIG["secret_key"], JWT_CONFIG["algorithm"]
# Coerced with int() in case the config stores the value as a string.
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])

# OAuth2 dependency: extracts the token from the "Authorization: Bearer <token>"
# request header; tokenUrl points at the login route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------
# Password helpers
# ------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check whether a plaintext password matches a stored hash.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: A hash previously produced by ``get_password_hash``.

    Returns:
        True if they match, False otherwise (as reported by ``pwd_context.verify``).
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the module's bcrypt ``pwd_context``.

    Args:
        password: The plaintext password.

    Returns:
        The hash string; this is the value ``verify_password`` expects as its
        ``hashed_password`` argument.
    """
    hashed = pwd_context.hash(password)
    return hashed
# ------------------------------
# JWT helpers
# ------------------------------
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). The dict is copied,
            so the caller's object is not mutated; an existing ``exp`` key in it
            is overwritten.
        expires_delta: Token lifetime measured from the current UTC time.
            ``None`` selects the 15-minute fallback. Any explicit value,
            including ``timedelta(0)``, is used as given.

    Returns:
        The encoded JWT string, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and a truthiness
    # test would silently turn an explicit zero lifetime into 15 minutes.
    if expires_delta is None:
        # NOTE(review): this fallback is hard-coded and does not read
        # ACCESS_TOKEN_EXPIRE_MINUTES; callers wanting the configured lifetime
        # presumably pass it explicitly -- confirm.
        expires_delta = timedelta(minutes=15)

    claims = data.copy()
    claims["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
# ------------------------------
# Authentication dependency (resolves the currently logged-in user)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)):  # no return annotation: UserResponse is imported lazily below
    """FastAPI dependency that resolves the authenticated user from a Bearer token.

    Steps:
        1. Decode ``token`` with ``SECRET_KEY``/``ALGORITHM`` and read the
           ``sub`` claim as the username.
        2. Look that username up in the ``users`` table to confirm the account
           still exists.
        3. Return the row as a ``UserResponse``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            fails to decode/verify (python-jose raises ``JWTError`` for this;
            it presumably covers expiry too -- see the detail message), has no
            ``sub`` claim, or names a user that does not exist.

    Database errors and ``UserResponse`` validation errors are deliberately NOT
    converted to 401. They propagate, so FastAPI answers 500 and logs them,
    rather than reporting a server fault to the client as a bad token.
    """
    # Deferred import to break the circular dependency with the schema module.
    from schema.user_schema import UserResponse

    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token 无效或已过期",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # 1. Decode and verify the token; only JWT failures count as auth failures.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise credentials_exception from exc
    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    # 2. Confirm the user still exists. The try/finally only guarantees cleanup;
    #    it does not swallow database errors.
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)  # rows come back as dicts
        query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
        cursor.execute(query, (username,))
        user = cursor.fetchone()
    finally:
        db.close_connection(conn, cursor)

    if user is None:
        raise credentials_exception  # token is valid but the user no longer exists

    # 3. Build the response model (pydantic validates the fields).
    return UserResponse(**user)