From ced84e49bcbbca46dd02950c8bbace68ea549210 Mon Sep 17 00:00:00 2001 From: ZZX9599 <536509593@qq.com> Date: Mon, 15 Sep 2025 18:08:54 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81=E9=A3=8E?= =?UTF-8?q?=E6=A0=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- encryption/encrypt_decorator.py | 39 ++++++++++++++++++ encryption/encryption.py | 72 +++++++++++++++++++++++++++++++++ main.py | 3 +- service/user_service.py | 2 + 4 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 encryption/encrypt_decorator.py create mode 100644 encryption/encryption.py diff --git a/encryption/encrypt_decorator.py b/encryption/encrypt_decorator.py new file mode 100644 index 0000000..da541a4 --- /dev/null +++ b/encryption/encrypt_decorator.py @@ -0,0 +1,39 @@ +from functools import wraps +from fastapi import HTTPException +from schema.response_schema import APIResponse +from encryption.encryption import AESCipher + + +def encrypt_response(func): + """ + 返回值加密装饰器: + - 仅对 APIResponse 的 data 字段加密(code/message 不加密,便于前端判断基础状态) + - 若 data 为 None(如注册接口),不加密 + """ + + @wraps(func) # 保留原函数元信息(如 __name__、__doc__,避免 FastAPI 路由异常) + async def wrapper(*args, **kwargs): + try: + # 1. 执行原接口函数,获取返回值(APIResponse 对象) + response: APIResponse = await func(*args, **kwargs) + + # 2. 仅当 data 不为 None 时加密 + if response.data is not None: + # 加密 data 字段(字典类型) + encrypted_result = AESCipher.encrypt(response.data) + # 替换原 data 为加密后的数据(包含密文和 IV) + response.data = { + "is_encrypted": True, # 标记是否加密,便于前端处理 + **encrypted_result + } + + return response + + except Exception as e: + # 加密过程异常时,返回 500 错误 + raise HTTPException( + status_code=500, + detail=f"返回值加密失败:{str(e)}" + ) from e + + return wrapper \ No newline at end of file diff --git a/encryption/encryption.py b/encryption/encryption.py new file mode 100644 index 0000000..2cc39d8 --- /dev/null +++ b/encryption/encryption.py @@ -0,0 +1,72 @@ +import base64 +import os +from dotenv import load_dotenv +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad +from Crypto.Random import get_random_bytes + +# 加载环境变量(从 .env 文件读取密钥) +load_dotenv() + + +class AESCipher: + """AES-CBC 对称加密工具类""" + # 从环境变量获取密钥(AES-256 需 32 字节密钥,AES-128 需 16 字节) + SECRET_KEY = "jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa".encode("utf-8") + # AES 块大小固定为 16 字节 + BLOCK_SIZE = 16 + + @classmethod + def _validate_key(cls): + """校验密钥长度(AES-256 需 32 字节,AES-128 需 16 字节)""" + if len(cls.SECRET_KEY) not in (16, 32): + raise ValueError("AES 密钥长度必须为 16 字节(AES-128)或 32 字节(AES-256)") + + @classmethod + def encrypt(cls, data: dict) -> dict: + """ + 加密函数:将字典类型的 data 加密 + 返回:{encrypted_data: 加密后Base64字符串, iv: 16字节IV的Base64字符串} + """ + cls._validate_key() + + # 1. 生成 16 字节随机 IV(每次加密都生成新 IV,无需保密但需和解密一致) + iv = get_random_bytes(cls.BLOCK_SIZE) + + # 2. 初始化 AES-CBC 加密器 + cipher = AES.new(cls.SECRET_KEY, AES.MODE_CBC, iv) + + # 3. 数据序列化(字典转JSON字符串)→ 编码为字节 → 填充(PKCS7) + data_str = str(data) # 若需更严谨,可使用 json.dumps(data, ensure_ascii=False) + data_bytes = data_str.encode("utf-8") + padded_data = pad(data_bytes, cls.BLOCK_SIZE, style="pkcs7") + + # 4. 加密 → 转为 Base64 字符串(便于接口传输) + encrypted_bytes = cipher.encrypt(padded_data) + encrypted_data = base64.b64encode(encrypted_bytes).decode("utf-8") + iv_b64 = base64.b64encode(iv).decode("utf-8") + + return { + "encrypted_data": encrypted_data, + "iv": iv_b64 # IV 需随密文一起返回,供前端解密 + } + + @classmethod + def decrypt(cls, encrypted_data: str, iv_b64: str) -> dict: + """ + 解密函数:将加密后的 Base64 字符串解密为字典 + 参数:encrypted_data(加密数据)、iv_b64(加密时的IV) + """ + cls._validate_key() + + # 1. 解码 Base64(IV 和 密文) + iv = base64.b64decode(iv_b64) + encrypted_bytes = base64.b64decode(encrypted_data) + + # 2. 初始化 AES-CBC 解密器 + cipher = AES.new(cls.SECRET_KEY, AES.MODE_CBC, iv) + + # 3. 解密 → 去除填充 → 解码为字符串 → 转为字典(此处简化,实际可用 json.loads) + decrypted_bytes = unpad(cipher.decrypt(encrypted_bytes), cls.BLOCK_SIZE, style="pkcs7") + decrypted_str = decrypted_bytes.decode("utf-8") + return eval(decrypted_str) # 生产环境建议用 json.loads,避免 eval 安全风险 diff --git a/main.py b/main.py index 6e7d53d..df05b4f 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ from service.user_service import router as user_router from service.sensitive_service import router as sensitive_router from service.face_service import router as face_router from service.device_service import router as device_router -from service.model_service import router as model_router # 模型管理路由 +from service.model_service import router as model_router from ws.ws import ws_router, lifespan from core.establish import create_directory_structure @@ -81,7 +81,6 @@ app.add_exception_handler(Exception, global_exception_handler) # 主服务启动入口(不变) if __name__ == "__main__": - # 1. 初始化资源 create_directory_structure() print(f"[初始化] 目录结构创建完成") diff --git a/service/user_service.py b/service/user_service.py index f96dcdb..b9537ad 100644 --- a/service/user_service.py +++ b/service/user_service.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from mysql.connector import Error as MySQLError from ds.db import db +from encryption.encrypt_decorator import encrypt_response from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse from schema.response_schema import APIResponse from middle.auth_middleware import ( @@ -26,6 +27,7 @@ router = APIRouter( # 1. 用户注册接口 # ------------------------------ @router.post("/register", response_model=APIResponse, summary="用户注册") +@encrypt_response async def user_register(request: UserRegisterRequest): """ 用户注册: