44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
import datetime
|
||
import json
|
||
from functools import wraps
|
||
from typing import Any
|
||
|
||
from encryption.encryption import aes_encrypt
|
||
from schema.response_schema import APIResponse
|
||
# 假设SensitiveResponse等是Pydantic模型,需导入BaseModel
|
||
from pydantic import BaseModel
|
||
|
||
|
||
def encrypt_response(field: str = "data"):
|
||
"""接口返回值加密装饰器:正确序列化自定义对象为JSON"""
|
||
|
||
def decorator(func):
|
||
@wraps(func)
|
||
async def wrapper(*args, **kwargs):
|
||
original_response: APIResponse = await func(*args, **kwargs)
|
||
field_value = getattr(original_response, field)
|
||
|
||
if not field_value:
|
||
return original_response
|
||
|
||
# 自定义JSON序列化函数:处理Pydantic模型和datetime
|
||
def json_default(obj: Any) -> Any:
|
||
# 处理Pydantic模型(转换为字典)
|
||
if isinstance(obj, BaseModel):
|
||
return obj.model_dump() # Pydantic v2用model_dump(),v1用dict()
|
||
# 处理datetime(转换为ISO格式字符串)
|
||
if isinstance(obj, datetime):
|
||
return obj.isoformat()
|
||
# 其他无法序列化的类型,可根据需要扩展
|
||
return str(obj) # 作为最后兜底
|
||
|
||
# 使用自定义序列化函数,确保生成标准JSON
|
||
field_value_json = json.dumps(field_value, default=json_default)
|
||
encrypted_data = aes_encrypt(field_value_json)
|
||
setattr(original_response, field, encrypted_data)
|
||
|
||
return original_response
|
||
|
||
return wrapper
|
||
|
||
return decorator |