from fastapi import APIRouter, Depends, HTTPException, Query, File, UploadFile from mysql.connector import Error as MySQLError from typing import Optional from ds.db import db from encryption.encrypt_decorator import encrypt_response from schema.sensitive_schema import ( SensitiveCreateRequest, SensitiveResponse, SensitiveListResponse ) from schema.response_schema import APIResponse from middle.auth_middleware import get_current_user from schema.user_schema import UserResponse from service.ocr_service import set_forbidden_words from service.sensitive_service import get_all_sensitive_words router = APIRouter( prefix="/api/sensitives", tags=["敏感信息管理"] ) # 创建敏感信息记录 @router.post("", response_model=APIResponse, summary="创建敏感信息记录") @encrypt_response() async def create_sensitive( sensitive: SensitiveCreateRequest, current_user: UserResponse = Depends(get_current_user) ): conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成) insert_query = """ INSERT INTO sensitives (name, created_at, updated_at) VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) """ cursor.execute(insert_query, (sensitive.name,)) conn.commit() # 获取刚插入记录的ID(使用LAST_INSERT_ID()函数) new_id = cursor.lastrowid # 查询刚创建的记录并返回 select_query = "SELECT * FROM sensitives WHERE id = %s" cursor.execute(select_query, (new_id,)) created_sensitive = cursor.fetchone() # 重新加载最新的敏感词 set_forbidden_words(get_all_sensitive_words()) return APIResponse( code=200, message="敏感信息记录创建成功", data=SensitiveResponse(**created_sensitive) ) except MySQLError as e: if conn: conn.rollback() raise HTTPException( status_code=500, detail=f"创建敏感信息记录失败: {str(e)}" ) from e finally: db.close_connection(conn, cursor) # 获取敏感信息分页列表 @router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)") @encrypt_response() async def get_sensitive_list( page: int = Query(1, ge=1, description="页码(默认1、最小1)"), page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10、1-100)"), name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)") ): conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 1. 构建查询条件(支持关键词搜索) where_clause = [] params = [] if name: where_clause.append("name LIKE %s") params.append(f"%{name}%") # 模糊匹配关键词 # 2. 查询总记录数(用于分页计算) count_sql = "SELECT COUNT(*) AS total FROM sensitives" if where_clause: count_sql += " WHERE " + " AND ".join(where_clause) cursor.execute(count_sql, params.copy()) # 复制参数列表、避免后续污染 total = cursor.fetchone()["total"] # 3. 计算分页偏移量 offset = (page - 1) * page_size # 4. 分页查询敏感词数据(按更新时间倒序、最新的在前) list_sql = "SELECT * FROM sensitives" if where_clause: list_sql += " WHERE " + " AND ".join(where_clause) # 排序+分页(LIMIT 条数 OFFSET 偏移量) list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s" # 补充分页参数(page_size和offset) params.extend([page_size, offset]) cursor.execute(list_sql, params) sensitive_list = cursor.fetchall() # 5. 构造分页响应数据 return APIResponse( code=200, message=f"敏感信息列表查询成功(共{total}条记录、当前第{page}页)", data=SensitiveListResponse( total=total, sensitives=[SensitiveResponse(**item) for item in sensitive_list] ) ) except MySQLError as e: raise HTTPException( status_code=500, detail=f"查询敏感信息列表失败: {str(e)}" ) from e finally: db.close_connection(conn, cursor) # 删除敏感信息记录 @router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录") @encrypt_response() async def delete_sensitive( sensitive_id: int, current_user: UserResponse = Depends(get_current_user) # 需登录认证 ): """ 删除敏感信息记录: - 需登录认证 - 根据ID删除敏感信息记录 - 返回删除成功信息 """ conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 1. 检查记录是否存在 check_query = "SELECT id FROM sensitives WHERE id = %s" cursor.execute(check_query, (sensitive_id,)) existing_sensitive = cursor.fetchone() if not existing_sensitive: raise HTTPException( status_code=404, detail=f"ID为 {sensitive_id} 的敏感信息记录不存在" ) # 2. 执行删除操作 delete_query = "DELETE FROM sensitives WHERE id = %s" cursor.execute(delete_query, (sensitive_id,)) conn.commit() # 重新加载最新的敏感词 set_forbidden_words(get_all_sensitive_words()) return APIResponse( code=200, message=f"ID为 {sensitive_id} 的敏感信息记录删除成功", data=None ) except MySQLError as e: if conn: conn.rollback() raise HTTPException( status_code=500, detail=f"删除敏感信息记录失败: {str(e)}" ) from e finally: db.close_connection(conn, cursor) # 批量导入敏感信息(从txt文件) @router.post("/batch-import", response_model=APIResponse, summary="批量导入敏感信息(从txt文件)") @encrypt_response() async def batch_import_sensitives( file: UploadFile = File(..., description="包含敏感词的txt文件,每行一个敏感词"), # current_user: UserResponse = Depends(get_current_user) # 添加认证依赖 ): """ 批量导入敏感信息: - 需登录认证 - 接收txt文件,文件中每行一个敏感词 - 批量插入到数据库中(仅插入不存在的敏感词) - 返回导入结果统计 """ # 检查文件类型 filename = file.filename or "" if not filename.lower().endswith(".txt"): raise HTTPException( status_code=400, detail=f"请上传txt格式的文件,当前文件格式: {filename.split('.')[-1] if '.' in filename else '未知'}" ) # 检查文件大小 file_size = await file.read(1) # 读取1字节获取文件信息 await file.seek(0) # 重置文件指针 if not file_size: # 文件为空 raise HTTPException( status_code=400, detail="上传的文件为空,请提供有效的敏感词文件" ) conn = None cursor = None try: # 读取文件内容 contents = await file.read() # 按行分割内容,处理不同操作系统的换行符 lines = contents.decode("utf-8", errors="replace").splitlines() # 过滤空行和仅含空白字符的行 sensitive_words = [line.strip() for line in lines if line.strip()] if not sensitive_words: return APIResponse( code=200, message="文件中没有有效的敏感词", data={"imported": 0, "total": 0} ) conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 先查询数据库中已存在的敏感词 query = "SELECT name FROM sensitives WHERE name IN (%s)" # 处理参数,根据敏感词数量生成占位符 placeholders = ', '.join(['%s'] * len(sensitive_words)) cursor.execute(query % placeholders, sensitive_words) existing_words = {row['name'] for row in cursor.fetchall()} # 过滤掉已存在的敏感词 new_words = [word for word in sensitive_words if word not in existing_words] if not new_words: return APIResponse( code=200, message="所有敏感词均已存在于数据库中", data={ "total": len(sensitive_words), "imported": 0, "duplicates": len(sensitive_words), "message": f"共处理{len(sensitive_words)}个敏感词,全部已存在,未导入任何新敏感词" } ) # 批量插入新的敏感词 insert_query = """ INSERT INTO sensitives (name, created_at, updated_at) VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) """ # 准备参数列表 params = [(word,) for word in new_words] # 执行批量插入 cursor.executemany(insert_query, params) conn.commit() # 重新加载最新的敏感词 set_forbidden_words(get_all_sensitive_words()) return APIResponse( code=200, message=f"敏感词批量导入成功", data={ "total": len(sensitive_words), "imported": len(new_words), "duplicates": len(sensitive_words) - len(new_words), "message": f"共处理{len(sensitive_words)}个敏感词,成功导入{len(new_words)}个,{len(sensitive_words) - len(new_words)}个已存在" } ) except UnicodeDecodeError: raise HTTPException( status_code=400, detail="文件编码格式错误,请使用UTF-8编码的txt文件" ) except MySQLError as e: if conn: conn.rollback() raise HTTPException( status_code=500, detail=f"批量导入敏感词失败: {str(e)}" ) from e except Exception as e: raise HTTPException( status_code=500, detail=f"处理文件时发生错误: {str(e)}" ) from e finally: await file.close() db.close_connection(conn, cursor)