优化代码风格
This commit is contained in:
@ -1,39 +1,33 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
from schema.response_schema import APIResponse
|
||||
from encryption.encryption import AESCipher
|
||||
from utils.encrypt_utils import aes_encrypt
|
||||
|
||||
|
||||
def encrypt_response(func):
|
||||
"""
|
||||
返回值加密装饰器:
|
||||
- 仅对 APIResponse 的 data 字段加密(code/message 不加密,便于前端判断基础状态)
|
||||
- 若 data 为 None(如注册接口),不加密
|
||||
"""
|
||||
def encrypt_response(field: str = "data"):
|
||||
"""接口返回值加密装饰器:默认加密APIResponse的data字段"""
|
||||
|
||||
@wraps(func) # 保留原函数元信息(如 __name__、__doc__,避免 FastAPI 路由异常)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
# 1. 执行原接口函数,获取返回值(APIResponse 对象)
|
||||
response: APIResponse = await func(*args, **kwargs)
|
||||
def decorator(func):
|
||||
@wraps(func) # 保留原函数元信息(避免FastAPI路由异常)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 执行原接口函数,获取原始响应(APIResponse对象)
|
||||
original_response: APIResponse = await func(*args, **kwargs)
|
||||
|
||||
# 2. 仅当 data 不为 None 时加密
|
||||
if response.data is not None:
|
||||
# 加密 data 字段(字典类型)
|
||||
encrypted_result = AESCipher.encrypt(response.data)
|
||||
# 替换原 data 为加密后的数据(包含密文和 IV)
|
||||
response.data = {
|
||||
"is_encrypted": True, # 标记是否加密,便于前端处理
|
||||
**encrypted_result
|
||||
}
|
||||
# 若需加密的字段为空,直接返回原响应(如注册接口data=None)
|
||||
if not getattr(original_response, field):
|
||||
return original_response
|
||||
|
||||
return response
|
||||
# 复杂数据转JSON字符串(支持datetime、字典、列表等)
|
||||
field_value = getattr(original_response, field)
|
||||
field_value_json = json.dumps(field_value, default=str) # 处理特殊类型
|
||||
|
||||
except Exception as e:
|
||||
# 加密过程异常时,返回 500 错误
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"返回值加密失败:{str(e)}"
|
||||
) from e
|
||||
# AES加密并替换原字段
|
||||
encrypted_data = aes_encrypt(field_value_json)
|
||||
setattr(original_response, field, encrypted_data)
|
||||
|
||||
return wrapper
|
||||
# 返回加密后的响应
|
||||
return original_response
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@ -1,72 +1,56 @@
|
||||
import base64
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import base64
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from Crypto.Random import get_random_bytes
|
||||
from fastapi import HTTPException
|
||||
|
||||
# 加载环境变量(从 .env 文件读取密钥)
|
||||
load_dotenv()
|
||||
# 硬编码AES密钥(32字节,AES-256)
|
||||
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa"
|
||||
AES_BLOCK_SIZE = 16 # AES固定块大小
|
||||
|
||||
# 校验密钥长度(确保符合AES规范)
|
||||
valid_key_lengths = [16, 24, 32]
|
||||
if len(AES_SECRET_KEY) not in valid_key_lengths:
|
||||
raise ValueError(
|
||||
f"AES密钥长度必须为{valid_key_lengths}字节,当前为{len(AES_SECRET_KEY)}字节"
|
||||
)
|
||||
|
||||
|
||||
class AESCipher:
|
||||
"""AES-CBC 对称加密工具类"""
|
||||
# 从环境变量获取密钥(AES-256 需 32 字节密钥,AES-128 需 16 字节)
|
||||
SECRET_KEY = "jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa".encode("utf-8")
|
||||
# AES 块大小固定为 16 字节
|
||||
BLOCK_SIZE = 16
|
||||
def aes_encrypt(plaintext: str) -> dict:
|
||||
"""AES-CBC模式加密(返回密文+IV,均为Base64编码)"""
|
||||
try:
|
||||
# 生成随机IV(16字节)
|
||||
iv = os.urandom(AES_BLOCK_SIZE)
|
||||
|
||||
@classmethod
|
||||
def _validate_key(cls):
|
||||
"""校验密钥长度(AES-256 需 32 字节,AES-128 需 16 字节)"""
|
||||
if len(cls.SECRET_KEY) not in (16, 32):
|
||||
raise ValueError("AES 密钥长度必须为 16 字节(AES-128)或 32 字节(AES-256)")
|
||||
# 创建加密器
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
|
||||
|
||||
@classmethod
|
||||
def encrypt(cls, data: dict) -> dict:
|
||||
"""
|
||||
加密函数:将字典类型的 data 加密
|
||||
返回:{encrypted_data: 加密后Base64字符串, iv: 16字节IV的Base64字符串}
|
||||
"""
|
||||
cls._validate_key()
|
||||
|
||||
# 1. 生成 16 字节随机 IV(每次加密都生成新 IV,无需保密但需和解密一致)
|
||||
iv = get_random_bytes(cls.BLOCK_SIZE)
|
||||
|
||||
# 2. 初始化 AES-CBC 加密器
|
||||
cipher = AES.new(cls.SECRET_KEY, AES.MODE_CBC, iv)
|
||||
|
||||
# 3. 数据序列化(字典转JSON字符串)→ 编码为字节 → 填充(PKCS7)
|
||||
data_str = str(data) # 若需更严谨,可使用 json.dumps(data, ensure_ascii=False)
|
||||
data_bytes = data_str.encode("utf-8")
|
||||
padded_data = pad(data_bytes, cls.BLOCK_SIZE, style="pkcs7")
|
||||
|
||||
# 4. 加密 → 转为 Base64 字符串(便于接口传输)
|
||||
encrypted_bytes = cipher.encrypt(padded_data)
|
||||
encrypted_data = base64.b64encode(encrypted_bytes).decode("utf-8")
|
||||
iv_b64 = base64.b64encode(iv).decode("utf-8")
|
||||
# 明文填充并加密
|
||||
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
|
||||
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
|
||||
iv_base64 = base64.b64encode(iv).decode("utf-8")
|
||||
|
||||
return {
|
||||
"encrypted_data": encrypted_data,
|
||||
"iv": iv_b64 # IV 需随密文一起返回,供前端解密
|
||||
"ciphertext": ciphertext,
|
||||
"iv": iv_base64,
|
||||
"algorithm": "AES-CBC"
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"AES加密失败:{str(e)}") from e
|
||||
|
||||
@classmethod
|
||||
def decrypt(cls, encrypted_data: str, iv_b64: str) -> dict:
|
||||
"""
|
||||
解密函数:将加密后的 Base64 字符串解密为字典
|
||||
参数:encrypted_data(加密数据)、iv_b64(加密时的IV)
|
||||
"""
|
||||
cls._validate_key()
|
||||
|
||||
# 1. 解码 Base64(IV 和 密文)
|
||||
iv = base64.b64decode(iv_b64)
|
||||
encrypted_bytes = base64.b64decode(encrypted_data)
|
||||
def aes_decrypt(ciphertext: str, iv: str) -> str:
|
||||
"""AES-CBC模式解密"""
|
||||
try:
|
||||
# 解码Base64
|
||||
ciphertext_bytes = base64.b64decode(ciphertext)
|
||||
iv_bytes = base64.b64decode(iv)
|
||||
|
||||
# 2. 初始化 AES-CBC 解密器
|
||||
cipher = AES.new(cls.SECRET_KEY, AES.MODE_CBC, iv)
|
||||
# 创建解密器
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv_bytes)
|
||||
|
||||
# 3. 解密 → 去除填充 → 解码为字符串 → 转为字典(此处简化,实际可用 json.loads)
|
||||
decrypted_bytes = unpad(cipher.decrypt(encrypted_bytes), cls.BLOCK_SIZE, style="pkcs7")
|
||||
decrypted_str = decrypted_bytes.decode("utf-8")
|
||||
return eval(decrypted_str) # 生产环境建议用 json.loads,避免 eval 安全风险
|
||||
# 解密并去填充
|
||||
decrypted_bytes = unpad(cipher.decrypt(ciphertext_bytes), AES_BLOCK_SIZE)
|
||||
return decrypted_bytes.decode("utf-8")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"AES解密失败:{str(e)}") from e
|
||||
@ -3,6 +3,7 @@ from mysql.connector import Error as MySQLError
|
||||
from typing import Optional
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.sensitive_schema import (
|
||||
SensitiveCreateRequest,
|
||||
SensitiveUpdateRequest,
|
||||
@ -120,6 +121,7 @@ async def get_sensitive(
|
||||
# 3. 获取敏感信息分页列表(重构:支持分页+关键词搜索)
|
||||
# ------------------------------
|
||||
@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
|
||||
@encrypt_response()
|
||||
async def get_sensitive_list(
|
||||
page: int = Query(1, ge=1, description="页码(默认1,最小1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10,1-100)"),
|
||||
|
||||
@ -27,7 +27,6 @@ router = APIRouter(
|
||||
# 1. 用户注册接口
|
||||
# ------------------------------
|
||||
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||
@encrypt_response
|
||||
async def user_register(request: UserRegisterRequest):
|
||||
"""
|
||||
用户注册:
|
||||
|
||||
Reference in New Issue
Block a user