103 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			103 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from datetime import datetime, timedelta, timezone
 | ||
| from typing import Optional
 | ||
| 
 | ||
| from fastapi import Depends, HTTPException, status
 | ||
| from fastapi.security import OAuth2PasswordBearer
 | ||
| from jose import JWTError, jwt
 | ||
| from passlib.context import CryptContext
 | ||
| 
 | ||
| from ds.config import JWT_CONFIG
 | ||
| from ds.db import db
 | ||
| 
 | ||
# ------------------------------
# Password hashing configuration
# ------------------------------
# Only bcrypt is enabled. deprecated="auto" is passlib's setting for treating
# every scheme other than the default one as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
 | ||
| 
 | ||
# ------------------------------
# JWT settings (read from ds.config.JWT_CONFIG)
# ------------------------------
SECRET_KEY, ALGORITHM = JWT_CONFIG["secret_key"], JWT_CONFIG["algorithm"]
# Passed through int() so the configured value may be an int or a numeric string.
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])

# OAuth2 dependency: takes the token from the request header
# ("Authorization: Bearer <token>"); "/users/login" is the token endpoint.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/users/login",
)
 | ||
| 
 | ||
| 
 | ||
# ------------------------------
# Password helpers
# ------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check whether *plain_password* matches the stored *hashed_password*.

    The comparison is delegated to the module-level passlib ``pwd_context``.

    Returns:
        True when the password matches the hash, False otherwise.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
 | ||
| 
 | ||
| 
 | ||
def get_password_hash(password: str) -> str:
    """Hash a plaintext *password* with the module's bcrypt ``pwd_context``.

    Returns:
        The hash string produced by passlib, suitable for storage.
    """
    hashed = pwd_context.hash(password)
    return hashed
 | ||
| 
 | ||
| 
 | ||
# ------------------------------
# JWT helpers
# ------------------------------
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The dict is shallow-copied, so the
            caller's object is not modified.
        expires_delta: Token lifetime measured from now (UTC). When ``None``
            (the default) the token expires after 15 minutes. Any explicit
            value, including ``timedelta(0)``, is used as given.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so the previous
    # truthiness test silently replaced it with the 15-minute default.
    if expires_delta is None:
        # NOTE(review): ACCESS_TOKEN_EXPIRE_MINUTES is configured in this module
        # but not used as the fallback here; callers presumably pass it
        # explicitly. Confirm before changing this default.
        expires_delta = timedelta(minutes=15)

    to_encode = data.copy()
    # "exp" is the expiry claim; it overwrites any "exp" already present in data.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
 | ||
| 
 | ||
| 
 | ||
# ------------------------------
# Authentication dependency (resolve the currently logged-in user)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)):  # no return annotation: UserResponse is imported lazily
    """Resolve the current user from a bearer token.

    Used as a FastAPI dependency; ``token`` is supplied by ``oauth2_scheme``.

    Args:
        token: Encoded JWT taken from the request's Authorization header.

    Returns:
        A ``UserResponse`` built from the matching row of the ``users`` table.

    Raises:
        HTTPException: 401 when the token fails to decode/verify (``JWTError``),
            carries no ``sub`` claim, or names a username not found in ``users``.
        Exception: errors from the database layer or from building
            ``UserResponse`` propagate unchanged. They are server-side failures
            and are deliberately not reported to the client as an invalid token.
    """
    # Deferred import to break a circular dependency.
    from schema.user_schema import UserResponse

    # Raised for every authentication failure.
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token 无效或已过期",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # 1. Decode the token and extract the username from the "sub" claim.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise credentials_exception from exc
    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    # 2. Look the user up. The try/finally only guarantees the connection and
    #    cursor are released; DB errors are not converted into a 401.
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
        cursor.execute(query, (username,))
        user = cursor.fetchone()
    finally:
        db.close_connection(conn, cursor)

    # 3. A valid token whose subject has no matching row is still unauthorized.
    if user is None:
        raise credentials_exception

    # Convert the row into the response model (which validates the fields).
    return UserResponse(**user)
 |