| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | from datetime import timedelta | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | from typing import Optional | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | from fastapi import APIRouter, Depends, HTTPException, Query | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | from mysql.connector import Error as MySQLError | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from ds.db import db | 
					
						
							| 
									
										
										
										
											2025-09-15 18:08:54 +08:00
										 |  |  |  | from encryption.encrypt_decorator import encrypt_response | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse | 
					
						
							|  |  |  |  | from schema.response_schema import APIResponse | 
					
						
							|  |  |  |  | from middle.auth_middleware import ( | 
					
						
							|  |  |  |  |     get_password_hash, | 
					
						
							|  |  |  |  |     verify_password, | 
					
						
							|  |  |  |  |     create_access_token, | 
					
						
							|  |  |  |  |     ACCESS_TOKEN_EXPIRE_MINUTES, | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |     get_current_user  # 仅保留登录用户校验,移除is_admin导入 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  | ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) | 
					
						
							|  |  |  |  | router = APIRouter( | 
					
						
							|  |  |  |  |     prefix="/users", | 
					
						
							|  |  |  |  |     tags=["用户管理"] | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # 1. 用户注册接口 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | @router.post("/register", response_model=APIResponse, summary="用户注册") | 
					
						
							|  |  |  |  | async def user_register(request: UserRegisterRequest): | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |     用户注册: | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     - 校验用户名是否已存在 | 
					
						
							|  |  |  |  |     - 加密密码后插入数据库 | 
					
						
							|  |  |  |  |     - 返回注册成功信息 | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 1. 检查用户名是否已存在(唯一索引) | 
					
						
							|  |  |  |  |         check_query = "SELECT username FROM users WHERE username = %s" | 
					
						
							|  |  |  |  |         cursor.execute(check_query, (request.username,)) | 
					
						
							|  |  |  |  |         existing_user = cursor.fetchone() | 
					
						
							|  |  |  |  |         if existing_user: | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"用户名 '{request.username}' 已存在、请更换其他用户名" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 2. 加密密码 | 
					
						
							|  |  |  |  |         hashed_password = get_password_hash(request.password) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 3. 插入新用户到数据库 | 
					
						
							|  |  |  |  |         insert_query = """
 | 
					
						
							|  |  |  |  |             INSERT INTO users (username, password) | 
					
						
							|  |  |  |  |             VALUES (%s, %s) | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         cursor.execute(insert_query, (request.username, hashed_password)) | 
					
						
							|  |  |  |  |         conn.commit()  # 提交事务 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 4. 返回注册成功响应 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=201,  # 201 表示资源创建成功 | 
					
						
							|  |  |  |  |             message=f"用户 '{request.username}' 注册成功", | 
					
						
							|  |  |  |  |             data=None | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         conn.rollback()  # 数据库错误时回滚事务 | 
					
						
							| 
									
										
										
										
											2025-09-08 17:34:23 +08:00
										 |  |  |  |         raise Exception(f"注册失败: {str(e)}") from e | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # 2. 用户登录接口 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | @router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)") | 
					
						
							|  |  |  |  | async def user_login(request: UserLoginRequest): | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |     用户登录: | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     - 校验用户名是否存在 | 
					
						
							|  |  |  |  |     - 校验密码是否正确 | 
					
						
							|  |  |  |  |     - 生成 JWT Token 并返回 | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-08 17:34:23 +08:00
										 |  |  |  |         # 修复: SQL查询添加 created_at 和 updated_at 字段 | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |         query = """
 | 
					
						
							|  |  |  |  |             SELECT id, username, password, created_at, updated_at  | 
					
						
							|  |  |  |  |             FROM users  | 
					
						
							|  |  |  |  |             WHERE username = %s | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         cursor.execute(query, (request.username,)) | 
					
						
							|  |  |  |  |         user = cursor.fetchone() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 2. 校验用户名和密码 | 
					
						
							|  |  |  |  |         if not user or not verify_password(request.password, user["password"]): | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=401, | 
					
						
							|  |  |  |  |                 detail="用户名或密码错误", | 
					
						
							|  |  |  |  |                 headers={"WWW-Authenticate": "Bearer"}, | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 3. 生成 Token(过期时间从配置读取) | 
					
						
							|  |  |  |  |         access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) | 
					
						
							|  |  |  |  |         access_token = create_access_token( | 
					
						
							|  |  |  |  |             data={"sub": user["username"]}, | 
					
						
							|  |  |  |  |             expires_delta=access_token_expires | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 4. 返回 Token 和用户基本信息 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							|  |  |  |  |             message="登录成功", | 
					
						
							|  |  |  |  |             data={ | 
					
						
							|  |  |  |  |                 "access_token": access_token, | 
					
						
							|  |  |  |  |                 "token_type": "bearer", | 
					
						
							|  |  |  |  |                 "user": UserResponse( | 
					
						
							|  |  |  |  |                     id=user["id"], | 
					
						
							|  |  |  |  |                     username=user["username"], | 
					
						
							|  |  |  |  |                     created_at=user.get("created_at"), | 
					
						
							|  |  |  |  |                     updated_at=user.get("updated_at") | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  |             } | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							| 
									
										
										
										
											2025-09-08 17:34:23 +08:00
										 |  |  |  |         raise Exception(f"登录失败: {str(e)}") from e | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # 3. 获取当前登录用户信息(需认证) | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | @router.get("/me", response_model=APIResponse, summary="获取当前用户信息") | 
					
						
							|  |  |  |  | async def get_current_user_info( | 
					
						
							|  |  |  |  |         current_user: UserResponse = Depends(get_current_user)  # 依赖认证中间件 | 
					
						
							|  |  |  |  | ): | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |     获取当前登录用户信息: | 
					
						
							| 
									
										
										
										
											2025-09-08 17:34:23 +08:00
										 |  |  |  |     - 需在请求头携带 Token(格式: Bearer <token>) | 
					
						
							| 
									
										
										
										
											2025-09-02 18:51:50 +08:00
										 |  |  |  |     - 认证通过后返回用户信息 | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     return APIResponse( | 
					
						
							|  |  |  |  |         code=200, | 
					
						
							|  |  |  |  |         message="获取用户信息成功", | 
					
						
							|  |  |  |  |         data=current_user | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | # 4. 获取用户列表(仅需登录权限) | 
					
						
							|  |  |  |  | # ------------------------------ | 
					
						
							|  |  |  |  | @router.get("/list", response_model=APIResponse, summary="获取用户列表") | 
					
						
							|  |  |  |  | async def get_user_list( | 
					
						
							|  |  |  |  |         page: int = Query(1, ge=1, description="页码,从1开始"), | 
					
						
							|  |  |  |  |         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | 
					
						
							|  |  |  |  |         username: Optional[str] = Query(None, description="用户名模糊搜索"), | 
					
						
							|  |  |  |  |         current_user: UserResponse = Depends(get_current_user)  # 仅需登录即可访问(移除管理员校验) | 
					
						
							|  |  |  |  | ): | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     获取用户列表: | 
					
						
							|  |  |  |  |     - 需登录权限(请求头携带 Token: Bearer <token>) | 
					
						
							|  |  |  |  |     - 支持分页查询(page=页码,page_size=每页条数) | 
					
						
							|  |  |  |  |     - 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等) | 
					
						
							|  |  |  |  |     - 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息) | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 计算分页偏移量(page从1开始,偏移量=(页码-1)*每页条数) | 
					
						
							|  |  |  |  |         offset = (page - 1) * page_size | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 基础查询(仅查非敏感字段) | 
					
						
							|  |  |  |  |         base_query = """
 | 
					
						
							|  |  |  |  |             SELECT id, username, created_at, updated_at  | 
					
						
							|  |  |  |  |             FROM users | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         # 总条数查询(用于分页计算) | 
					
						
							|  |  |  |  |         count_query = "SELECT COUNT(*) as total FROM users" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 条件拼接(支持用户名模糊搜索) | 
					
						
							|  |  |  |  |         conditions = [] | 
					
						
							|  |  |  |  |         params = [] | 
					
						
							|  |  |  |  |         if username: | 
					
						
							|  |  |  |  |             conditions.append("username LIKE %s") | 
					
						
							|  |  |  |  |             params.append(f"%{username}%")  # 模糊匹配:%表示任意字符 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 构建最终查询语句 | 
					
						
							|  |  |  |  |         if conditions: | 
					
						
							|  |  |  |  |             where_clause = " WHERE " + " AND ".join(conditions) | 
					
						
							|  |  |  |  |             final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s" | 
					
						
							|  |  |  |  |             final_count_query = f"{count_query}{where_clause}" | 
					
						
							|  |  |  |  |             params.extend([page_size, offset])  # 追加分页参数 | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             final_query = f"{base_query} LIMIT %s OFFSET %s" | 
					
						
							|  |  |  |  |             final_count_query = count_query | 
					
						
							|  |  |  |  |             params = [page_size, offset] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 1. 查询用户列表数据 | 
					
						
							|  |  |  |  |         cursor.execute(final_query, params) | 
					
						
							|  |  |  |  |         users = cursor.fetchall() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 2. 查询总条数(用于计算总页数) | 
					
						
							|  |  |  |  |         count_params = [f"%{username}%"] if username else [] | 
					
						
							|  |  |  |  |         cursor.execute(final_count_query, count_params) | 
					
						
							|  |  |  |  |         total = cursor.fetchone()["total"] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 3. 转换为UserResponse模型(确保字段匹配) | 
					
						
							|  |  |  |  |         user_list = [ | 
					
						
							|  |  |  |  |             UserResponse( | 
					
						
							|  |  |  |  |                 id=user["id"], | 
					
						
							|  |  |  |  |                 username=user["username"], | 
					
						
							|  |  |  |  |                 created_at=user["created_at"], | 
					
						
							|  |  |  |  |                 updated_at=user["updated_at"] | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  |             for user in users | 
					
						
							|  |  |  |  |         ] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 4. 计算总页数(向上取整,如11条数据每页10条=2页) | 
					
						
							|  |  |  |  |         total_pages = (total + page_size - 1) // page_size | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 返回结果(包含列表和分页信息) | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							|  |  |  |  |             message="获取用户列表成功", | 
					
						
							|  |  |  |  |             data={ | 
					
						
							|  |  |  |  |                 "users": user_list, | 
					
						
							|  |  |  |  |                 "pagination": { | 
					
						
							|  |  |  |  |                     "page": page,          # 当前页码 | 
					
						
							|  |  |  |  |                     "page_size": page_size, # 每页条数 | 
					
						
							|  |  |  |  |                     "total": total,         # 总数据量 | 
					
						
							|  |  |  |  |                     "total_pages": total_pages  # 总页数 | 
					
						
							|  |  |  |  |                 } | 
					
						
							|  |  |  |  |             } | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         raise Exception(f"获取用户列表失败: {str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         # 无论成功失败,都关闭数据库连接 | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) |