from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext

from ds.config import JWT_CONFIG
from ds.db import db
from service.user_service import UserResponse

# ------------------------------
# Password hashing configuration
# ------------------------------
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# ------------------------------
# JWT configuration
# ------------------------------
SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])

# OAuth2 dependency: extracts the token from the request's
# "Authorization: Bearer <token>" header; tokenUrl points at the login route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")


# ------------------------------
# Password helpers
# ------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches the stored ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the configured (bcrypt) context."""
    return pwd_context.hash(password)


# ------------------------------
# JWT helpers
# ------------------------------
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token (``get_current_user`` expects the
            username under ``"sub"``). The dict is copied, not mutated.
        expires_delta: Token lifetime. When None, the token expires after
            15 minutes.

    Returns:
        The encoded JWT, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently swap an explicit zero lifetime for the
    # 15-minute default.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


# ------------------------------
# Auth dependency (current logged-in user)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
    """Resolve the currently authenticated user from a bearer token.

    Decodes the JWT, reads the username from its ``sub`` claim, and checks
    that a matching row still exists in the ``users`` table.

    Args:
        token: Raw JWT, injected by ``oauth2_scheme``.

    Returns:
        The matching user as a ``UserResponse``.

    Raises:
        HTTPException: 401 when the token cannot be decoded/verified, has no
            usable ``sub`` claim, or names a user that does not exist.

    Database errors and ``UserResponse`` validation errors are deliberately
    NOT converted to 401: they are server-side faults, not bad credentials,
    so they propagate and surface as internal errors.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token 无效或已过期",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        # python-jose raises JWTError for a bad signature, malformed token,
        # or (with default options) an expired "exp" claim.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise credentials_exception from exc

    username: Optional[str] = payload.get("sub")
    # Reject a missing, empty, or non-string subject instead of passing it to
    # the SQL query as-is.
    if not isinstance(username, str) or not username:
        raise credentials_exception

    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)  # fetch rows as dicts
        cursor.execute(
            "SELECT id, username, created_at, updated_at FROM users WHERE username = %s",
            (username,),
        )
        user = cursor.fetchone()
    finally:
        # NOTE(review): conn/cursor may still be None here if acquisition
        # failed; presumably close_connection tolerates that — confirm.
        db.close_connection(conn, cursor)

    if user is None:
        # Token was valid but the user no longer exists.
        raise credentials_exception
    return UserResponse(**user)