40 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			40 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import json | |||
|  | from datetime import datetime | |||
|  | from functools import wraps | |||
|  | from typing import Any | |||
|  | 
 | |||
|  | from encryption.encryption import aes_encrypt | |||
|  | from schema.response_schema import APIResponse | |||
|  | from pydantic import BaseModel | |||
|  | 
 | |||
|  | 
 | |||
def _json_default(obj: Any) -> Any:
    """Fallback serializer for ``json.dumps`` covering types it cannot encode natively.

    ``json`` calls this for each object it cannot encode and then encodes the
    returned value, so values nested inside the result are handled as well.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, (set, frozenset)):
        # str() would produce the Python repr "{1, 2}"; emit a JSON array instead.
        return list(obj)
    if isinstance(obj, BaseModel):
        # mode="json" makes Pydantic convert its own field types (Enum, UUID,
        # Decimal, sets, nested datetimes, custom serializers) to JSON-safe
        # values. The default "python" mode would leave them to the str()
        # fallback below, which e.g. renders a plain Enum member as "Color.RED".
        return obj.model_dump(mode="json")
    # Last resort: stringify anything else.
    return str(obj)


def _encrypt_field(response: "APIResponse", field: str) -> "APIResponse":
    """Encrypt ``response.<field>`` in place and return ``response``.

    The attribute value is serialized to JSON (using ``_json_default`` for
    non-native types), passed to ``aes_encrypt``, and the result is written
    back with ``setattr``. A ``None`` value is left untouched.
    """
    value = getattr(response, field)

    # Only a missing payload (None) is passed through. Falsy payloads such as
    # 0, False, "", [] or {} are real data; a truthiness check would return
    # them unencrypted, so the client could not treat the field uniformly as
    # ciphertext.
    if value is None:
        return response

    payload = json.dumps(value, default=_json_default)
    setattr(response, field, aes_encrypt(payload))
    return response


def encrypt_response(field: str = "data"):
    """Decorator that encrypts one field of an async endpoint's response.

    The decorated coroutine function must return an object (expected to be an
    ``APIResponse``) with an attribute named ``field``. Unless that attribute
    is ``None``, its value is serialized to JSON (Pydantic models, datetimes
    and sets included), encrypted with ``aes_encrypt``, and stored back on the
    same object, which is then returned.

    Args:
        field: Name of the response attribute to encrypt. Defaults to "data".

    Returns:
        A decorator for ``async def`` functions. The produced wrapper is itself
        a coroutine function and keeps the wrapped function's metadata via
        ``functools.wraps``.

    NOTE(review): the response object is mutated in place; if an endpoint ever
    returns a shared/cached instance, the next call would encrypt the already
    encrypted value again.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            response: "APIResponse" = await func(*args, **kwargs)
            return _encrypt_field(response, field)

        return wrapper

    return decorator