From de6d1b957a7db101e48b210754ddf1197fc31b16 Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Tue, 16 Sep 2025 20:17:48 +0800 Subject: [PATCH] =?UTF-8?q?yolo=E6=A8=A1=E5=9E=8B=E8=AF=86=E5=88=AB?= =?UTF-8?q?=E4=B8=8D=E5=88=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/all.py | 10 +- core/establish.py | 56 +++-- main.py | 85 ++----- schema/device_danger_schema.py | 33 +++ schema/device_schema.py | 1 - service/device_action_service.py | 2 +- service/device_danger_service.py | 267 ++++++++++++++++++++++ service/device_service.py | 2 +- service/face_service.py | 2 +- service/file_service.py | 370 +++++++++++-------------------- service/model_service.py | 2 +- service/sensitive_service.py | 2 +- service/user_service.py | 2 +- util/file_util.py | 46 ++-- ws/ws.py | 129 ++++------- 15 files changed, 568 insertions(+), 441 deletions(-) create mode 100644 schema/device_danger_schema.py create mode 100644 service/device_danger_service.py diff --git a/core/all.py b/core/all.py index 2e36f91..5a243c5 100644 --- a/core/all.py +++ b/core/all.py @@ -67,7 +67,7 @@ def detect(client_ip, frame): if full_save_path: cv2.imwrite(full_save_path, frame) print(f"✅ yolo违规图片已保存:{display_path}") # 日志也修正 - save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) + save_db(model_type="yolo", client_ip=client_ip, result=str(display_path)) return (True, yolo_result, "yolo") # 2. 人脸检测 @@ -77,17 +77,19 @@ def detect(client_ip, frame): if full_save_path: cv2.imwrite(full_save_path, frame) print(f"✅ face违规图片已保存:{display_path}") # 日志也修正 - save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) + save_db(model_type="face", client_ip=client_ip, result=str(display_path)) return (True, face_result, "face") # 3. OCR检测 ocr_flag, ocr_result = ocrDetect(frame) if ocr_flag: - full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了 + full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) + print(f"✅ ocr违规图片已保存:{display_path}") + # 这里改了 if full_save_path: cv2.imwrite(full_save_path, frame) print(f"✅ ocr违规图片已保存:{display_path}") # 日志也修正 - save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) + save_db(model_type="ocr", client_ip=client_ip, result=str(display_path)) return (True, ocr_result, "ocr") # 4. 无违规内容(不保存图片) diff --git a/core/establish.py b/core/establish.py index 2092319..aca1d5d 100644 --- a/core/establish.py +++ b/core/establish.py @@ -1,30 +1,29 @@ import datetime from pathlib import Path +from typing import List, Tuple from service.device_service import get_unique_client_ips + + def create_directory_structure(): """创建项目所需的目录结构,为所有客户端IP预创建基础目录""" try: # 1. 创建根目录下的resource文件夹(存在则跳过,不覆盖子内容) resource_dir = Path("resource") resource_dir.mkdir(exist_ok=True) - # print(f"确保resource目录存在: {resource_dir.absolute()}") # 2. 在resource下创建dect文件夹 dect_dir = resource_dir / "dect" dect_dir.mkdir(exist_ok=True) - # print(f"确保dect目录存在: {dect_dir.absolute()}") # 3. 在dect下创建三个模型文件夹 model_dirs = ["ocr", "face", "yolo"] for model in model_dirs: model_dir = dect_dir / model model_dir.mkdir(exist_ok=True) - # print(f"确保{model}模型目录存在: {model_dir.absolute()}") # 4. 调用外部方法获取所有客户端IP地址 try: - # 调用外部ip_read()方法获取所有客户端IP地址列表 all_ip_addresses = get_unique_client_ips() # 确保返回的是列表类型 @@ -58,7 +57,6 @@ def create_directory_structure(): # 递归创建目录(存在则跳过,不覆盖) month_dir.mkdir(parents=True, exist_ok=True) - # print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}") except Exception as e: print(f"处理客户端IP和日期目录时发生错误: {str(e)}") @@ -67,52 +65,68 @@ def create_directory_structure(): print(f"创建基础目录结构时发生错误: {str(e)}") -def get_image_save_path(model_type: str, client_ip: str) -> tuple: +def get_image_save_path(model_type: str, client_ip: str) -> Tuple[str, str]: """ - 获取图片保存的「完整路径」和「显示用短路径」 + 获取图片保存的「本地完整路径」和「带路由前缀的显示路径」 参数: model_type: 模型类型,应为"ocr"、"face"或"yolo" client_ip: 检测到违禁的客户端IP地址(原始格式,如192.168.1.101) 返回: - 元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "") + 元组 (本地完整保存路径, 带/api/file/前缀的显示路径);若出错则返回 ("", "") """ try: + # 验证模型类型有效性 + valid_models = ["ocr", "face", "yolo"] + if model_type not in valid_models: + raise ValueError(f"无效的模型类型: {model_type},必须是{valid_models}之一") + # 1. 验证客户端IP有效性(检查是否在已知IP列表中) all_ip_addresses = get_unique_client_ips() if not isinstance(all_ip_addresses, list): all_ip_addresses = [all_ip_addresses] valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] - if client_ip.strip() not in valid_ips: + client_ip_stripped = client_ip.strip() + if client_ip_stripped not in valid_ips: raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中,无法保存文件") - # 2. 处理IP地址(与目录创建逻辑一致,将.替换为_) - safe_ip = client_ip.strip().replace(".", "_") + # 2. 处理IP地址(将.替换为_,避免路径问题) + safe_ip = client_ip_stripped.replace(".", "_") # 3. 获取当前日期和毫秒级时间戳(确保文件名唯一) now = datetime.datetime.now() current_year = str(now.year) - current_month = str(now.month) - current_day = str(now.day) - # 时间戳格式:年月日时分秒毫秒(如20250910143050123) - timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] + current_month = str(now.month).zfill(2) # 确保月份为两位数 + current_day = str(now.day).zfill(2) # 确保日期为两位数 + timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] # 取毫秒级时间戳 # 4. 定义基础目录(用于生成相对路径) - base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀 + base_dir = Path("resource") / "dect" # 构建日级目录(完整路径:resource/dect/{model}/{safe_ip}/{年}/{月}/{日}) day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day - day_dir.mkdir(parents=True, exist_ok=True) # 确保日目录存在 + day_dir.mkdir(parents=True, exist_ok=True) # 确保目录存在 # 5. 构建唯一文件名 image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg" - # 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印) - full_path = day_dir / image_filename # 完整路径:resource/dect/.../xxx.jpg - display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg(去掉resource/dect) + # 6. 生成「本地完整路径」(使用系统路径,但在字符串表示时统一为正斜杠) + local_full_path = day_dir / image_filename + # 转换为字符串并统一使用正斜杠 + local_full_path_str = str(local_full_path).replace("\\", "/") - return str(full_path), str(display_path) + # 7. 生成带路由前缀的显示路径(核心修改部分) + # 获取项目根目录(base_dir是resource/dect,向上两级即为项目根目录) + project_root = base_dir.parents[1] + # 计算相对于项目根目录的路径(包含resource/dect层级) + relative_path = local_full_path.relative_to(project_root) + # 转换为字符串并统一使用正斜杠 + relative_path_str = str(relative_path).replace("\\", "/") + # 拼接路由前缀 + routed_display_path = f"/api/file/{relative_path_str}" + + return local_full_path_str, routed_display_path except Exception as e: print(f"获取图片保存路径时发生错误: {str(e)}") diff --git a/main.py b/main.py index df05b4f..166159e 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,8 @@ import uvicorn -import threading import time import os from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from service.file_service import app as flask_app # 原有业务导入 from core.all import load_model @@ -15,32 +13,12 @@ from service.sensitive_service import router as sensitive_router from service.face_service import router as face_router from service.device_service import router as device_router from service.model_service import router as model_router +from service.file_service import router as file_router +from service.device_danger_service import router as device_danger_router from ws.ws import ws_router, lifespan from core.establish import create_directory_structure - -# Flask 服务启动函数(不变) -def start_flask_service(): - try: - print(f"\n[Flask 服务] 准备启动,端口:5000") - print(f"[Flask 服务] 访问示例:http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n") - - BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect")) - if not os.path.exists(BASE_IMAGE_DIR): - print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}") - os.makedirs(BASE_IMAGE_DIR, exist_ok=True) - - flask_app.run( - host="0.0.0.0", - port=5000, - debug=False, - use_reloader=False - ) - except Exception as e: - print(f"[Flask 服务] 启动失败:{str(e)}") - - -# 初始化 FastAPI 应用(新增 CORS 配置) +# 初始化 FastAPI 应用 app = FastAPI( title="内容安全审核后台", description="含图片访问服务和动态模型管理", @@ -48,38 +26,33 @@ app = FastAPI( lifespan=lifespan ) -# ------------------------------ -# 新增:完整 CORS 配置(解决跨域问题) -# ------------------------------ -# 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等) ALLOWED_ORIGINS = [ - # "http://localhost:8080", # 前端本地开发地址(必改,填实际前端地址) - # "http://127.0.0.1:8080", - # "http://服务器IP:8080", # 部署后前端地址(如适用) - "*" #表示允许所有域名(开发环境可用,生产环境不推荐) + "*" ] -# 2. 配置 CORS 中间件 +# 配置 CORS 中间件 app.add_middleware( CORSMiddleware, allow_origins=ALLOWED_ORIGINS, # 允许的前端域名 - allow_credentials=True, # 允许携带 Cookie(如需登录态则必开) - allow_methods=["*"], # 允许所有 HTTP 方法(包括 PUT/DELETE) - allow_headers=["*"], # 允许所有请求头(包括 Content-Type) + allow_credentials=True, # 允许携带 Cookie + allow_methods=["*"], # 允许所有 HTTP 方法 + allow_headers=["*"], # 允许所有请求头 ) -# 注册路由(不变) +# 注册路由 app.include_router(user_router) app.include_router(device_router) app.include_router(face_router) app.include_router(sensitive_router) -app.include_router(model_router) # 模型管理路由 +app.include_router(model_router) +app.include_router(file_router) +app.include_router(device_danger_router) app.include_router(ws_router) -# 注册全局异常处理器(不变) +# 注册全局异常处理器 app.add_exception_handler(Exception, global_exception_handler) -# 主服务启动入口(不变) +# 主服务启动入口 if __name__ == "__main__": create_directory_structure() print(f"[初始化] 目录结构创建完成") @@ -89,11 +62,11 @@ if __name__ == "__main__": os.makedirs(MODEL_SAVE_DIR, exist_ok=True) print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}") - # # 模型路径配置 - # YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt") - # OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml") - # print(f"[初始化] 默认YOLO模型路径:{YOLO_MODEL_PATH}") - # print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}") + # 确保图片目录存在(原Flask服务负责的目录) + BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect")) + if not os.path.exists(BASE_IMAGE_DIR): + print(f"[初始化] 图片根目录不存在,创建:{BASE_IMAGE_DIR}") + os.makedirs(BASE_IMAGE_DIR, exist_ok=True) # 加载检测模型 try: @@ -105,23 +78,7 @@ if __name__ == "__main__": except Exception as e: print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)") - - - # 2. 启动 Flask 服务(子线程) - flask_thread = threading.Thread( - target=start_flask_service, - daemon=True - ) - flask_thread.start() - - # 等待 Flask 初始化 - time.sleep(1) - if flask_thread.is_alive(): - print(f"[Flask 服务] 启动成功(运行中)") - else: - print(f"[Flask 服务] 启动失败!图片访问不可用") - - # 3. 启动 FastAPI 主服务 + # 启动 FastAPI 主服务(仅使用8000端口) port = int(SERVER_CONFIG.get("port", 8000)) print(f"\n[FastAPI 服务] 准备启动,端口:{port}") print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n") @@ -133,4 +90,4 @@ if __name__ == "__main__": workers=1, ws="websockets", reload=False - ) \ No newline at end of file + ) diff --git a/schema/device_danger_schema.py b/schema/device_danger_schema.py new file mode 100644 index 0000000..5731c29 --- /dev/null +++ b/schema/device_danger_schema.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import Optional, List +from pydantic import BaseModel, Field + + +# ------------------------------ +# 请求模型 +# ------------------------------ +class DeviceDangerCreateRequest(BaseModel): + """设备危险记录创建请求模型""" + client_ip: str = Field(..., max_length=100, description="设备IP地址(必须与devices表中IP对应)") + type: str = Field(..., max_length=50, description="危险类型(如:病毒检测、端口异常、权限泄露等)") + result: str = Field(..., description="危险检测结果/处理结果(如:检测到木马病毒,已隔离;端口22异常开放,已关闭)") + + +# ------------------------------ +# 响应模型 +# ------------------------------ +class DeviceDangerResponse(BaseModel): + """单条设备危险记录响应模型(与device_danger表字段对齐,updated_at允许为null)""" + id: int = Field(..., description="危险记录主键ID") + client_ip: str = Field(..., max_length=100, description="设备IP地址") + type: str = Field(..., max_length=50, description="危险类型") + result: str = Field(..., description="危险检测结果/处理结果") + created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)") + updated_at: Optional[datetime] = Field(None, description="记录更新时间(数据库中该字段当前为null)") + model_config = {"from_attributes": True} + + +class DeviceDangerListResponse(BaseModel): + """设备危险记录列表响应模型(带分页)""" + total: int = Field(..., description="危险记录总数") + dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表") \ No newline at end of file diff --git a/schema/device_schema.py b/schema/device_schema.py index f3af2ab..e2e9114 100644 --- a/schema/device_schema.py +++ b/schema/device_schema.py @@ -28,7 +28,6 @@ class DeviceResponse(BaseModel): params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") created_at: datetime = Field(..., description="记录创建时间") updated_at: datetime = Field(..., description="记录更新时间") - model_config = {"from_attributes": True} # 支持从数据库结果直接转换 diff --git a/service/device_action_service.py b/service/device_action_service.py index e93a36d..6cd4b74 100644 --- a/service/device_action_service.py +++ b/service/device_action_service.py @@ -12,7 +12,7 @@ from schema.response_schema import APIResponse # 路由配置 router = APIRouter( - prefix="/device/actions", + prefix="/api/device/actions", tags=["设备操作记录"] ) diff --git a/service/device_danger_service.py b/service/device_danger_service.py new file mode 100644 index 0000000..ef7bb56 --- /dev/null +++ b/service/device_danger_service.py @@ -0,0 +1,267 @@ +import json +from datetime import date + +from fastapi import APIRouter, Query, HTTPException, Path +from mysql.connector import Error as MySQLError + +from ds.db import db +from encryption.encrypt_decorator import encrypt_response +from schema.device_danger_schema import ( + DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse +) +from schema.response_schema import APIResponse + +# 路由初始化(前缀与设备管理相关,标签区分功能) +router = APIRouter( + prefix="/api/devices/dangers", + tags=["设备管理-危险记录"] +) + + +# ------------------------------ +# 内部工具方法 - 检查设备是否存在(复用设备表逻辑) +# ------------------------------ +def check_device_exist(client_ip: str) -> bool: + """ + 检查指定IP的设备是否在devices表中存在 + + :param client_ip: 设备IP地址 + :return: 存在返回True,不存在返回False + """ + if not client_ip: + raise ValueError("设备IP不能为空") + + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) + return cursor.fetchone() is not None + except MySQLError as e: + raise Exception(f"检查设备存在性失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 内部工具方法 - 创建设备危险记录(核心插入逻辑) +# ------------------------------ +def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse: + """ + 内部工具方法:向device_danger表插入新的危险记录 + + :param danger_data: 危险记录创建请求数据 + :return: 创建成功的危险记录模型对象 + """ + # 先检查设备是否存在 + if not check_device_exist(danger_data.client_ip): + raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录") + + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 插入危险记录(id自增,时间自动填充) + insert_query = """ + INSERT INTO device_danger + (client_ip, type, result, created_at, updated_at) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + """ + cursor.execute(insert_query, ( + danger_data.client_ip, + danger_data.type, + danger_data.result + )) + conn.commit() + + # 获取刚创建的记录(用自增ID查询) + danger_id = cursor.lastrowid + cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,)) + new_danger = cursor.fetchone() + + return DeviceDangerResponse(**new_danger) + except MySQLError as e: + if conn: + conn.rollback() + raise Exception(f"插入危险记录失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 接口1:创建设备危险记录 +# ------------------------------ +@router.post("/add", response_model=APIResponse, summary="创建设备危险记录") +@encrypt_response() +async def add_device_danger(danger_data: DeviceDangerCreateRequest): + try: + # 调用内部方法创建记录 + new_danger = create_danger_record(danger_data) + + return APIResponse( + code=200, + message=f"设备[{danger_data.client_ip}]危险记录创建成功", + data=new_danger + ) + except ValueError as e: + # 设备不存在等业务异常 + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + # 数据库异常等系统错误 + raise HTTPException(status_code=500, detail=str(e)) from e + + +# ------------------------------ +# 接口2:获取危险记录列表(支持多条件筛选+分页) +# ------------------------------ +@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)") +@encrypt_response() +async def get_danger_list( + page: int = Query(1, ge=1, description="页码,默认第1页"), + page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), + client_ip: str = Query(None, max_length=100, description="按设备IP筛选"), + danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"), + start_date: date = Query(None, description="按创建时间筛选(开始日期,格式YYYY-MM-DD)"), + end_date: date = Query(None, description="按创建时间筛选(结束日期,格式YYYY-MM-DD)") +): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 构建筛选条件 + where_clause = [] + params = [] + + if client_ip: + where_clause.append("client_ip = %s") + params.append(client_ip) + if danger_type: + where_clause.append("type = %s") + params.append(danger_type) + if start_date: + where_clause.append("DATE(created_at) >= %s") + params.append(start_date.strftime("%Y-%m-%d")) + if end_date: + where_clause.append("DATE(created_at) <= %s") + params.append(end_date.strftime("%Y-%m-%d")) + + # 1. 统计符合条件的总记录数 + count_query = "SELECT COUNT(*) AS total FROM device_danger" + if where_clause: + count_query += " WHERE " + " AND ".join(where_clause) + cursor.execute(count_query, params) + total = cursor.fetchone()["total"] + + # 2. 分页查询记录(按创建时间倒序,最新的在前) + offset = (page - 1) * page_size + list_query = "SELECT * FROM device_danger" + if where_clause: + list_query += " WHERE " + " AND ".join(where_clause) + list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s" + params.extend([page_size, offset]) # 追加分页参数 + + cursor.execute(list_query, params) + danger_list = cursor.fetchall() + + # 转换为响应模型 + return APIResponse( + code=200, + message="获取危险记录列表成功", + data=DeviceDangerListResponse( + total=total, + dangers=[DeviceDangerResponse(**item) for item in danger_list] + ) + ) + except MySQLError as e: + raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 接口3:获取单个设备的所有危险记录 +# ------------------------------ +@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录") +# @encrypt_response() +async def get_device_dangers( + client_ip: str = Path(..., max_length=100, description="设备IP地址"), + page: int = Query(1, ge=1, description="页码,默认第1页"), + page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间") +): + # 先检查设备是否存在 + if not check_device_exist(client_ip): + raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在") + + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 1. 统计该设备的危险记录总数 + count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s" + cursor.execute(count_query, (client_ip,)) + total = cursor.fetchone()["total"] + + # 2. 分页查询该设备的危险记录 + offset = (page - 1) * page_size + list_query = """ + SELECT * FROM device_danger + WHERE client_ip = %s + ORDER BY created_at DESC + LIMIT %s OFFSET %s + """ + cursor.execute(list_query, (client_ip, page_size, offset)) + danger_list = cursor.fetchall() + + return APIResponse( + code=200, + message=f"获取设备[{client_ip}]危险记录成功(共{total}条)", + data=DeviceDangerListResponse( + total=total, + dangers=[DeviceDangerResponse(**item) for item in danger_list] + ) + ) + except MySQLError as e: + raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 接口4:根据ID获取单个危险记录详情 +# ------------------------------ +@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情") +@encrypt_response() +async def get_danger_detail( + danger_id: int = Path(..., ge=1, description="危险记录ID") +): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 查询单个危险记录 + query = "SELECT * FROM device_danger WHERE id = %s" + cursor.execute(query, (danger_id,)) + danger = cursor.fetchone() + + if not danger: + raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在") + + return APIResponse( + code=200, + message="获取危险记录详情成功", + data=DeviceDangerResponse(**danger) + ) + except MySQLError as e: + raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) \ No newline at end of file diff --git a/service/device_service.py b/service/device_service.py index 39e2ad8..c772133 100644 --- a/service/device_service.py +++ b/service/device_service.py @@ -13,7 +13,7 @@ from schema.device_schema import ( from schema.response_schema import APIResponse router = APIRouter( - prefix="/devices", + prefix="/api/devices", tags=["设备管理"] ) diff --git a/service/face_service.py b/service/face_service.py index 5bfd80b..409c22b 100644 --- a/service/face_service.py +++ b/service/face_service.py @@ -17,7 +17,7 @@ from schema.response_schema import APIResponse from util.face_util import add_binary_data, get_average_feature from util.file_util import save_face_to_up_images -router = APIRouter(prefix="/faces", tags=["人脸管理"]) +router = APIRouter(prefix="/api/faces", tags=["人脸管理"]) # ------------------------------ diff --git a/service/file_service.py b/service/file_service.py index f8a8f0f..232cbc7 100644 --- a/service/file_service.py +++ b/service/file_service.py @@ -1,276 +1,174 @@ -from flask import Flask, send_from_directory, abort, request +from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter +from fastapi.responses import FileResponse import os import logging from functools import wraps from pathlib import Path -from flask_cors import CORS +from fastapi.middleware.cors import CORSMiddleware +from typing import Annotated -# 配置日志(保持原有格式) -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) -# 初始化 Flask 应用(供 main.py 导入) -app = Flask(__name__) - -# ------------------------------ -# 核心修改:与 FastAPI 对齐的跨域配置 -# ------------------------------ -# 1. 允许的前端域名(根据实际环境修改,生产环境删除 "*") -ALLOWED_ORIGINS = [ - # "http://localhost:8080", # 本地前端开发地址 - # "http://127.0.0.1:8080", - # "http://服务器IP:8080", # 部署后前端地址 - "*" -] - -# 2. 配置 CORS(与 FastAPI 规则完全对齐) -CORS( - app, - resources={ - r"/*": { - "origins": ALLOWED_ORIGINS, - "allow_credentials": True, - "methods": ["*"], - "allow_headers": ["*"], - } - }, +router = APIRouter( + prefix="/api/file", + tags=["文件管理"] ) # ------------------------------ -# 核心路径配置(关键修改:修正 PROJECT_ROOT 计算) -# 原问题:file_service.py 在 service 文件夹内,需向上跳一级到项目根目录 +# 4. 路径配置 # ------------------------------ -CURRENT_FILE_PATH = Path(__file__).resolve() # 当前文件路径:service/file_service.py -PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录(service 文件夹的父目录) -# 资源目录(现在正确指向项目根目录下的文件夹) -BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 根目录/resource/dect -BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 根目录/up_images -BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 根目录/resource/models +CURRENT_FILE_PATH = Path(__file__).resolve() +PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录 +# 资源目录定义 +BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录 +BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录 +BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录 + +# 确保资源目录存在 +for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]: + if not os.path.exists(dir_path): + os.makedirs(dir_path, exist_ok=True) + print(f"[创建目录] {dir_path}") # ------------------------------ -# 安全检查装饰器(不变) +# 5. 安全依赖项(替代Flask装饰器) # ------------------------------ def safe_path_check(root_dir: str): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - resource_path = kwargs.get('resource_path', '').strip() - # 统一路径分隔符(兼容 Windows \ 和 Linux /) - resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) - # 拼接完整路径(防止路径遍历) - full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) - logger.debug( - f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}" - ) - - # 1. 禁止路径遍历(确保请求文件在根目录内) - if not full_file_path.startswith(root_dir): - logger.warning( - f"[Flask 安全拦截] 非法路径遍历!IP:{request.remote_addr} | 请求路径:{resource_path}" - ) - abort(403) - - # 2. 检查文件存在且为有效文件(非目录) - if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path): - logger.warning( - f"[Flask 资源错误] 文件不存在/非文件!IP:{request.remote_addr} | 路径:{full_file_path}" - ) - abort(404) - - # 3. 限制文件大小(模型200MB,图片10MB) - max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024 - if os.path.getsize(full_file_path) > max_size: - logger.warning( - f"[Flask 大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.remote_addr} | 路径:{full_file_path}" - ) - abort(413) - - # 安全检查通过,传递根目录给视图函数 - return func(*args, **kwargs, root_dir=root_dir) - return wrapper - return decorator - -# ------------------------------ -# 1. 模型下载接口(/model/download/*) -# ------------------------------ -@app.route('/model/download/') -@safe_path_check(root_dir=BASE_MODEL_DIR) -def download_model(resource_path, root_dir): - try: + """ + 安全路径校验依赖项: + 1. 禁止路径遍历(确保请求文件在根目录内) + 2. 校验文件存在且为有效文件(非目录) + 3. 限制文件大小(模型200MB,图片10MB) + """ + async def dependency(request: Request, resource_path: str): + # 统一路径分隔符 resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) - dir_path, file_name = os.path.split(resource_path) - full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) + # 拼接完整路径 + full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) - # 仅允许 .pt 格式(YOLO 模型) - if not file_name.lower().endswith('.pt'): - logger.warning( - f"[Flask 格式错误] 非 .pt 模型文件!IP:{request.remote_addr} | 文件名:{file_name}" - ) - abort(415) + # 校验1:禁止路径遍历 + if not full_file_path.startswith(root_dir): + print(f"[安全检查] 禁止路径遍历!IP:{request.client.host} | 请求路径:{resource_path}") + raise HTTPException(status_code=403, detail="非法路径访问") - logger.info( - f"[Flask 模型下载] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" + # 校验2:文件存在且为有效文件 + if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path): + print(f"[资源错误] 文件不存在/非文件!IP:{request.client.host} | 路径:{full_file_path}") + raise HTTPException(status_code=404, detail="文件不存在") + + # 校验3:文件大小限制 + max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024 + if os.path.getsize(full_file_path) > max_size: + print(f"[大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.client.host} | 路径:{full_file_path}") + raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)") + + return full_file_path + return dependency + +# ------------------------------ +# 6. 核心接口 +# ------------------------------ +@router.get("/model/download/{resource_path:path}", summary="模型下载接口") +async def download_model( + resource_path: str, + full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))], + request: Request +): + """模型下载接口(仅允许 .pt 格式,强制浏览器下载)""" + try: + dir_path, file_name = os.path.split(full_file_path) + + # 额外校验:仅允许 YOLO 模型格式(.pt) + if not file_name.lower().endswith(".pt"): + print(f"[格式错误] 非 .pt 模型文件!IP:{request.client.host} | 文件名:{file_name}") + raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件") + + print(f"[模型下载] 尝试下载!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") + + # 强制下载 + return FileResponse( + full_file_path, + filename=file_name, + media_type="application/octet-stream" ) - - # 强制浏览器下载(而非预览) - return send_from_directory( - full_dir, - file_name, - as_attachment=True, - mimetype="application/octet-stream" - ) - + except HTTPException: + raise except Exception as e: - logger.error( - f"[Flask 模型下载异常] IP:{request.remote_addr} | 错误:{str(e)}" - ) - abort(500) + print(f"[下载异常] IP:{request.client.host} | 错误:{str(e)}") + raise HTTPException(status_code=500, detail="服务器内部错误") -# ------------------------------ -# 2. 人脸图片访问接口(/up_images/*) -# ------------------------------ -@app.route('/up_images/') -@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES) -def get_face_image(resource_path, root_dir): + +@router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口") +async def get_face_image( + resource_path: str, + full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))], + request: Request +): + """人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)""" try: - resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) - dir_path, file_name = os.path.split(resource_path) - full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) + dir_path, file_name = os.path.split(full_file_path) - # 仅允许常见图片格式 - allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') + # 图片格式校验 + allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") if not file_name.lower().endswith(allowed_ext): - logger.warning( - f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" - ) - abort(415) + print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") + raise HTTPException(status_code=415, detail="仅支持常见图片格式") - logger.info( - f"[Flask 人脸图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" - ) - - # 允许浏览器预览图片 - return send_from_directory(full_dir, file_name, as_attachment=False) + print(f"[人脸图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") + return FileResponse(full_file_path) + except HTTPException: + raise except Exception as e: - logger.error( - f"[Flask 人脸图片异常] IP:{request.remote_addr} | 错误:{str(e)}" - ) - abort(500) + print(f"[人脸图片异常] IP:{request.client.host} | 错误:{str(e)}") -# ------------------------------ -# 3. 检测图片访问接口(/resource/dect/*) -# ------------------------------ -@app.route('/resource/dect/') -@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) -def get_dect_image(resource_path, root_dir): + +@router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口") +async def get_dect_image( + resource_path: str, + full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], + request: Request +): + """检测图片访问接口(允许浏览器预览,仅支持常见图片格式)""" try: - resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) - dir_path, file_name = os.path.split(resource_path) - full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) + dir_path, file_name = os.path.split(full_file_path) - # 仅允许常见图片格式 - allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') + # 图片格式校验 + allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") if not file_name.lower().endswith(allowed_ext): - logger.warning( - f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" - ) - abort(415) + print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") + raise HTTPException(status_code=415, detail="仅支持常见图片格式") - logger.info( - f"[Flask 检测图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" - ) - - return send_from_directory(full_dir, file_name, as_attachment=False) + print(f"[检测图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") + return FileResponse(full_file_path) + except HTTPException: + raise except Exception as e: - logger.error( - f"[Flask 检测图片异常] IP:{request.remote_addr} | 错误:{str(e)}" - ) - abort(500) + print(f"[检测图片异常] IP:{request.client.host} | 错误:{str(e)}") + raise HTTPException(status_code=500, detail="服务器内部错误") -# ------------------------------ -# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*) -# ------------------------------ -@app.route('/images/') -@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) -def get_compatible_image(resource_path, root_dir): + +@router.get("/images/{resource_path:path}", summary="兼容旧接口") +async def get_compatible_image( + resource_path: str, + full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))], + request: Request +): + """兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)""" try: - resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) - dir_path, file_name = os.path.split(resource_path) - full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) + dir_path, file_name = os.path.split(full_file_path) - allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') + # 图片格式校验 + allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp") if not file_name.lower().endswith(allowed_ext): - logger.warning( - f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}" - ) - abort(415) - - logger.info( - f"[Flask 兼容图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" - ) - - return send_from_directory(full_dir, file_name, as_attachment=False) + print(f"[格式错误] 非图片文件!IP:{request.client.host} | 文件名:{file_name}") + raise HTTPException(status_code=415, detail="仅支持常见图片格式") + print(f"[兼容图片] 尝试访问!IP:{request.client.host} | 文件:{file_name} | 目录:{dir_path}") + return FileResponse(full_file_path) + except HTTPException: + raise except Exception as e: - logger.error( - f"[Flask 兼容图片异常] IP:{request.remote_addr} | 错误:{str(e)}" - ) - abort(500) - -# ------------------------------ -# 全局错误处理器(不变) -# ------------------------------ -@app.errorhandler(403) -def forbidden_error(error): - return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403 - -@app.errorhandler(404) -def not_found_error(error): - return "❌ 资源不存在:请检查URL路径(IP、目录、文件名)是否正确", 404 - -@app.errorhandler(413) -def too_large_error(error): - return "❌ 文件过大:图片最大10MB,模型最大200MB", 413 - -@app.errorhandler(415) -def unsupported_type_error(error): - return "❌ 不支持的文件类型:图片支持png/jpg/jpeg/gif/bmp,模型仅支持pt", 415 - -@app.errorhandler(500) -def server_error(error): - return "❌ 服务器内部错误:请联系管理员查看后台日志", 500 - -# ------------------------------ -# Flask 独立启动入口(供测试,实际由 main.py 子线程启动) -# ------------------------------ -if __name__ == '__main__': - # 确保所有资源目录存在 - required_dirs = [ - (BASE_IMAGE_DIR_DECT, "检测图片目录"), - (BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"), - (BASE_MODEL_DIR, "模型文件目录") - ] - for dir_path, dir_desc in required_dirs: - if not os.path.exists(dir_path): - logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}") - os.makedirs(dir_path, exist_ok=True) - - # 启动提示 - logger.info("\n[Flask 服务启动成功!] 支持的接口:") - logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt") - logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg") - logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n") - - # 启动服务(禁用 debug 和自动重载) - app.run( - host="0.0.0.0", - port=5000, - debug=False, - use_reloader=False - ) \ No newline at end of file + print(f"[兼容图片异常] IP:{request.client.host} | 错误:{str(e)}") + raise HTTPException(status_code=500, detail="服务器内部错误") diff --git a/service/model_service.py b/service/model_service.py index 80c9edf..4a0f068 100644 --- a/service/model_service.py +++ b/service/model_service.py @@ -38,7 +38,7 @@ _yolo_model = None _current_model_version = None # 模型版本标识 _current_conf_threshold = 0.8 # 默认置信度初始值 -router = APIRouter(prefix="/models", tags=["模型管理"]) +router = APIRouter(prefix="/api/models", tags=["模型管理"]) # 服务重启核心工具函数(保持不变) diff --git a/service/sensitive_service.py b/service/sensitive_service.py index 955bc62..b253ce5 100644 --- a/service/sensitive_service.py +++ b/service/sensitive_service.py @@ -16,7 +16,7 @@ from schema.user_schema import UserResponse # 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类) router = APIRouter( - prefix="/sensitives", + prefix="/api/sensitives", tags=["敏感信息管理"] ) diff --git a/service/user_service.py b/service/user_service.py index c30d5b8..53a6128 100644 --- a/service/user_service.py +++ b/service/user_service.py @@ -18,7 +18,7 @@ from middle.auth_middleware import ( # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) router = APIRouter( - prefix="/users", + prefix="/api/users", tags=["用户管理"] ) diff --git a/util/file_util.py b/util/file_util.py index 2e9be83..4dbf4b1 100644 --- a/util/file_util.py +++ b/util/file_util.py @@ -12,7 +12,8 @@ def save_face_to_up_images( ) -> Dict[str, str]: """ 保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 - 确保db_path以up_images开头,且统一使用正斜杠 + 确保db_path以 /api/file/up_images 开头,且统一使用正斜杠 + 本地不创建/api/file/文件夹,仅URL访问时使用该前缀路由 参数: client_ip: 客户端IP(原始格式,如192.168.1.101) @@ -21,10 +22,10 @@ def save_face_to_up_images( image_format: 图片格式(默认jpg) 返回: - 字典:success(是否成功)、db_path(存数据库的相对路径)、local_abs_path(本地绝对路径)、msg(提示) + 字典:success(是否成功)、db_path(存数据库的路径,带/api/file/前缀)、local_abs_path(本地绝对路径)、msg(提示) """ try: - # 1. 基础参数校验 + # 1. 基础参数校验(不变) if not client_ip.strip(): return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"} if not image_bytes: @@ -32,53 +33,54 @@ def save_face_to_up_images( if image_format.lower() not in ["jpg", "jpeg", "png"]: return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"} - # 2. 处理特殊字符(避免路径错误) + # 2. 处理特殊字符(避免路径错误)(不变) safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_ safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名" safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符 # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) - root_dir = Path("up_images").resolve() # 转为绝对路径(如D:/Git/bin/video/up_images) + root_dir = Path("up_images").resolve() if not root_dir.exists(): root_dir.mkdir(parents=True, exist_ok=True) print(f"[FileUtil] 已创建up_images根目录:{root_dir}") - # 4. 构建文件层级路径(确保在root_dir子目录下) + # 4. 构建文件层级路径(确保在root_dir子目录下)(不变) ip_dir = root_dir / safe_ip face_name_dir = ip_dir / safe_face_name - face_name_dir.mkdir(parents=True, exist_ok=True) # 自动创建目录 - print(f"[FileUtil] 图片存储目录:{face_name_dir}") + face_name_dir.mkdir(parents=True, exist_ok=True) + print(f"[FileUtil] 图片存储目录(本地):{face_name_dir}") - # 5. 生成唯一文件名(毫秒级时间戳) + # 5. 生成唯一文件名(毫秒级时间戳)(不变) timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] + image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" + local_abs_path = face_name_dir / image_filename - # 6. 计算路径(确保所有路径都是绝对路径且在root_dir下) - local_abs_path = face_name_dir / image_filename # 绝对路径 - - # 验证路径是否在root_dir下(防止路径穿越攻击) if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") - # 数据库存储路径:强制包含up_images前缀,统一使用正斜杠 - relative_path = local_abs_path.relative_to(root_dir.parent) # 相对于root_dir的父目录 - db_path = str(relative_path).replace("\\", "/") # 此时会包含up_images部分 + # 数据库存储路径:核心修改——在原有relative_path前添加 /api/file/ 前缀 + relative_path = local_abs_path.relative_to(root_dir.parent) - # 7. 写入图片文件 + relative_path_str = str(relative_path).replace("\\", "/") + # 2. 再拼接/api/file/前缀 + db_path = f"/api/file/{relative_path_str}" + + # 7. 写入图片文件(不变) with open(local_abs_path, "wb") as f: f.write(image_bytes) print(f"[FileUtil] 图片保存成功:") - print(f" 数据库路径:{db_path}") - print(f" 本地绝对路径:{local_abs_path}") + print(f" 数据库路径(带/api/file/前缀):{db_path}") + print(f" 本地绝对路径(无/api/file/):{local_abs_path}") return { "success": True, - "db_path": db_path, # 格式为 up_images/192_168_110_31/小龙/xxx.jpg - "local_abs_path": str(local_abs_path), # 本地绝对路径(完整路径) + "db_path": db_path, + "local_abs_path": str(local_abs_path), "msg": "图片保存成功" } except Exception as e: error_msg = f"图片保存失败:{str(e)}" print(f"[FileUtil] 错误:{error_msg}") - return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg} + return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg} \ No newline at end of file diff --git a/ws/ws.py b/ws/ws.py index c1b0665..49a33e3 100644 --- a/ws/ws.py +++ b/ws/ws.py @@ -17,17 +17,16 @@ from service.device_action_service import add_device_action from schema.device_action_schema import DeviceActionCreate from core.all import detect, load_model -# -------------------------- 1. AES 加密解密工具(固定密钥)-------------------------- +# -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息)-------------------------- AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥(32字节) AES_BLOCK_SIZE = 16 # AES固定块大小 def aes_encrypt(plaintext: str) -> dict: - """AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)""" + """AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)- 仅用于服务器发消息""" try: iv = os.urandom(AES_BLOCK_SIZE) # 随机IV(16字节) cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) - # 明文填充+加密+Base64编码 padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE) ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8") iv_base64 = base64.b64encode(iv).decode("utf-8") @@ -40,20 +39,6 @@ def aes_encrypt(plaintext: str) -> dict: raise Exception(f"AES加密失败: {str(e)}") from e -def aes_decrypt(encrypted_dict: dict) -> str: - """AES-CBC解密:输入加密字典,返回原始文本""" - try: - # Base64解码密文和IV - ciphertext = base64.b64decode(encrypted_dict["ciphertext"]) - iv = base64.b64decode(encrypted_dict["iv"]) - # 解密+去除填充 - cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) - decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8") - return decrypted - except Exception as e: - raise Exception(f"AES解密失败: {str(e)}") from e - - # -------------------------- 2. 配置常量(保持原有)-------------------------- HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) @@ -72,7 +57,7 @@ def get_current_time_file_str() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") -# -------------------------- 4. 客户端连接封装(新增消息加密)-------------------------- +# -------------------------- 4. 客户端连接封装(服务器发消息仍加密,接收消息改明文)-------------------------- class ClientConnection: def __init__(self, websocket: WebSocket, client_ip: str): self.websocket = websocket @@ -96,28 +81,25 @@ class ClientConnection: return self.consumer_task async def send_frame_permit(self): - """发送加密的帧许可信号""" + """发送加密的帧许可信号(服务器→客户端:加密)""" try: - # 1. 构建原始消息 frame_permit_msg = { "type": "frame", "timestamp": get_current_time_str(), "client_ip": self.client_ip } - # 2. AES加密消息 - encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) - # 3. 发送加密消息 + encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) # 保持加密 await self.websocket.send_json(encrypted_msg) print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可") except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}") async def consume_frames(self) -> None: - """消费队列中的帧并处理""" + """消费队列中的明文图像帧并处理""" try: while True: frame_data = await self.frame_queue.get() - await self.send_frame_permit() # 发送下一帧许可 + await self.send_frame_permit() # 回复仍加密 try: await self.process_frame(frame_data) finally: @@ -128,23 +110,22 @@ class ClientConnection: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}") async def process_frame(self, frame_data: bytes) -> None: - """处理单帧图像(含加密危险通知)""" - # 二进制转OpenCV图像 + """处理明文图像帧(危险通知仍加密发送)""" + # 二进制转OpenCV图像(客户端发的是明文二进制,直接解析) nparr = np.frombuffer(frame_data, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if img is None: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像") return try: - # 调用检测函数(client_ip + img 双参数) has_violation, data, detector_type = await asyncio.to_thread( detect, self.client_ip, img ) print( f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}") - # 处理违规逻辑(发送加密危险通知) + # 违规通知:服务器→客户端,仍加密 if has_violation: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}") # 违规次数+1 @@ -154,19 +135,17 @@ class ClientConnection: except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") - # 1. 构建原始危险通知 + # 构建危险通知并加密发送 danger_msg = { "type": "danger", "timestamp": get_current_time_str(), "client_ip": self.client_ip, "detail": data } - # 2. AES加密通知 - encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) - # 3. 发送加密通知 + encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) # 保持加密 await self.websocket.send_json(encrypted_danger_msg) except Exception as e: - print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}") # -------------------------- 5. 全局状态与心跳管理(保持原有)-------------------------- @@ -178,7 +157,6 @@ async def heartbeat_checker(): """全局心跳检查任务""" while True: current_time = get_current_time_str() - # 筛选超时客户端 timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] if timeout_ips: @@ -186,11 +164,9 @@ async def heartbeat_checker(): for ip in timeout_ips: try: conn = connected_clients[ip] - # 取消消费任务+关闭连接 if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel() await conn.websocket.close(code=1008, reason="心跳超时") - # 标记离线 await asyncio.to_thread(update_online_status_by_ip, ip, 0) action_data = DeviceActionCreate(client_ip=ip, action=0) await asyncio.to_thread(add_device_action, action_data) @@ -205,19 +181,16 @@ async def heartbeat_checker(): await asyncio.sleep(HEARTBEAT_INTERVAL) -# -------------------------- 6. 消息处理工具(新增消息解密)-------------------------- +# -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑)-------------------------- async def send_heartbeat_ack(conn: ClientConnection): - """发送加密的心跳确认""" + """发送加密的心跳确认(服务器→客户端:加密)""" try: - # 1. 构建原始心跳确认 heartbeat_ack_msg = { "type": "heart", "timestamp": get_current_time_str(), "client_ip": conn.client_ip } - # 2. AES加密 - encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) - # 3. 发送 + encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) # 保持加密 await conn.websocket.send_json(encrypted_msg) print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认") return True @@ -228,44 +201,22 @@ async def send_heartbeat_ack(conn: ClientConnection): async def handle_text_msg(conn: ClientConnection, text: str): - """处理加密的文本消息(如心跳)""" + """处理客户端明文文本消息(如心跳)- 关键修改:无需解密,直接解析JSON""" try: - # 1. 解析加密字典 - encrypted_dict = json.loads(text) - # 2. AES解密 - decrypted_text = aes_decrypt(encrypted_dict) - # 3. 解析业务消息 - msg = json.loads(decrypted_text) - + # 客户端发的是明文JSON,直接解析(删除原解密步骤) + msg = json.loads(text) if msg.get("type") == "heart": conn.update_heartbeat() - await send_heartbeat_ack(conn) + await send_heartbeat_ack(conn) # 服务器回复仍加密 else: - print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')})") + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')})") except json.JSONDecodeError: - print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式") + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)") except Exception as e: - print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}") + print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}") -async def handle_binary_msg(conn: ClientConnection, data: str): - """处理加密的图像消息(客户端需先转Base64+加密)""" - try: - # 1. 解密得到Base64编码的图像 - encrypted_dict = json.loads(data) - decrypted_base64 = aes_decrypt(encrypted_dict) - # 2. Base64解码为二进制图像 - frame_data = base64.b64decode(decrypted_base64) - # 3. 加入帧队列 - conn.frame_queue.put_nowait(frame_data) - print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队") - except asyncio.QueueFull: - print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满,丢弃数据") - except Exception as e: - print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}") - - -# -------------------------- 7. WebSocket路由与生命周期(保持原有结构)-------------------------- +# -------------------------- 7. WebSocket路由与生命周期(关键修改:处理明文二进制图像)-------------------------- ws_router = APIRouter() @@ -276,7 +227,6 @@ async def lifespan(app: FastAPI): heartbeat_task = asyncio.create_task(heartbeat_checker()) print(f"[{get_current_time_str()}] 心跳检查任务启动(ID: {id(heartbeat_task)})") yield - # 关闭时清理 if heartbeat_task and not heartbeat_task.done(): heartbeat_task.cancel() await heartbeat_task @@ -285,8 +235,8 @@ async def lifespan(app: FastAPI): @ws_router.websocket(WS_ENDPOINT) async def websocket_endpoint(websocket: WebSocket): - """WebSocket连接处理入口""" - load_model() # 加载检测模型(仅一次) + """WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像""" + load_model() # 加载检测模型(建议移到全局,避免重复加载) await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown_ip" current_time = get_current_time_str() @@ -306,8 +256,8 @@ async def websocket_endpoint(websocket: WebSocket): # 注册新连接 new_conn = ClientConnection(websocket, client_ip) connected_clients[client_ip] = new_conn - new_conn.start_consumer() # 启动帧消费 - await new_conn.send_frame_permit() # 发送首次帧许可 + new_conn.start_consumer() + await new_conn.send_frame_permit() # 首次许可仍加密 # 标记客户端上线 try: @@ -321,28 +271,33 @@ async def websocket_endpoint(websocket: WebSocket): print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}") - # 消息循环(接收客户端消息) + # 消息循环:接收客户端明文消息(关键修改) while True: data = await websocket.receive() if "text" in data: - # 处理加密文本消息(心跳、客户端指令) + # 处理客户端明文文本(如心跳:{"type":"heart",...}) await handle_text_msg(new_conn, data["text"]) elif "bytes" in data: - # 兼容客户端发送二进制:先转Base64再处理 - base64_data = base64.b64encode(data["bytes"]).decode("utf-8") - await handle_binary_msg(new_conn, base64_data) + # 处理客户端明文二进制图像(直接入队,无需解密) + frame_data = data["bytes"] + try: + new_conn.frame_queue.put_nowait(frame_data) + print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队") + except asyncio.QueueFull: + print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据") + except Exception as e: + print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}") except WebSocketDisconnect as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})") except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") finally: - # 清理资源(断开后处理) + # 清理资源 if client_ip in connected_clients: conn = connected_clients[client_ip] if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel() - # 仅上线成功的客户端,才标记离线 if is_online_updated: try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) @@ -352,4 +307,4 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}") connected_clients.pop(client_ip, None) - print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}") + print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}") \ No newline at end of file