103 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			103 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | from datetime import datetime, timedelta, timezone | |||
|  | from typing import Optional | |||
|  | 
 | |||
|  | from fastapi import Depends, HTTPException, status | |||
|  | from fastapi.security import OAuth2PasswordBearer | |||
|  | from jose import JWTError, jwt | |||
|  | from passlib.context import CryptContext | |||
|  | 
 | |||
|  | from ds.config import JWT_CONFIG | |||
|  | from ds.db import db | |||
|  | 
 | |||
# ------------------------------
# Password hashing configuration
# ------------------------------
# Shared passlib context. bcrypt is the only listed scheme, so it is also the
# default used by pwd_context.hash(). deprecated="auto" marks every scheme other
# than the default as deprecated (passlib semantics).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# ------------------------------
# JWT configuration
# ------------------------------
# Read once at import time from ds.config.JWT_CONFIG; a missing entry therefore
# fails when this module is imported, not when a token is issued.
# NOTE(review): JWT_CONFIG is presumably a dict loaded from settings — confirm.
SECRET_KEY = JWT_CONFIG["secret_key"]  # key used to sign and verify tokens
ALGORITHM = JWT_CONFIG["algorithm"]  # signing algorithm name passed to python-jose
# Configured token lifetime in minutes, coerced to int. create_access_token does
# not read it as its default; callers presumably pass it as expires_delta.
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(JWT_CONFIG["access_token_expire_minutes"])

# OAuth2 dependency: extracts the token from the "Authorization: Bearer <token>"
# request header. tokenUrl only tells the OpenAPI docs where the login endpoint
# is; it does not register that route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|  | 
 | |||
|  | 
 | |||
# ------------------------------
# Password helpers
# ------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check whether a plaintext password matches a stored password hash.

    Args:
        plain_password: The password exactly as supplied by the user.
        hashed_password: The stored hash to compare against
            (e.g. a value produced by ``get_password_hash``).

    Returns:
        True if the password matches the hash, otherwise False.

    Raises:
        ValueError: passlib raises this when the hash format is not
            recognised by the configured context.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|  | 
 | |||
|  | 
 | |||
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the module's passlib context (bcrypt).

    passlib generates a fresh salt for every call, so hashing the same
    password twice yields different strings; compare with verify_password.

    Args:
        password: The plaintext password to hash.

    Returns:
        The encoded hash string, suitable for storing in the database.
    """
    digest = pwd_context.hash(password)
    return digest
|  | 
 | |||
|  | 
 | |||
# ------------------------------
# JWT helpers
# ------------------------------
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token (e.g. ``{"sub": username}``).
            The dict is shallow-copied; the caller's dict is not modified.
        expires_delta: Token lifetime measured from now (UTC). ``None`` falls
            back to 15 minutes. A zero or negative delta is honoured as given,
            producing a token that is already expired.

    Returns:
        The encoded JWT string, signed with SECRET_KEY using ALGORITHM.
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness check would silently replace it with the 15-minute default.
    # NOTE(review): the module-level ACCESS_TOKEN_EXPIRE_MINUTES is not used as
    # this fallback; callers presumably pass it explicitly — confirm.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|  | 
 | |||
|  | 
 | |||
# ------------------------------
# Authentication dependency (resolve the currently logged-in user)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the authenticated user from a bearer token.

    The token is decoded and verified with SECRET_KEY/ALGORITHM. Its ``sub``
    claim is taken as the username, and the matching row is loaded from the
    ``users`` table.

    Args:
        token: Bearer token injected by ``oauth2_scheme``.

    Returns:
        A ``UserResponse`` built from the user's id, username, created_at and
        updated_at columns.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token is invalid or expired, has no ``sub`` claim, or names a user
            that does not exist.

    Database or model-validation errors are deliberately NOT mapped to 401:
    they are server faults, not credential problems, so they propagate
    unchanged (FastAPI's default handling turns them into a 500 and logs them).
    """
    # Deferred import to break a circular dependency with schema.user_schema.
    # For the same reason the function carries no return type annotation.
    from schema.user_schema import UserResponse

    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token 无效或已过期",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Signature and claim validation (including expiry) happens in jwt.decode.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise credentials_exception from exc

    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Confirm the user still exists. Only the DB calls sit inside the try, so
    # the connection is always released without swallowing real errors.
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
        cursor.execute(query, (username,))
        user = cursor.fetchone()
    finally:
        db.close_connection(conn, cursor)

    if user is None:
        raise credentials_exception

    # Build the response model (pydantic validates the fields).
    return UserResponse(**user)