307 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			307 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from fastapi import APIRouter, Depends, HTTPException, Query, File, UploadFile
 | ||
| from mysql.connector import Error as MySQLError
 | ||
| from typing import Optional
 | ||
| 
 | ||
| from ds.db import db
 | ||
| from encryption.encrypt_decorator import encrypt_response
 | ||
| from schema.sensitive_schema import (
 | ||
|     SensitiveCreateRequest,
 | ||
|     SensitiveResponse,
 | ||
|     SensitiveListResponse
 | ||
| )
 | ||
| from schema.response_schema import APIResponse
 | ||
| from middle.auth_middleware import get_current_user
 | ||
| from schema.user_schema import UserResponse
 | ||
| from service.ocr_service import  set_forbidden_words
 | ||
| from service.sensitive_service import get_all_sensitive_words
 | ||
| 
 | ||
| router = APIRouter(
 | ||
|     prefix="/api/sensitives",
 | ||
|     tags=["敏感信息管理"]
 | ||
| )
 | ||
| 
 | ||
| 
 | ||
| # 创建敏感信息记录
 | ||
| @router.post("", response_model=APIResponse, summary="创建敏感信息记录")
 | ||
| @encrypt_response()
 | ||
| async def create_sensitive(
 | ||
|         sensitive: SensitiveCreateRequest,
 | ||
|         current_user: UserResponse = Depends(get_current_user)
 | ||
| ):
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         conn = db.get_connection()
 | ||
|         cursor = conn.cursor(dictionary=True)
 | ||
| 
 | ||
|         # 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
 | ||
|         insert_query = """
 | ||
|             INSERT INTO sensitives (name, created_at, updated_at)
 | ||
|             VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
 | ||
|         """
 | ||
|         cursor.execute(insert_query, (sensitive.name,))
 | ||
|         conn.commit()
 | ||
| 
 | ||
|         # 获取刚插入记录的ID(使用LAST_INSERT_ID()函数)
 | ||
|         new_id = cursor.lastrowid
 | ||
| 
 | ||
|         # 查询刚创建的记录并返回
 | ||
|         select_query = "SELECT * FROM sensitives WHERE id = %s"
 | ||
|         cursor.execute(select_query, (new_id,))
 | ||
|         created_sensitive = cursor.fetchone()
 | ||
| 
 | ||
|         # 重新加载最新的敏感词
 | ||
|         set_forbidden_words(get_all_sensitive_words())
 | ||
| 
 | ||
|         return APIResponse(
 | ||
|             code=200,
 | ||
|             message="敏感信息记录创建成功",
 | ||
|             data=SensitiveResponse(**created_sensitive)
 | ||
|         )
 | ||
|     except MySQLError as e:
 | ||
|         if conn:
 | ||
|             conn.rollback()
 | ||
|         raise HTTPException(
 | ||
|             status_code=500,
 | ||
|             detail=f"创建敏感信息记录失败: {str(e)}"
 | ||
|         ) from e
 | ||
|     finally:
 | ||
|         db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
| # 获取敏感信息分页列表
 | ||
| @router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
 | ||
| @encrypt_response()
 | ||
| async def get_sensitive_list(
 | ||
|         page: int = Query(1, ge=1, description="页码(默认1、最小1)"),
 | ||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10、1-100)"),
 | ||
|         name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)")
 | ||
| ):
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         conn = db.get_connection()
 | ||
|         cursor = conn.cursor(dictionary=True)
 | ||
| 
 | ||
|         # 1. 构建查询条件(支持关键词搜索)
 | ||
|         where_clause = []
 | ||
|         params = []
 | ||
|         if name:
 | ||
|             where_clause.append("name LIKE %s")
 | ||
|             params.append(f"%{name}%")  # 模糊匹配关键词
 | ||
| 
 | ||
|         # 2. 查询总记录数(用于分页计算)
 | ||
|         count_sql = "SELECT COUNT(*) AS total FROM sensitives"
 | ||
|         if where_clause:
 | ||
|             count_sql += " WHERE " + " AND ".join(where_clause)
 | ||
|         cursor.execute(count_sql, params.copy())  # 复制参数列表、避免后续污染
 | ||
|         total = cursor.fetchone()["total"]
 | ||
| 
 | ||
|         # 3. 计算分页偏移量
 | ||
|         offset = (page - 1) * page_size
 | ||
| 
 | ||
|         # 4. 分页查询敏感词数据(按更新时间倒序、最新的在前)
 | ||
|         list_sql = "SELECT * FROM sensitives"
 | ||
|         if where_clause:
 | ||
|             list_sql += " WHERE " + " AND ".join(where_clause)
 | ||
|         # 排序+分页(LIMIT 条数 OFFSET 偏移量)
 | ||
|         list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
 | ||
|         # 补充分页参数(page_size和offset)
 | ||
|         params.extend([page_size, offset])
 | ||
| 
 | ||
|         cursor.execute(list_sql, params)
 | ||
|         sensitive_list = cursor.fetchall()
 | ||
| 
 | ||
|         # 5. 构造分页响应数据
 | ||
|         return APIResponse(
 | ||
|             code=200,
 | ||
|             message=f"敏感信息列表查询成功(共{total}条记录、当前第{page}页)",
 | ||
|             data=SensitiveListResponse(
 | ||
|                 total=total,
 | ||
|                 sensitives=[SensitiveResponse(**item) for item in sensitive_list]
 | ||
|             )
 | ||
|         )
 | ||
|     except MySQLError as e:
 | ||
|         raise HTTPException(
 | ||
|             status_code=500,
 | ||
|             detail=f"查询敏感信息列表失败: {str(e)}"
 | ||
|         ) from e
 | ||
|     finally:
 | ||
|         db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
| # 删除敏感信息记录
 | ||
| @router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录")
 | ||
| @encrypt_response()
 | ||
| async def delete_sensitive(
 | ||
|         sensitive_id: int,
 | ||
|         current_user: UserResponse = Depends(get_current_user)  # 需登录认证
 | ||
| ):
 | ||
|     """
 | ||
|     删除敏感信息记录:
 | ||
|     - 需登录认证
 | ||
|     - 根据ID删除敏感信息记录
 | ||
|     - 返回删除成功信息
 | ||
|     """
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         conn = db.get_connection()
 | ||
|         cursor = conn.cursor(dictionary=True)
 | ||
| 
 | ||
|         # 1. 检查记录是否存在
 | ||
|         check_query = "SELECT id FROM sensitives WHERE id = %s"
 | ||
|         cursor.execute(check_query, (sensitive_id,))
 | ||
|         existing_sensitive = cursor.fetchone()
 | ||
|         if not existing_sensitive:
 | ||
|             raise HTTPException(
 | ||
|                 status_code=404,
 | ||
|                 detail=f"ID为 {sensitive_id} 的敏感信息记录不存在"
 | ||
|             )
 | ||
| 
 | ||
|         # 2. 执行删除操作
 | ||
|         delete_query = "DELETE FROM sensitives WHERE id = %s"
 | ||
|         cursor.execute(delete_query, (sensitive_id,))
 | ||
|         conn.commit()
 | ||
| 
 | ||
|         # 重新加载最新的敏感词
 | ||
|         set_forbidden_words(get_all_sensitive_words())
 | ||
| 
 | ||
|         return APIResponse(
 | ||
|             code=200,
 | ||
|             message=f"ID为 {sensitive_id} 的敏感信息记录删除成功",
 | ||
|             data=None
 | ||
|         )
 | ||
|     except MySQLError as e:
 | ||
|         if conn:
 | ||
|             conn.rollback()
 | ||
|         raise HTTPException(
 | ||
|             status_code=500,
 | ||
|             detail=f"删除敏感信息记录失败: {str(e)}"
 | ||
|         ) from e
 | ||
|     finally:
 | ||
|         db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
| # 批量导入敏感信息(从txt文件)
 | ||
| @router.post("/batch-import", response_model=APIResponse, summary="批量导入敏感信息(从txt文件)")
 | ||
| @encrypt_response()
 | ||
| async def batch_import_sensitives(
 | ||
|         file: UploadFile = File(..., description="包含敏感词的txt文件,每行一个敏感词"),
 | ||
|         # current_user: UserResponse = Depends(get_current_user)  # 添加认证依赖
 | ||
| ):
 | ||
|     """
 | ||
|     批量导入敏感信息:
 | ||
|     - 需登录认证
 | ||
|     - 接收txt文件,文件中每行一个敏感词
 | ||
|     - 批量插入到数据库中(仅插入不存在的敏感词)
 | ||
|     - 返回导入结果统计
 | ||
|     """
 | ||
|     # 检查文件类型
 | ||
|     filename = file.filename or ""
 | ||
|     if not filename.lower().endswith(".txt"):
 | ||
|         raise HTTPException(
 | ||
|             status_code=400,
 | ||
|             detail=f"请上传txt格式的文件,当前文件格式: {filename.split('.')[-1] if '.' in filename else '未知'}"
 | ||
|         )
 | ||
| 
 | ||
|     # 检查文件大小
 | ||
|     file_size = await file.read(1)  # 读取1字节获取文件信息
 | ||
|     await file.seek(0)  # 重置文件指针
 | ||
|     if not file_size:  # 文件为空
 | ||
|         raise HTTPException(
 | ||
|             status_code=400,
 | ||
|             detail="上传的文件为空,请提供有效的敏感词文件"
 | ||
|         )
 | ||
| 
 | ||
|     conn = None
 | ||
|     cursor = None
 | ||
|     try:
 | ||
|         # 读取文件内容
 | ||
|         contents = await file.read()
 | ||
|         # 按行分割内容,处理不同操作系统的换行符
 | ||
|         lines = contents.decode("utf-8", errors="replace").splitlines()
 | ||
| 
 | ||
|         # 过滤空行和仅含空白字符的行
 | ||
|         sensitive_words = [line.strip() for line in lines if line.strip()]
 | ||
| 
 | ||
|         if not sensitive_words:
 | ||
|             return APIResponse(
 | ||
|                 code=200,
 | ||
|                 message="文件中没有有效的敏感词",
 | ||
|                 data={"imported": 0, "total": 0}
 | ||
|             )
 | ||
| 
 | ||
|         conn = db.get_connection()
 | ||
|         cursor = conn.cursor(dictionary=True)
 | ||
| 
 | ||
|         # 先查询数据库中已存在的敏感词
 | ||
|         query = "SELECT name FROM sensitives WHERE name IN (%s)"
 | ||
|         # 处理参数,根据敏感词数量生成占位符
 | ||
|         placeholders = ', '.join(['%s'] * len(sensitive_words))
 | ||
|         cursor.execute(query % placeholders, sensitive_words)
 | ||
|         existing_words = {row['name'] for row in cursor.fetchall()}
 | ||
| 
 | ||
|         # 过滤掉已存在的敏感词
 | ||
|         new_words = [word for word in sensitive_words if word not in existing_words]
 | ||
| 
 | ||
|         if not new_words:
 | ||
|             return APIResponse(
 | ||
|                 code=200,
 | ||
|                 message="所有敏感词均已存在于数据库中",
 | ||
|                 data={
 | ||
|                     "total": len(sensitive_words),
 | ||
|                     "imported": 0,
 | ||
|                     "duplicates": len(sensitive_words),
 | ||
|                     "message": f"共处理{len(sensitive_words)}个敏感词,全部已存在,未导入任何新敏感词"
 | ||
|                 }
 | ||
|             )
 | ||
| 
 | ||
|         # 批量插入新的敏感词
 | ||
|         insert_query = """
 | ||
|             INSERT INTO sensitives (name, created_at, updated_at)
 | ||
|             VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
 | ||
|         """
 | ||
| 
 | ||
|         # 准备参数列表
 | ||
|         params = [(word,) for word in new_words]
 | ||
| 
 | ||
|         # 执行批量插入
 | ||
|         cursor.executemany(insert_query, params)
 | ||
|         conn.commit()
 | ||
| 
 | ||
|         # 重新加载最新的敏感词
 | ||
|         set_forbidden_words(get_all_sensitive_words())
 | ||
| 
 | ||
|         return APIResponse(
 | ||
|             code=200,
 | ||
|             message=f"敏感词批量导入成功",
 | ||
|             data={
 | ||
|                 "total": len(sensitive_words),
 | ||
|                 "imported": len(new_words),
 | ||
|                 "duplicates": len(sensitive_words) - len(new_words),
 | ||
|                 "message": f"共处理{len(sensitive_words)}个敏感词,成功导入{len(new_words)}个,{len(sensitive_words) - len(new_words)}个已存在"
 | ||
|             }
 | ||
|         )
 | ||
| 
 | ||
|     except UnicodeDecodeError:
 | ||
|         raise HTTPException(
 | ||
|             status_code=400,
 | ||
|             detail="文件编码格式错误,请使用UTF-8编码的txt文件"
 | ||
|         )
 | ||
|     except MySQLError as e:
 | ||
|         if conn:
 | ||
|             conn.rollback()
 | ||
|         raise HTTPException(
 | ||
|             status_code=500,
 | ||
|             detail=f"批量导入敏感词失败: {str(e)}"
 | ||
|         ) from e
 | ||
|     except Exception as e:
 | ||
|         raise HTTPException(
 | ||
|             status_code=500,
 | ||
|             detail=f"处理文件时发生错误: {str(e)}"
 | ||
|         ) from e
 | ||
|     finally:
 | ||
|         await file.close()
 | ||
|         db.close_connection(conn, cursor)
 |