from datetime import timedelta from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query from mysql.connector import Error as MySQLError from ds.db import db from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse from schema.response_schema import APIResponse from middle.auth_middleware import ( get_password_hash, verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, get_current_user # 仅保留登录用户校验,移除is_admin导入 ) # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) router = APIRouter( prefix="/users", tags=["用户管理"] ) # ------------------------------ # 1. 用户注册接口 # ------------------------------ @router.post("/register", response_model=APIResponse, summary="用户注册") async def user_register(request: UserRegisterRequest): """ 用户注册: - 校验用户名是否已存在 - 加密密码后插入数据库 - 返回注册成功信息 """ conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 1. 检查用户名是否已存在(唯一索引) check_query = "SELECT username FROM users WHERE username = %s" cursor.execute(check_query, (request.username,)) existing_user = cursor.fetchone() if existing_user: raise HTTPException( status_code=400, detail=f"用户名 '{request.username}' 已存在、请更换其他用户名" ) # 2. 加密密码 hashed_password = get_password_hash(request.password) # 3. 插入新用户到数据库 insert_query = """ INSERT INTO users (username, password) VALUES (%s, %s) """ cursor.execute(insert_query, (request.username, hashed_password)) conn.commit() # 提交事务 # 4. 返回注册成功响应 return APIResponse( code=201, # 201 表示资源创建成功 message=f"用户 '{request.username}' 注册成功", data=None ) except MySQLError as e: conn.rollback() # 数据库错误时回滚事务 raise Exception(f"注册失败: {str(e)}") from e finally: db.close_connection(conn, cursor) # ------------------------------ # 2. 用户登录接口 # ------------------------------ @router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)") async def user_login(request: UserLoginRequest): """ 用户登录: - 校验用户名是否存在 - 校验密码是否正确 - 生成 JWT Token 并返回 """ conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 修复: SQL查询添加 created_at 和 updated_at 字段 query = """ SELECT id, username, password, created_at, updated_at FROM users WHERE username = %s """ cursor.execute(query, (request.username,)) user = cursor.fetchone() # 2. 校验用户名和密码 if not user or not verify_password(request.password, user["password"]): raise HTTPException( status_code=401, detail="用户名或密码错误", headers={"WWW-Authenticate": "Bearer"}, ) # 3. 生成 Token(过期时间从配置读取) access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( data={"sub": user["username"]}, expires_delta=access_token_expires ) # 4. 返回 Token 和用户基本信息 return APIResponse( code=200, message="登录成功", data={ "access_token": access_token, "token_type": "bearer", "user": UserResponse( id=user["id"], username=user["username"], created_at=user.get("created_at"), updated_at=user.get("updated_at") ) } ) except MySQLError as e: raise Exception(f"登录失败: {str(e)}") from e finally: db.close_connection(conn, cursor) # ------------------------------ # 3. 获取当前登录用户信息(需认证) # ------------------------------ @router.get("/me", response_model=APIResponse, summary="获取当前用户信息") async def get_current_user_info( current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件 ): """ 获取当前登录用户信息: - 需在请求头携带 Token(格式: Bearer ) - 认证通过后返回用户信息 """ return APIResponse( code=200, message="获取用户信息成功", data=current_user ) # ------------------------------ # 4. 获取用户列表(仅需登录权限) # ------------------------------ @router.get("/list", response_model=APIResponse, summary="获取用户列表") async def get_user_list( page: int = Query(1, ge=1, description="页码,从1开始"), page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), username: Optional[str] = Query(None, description="用户名模糊搜索"), current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验) ): """ 获取用户列表: - 需登录权限(请求头携带 Token: Bearer ) - 支持分页查询(page=页码,page_size=每页条数) - 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等) - 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息) """ conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 计算分页偏移量(page从1开始,偏移量=(页码-1)*每页条数) offset = (page - 1) * page_size # 基础查询(仅查非敏感字段) base_query = """ SELECT id, username, created_at, updated_at FROM users """ # 总条数查询(用于分页计算) count_query = "SELECT COUNT(*) as total FROM users" # 条件拼接(支持用户名模糊搜索) conditions = [] params = [] if username: conditions.append("username LIKE %s") params.append(f"%{username}%") # 模糊匹配:%表示任意字符 # 构建最终查询语句 if conditions: where_clause = " WHERE " + " AND ".join(conditions) final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s" final_count_query = f"{count_query}{where_clause}" params.extend([page_size, offset]) # 追加分页参数 else: final_query = f"{base_query} LIMIT %s OFFSET %s" final_count_query = count_query params = [page_size, offset] # 1. 查询用户列表数据 cursor.execute(final_query, params) users = cursor.fetchall() # 2. 查询总条数(用于计算总页数) count_params = [f"%{username}%"] if username else [] cursor.execute(final_count_query, count_params) total = cursor.fetchone()["total"] # 3. 转换为UserResponse模型(确保字段匹配) user_list = [ UserResponse( id=user["id"], username=user["username"], created_at=user["created_at"], updated_at=user["updated_at"] ) for user in users ] # 4. 计算总页数(向上取整,如11条数据每页10条=2页) total_pages = (total + page_size - 1) // page_size # 返回结果(包含列表和分页信息) return APIResponse( code=200, message="获取用户列表成功", data={ "users": user_list, "pagination": { "page": page, # 当前页码 "page_size": page_size, # 每页条数 "total": total, # 总数据量 "total_pages": total_pages # 总页数 } } ) except MySQLError as e: raise Exception(f"获取用户列表失败: {str(e)}") from e finally: # 无论成功失败,都关闭数据库连接 db.close_connection(conn, cursor)