2025-09-02 18:51:50 +08:00
|
|
|
|
from datetime import timedelta
|
2025-09-12 14:05:09 +08:00
|
|
|
|
from typing import Optional
|
2025-09-02 18:51:50 +08:00
|
|
|
|
|
2025-09-12 14:05:09 +08:00
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
2025-09-02 18:51:50 +08:00
|
|
|
|
from mysql.connector import Error as MySQLError
|
|
|
|
|
|
|
|
|
|
from ds.db import db
|
|
|
|
|
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
|
|
|
|
|
from schema.response_schema import APIResponse
|
|
|
|
|
from middle.auth_middleware import (
|
|
|
|
|
get_password_hash,
|
|
|
|
|
verify_password,
|
|
|
|
|
create_access_token,
|
|
|
|
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
2025-09-12 14:05:09 +08:00
|
|
|
|
get_current_user # 仅保留登录用户校验,移除is_admin导入
|
2025-09-02 18:51:50 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
|
|
|
|
router = APIRouter(
|
|
|
|
|
prefix="/users",
|
|
|
|
|
tags=["用户管理"]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------
|
|
|
|
|
# 1. 用户注册接口
|
|
|
|
|
# ------------------------------
|
|
|
|
|
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
|
|
|
|
async def user_register(request: UserRegisterRequest):
|
|
|
|
|
"""
|
2025-09-12 14:05:09 +08:00
|
|
|
|
用户注册:
|
2025-09-02 18:51:50 +08:00
|
|
|
|
- 校验用户名是否已存在
|
|
|
|
|
- 加密密码后插入数据库
|
|
|
|
|
- 返回注册成功信息
|
|
|
|
|
"""
|
|
|
|
|
conn = None
|
|
|
|
|
cursor = None
|
|
|
|
|
try:
|
|
|
|
|
conn = db.get_connection()
|
|
|
|
|
cursor = conn.cursor(dictionary=True)
|
|
|
|
|
|
|
|
|
|
# 1. 检查用户名是否已存在(唯一索引)
|
|
|
|
|
check_query = "SELECT username FROM users WHERE username = %s"
|
|
|
|
|
cursor.execute(check_query, (request.username,))
|
|
|
|
|
existing_user = cursor.fetchone()
|
|
|
|
|
if existing_user:
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
status_code=400,
|
|
|
|
|
detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 2. 加密密码
|
|
|
|
|
hashed_password = get_password_hash(request.password)
|
|
|
|
|
|
|
|
|
|
# 3. 插入新用户到数据库
|
|
|
|
|
insert_query = """
|
|
|
|
|
INSERT INTO users (username, password)
|
|
|
|
|
VALUES (%s, %s)
|
|
|
|
|
"""
|
|
|
|
|
cursor.execute(insert_query, (request.username, hashed_password))
|
|
|
|
|
conn.commit() # 提交事务
|
|
|
|
|
|
|
|
|
|
# 4. 返回注册成功响应
|
|
|
|
|
return APIResponse(
|
|
|
|
|
code=201, # 201 表示资源创建成功
|
|
|
|
|
message=f"用户 '{request.username}' 注册成功",
|
|
|
|
|
data=None
|
|
|
|
|
)
|
|
|
|
|
except MySQLError as e:
|
|
|
|
|
conn.rollback() # 数据库错误时回滚事务
|
2025-09-08 17:34:23 +08:00
|
|
|
|
raise Exception(f"注册失败: {str(e)}") from e
|
2025-09-02 18:51:50 +08:00
|
|
|
|
finally:
|
|
|
|
|
db.close_connection(conn, cursor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------
|
|
|
|
|
# 2. 用户登录接口
|
|
|
|
|
# ------------------------------
|
|
|
|
|
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
|
|
|
|
async def user_login(request: UserLoginRequest):
|
|
|
|
|
"""
|
2025-09-12 14:05:09 +08:00
|
|
|
|
用户登录:
|
2025-09-02 18:51:50 +08:00
|
|
|
|
- 校验用户名是否存在
|
|
|
|
|
- 校验密码是否正确
|
|
|
|
|
- 生成 JWT Token 并返回
|
|
|
|
|
"""
|
|
|
|
|
conn = None
|
|
|
|
|
cursor = None
|
|
|
|
|
try:
|
|
|
|
|
conn = db.get_connection()
|
|
|
|
|
cursor = conn.cursor(dictionary=True)
|
|
|
|
|
|
2025-09-08 17:34:23 +08:00
|
|
|
|
# 修复: SQL查询添加 created_at 和 updated_at 字段
|
2025-09-02 18:51:50 +08:00
|
|
|
|
query = """
|
|
|
|
|
SELECT id, username, password, created_at, updated_at
|
|
|
|
|
FROM users
|
|
|
|
|
WHERE username = %s
|
|
|
|
|
"""
|
|
|
|
|
cursor.execute(query, (request.username,))
|
|
|
|
|
user = cursor.fetchone()
|
|
|
|
|
|
|
|
|
|
# 2. 校验用户名和密码
|
|
|
|
|
if not user or not verify_password(request.password, user["password"]):
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
status_code=401,
|
|
|
|
|
detail="用户名或密码错误",
|
|
|
|
|
headers={"WWW-Authenticate": "Bearer"},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 3. 生成 Token(过期时间从配置读取)
|
|
|
|
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
|
|
access_token = create_access_token(
|
|
|
|
|
data={"sub": user["username"]},
|
|
|
|
|
expires_delta=access_token_expires
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 4. 返回 Token 和用户基本信息
|
|
|
|
|
return APIResponse(
|
|
|
|
|
code=200,
|
|
|
|
|
message="登录成功",
|
|
|
|
|
data={
|
|
|
|
|
"access_token": access_token,
|
|
|
|
|
"token_type": "bearer",
|
|
|
|
|
"user": UserResponse(
|
|
|
|
|
id=user["id"],
|
|
|
|
|
username=user["username"],
|
|
|
|
|
created_at=user.get("created_at"),
|
|
|
|
|
updated_at=user.get("updated_at")
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
except MySQLError as e:
|
2025-09-08 17:34:23 +08:00
|
|
|
|
raise Exception(f"登录失败: {str(e)}") from e
|
2025-09-02 18:51:50 +08:00
|
|
|
|
finally:
|
|
|
|
|
db.close_connection(conn, cursor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------
|
|
|
|
|
# 3. 获取当前登录用户信息(需认证)
|
|
|
|
|
# ------------------------------
|
|
|
|
|
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
|
|
|
|
|
async def get_current_user_info(
|
|
|
|
|
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
|
|
|
|
):
|
|
|
|
|
"""
|
2025-09-12 14:05:09 +08:00
|
|
|
|
获取当前登录用户信息:
|
2025-09-08 17:34:23 +08:00
|
|
|
|
- 需在请求头携带 Token(格式: Bearer <token>)
|
2025-09-02 18:51:50 +08:00
|
|
|
|
- 认证通过后返回用户信息
|
|
|
|
|
"""
|
|
|
|
|
return APIResponse(
|
|
|
|
|
code=200,
|
|
|
|
|
message="获取用户信息成功",
|
|
|
|
|
data=current_user
|
|
|
|
|
)
|
|
|
|
|
|
2025-09-12 14:05:09 +08:00
|
|
|
|
|
|
|
|
|
# ------------------------------
|
|
|
|
|
# 4. 获取用户列表(仅需登录权限)
|
|
|
|
|
# ------------------------------
|
|
|
|
|
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
|
|
|
|
|
async def get_user_list(
|
|
|
|
|
page: int = Query(1, ge=1, description="页码,从1开始"),
|
|
|
|
|
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
|
|
|
|
username: Optional[str] = Query(None, description="用户名模糊搜索"),
|
|
|
|
|
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
获取用户列表:
|
|
|
|
|
- 需登录权限(请求头携带 Token: Bearer <token>)
|
|
|
|
|
- 支持分页查询(page=页码,page_size=每页条数)
|
|
|
|
|
- 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等)
|
|
|
|
|
- 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息)
|
|
|
|
|
"""
|
|
|
|
|
conn = None
|
|
|
|
|
cursor = None
|
|
|
|
|
try:
|
|
|
|
|
conn = db.get_connection()
|
|
|
|
|
cursor = conn.cursor(dictionary=True)
|
|
|
|
|
|
|
|
|
|
# 计算分页偏移量(page从1开始,偏移量=(页码-1)*每页条数)
|
|
|
|
|
offset = (page - 1) * page_size
|
|
|
|
|
|
|
|
|
|
# 基础查询(仅查非敏感字段)
|
|
|
|
|
base_query = """
|
|
|
|
|
SELECT id, username, created_at, updated_at
|
|
|
|
|
FROM users
|
|
|
|
|
"""
|
|
|
|
|
# 总条数查询(用于分页计算)
|
|
|
|
|
count_query = "SELECT COUNT(*) as total FROM users"
|
|
|
|
|
|
|
|
|
|
# 条件拼接(支持用户名模糊搜索)
|
|
|
|
|
conditions = []
|
|
|
|
|
params = []
|
|
|
|
|
if username:
|
|
|
|
|
conditions.append("username LIKE %s")
|
|
|
|
|
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
|
|
|
|
|
|
|
|
|
|
# 构建最终查询语句
|
|
|
|
|
if conditions:
|
|
|
|
|
where_clause = " WHERE " + " AND ".join(conditions)
|
|
|
|
|
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
|
|
|
|
|
final_count_query = f"{count_query}{where_clause}"
|
|
|
|
|
params.extend([page_size, offset]) # 追加分页参数
|
|
|
|
|
else:
|
|
|
|
|
final_query = f"{base_query} LIMIT %s OFFSET %s"
|
|
|
|
|
final_count_query = count_query
|
|
|
|
|
params = [page_size, offset]
|
|
|
|
|
|
|
|
|
|
# 1. 查询用户列表数据
|
|
|
|
|
cursor.execute(final_query, params)
|
|
|
|
|
users = cursor.fetchall()
|
|
|
|
|
|
|
|
|
|
# 2. 查询总条数(用于计算总页数)
|
|
|
|
|
count_params = [f"%{username}%"] if username else []
|
|
|
|
|
cursor.execute(final_count_query, count_params)
|
|
|
|
|
total = cursor.fetchone()["total"]
|
|
|
|
|
|
|
|
|
|
# 3. 转换为UserResponse模型(确保字段匹配)
|
|
|
|
|
user_list = [
|
|
|
|
|
UserResponse(
|
|
|
|
|
id=user["id"],
|
|
|
|
|
username=user["username"],
|
|
|
|
|
created_at=user["created_at"],
|
|
|
|
|
updated_at=user["updated_at"]
|
|
|
|
|
)
|
|
|
|
|
for user in users
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# 4. 计算总页数(向上取整,如11条数据每页10条=2页)
|
|
|
|
|
total_pages = (total + page_size - 1) // page_size
|
|
|
|
|
|
|
|
|
|
# 返回结果(包含列表和分页信息)
|
|
|
|
|
return APIResponse(
|
|
|
|
|
code=200,
|
|
|
|
|
message="获取用户列表成功",
|
|
|
|
|
data={
|
|
|
|
|
"users": user_list,
|
|
|
|
|
"pagination": {
|
|
|
|
|
"page": page, # 当前页码
|
|
|
|
|
"page_size": page_size, # 每页条数
|
|
|
|
|
"total": total, # 总数据量
|
|
|
|
|
"total_pages": total_pages # 总页数
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
except MySQLError as e:
|
|
|
|
|
raise Exception(f"获取用户列表失败: {str(e)}") from e
|
|
|
|
|
finally:
|
|
|
|
|
# 无论成功失败,都关闭数据库连接
|
|
|
|
|
db.close_connection(conn, cursor)
|