| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | from datetime import datetime, timedelta, timezone | 
					
						
							|  |  |  |  | from typing import Optional | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from fastapi import Depends, HTTPException, status | 
					
						
							|  |  |  |  | from fastapi.security import OAuth2PasswordBearer | 
					
						
							|  |  |  |  | from jose import JWTError, jwt | 
					
						
							|  |  |  |  | from passlib.context import CryptContext | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from ds.config import JWT_CONFIG | 
					
						
							|  |  |  |  | from ds.db import db | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # 密码加密配置 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # JWT 配置 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | SECRET_KEY = JWT_CONFIG["secret_key"] | 
					
						
							|  |  |  |  | ALGORITHM = JWT_CONFIG["algorithm"] | 
					
						
							|  |  |  |  | ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"]) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 17:34:23 +08:00
										 |  |  |  | # OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>) | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # 密码工具函数 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | def verify_password(plain_password: str, hashed_password: str) -> bool: | 
					
						
							|  |  |  |  |     """验证明文密码与加密密码是否匹配""" | 
					
						
							|  |  |  |  |     return pwd_context.verify(plain_password, hashed_password) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | def get_password_hash(password: str) -> str: | 
					
						
							|  |  |  |  |     """对明文密码进行 bcrypt 加密""" | 
					
						
							|  |  |  |  |     return pwd_context.hash(password) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # JWT 工具函数 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: | 
					
						
							|  |  |  |  |     """生成 JWT Token""" | 
					
						
							|  |  |  |  |     to_encode = data.copy() | 
					
						
							|  |  |  |  |     # 设置过期时间 | 
					
						
							|  |  |  |  |     if expires_delta: | 
					
						
							|  |  |  |  |         expire = datetime.now(timezone.utc) + expires_delta | 
					
						
							|  |  |  |  |     else: | 
					
						
							|  |  |  |  |         expire = datetime.now(timezone.utc) + timedelta(minutes=15) | 
					
						
							|  |  |  |  |     # 添加过期时间到 Token 数据 | 
					
						
							|  |  |  |  |     to_encode.update({"exp": expire}) | 
					
						
							|  |  |  |  |     # 生成 Token | 
					
						
							|  |  |  |  |     encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) | 
					
						
							|  |  |  |  |     return encoded_jwt | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # 认证依赖(获取当前登录用户) | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  | def get_current_user(token: str = Depends(oauth2_scheme)):  # 移除返回类型注解 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     """从 Token 中解析用户信息、验证通过后返回当前用户""" | 
					
						
							| 
									
										
										
										
											2025-09-08 17:34:23 +08:00
										 |  |  |  |     # 延迟导入、打破循环依赖 | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  |     from schema.user_schema import UserResponse  # 在这里导入 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     # 认证失败异常 | 
					
						
							|  |  |  |  |     credentials_exception = HTTPException( | 
					
						
							|  |  |  |  |         status_code=status.HTTP_401_UNAUTHORIZED, | 
					
						
							|  |  |  |  |         detail="Token 无效或已过期", | 
					
						
							|  |  |  |  |         headers={"WWW-Authenticate": "Bearer"}, | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         # 解码 Token | 
					
						
							|  |  |  |  |         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) | 
					
						
							|  |  |  |  |         # 获取 Token 中的用户名 | 
					
						
							|  |  |  |  |         username: str = payload.get("sub") | 
					
						
							|  |  |  |  |         if username is None: | 
					
						
							|  |  |  |  |             raise credentials_exception | 
					
						
							|  |  |  |  |     except JWTError: | 
					
						
							|  |  |  |  |         raise credentials_exception | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # 从数据库查询用户(验证用户是否存在) | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True)  # 返回字典格式结果 | 
					
						
							|  |  |  |  |         query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s" | 
					
						
							|  |  |  |  |         cursor.execute(query, (username,)) | 
					
						
							|  |  |  |  |         user = cursor.fetchone() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if user is None: | 
					
						
							|  |  |  |  |             raise credentials_exception  # 用户不存在 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 转换为 UserResponse 模型(自动校验字段) | 
					
						
							| 
									
										
										
										
											2025-09-04 22:59:27 +08:00
										 |  |  |  |         return UserResponse(**user) | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         raise credentials_exception from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							| 
									
										
										
										
											2025-09-08 17:34:23 +08:00
										 |  |  |  |         db.close_connection(conn, cursor) |