247 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			247 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from datetime import timedelta
 | ||
| from typing import Optional
 | ||
| 
 | ||
| from fastapi import APIRouter, Depends, HTTPException, Query
 | ||
| from mysql.connector import Error as MySQLError
 | ||
| 
 | ||
| from ds.db import db
 | ||
| from encryption.encrypt_decorator import encrypt_response
 | ||
| from middle.auth_middleware import (
 | ||
|     get_password_hash,
 | ||
|     verify_password,
 | ||
|     create_access_token,
 | ||
|     ACCESS_TOKEN_EXPIRE_MINUTES,
 | ||
|     get_current_user
 | ||
| )
 | ||
| from schema.response_schema import APIResponse
 | ||
| from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
 | ||
| 
 | ||
| router = APIRouter(
 | ||
|     prefix="/api/users",
 | ||
|     tags=["用户管理"]
 | ||
| )
 | ||
| 
 | ||
| 
 | ||
| # 用户注册接口
 | ||
| @router.post("/register", response_model=APIResponse, summary="用户注册")
 | ||
| @encrypt_response()
 | ||
| async def user_register(request: UserRegisterRequest):
 | ||
|     """
 | ||
|     用户注册:
 | ||
|     - 校验用户名是否已存在
 | ||
|     - 加密密码后插入数据库
 | ||
|     - 返回注册成功信息
 | ||
|     """
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         conn = db.get_connection()
 | ||
|         cursor = conn.cursor(dictionary=True)
 | ||
| 
 | ||
|         # 1. 检查用户名是否已存在(唯一索引)
 | ||
|         check_query = "SELECT username FROM users WHERE username = %s"
 | ||
|         cursor.execute(check_query, (request.username,))
 | ||
|         existing_user = cursor.fetchone()
 | ||
|         if existing_user:
 | ||
|             raise HTTPException(
 | ||
|                 status_code=400,
 | ||
|                 detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
 | ||
|             )
 | ||
| 
 | ||
|         # 2. 加密密码
 | ||
|         hashed_password = get_password_hash(request.password)
 | ||
| 
 | ||
|         # 3. 插入新用户到数据库
 | ||
|         insert_query = """
 | ||
|             INSERT INTO users (username, password)
 | ||
|             VALUES (%s, %s)
 | ||
|         """
 | ||
|         cursor.execute(insert_query, (request.username, hashed_password))
 | ||
|         conn.commit()  # 提交事务
 | ||
| 
 | ||
|         # 4. 返回注册成功响应
 | ||
|         return APIResponse(
 | ||
|             code=200,  # 200 表示资源创建成功
 | ||
|             message=f"用户 '{request.username}' 注册成功",
 | ||
|             data=None
 | ||
|         )
 | ||
|     except MySQLError as e:
 | ||
|         conn.rollback()  # 数据库错误时回滚事务
 | ||
|         raise Exception(f"注册失败: {str(e)}") from e
 | ||
|     finally:
 | ||
|         db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
| # 用户登录接口
 | ||
| @router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
 | ||
| @encrypt_response()
 | ||
| async def user_login(request: UserLoginRequest):
 | ||
|     """
 | ||
|     用户登录:
 | ||
|     - 校验用户名是否存在
 | ||
|     - 校验密码是否正确
 | ||
|     - 生成 JWT Token 并返回
 | ||
|     """
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         conn = db.get_connection()
 | ||
|         cursor = conn.cursor(dictionary=True)
 | ||
| 
 | ||
|         # 修复: SQL查询添加 created_at 和 updated_at 字段
 | ||
|         query = """
 | ||
|             SELECT id, username, password, created_at, updated_at 
 | ||
|             FROM users 
 | ||
|             WHERE username = %s
 | ||
|         """
 | ||
|         cursor.execute(query, (request.username,))
 | ||
|         user = cursor.fetchone()
 | ||
| 
 | ||
|         # 2. 校验用户名和密码
 | ||
|         if not user or not verify_password(request.password, user["password"]):
 | ||
|             raise HTTPException(
 | ||
|                 status_code=401,
 | ||
|                 detail="用户名或密码错误",
 | ||
|                 headers={"WWW-Authenticate": "Bearer"},
 | ||
|             )
 | ||
| 
 | ||
|         # 3. 生成 Token(过期时间从配置读取)
 | ||
|         access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
 | ||
|         access_token = create_access_token(
 | ||
|             data={"sub": user["username"]},
 | ||
|             expires_delta=access_token_expires
 | ||
|         )
 | ||
| 
 | ||
|         # 4. 返回 Token 和用户基本信息
 | ||
|         return APIResponse(
 | ||
|             code=200,
 | ||
|             message="登录成功",
 | ||
|             data={
 | ||
|                 "access_token": access_token,
 | ||
|                 "token_type": "bearer",
 | ||
|                 "user": UserResponse(
 | ||
|                     id=user["id"],
 | ||
|                     username=user["username"],
 | ||
|                     created_at=user.get("created_at"),
 | ||
|                     updated_at=user.get("updated_at")
 | ||
|                 )
 | ||
|             }
 | ||
|         )
 | ||
|     except MySQLError as e:
 | ||
|         raise Exception(f"登录失败: {str(e)}") from e
 | ||
|     finally:
 | ||
|         db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
| # 获取当前登录用户信息(需认证)
 | ||
| @router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
 | ||
| @encrypt_response()
 | ||
| async def get_current_user_info(
 | ||
|         current_user: UserResponse = Depends(get_current_user)  # 依赖认证中间件
 | ||
| ):
 | ||
|     """
 | ||
|     获取当前登录用户信息:
 | ||
|     - 需在请求头携带 Token(格式: Bearer <token>)
 | ||
|     - 认证通过后返回用户信息
 | ||
|     """
 | ||
|     return APIResponse(
 | ||
|         code=200,
 | ||
|         message="获取用户信息成功",
 | ||
|         data=current_user
 | ||
|     )
 | ||
| 
 | ||
| 
 | ||
| # 获取用户列表(仅需登录权限)
 | ||
| @router.get("/list", response_model=APIResponse, summary="获取用户列表")
 | ||
| @encrypt_response()
 | ||
| async def get_user_list(
 | ||
|         page: int = Query(1, ge=1, description="页码、从1开始"),
 | ||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
 | ||
|         username: Optional[str] = Query(None, description="用户名模糊搜索"),
 | ||
|         current_user: UserResponse = Depends(get_current_user)  # 仅需登录即可访问(移除管理员校验)
 | ||
| ):
 | ||
|     """
 | ||
|     获取用户列表:
 | ||
|     - 需登录权限(请求头携带 Token: Bearer <token>)
 | ||
|     - 支持分页查询(page=页码、page_size=每页条数)
 | ||
|     - 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等)
 | ||
|     - 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息)
 | ||
|     """
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         conn = db.get_connection()
 | ||
|         cursor = conn.cursor(dictionary=True)
 | ||
| 
 | ||
|         # 计算分页偏移量(page从1开始、偏移量=(页码-1)*每页条数)
 | ||
|         offset = (page - 1) * page_size
 | ||
| 
 | ||
|         # 基础查询(仅查非敏感字段)
 | ||
|         base_query = """
 | ||
|             SELECT id, username, created_at, updated_at 
 | ||
|             FROM users
 | ||
|         """
 | ||
|         # 总条数查询(用于分页计算)
 | ||
|         count_query = "SELECT COUNT(*) as total FROM users"
 | ||
| 
 | ||
|         # 条件拼接(支持用户名模糊搜索)
 | ||
|         conditions = []
 | ||
|         params = []
 | ||
|         if username:
 | ||
|             conditions.append("username LIKE %s")
 | ||
|             params.append(f"%{username}%")  # 模糊匹配:%表示任意字符
 | ||
| 
 | ||
|         # 构建最终查询语句
 | ||
|         if conditions:
 | ||
|             where_clause = " WHERE " + " AND ".join(conditions)
 | ||
|             final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
 | ||
|             final_count_query = f"{count_query}{where_clause}"
 | ||
|             params.extend([page_size, offset])  # 追加分页参数
 | ||
|         else:
 | ||
|             final_query = f"{base_query} LIMIT %s OFFSET %s"
 | ||
|             final_count_query = count_query
 | ||
|             params = [page_size, offset]
 | ||
| 
 | ||
|         # 1. 查询用户列表数据
 | ||
|         cursor.execute(final_query, params)
 | ||
|         users = cursor.fetchall()
 | ||
| 
 | ||
|         # 2. 查询总条数(用于计算总页数)
 | ||
|         count_params = [f"%{username}%"] if username else []
 | ||
|         cursor.execute(final_count_query, count_params)
 | ||
|         total = cursor.fetchone()["total"]
 | ||
| 
 | ||
|         # 3. 转换为UserResponse模型(确保字段匹配)
 | ||
|         user_list = [
 | ||
|             UserResponse(
 | ||
|                 id=user["id"],
 | ||
|                 username=user["username"],
 | ||
|                 created_at=user["created_at"],
 | ||
|                 updated_at=user["updated_at"]
 | ||
|             )
 | ||
|             for user in users
 | ||
|         ]
 | ||
| 
 | ||
|         # 4. 计算总页数(向上取整、如11条数据每页10条=2页)
 | ||
|         total_pages = (total + page_size - 1) // page_size
 | ||
| 
 | ||
|         # 返回结果(包含列表和分页信息)
 | ||
|         return APIResponse(
 | ||
|             code=200,
 | ||
|             message="获取用户列表成功",
 | ||
|             data={
 | ||
|                 "users": user_list,
 | ||
|                 "pagination": {
 | ||
|                     "page": page,  # 当前页码
 | ||
|                     "page_size": page_size,  # 每页条数
 | ||
|                     "total": total,  # 总数据量
 | ||
|                     "total_pages": total_pages  # 总页数
 | ||
|                 }
 | ||
|             }
 | ||
|         )
 | ||
|     except MySQLError as e:
 | ||
|         raise Exception(f"获取用户列表失败: {str(e)}") from e
 | ||
|     finally:
 | ||
|         # 无论成功失败、都关闭数据库连接
 | ||
|         db.close_connection(conn, cursor)
 |