yolo模型识别不到

This commit is contained in:
2025-09-16 20:17:48 +08:00
parent 396505d8c2
commit de6d1b957a
15 changed files with 568 additions and 441 deletions

View File

@ -67,7 +67,7 @@ def detect(client_ip, frame):
if full_save_path: if full_save_path:
cv2.imwrite(full_save_path, frame) cv2.imwrite(full_save_path, frame)
print(f"✅ yolo违规图片已保存{display_path}") # 日志也修正 print(f"✅ yolo违规图片已保存{display_path}") # 日志也修正
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) save_db(model_type="yolo", client_ip=client_ip, result=str(display_path))
return (True, yolo_result, "yolo") return (True, yolo_result, "yolo")
# 2. 人脸检测 # 2. 人脸检测
@ -77,17 +77,19 @@ def detect(client_ip, frame):
if full_save_path: if full_save_path:
cv2.imwrite(full_save_path, frame) cv2.imwrite(full_save_path, frame)
print(f"✅ face违规图片已保存{display_path}") # 日志也修正 print(f"✅ face违规图片已保存{display_path}") # 日志也修正
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) save_db(model_type="face", client_ip=client_ip, result=str(display_path))
return (True, face_result, "face") return (True, face_result, "face")
# 3. OCR检测 # 3. OCR检测
ocr_flag, ocr_result = ocrDetect(frame) ocr_flag, ocr_result = ocrDetect(frame)
if ocr_flag: if ocr_flag:
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了 full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)
print(f"✅ ocr违规图片已保存{display_path}")
# 这里改了
if full_save_path: if full_save_path:
cv2.imwrite(full_save_path, frame) cv2.imwrite(full_save_path, frame)
print(f"✅ ocr违规图片已保存{display_path}") # 日志也修正 print(f"✅ ocr违规图片已保存{display_path}") # 日志也修正
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) save_db(model_type="ocr", client_ip=client_ip, result=str(display_path))
return (True, ocr_result, "ocr") return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片) # 4. 无违规内容(不保存图片)

View File

@ -1,30 +1,29 @@
import datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Tuple
from service.device_service import get_unique_client_ips from service.device_service import get_unique_client_ips
def create_directory_structure(): def create_directory_structure():
"""创建项目所需的目录结构为所有客户端IP预创建基础目录""" """创建项目所需的目录结构为所有客户端IP预创建基础目录"""
try: try:
# 1. 创建根目录下的resource文件夹存在则跳过不覆盖子内容 # 1. 创建根目录下的resource文件夹存在则跳过不覆盖子内容
resource_dir = Path("resource") resource_dir = Path("resource")
resource_dir.mkdir(exist_ok=True) resource_dir.mkdir(exist_ok=True)
# print(f"确保resource目录存在: {resource_dir.absolute()}")
# 2. 在resource下创建dect文件夹 # 2. 在resource下创建dect文件夹
dect_dir = resource_dir / "dect" dect_dir = resource_dir / "dect"
dect_dir.mkdir(exist_ok=True) dect_dir.mkdir(exist_ok=True)
# print(f"确保dect目录存在: {dect_dir.absolute()}")
# 3. 在dect下创建三个模型文件夹 # 3. 在dect下创建三个模型文件夹
model_dirs = ["ocr", "face", "yolo"] model_dirs = ["ocr", "face", "yolo"]
for model in model_dirs: for model in model_dirs:
model_dir = dect_dir / model model_dir = dect_dir / model
model_dir.mkdir(exist_ok=True) model_dir.mkdir(exist_ok=True)
# print(f"确保{model}模型目录存在: {model_dir.absolute()}")
# 4. 调用外部方法获取所有客户端IP地址 # 4. 调用外部方法获取所有客户端IP地址
try: try:
# 调用外部ip_read()方法获取所有客户端IP地址列表
all_ip_addresses = get_unique_client_ips() all_ip_addresses = get_unique_client_ips()
# 确保返回的是列表类型 # 确保返回的是列表类型
@ -58,7 +57,6 @@ def create_directory_structure():
# 递归创建目录(存在则跳过,不覆盖) # 递归创建目录(存在则跳过,不覆盖)
month_dir.mkdir(parents=True, exist_ok=True) month_dir.mkdir(parents=True, exist_ok=True)
# print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}")
except Exception as e: except Exception as e:
print(f"处理客户端IP和日期目录时发生错误: {str(e)}") print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
@ -67,52 +65,68 @@ def create_directory_structure():
print(f"创建基础目录结构时发生错误: {str(e)}") print(f"创建基础目录结构时发生错误: {str(e)}")
def get_image_save_path(model_type: str, client_ip: str) -> tuple: def get_image_save_path(model_type: str, client_ip: str) -> Tuple[str, str]:
""" """
获取图片保存的「完整路径」和「显示用短路径」 获取图片保存的「本地完整路径」和「带路由前缀的显示路径」
参数: 参数:
model_type: 模型类型,应为"ocr""face""yolo" model_type: 模型类型,应为"ocr""face""yolo"
client_ip: 检测到违禁的客户端IP地址原始格式如192.168.1.101 client_ip: 检测到违禁的客户端IP地址原始格式如192.168.1.101
返回: 返回:
元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "") 元组 (本地完整保存路径, 带/api/file/前缀的显示路径);若出错则返回 ("", "")
""" """
try: try:
# 验证模型类型有效性
valid_models = ["ocr", "face", "yolo"]
if model_type not in valid_models:
raise ValueError(f"无效的模型类型: {model_type},必须是{valid_models}之一")
# 1. 验证客户端IP有效性检查是否在已知IP列表中 # 1. 验证客户端IP有效性检查是否在已知IP列表中
all_ip_addresses = get_unique_client_ips() all_ip_addresses = get_unique_client_ips()
if not isinstance(all_ip_addresses, list): if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses] all_ip_addresses = [all_ip_addresses]
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()] valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if client_ip.strip() not in valid_ips: client_ip_stripped = client_ip.strip()
if client_ip_stripped not in valid_ips:
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中无法保存文件") raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中无法保存文件")
# 2. 处理IP地址与目录创建逻辑一致,将.替换为_ # 2. 处理IP地址将.替换为_,避免路径问题
safe_ip = client_ip.strip().replace(".", "_") safe_ip = client_ip_stripped.replace(".", "_")
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一) # 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
now = datetime.datetime.now() now = datetime.datetime.now()
current_year = str(now.year) current_year = str(now.year)
current_month = str(now.month) current_month = str(now.month).zfill(2) # 确保月份为两位数
current_day = str(now.day) current_day = str(now.day).zfill(2) # 确保日期为两位数
# 时间戳格式年月日时分秒毫秒如20250910143050123 timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] # 取毫秒级时间戳
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]
# 4. 定义基础目录(用于生成相对路径) # 4. 定义基础目录(用于生成相对路径)
base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀 base_dir = Path("resource") / "dect"
# 构建日级目录完整路径resource/dect/{model}/{safe_ip}/{年}/{月}/{日} # 构建日级目录完整路径resource/dect/{model}/{safe_ip}/{年}/{月}/{日}
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
day_dir.mkdir(parents=True, exist_ok=True) # 确保目录存在 day_dir.mkdir(parents=True, exist_ok=True) # 确保目录存在
# 5. 构建唯一文件名 # 5. 构建唯一文件名
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg" image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
# 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印 # 6. 生成「本地完整路径」(使用系统路径,但在字符串表示时统一为正斜杠
full_path = day_dir / image_filename # 完整路径resource/dect/.../xxx.jpg local_full_path = day_dir / image_filename
display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg去掉resource/dect # 转换为字符串并统一使用正斜杠
local_full_path_str = str(local_full_path).replace("\\", "/")
return str(full_path), str(display_path) # 7. 生成带路由前缀的显示路径(核心修改部分)
# 获取项目根目录base_dir是resource/dect向上两级即为项目根目录
project_root = base_dir.parents[1]
# 计算相对于项目根目录的路径包含resource/dect层级
relative_path = local_full_path.relative_to(project_root)
# 转换为字符串并统一使用正斜杠
relative_path_str = str(relative_path).replace("\\", "/")
# 拼接路由前缀
routed_display_path = f"/api/file/{relative_path_str}"
return local_full_path_str, routed_display_path
except Exception as e: except Exception as e:
print(f"获取图片保存路径时发生错误: {str(e)}") print(f"获取图片保存路径时发生错误: {str(e)}")

85
main.py
View File

@ -1,10 +1,8 @@
import uvicorn import uvicorn
import threading
import time import time
import os import os
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from service.file_service import app as flask_app
# 原有业务导入 # 原有业务导入
from core.all import load_model from core.all import load_model
@ -15,32 +13,12 @@ from service.sensitive_service import router as sensitive_router
from service.face_service import router as face_router from service.face_service import router as face_router
from service.device_service import router as device_router from service.device_service import router as device_router
from service.model_service import router as model_router from service.model_service import router as model_router
from service.file_service import router as file_router
from service.device_danger_service import router as device_danger_router
from ws.ws import ws_router, lifespan from ws.ws import ws_router, lifespan
from core.establish import create_directory_structure from core.establish import create_directory_structure
# 初始化 FastAPI 应用
# Flask 服务启动函数(不变)
def start_flask_service():
try:
print(f"\n[Flask 服务] 准备启动端口5000")
print(f"[Flask 服务] 访问示例http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n")
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
if not os.path.exists(BASE_IMAGE_DIR):
print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
flask_app.run(
host="0.0.0.0",
port=5000,
debug=False,
use_reloader=False
)
except Exception as e:
print(f"[Flask 服务] 启动失败:{str(e)}")
# 初始化 FastAPI 应用(新增 CORS 配置)
app = FastAPI( app = FastAPI(
title="内容安全审核后台", title="内容安全审核后台",
description="含图片访问服务和动态模型管理", description="含图片访问服务和动态模型管理",
@ -48,38 +26,33 @@ app = FastAPI(
lifespan=lifespan lifespan=lifespan
) )
# ------------------------------
# 新增:完整 CORS 配置(解决跨域问题)
# ------------------------------
# 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等)
ALLOWED_ORIGINS = [ ALLOWED_ORIGINS = [
# "http://localhost:8080", # 前端本地开发地址(必改,填实际前端地址) "*"
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址(如适用)
"*" #表示允许所有域名(开发环境可用,生产环境不推荐)
] ]
# 2. 配置 CORS 中间件 # 配置 CORS 中间件
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名 allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
allow_credentials=True, # 允许携带 Cookie(如需登录态则必开) allow_credentials=True, # 允许携带 Cookie
allow_methods=["*"], # 允许所有 HTTP 方法(包括 PUT/DELETE allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有请求头(包括 Content-Type allow_headers=["*"], # 允许所有请求头
) )
# 注册路由(不变) # 注册路由
app.include_router(user_router) app.include_router(user_router)
app.include_router(device_router) app.include_router(device_router)
app.include_router(face_router) app.include_router(face_router)
app.include_router(sensitive_router) app.include_router(sensitive_router)
app.include_router(model_router) # 模型管理路由 app.include_router(model_router)
app.include_router(file_router)
app.include_router(device_danger_router)
app.include_router(ws_router) app.include_router(ws_router)
# 注册全局异常处理器(不变) # 注册全局异常处理器
app.add_exception_handler(Exception, global_exception_handler) app.add_exception_handler(Exception, global_exception_handler)
# 主服务启动入口(不变) # 主服务启动入口
if __name__ == "__main__": if __name__ == "__main__":
create_directory_structure() create_directory_structure()
print(f"[初始化] 目录结构创建完成") print(f"[初始化] 目录结构创建完成")
@ -89,11 +62,11 @@ if __name__ == "__main__":
os.makedirs(MODEL_SAVE_DIR, exist_ok=True) os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}") print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}")
# # 模型路径配置 # 确保图片目录存在原Flask服务负责的目录
# YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt") BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
# OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml") if not os.path.exists(BASE_IMAGE_DIR):
# print(f"[初始化] 默认YOLO模型路径{YOLO_MODEL_PATH}") print(f"[初始化] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
# print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}") os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
# 加载检测模型 # 加载检测模型
try: try:
@ -105,23 +78,7 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)") print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)")
# 启动 FastAPI 主服务仅使用8000端口
# 2. 启动 Flask 服务(子线程)
flask_thread = threading.Thread(
target=start_flask_service,
daemon=True
)
flask_thread.start()
# 等待 Flask 初始化
time.sleep(1)
if flask_thread.is_alive():
print(f"[Flask 服务] 启动成功(运行中)")
else:
print(f"[Flask 服务] 启动失败!图片访问不可用")
# 3. 启动 FastAPI 主服务
port = int(SERVER_CONFIG.get("port", 8000)) port = int(SERVER_CONFIG.get("port", 8000))
print(f"\n[FastAPI 服务] 准备启动,端口:{port}") print(f"\n[FastAPI 服务] 准备启动,端口:{port}")
print(f"[FastAPI 服务] 接口文档http://服务器IP:{port}/docs\n") print(f"[FastAPI 服务] 接口文档http://服务器IP:{port}/docs\n")
@ -133,4 +90,4 @@ if __name__ == "__main__":
workers=1, workers=1,
ws="websockets", ws="websockets",
reload=False reload=False
) )

View File

@ -0,0 +1,33 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ------------------------------
# 请求模型
# ------------------------------
class DeviceDangerCreateRequest(BaseModel):
"""设备危险记录创建请求模型"""
client_ip: str = Field(..., max_length=100, description="设备IP地址必须与devices表中IP对应")
type: str = Field(..., max_length=50, description="危险类型(如:病毒检测、端口异常、权限泄露等)")
result: str = Field(..., description="危险检测结果/处理结果检测到木马病毒已隔离端口22异常开放已关闭")
# ------------------------------
# 响应模型
# ------------------------------
class DeviceDangerResponse(BaseModel):
"""单条设备危险记录响应模型与device_danger表字段对齐updated_at允许为null"""
id: int = Field(..., description="危险记录主键ID")
client_ip: str = Field(..., max_length=100, description="设备IP地址")
type: str = Field(..., max_length=50, description="危险类型")
result: str = Field(..., description="危险检测结果/处理结果")
created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)")
updated_at: Optional[datetime] = Field(None, description="记录更新时间数据库中该字段当前为null")
model_config = {"from_attributes": True}
class DeviceDangerListResponse(BaseModel):
"""设备危险记录列表响应模型(带分页)"""
total: int = Field(..., description="危险记录总数")
dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表")

View File

@ -28,7 +28,6 @@ class DeviceResponse(BaseModel):
params: Optional[str] = Field(None, description="扩展参数JSON字符串") params: Optional[str] = Field(None, description="扩展参数JSON字符串")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")
model_config = {"from_attributes": True} # 支持从数据库结果直接转换 model_config = {"from_attributes": True} # 支持从数据库结果直接转换

View File

@ -12,7 +12,7 @@ from schema.response_schema import APIResponse
# 路由配置 # 路由配置
router = APIRouter( router = APIRouter(
prefix="/device/actions", prefix="/api/device/actions",
tags=["设备操作记录"] tags=["设备操作记录"]
) )

View File

@ -0,0 +1,267 @@
import json
from datetime import date
from fastapi import APIRouter, Query, HTTPException, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_danger_schema import (
DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse
)
from schema.response_schema import APIResponse
# 路由初始化(前缀与设备管理相关,标签区分功能)
router = APIRouter(
prefix="/api/devices/dangers",
tags=["设备管理-危险记录"]
)
# ------------------------------
# 内部工具方法 - 检查设备是否存在(复用设备表逻辑)
# ------------------------------
def check_device_exist(client_ip: str) -> bool:
"""
检查指定IP的设备是否在devices表中存在
:param client_ip: 设备IP地址
:return: 存在返回True不存在返回False
"""
if not client_ip:
raise ValueError("设备IP不能为空")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
return cursor.fetchone() is not None
except MySQLError as e:
raise Exception(f"检查设备存在性失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法 - 创建设备危险记录(核心插入逻辑)
# ------------------------------
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
"""
内部工具方法向device_danger表插入新的危险记录
:param danger_data: 危险记录创建请求数据
:return: 创建成功的危险记录模型对象
"""
# 先检查设备是否存在
if not check_device_exist(danger_data.client_ip):
raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入危险记录id自增时间自动填充
insert_query = """
INSERT INTO device_danger
(client_ip, type, result, created_at, updated_at)
VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (
danger_data.client_ip,
danger_data.type,
danger_data.result
))
conn.commit()
# 获取刚创建的记录用自增ID查询
danger_id = cursor.lastrowid
cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
new_danger = cursor.fetchone()
return DeviceDangerResponse(**new_danger)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"插入危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口1创建设备危险记录
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备危险记录")
@encrypt_response()
async def add_device_danger(danger_data: DeviceDangerCreateRequest):
try:
# 调用内部方法创建记录
new_danger = create_danger_record(danger_data)
return APIResponse(
code=200,
message=f"设备[{danger_data.client_ip}]危险记录创建成功",
data=new_danger
)
except ValueError as e:
# 设备不存在等业务异常
raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
# 数据库异常等系统错误
raise HTTPException(status_code=500, detail=str(e)) from e
# ------------------------------
# 接口2获取危险记录列表支持多条件筛选+分页)
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
@encrypt_response()
async def get_danger_list(
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"),
start_date: date = Query(None, description="按创建时间筛选开始日期格式YYYY-MM-DD"),
end_date: date = Query(None, description="按创建时间筛选结束日期格式YYYY-MM-DD")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 构建筛选条件
where_clause = []
params = []
if client_ip:
where_clause.append("client_ip = %s")
params.append(client_ip)
if danger_type:
where_clause.append("type = %s")
params.append(danger_type)
if start_date:
where_clause.append("DATE(created_at) >= %s")
params.append(start_date.strftime("%Y-%m-%d"))
if end_date:
where_clause.append("DATE(created_at) <= %s")
params.append(end_date.strftime("%Y-%m-%d"))
# 1. 统计符合条件的总记录数
count_query = "SELECT COUNT(*) AS total FROM device_danger"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 2. 分页查询记录(按创建时间倒序,最新的在前)
offset = (page - 1) * page_size
list_query = "SELECT * FROM device_danger"
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数
cursor.execute(list_query, params)
danger_list = cursor.fetchall()
# 转换为响应模型
return APIResponse(
code=200,
message="获取危险记录列表成功",
data=DeviceDangerListResponse(
total=total,
dangers=[DeviceDangerResponse(**item) for item in danger_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口3获取单个设备的所有危险记录
# ------------------------------
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
# @encrypt_response()
async def get_device_dangers(
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间")
):
# 先检查设备是否存在
if not check_device_exist(client_ip):
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 统计该设备的危险记录总数
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
cursor.execute(count_query, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 分页查询该设备的危险记录
offset = (page - 1) * page_size
list_query = """
SELECT * FROM device_danger
WHERE client_ip = %s
ORDER BY created_at DESC
LIMIT %s OFFSET %s
"""
cursor.execute(list_query, (client_ip, page_size, offset))
danger_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
data=DeviceDangerListResponse(
total=total,
dangers=[DeviceDangerResponse(**item) for item in danger_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口4根据ID获取单个危险记录详情
# ------------------------------
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
@encrypt_response()
async def get_danger_detail(
danger_id: int = Path(..., ge=1, description="危险记录ID")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询单个危险记录
query = "SELECT * FROM device_danger WHERE id = %s"
cursor.execute(query, (danger_id,))
danger = cursor.fetchone()
if not danger:
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
return APIResponse(
code=200,
message="获取危险记录详情成功",
data=DeviceDangerResponse(**danger)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -13,7 +13,7 @@ from schema.device_schema import (
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
router = APIRouter( router = APIRouter(
prefix="/devices", prefix="/api/devices",
tags=["设备管理"] tags=["设备管理"]
) )

View File

@ -17,7 +17,7 @@ from schema.response_schema import APIResponse
from util.face_util import add_binary_data, get_average_feature from util.face_util import add_binary_data, get_average_feature
from util.file_util import save_face_to_up_images from util.file_util import save_face_to_up_images
router = APIRouter(prefix="/faces", tags=["人脸管理"]) router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
# ------------------------------ # ------------------------------

View File

@ -1,276 +1,174 @@
from flask import Flask, send_from_directory, abort, request from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter
from fastapi.responses import FileResponse
import os import os
import logging import logging
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from flask_cors import CORS from fastapi.middleware.cors import CORSMiddleware
from typing import Annotated
# 配置日志(保持原有格式)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 初始化 Flask 应用(供 main.py 导入) router = APIRouter(
app = Flask(__name__) prefix="/api/file",
tags=["文件管理"]
# ------------------------------
# 核心修改:与 FastAPI 对齐的跨域配置
# ------------------------------
# 1. 允许的前端域名(根据实际环境修改,生产环境删除 "*"
ALLOWED_ORIGINS = [
# "http://localhost:8080", # 本地前端开发地址
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址
"*"
]
# 2. 配置 CORS与 FastAPI 规则完全对齐)
CORS(
app,
resources={
r"/*": {
"origins": ALLOWED_ORIGINS,
"allow_credentials": True,
"methods": ["*"],
"allow_headers": ["*"],
}
},
) )
# ------------------------------ # ------------------------------
# 核心路径配置(关键修改:修正 PROJECT_ROOT 计算) # 4. 路径配置
# 原问题file_service.py 在 service 文件夹内,需向上跳一级到项目根目录
# ------------------------------ # ------------------------------
CURRENT_FILE_PATH = Path(__file__).resolve() # 当前文件路径service/file_service.py CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录service 文件夹的父目录) PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录
# 资源目录(现在正确指向项目根目录下的文件夹)
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 根目录/resource/dect
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 根目录/up_images
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 根目录/resource/models
# 资源目录定义
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
# 确保资源目录存在
for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]:
if not os.path.exists(dir_path):
os.makedirs(dir_path, exist_ok=True)
print(f"[创建目录] {dir_path}")
# ------------------------------ # ------------------------------
# 安全检查装饰器(不变 # 5. 安全依赖项替代Flask装饰器
# ------------------------------ # ------------------------------
def safe_path_check(root_dir: str): def safe_path_check(root_dir: str):
def decorator(func): """
@wraps(func) 安全路径校验依赖项:
def wrapper(*args, **kwargs): 1. 禁止路径遍历(确保请求文件在根目录内)
resource_path = kwargs.get('resource_path', '').strip() 2. 校验文件存在且为有效文件(非目录)
# 统一路径分隔符(兼容 Windows \ 和 Linux / 3. 限制文件大小模型200MB图片10MB
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) """
# 拼接完整路径(防止路径遍历) async def dependency(request: Request, resource_path: str):
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path)) # 统一路径分隔符
logger.debug(
f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}"
)
# 1. 禁止路径遍历(确保请求文件在根目录内)
if not full_file_path.startswith(root_dir):
logger.warning(
f"[Flask 安全拦截] 非法路径遍历IP{request.remote_addr} | 请求路径:{resource_path}"
)
abort(403)
# 2. 检查文件存在且为有效文件(非目录)
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
logger.warning(
f"[Flask 资源错误] 文件不存在/非文件IP{request.remote_addr} | 路径:{full_file_path}"
)
abort(404)
# 3. 限制文件大小模型200MB图片10MB
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
if os.path.getsize(full_file_path) > max_size:
logger.warning(
f"[Flask 大小超限] 文件超过{max_size//1024//1024}MBIP{request.remote_addr} | 路径:{full_file_path}"
)
abort(413)
# 安全检查通过,传递根目录给视图函数
return func(*args, **kwargs, root_dir=root_dir)
return wrapper
return decorator
# ------------------------------
# 1. 模型下载接口(/model/download/*
# ------------------------------
@app.route('/model/download/<path:resource_path>')
@safe_path_check(root_dir=BASE_MODEL_DIR)
def download_model(resource_path, root_dir):
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path) # 拼接完整路径
full_dir = os.path.abspath(os.path.join(root_dir, dir_path)) full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
# 仅允许 .pt 格式YOLO 模型) # 校验1禁止路径遍历
if not file_name.lower().endswith('.pt'): if not full_file_path.startswith(root_dir):
logger.warning( print(f"[安全检查] 禁止路径遍历IP{request.client.host} | 请求路径:{resource_path}")
f"[Flask 格式错误] 非 .pt 模型文件IP{request.remote_addr} | 文件名:{file_name}" raise HTTPException(status_code=403, detail="非法路径访问")
)
abort(415)
logger.info( # 校验2文件存在且为有效文件
f"[Flask 模型下载] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}" if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
print(f"[资源错误] 文件不存在/非文件IP{request.client.host} | 路径:{full_file_path}")
raise HTTPException(status_code=404, detail="文件不存在")
# 校验3文件大小限制
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
if os.path.getsize(full_file_path) > max_size:
print(f"[大小超限] 文件超过{max_size//1024//1024}MBIP{request.client.host} | 路径:{full_file_path}")
raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)")
return full_file_path
return dependency
# ------------------------------
# 6. 核心接口
# ------------------------------
@router.get("/model/download/{resource_path:path}", summary="模型下载接口")
async def download_model(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))],
request: Request
):
"""模型下载接口(仅允许 .pt 格式,强制浏览器下载)"""
try:
dir_path, file_name = os.path.split(full_file_path)
# 额外校验:仅允许 YOLO 模型格式(.pt
if not file_name.lower().endswith(".pt"):
print(f"[格式错误] 非 .pt 模型文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件")
print(f"[模型下载] 尝试下载IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
# 强制下载
return FileResponse(
full_file_path,
filename=file_name,
media_type="application/octet-stream"
) )
except HTTPException:
# 强制浏览器下载(而非预览) raise
return send_from_directory(
full_dir,
file_name,
as_attachment=True,
mimetype="application/octet-stream"
)
except Exception as e: except Exception as e:
logger.error( print(f"[下载异常] IP{request.client.host} | 错误:{str(e)}")
f"[Flask 模型下载异常] IP{request.remote_addr} | 错误:{str(e)}" raise HTTPException(status_code=500, detail="服务器内部错误")
)
abort(500)
# ------------------------------
# 2. 人脸图片访问接口(/up_images/* @router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口")
# ------------------------------ async def get_face_image(
@app.route('/up_images/<path:resource_path>') resource_path: str,
@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES) full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))],
def get_face_image(resource_path, root_dir): request: Request
):
"""人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
try: try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) dir_path, file_name = os.path.split(full_file_path)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许常见图片格式 # 图片格式校验
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext): if not file_name.lower().endswith(allowed_ext):
logger.warning( print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}" raise HTTPException(status_code=415, detail="仅支持常见图片格式")
)
abort(415)
logger.info( print(f"[人脸图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
f"[Flask 人脸图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
# 允许浏览器预览图片
return send_from_directory(full_dir, file_name, as_attachment=False)
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error( print(f"[人脸图片异常] IP{request.client.host} | 错误:{str(e)}")
f"[Flask 人脸图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 3. 检测图片访问接口(/resource/dect/* @router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口")
# ------------------------------ async def get_dect_image(
@app.route('/resource/dect/<path:resource_path>') resource_path: str,
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
def get_dect_image(resource_path, root_dir): request: Request
):
"""检测图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
try: try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) dir_path, file_name = os.path.split(full_file_path)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许常见图片格式 # 图片格式校验
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext): if not file_name.lower().endswith(allowed_ext):
logger.warning( print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}" raise HTTPException(status_code=415, detail="仅支持常见图片格式")
)
abort(415)
logger.info( print(f"[检测图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
f"[Flask 检测图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error( print(f"[检测图片异常] IP{request.client.host} | 错误:{str(e)}")
f"[Flask 检测图片异常] IP{request.remote_addr} | 错误:{str(e)}" raise HTTPException(status_code=500, detail="服务器内部错误")
)
abort(500)
# ------------------------------
# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/* @router.get("/images/{resource_path:path}", summary="兼容旧接口")
# ------------------------------ async def get_compatible_image(
@app.route('/images/<path:resource_path>') resource_path: str,
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT) full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
def get_compatible_image(resource_path, root_dir): request: Request
):
"""兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)"""
try: try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep) dir_path, file_name = os.path.split(full_file_path)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp') # 图片格式校验
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext): if not file_name.lower().endswith(allowed_ext):
logger.warning( print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}" raise HTTPException(status_code=415, detail="仅支持常见图片格式")
) print(f"[兼容图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
abort(415)
logger.info(
f"[Flask 兼容图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error( print(f"[兼容图片异常] IP{request.client.host} | 错误:{str(e)}")
f"[Flask 兼容图片异常] IP{request.remote_addr} | 错误:{str(e)}" raise HTTPException(status_code=500, detail="服务器内部错误")
)
abort(500)
# ------------------------------
# 全局错误处理器(不变)
# ------------------------------
@app.errorhandler(403)
def forbidden_error(error):
return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403
@app.errorhandler(404)
def not_found_error(error):
return "❌ 资源不存在请检查URL路径IP、目录、文件名是否正确", 404
@app.errorhandler(413)
def too_large_error(error):
return "❌ 文件过大图片最大10MB模型最大200MB", 413
@app.errorhandler(415)
def unsupported_type_error(error):
return "❌ 不支持的文件类型图片支持png/jpg/jpeg/gif/bmp模型仅支持pt", 415
@app.errorhandler(500)
def server_error(error):
return "❌ 服务器内部错误:请联系管理员查看后台日志", 500
# ------------------------------
# Flask 独立启动入口(供测试,实际由 main.py 子线程启动)
# ------------------------------
if __name__ == '__main__':
# 确保所有资源目录存在
required_dirs = [
(BASE_IMAGE_DIR_DECT, "检测图片目录"),
(BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"),
(BASE_MODEL_DIR, "模型文件目录")
]
for dir_path, dir_desc in required_dirs:
if not os.path.exists(dir_path):
logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}")
os.makedirs(dir_path, exist_ok=True)
# 启动提示
logger.info("\n[Flask 服务启动成功!] 支持的接口:")
logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt")
logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg")
logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n")
# 启动服务(禁用 debug 和自动重载)
app.run(
host="0.0.0.0",
port=5000,
debug=False,
use_reloader=False
)

View File

@ -38,7 +38,7 @@ _yolo_model = None
_current_model_version = None # 模型版本标识 _current_model_version = None # 模型版本标识
_current_conf_threshold = 0.8 # 默认置信度初始值 _current_conf_threshold = 0.8 # 默认置信度初始值
router = APIRouter(prefix="/models", tags=["模型管理"]) router = APIRouter(prefix="/api/models", tags=["模型管理"])
# 服务重启核心工具函数(保持不变) # 服务重启核心工具函数(保持不变)

View File

@ -16,7 +16,7 @@ from schema.user_schema import UserResponse
# 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类) # 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类)
router = APIRouter( router = APIRouter(
prefix="/sensitives", prefix="/api/sensitives",
tags=["敏感信息管理"] tags=["敏感信息管理"]
) )

View File

@ -18,7 +18,7 @@ from middle.auth_middleware import (
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
router = APIRouter( router = APIRouter(
prefix="/users", prefix="/api/users",
tags=["用户管理"] tags=["用户管理"]
) )

View File

@ -12,7 +12,8 @@ def save_face_to_up_images(
) -> Dict[str, str]: ) -> Dict[str, str]:
""" """
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
确保db_path以up_images开头且统一使用正斜杠 确保db_path以 /api/file/up_images 开头,且统一使用正斜杠
本地不创建/api/file/文件夹仅URL访问时使用该前缀路由
参数: 参数:
client_ip: 客户端IP原始格式如192.168.1.101 client_ip: 客户端IP原始格式如192.168.1.101
@ -21,10 +22,10 @@ def save_face_to_up_images(
image_format: 图片格式默认jpg image_format: 图片格式默认jpg
返回: 返回:
字典success是否成功、db_path存数据库的相对路径、local_abs_path本地绝对路径、msg提示 字典success是否成功、db_path存数据库的路径,带/api/file/前缀、local_abs_path本地绝对路径、msg提示
""" """
try: try:
# 1. 基础参数校验 # 1. 基础参数校验(不变)
if not client_ip.strip(): if not client_ip.strip():
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"} return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"}
if not image_bytes: if not image_bytes:
@ -32,53 +33,54 @@ def save_face_to_up_images(
if image_format.lower() not in ["jpg", "jpeg", "png"]: if image_format.lower() not in ["jpg", "jpeg", "png"]:
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"} return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"}
# 2. 处理特殊字符(避免路径错误) # 2. 处理特殊字符(避免路径错误)(不变)
safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_ safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_
safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名" safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名"
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符 safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
root_dir = Path("up_images").resolve() # 转为绝对路径如D:/Git/bin/video/up_images root_dir = Path("up_images").resolve()
if not root_dir.exists(): if not root_dir.exists():
root_dir.mkdir(parents=True, exist_ok=True) root_dir.mkdir(parents=True, exist_ok=True)
print(f"[FileUtil] 已创建up_images根目录{root_dir}") print(f"[FileUtil] 已创建up_images根目录{root_dir}")
# 4. 构建文件层级路径确保在root_dir子目录下 # 4. 构建文件层级路径确保在root_dir子目录下(不变)
ip_dir = root_dir / safe_ip ip_dir = root_dir / safe_ip
face_name_dir = ip_dir / safe_face_name face_name_dir = ip_dir / safe_face_name
face_name_dir.mkdir(parents=True, exist_ok=True) # 自动创建目录 face_name_dir.mkdir(parents=True, exist_ok=True)
print(f"[FileUtil] 图片存储目录:{face_name_dir}") print(f"[FileUtil] 图片存储目录(本地){face_name_dir}")
# 5. 生成唯一文件名(毫秒级时间戳) # 5. 生成唯一文件名(毫秒级时间戳)(不变)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
local_abs_path = face_name_dir / image_filename
# 6. 计算路径确保所有路径都是绝对路径且在root_dir下
local_abs_path = face_name_dir / image_filename # 绝对路径
# 验证路径是否在root_dir下防止路径穿越攻击
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
raise Exception(f"图片路径不在up_images根目录下安全校验失败{local_abs_path}") raise Exception(f"图片路径不在up_images根目录下安全校验失败{local_abs_path}")
# 数据库存储路径:强制包含up_images前缀统一使用正斜杠 # 数据库存储路径:核心修改——在原有relative_path前添加 /api/file/ 前缀
relative_path = local_abs_path.relative_to(root_dir.parent) # 相对于root_dir的父目录 relative_path = local_abs_path.relative_to(root_dir.parent)
db_path = str(relative_path).replace("\\", "/") # 此时会包含up_images部分
# 7. 写入图片文件 relative_path_str = str(relative_path).replace("\\", "/")
# 2. 再拼接/api/file/前缀
db_path = f"/api/file/{relative_path_str}"
# 7. 写入图片文件(不变)
with open(local_abs_path, "wb") as f: with open(local_abs_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
print(f"[FileUtil] 图片保存成功:") print(f"[FileUtil] 图片保存成功:")
print(f" 数据库路径:{db_path}") print(f" 数据库路径(带/api/file/前缀){db_path}")
print(f" 本地绝对路径:{local_abs_path}") print(f" 本地绝对路径(无/api/file/{local_abs_path}")
return { return {
"success": True, "success": True,
"db_path": db_path, # 格式为 up_images/192_168_110_31/小龙/xxx.jpg "db_path": db_path,
"local_abs_path": str(local_abs_path), # 本地绝对路径(完整路径) "local_abs_path": str(local_abs_path),
"msg": "图片保存成功" "msg": "图片保存成功"
} }
except Exception as e: except Exception as e:
error_msg = f"图片保存失败:{str(e)}" error_msg = f"图片保存失败:{str(e)}"
print(f"[FileUtil] 错误:{error_msg}") print(f"[FileUtil] 错误:{error_msg}")
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg} return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}

129
ws/ws.py
View File

@ -17,17 +17,16 @@ from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate from schema.device_action_schema import DeviceActionCreate
from core.all import detect, load_model from core.all import detect, load_model
# -------------------------- 1. AES 加密解密工具(固定密钥-------------------------- # -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息--------------------------
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥32字节 AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥32字节
AES_BLOCK_SIZE = 16 # AES固定块大小 AES_BLOCK_SIZE = 16 # AES固定块大小
def aes_encrypt(plaintext: str) -> dict: def aes_encrypt(plaintext: str) -> dict:
"""AES-CBC加密返回{密文, IV, 算法标识}均Base64编码""" """AES-CBC加密返回{密文, IV, 算法标识}均Base64编码- 仅用于服务器发消息"""
try: try:
iv = os.urandom(AES_BLOCK_SIZE) # 随机IV16字节 iv = os.urandom(AES_BLOCK_SIZE) # 随机IV16字节
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
# 明文填充+加密+Base64编码
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE) padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8") ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
iv_base64 = base64.b64encode(iv).decode("utf-8") iv_base64 = base64.b64encode(iv).decode("utf-8")
@ -40,20 +39,6 @@ def aes_encrypt(plaintext: str) -> dict:
raise Exception(f"AES加密失败: {str(e)}") from e raise Exception(f"AES加密失败: {str(e)}") from e
def aes_decrypt(encrypted_dict: dict) -> str:
"""AES-CBC解密输入加密字典返回原始文本"""
try:
# Base64解码密文和IV
ciphertext = base64.b64decode(encrypted_dict["ciphertext"])
iv = base64.b64decode(encrypted_dict["iv"])
# 解密+去除填充
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8")
return decrypted
except Exception as e:
raise Exception(f"AES解密失败: {str(e)}") from e
# -------------------------- 2. 配置常量(保持原有)-------------------------- # -------------------------- 2. 配置常量(保持原有)--------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
@ -72,7 +57,7 @@ def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# -------------------------- 4. 客户端连接封装(新增消息加密)-------------------------- # -------------------------- 4. 客户端连接封装(服务器发消息加密,接收消息改明文--------------------------
class ClientConnection: class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str): def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket self.websocket = websocket
@ -96,28 +81,25 @@ class ClientConnection:
return self.consumer_task return self.consumer_task
async def send_frame_permit(self): async def send_frame_permit(self):
"""发送加密的帧许可信号""" """发送加密的帧许可信号(服务器→客户端:加密)"""
try: try:
# 1. 构建原始消息
frame_permit_msg = { frame_permit_msg = {
"type": "frame", "type": "frame",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": self.client_ip "client_ip": self.client_ip
} }
# 2. AES加密消息 encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) # 保持加密
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg))
# 3. 发送加密消息
await self.websocket.send_json(encrypted_msg) await self.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}")
async def consume_frames(self) -> None: async def consume_frames(self) -> None:
"""消费队列中的帧并处理""" """消费队列中的明文图像帧并处理"""
try: try:
while True: while True:
frame_data = await self.frame_queue.get() frame_data = await self.frame_queue.get()
await self.send_frame_permit() # 发送下一帧许可 await self.send_frame_permit() # 回复仍加密
try: try:
await self.process_frame(frame_data) await self.process_frame(frame_data)
finally: finally:
@ -128,23 +110,22 @@ class ClientConnection:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像(含加密危险通知""" """处理明文图像帧(危险通知仍加密发送"""
# 二进制转OpenCV图像 # 二进制转OpenCV图像(客户端发的是明文二进制,直接解析)
nparr = np.frombuffer(frame_data, np.uint8) nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None: if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像")
return return
try: try:
# 调用检测函数client_ip + img 双参数)
has_violation, data, detector_type = await asyncio.to_thread( has_violation, data, detector_type = await asyncio.to_thread(
detect, self.client_ip, img detect, self.client_ip, img
) )
print( print(
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}") f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}")
# 处理违规逻辑(发送加密危险通知) # 违规通知:服务器→客户端,仍加密
if has_violation: if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}")
# 违规次数+1 # 违规次数+1
@ -154,19 +135,17 @@ class ClientConnection:
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 1. 构建原始危险通知 # 构建危险通知并加密发送
danger_msg = { danger_msg = {
"type": "danger", "type": "danger",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": self.client_ip, "client_ip": self.client_ip,
"detail": data "detail": data
} }
# 2. AES加密通知 encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) # 保持加密
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg))
# 3. 发送加密通知
await self.websocket.send_json(encrypted_danger_msg) await self.websocket.send_json(encrypted_danger_msg)
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}")
# -------------------------- 5. 全局状态与心跳管理(保持原有)-------------------------- # -------------------------- 5. 全局状态与心跳管理(保持原有)--------------------------
@ -178,7 +157,6 @@ async def heartbeat_checker():
"""全局心跳检查任务""" """全局心跳检查任务"""
while True: while True:
current_time = get_current_time_str() current_time = get_current_time_str()
# 筛选超时客户端
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips: if timeout_ips:
@ -186,11 +164,9 @@ async def heartbeat_checker():
for ip in timeout_ips: for ip in timeout_ips:
try: try:
conn = connected_clients[ip] conn = connected_clients[ip]
# 取消消费任务+关闭连接
if conn.consumer_task and not conn.consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel() conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时") await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线
await asyncio.to_thread(update_online_status_by_ip, ip, 0) await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0) action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data) await asyncio.to_thread(add_device_action, action_data)
@ -205,19 +181,16 @@ async def heartbeat_checker():
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 6. 消息处理工具(新增消息解密-------------------------- # -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑--------------------------
async def send_heartbeat_ack(conn: ClientConnection): async def send_heartbeat_ack(conn: ClientConnection):
"""发送加密的心跳确认""" """发送加密的心跳确认(服务器→客户端:加密)"""
try: try:
# 1. 构建原始心跳确认
heartbeat_ack_msg = { heartbeat_ack_msg = {
"type": "heart", "type": "heart",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": conn.client_ip "client_ip": conn.client_ip
} }
# 2. AES加密 encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) # 保持加密
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg))
# 3. 发送
await conn.websocket.send_json(encrypted_msg) await conn.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认")
return True return True
@ -228,44 +201,22 @@ async def send_heartbeat_ack(conn: ClientConnection):
async def handle_text_msg(conn: ClientConnection, text: str): async def handle_text_msg(conn: ClientConnection, text: str):
"""处理加密的文本消息(如心跳)""" """处理客户端明文文本消息(如心跳)- 关键修改无需解密直接解析JSON"""
try: try:
# 1. 解析加密字典 # 客户端发的是明文JSON直接解析删除原解密步骤
encrypted_dict = json.loads(text) msg = json.loads(text)
# 2. AES解密
decrypted_text = aes_decrypt(encrypted_dict)
# 3. 解析业务消息
msg = json.loads(decrypted_text)
if msg.get("type") == "heart": if msg.get("type") == "heart":
conn.update_heartbeat() conn.update_heartbeat()
await send_heartbeat_ack(conn) await send_heartbeat_ack(conn) # 服务器回复仍加密
else: else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')}") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')}")
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}")
async def handle_binary_msg(conn: ClientConnection, data: str): # -------------------------- 7. WebSocket路由与生命周期关键修改处理明文二进制图像--------------------------
"""处理加密的图像消息客户端需先转Base64+加密)"""
try:
# 1. 解密得到Base64编码的图像
encrypted_dict = json.loads(data)
decrypted_base64 = aes_decrypt(encrypted_dict)
# 2. Base64解码为二进制图像
frame_data = base64.b64decode(decrypted_base64)
# 3. 加入帧队列
conn.frame_queue.put_nowait(frame_data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满,丢弃数据")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}")
# -------------------------- 7. WebSocket路由与生命周期保持原有结构--------------------------
ws_router = APIRouter() ws_router = APIRouter()
@ -276,7 +227,6 @@ async def lifespan(app: FastAPI):
heartbeat_task = asyncio.create_task(heartbeat_checker()) heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 心跳检查任务启动ID: {id(heartbeat_task)}") print(f"[{get_current_time_str()}] 心跳检查任务启动ID: {id(heartbeat_task)}")
yield yield
# 关闭时清理
if heartbeat_task and not heartbeat_task.done(): if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel() heartbeat_task.cancel()
await heartbeat_task await heartbeat_task
@ -285,8 +235,8 @@ async def lifespan(app: FastAPI):
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
"""WebSocket连接处理入口""" """WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像"""
load_model() # 加载检测模型(仅一次 load_model() # 加载检测模型(建议移到全局,避免重复加载
await websocket.accept() await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip" client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str() current_time = get_current_time_str()
@ -306,8 +256,8 @@ async def websocket_endpoint(websocket: WebSocket):
# 注册新连接 # 注册新连接
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn connected_clients[client_ip] = new_conn
new_conn.start_consumer() # 启动帧消费 new_conn.start_consumer()
await new_conn.send_frame_permit() # 发送首次许可 await new_conn.send_frame_permit() # 首次许可仍加密
# 标记客户端上线 # 标记客户端上线
try: try:
@ -321,28 +271,33 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}") print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}")
# 消息循环接收客户端消息 # 消息循环接收客户端明文消息(关键修改
while True: while True:
data = await websocket.receive() data = await websocket.receive()
if "text" in data: if "text" in data:
# 处理加密文本消息(心跳、客户端指令 # 处理客户端明文文本(如心跳:{"type":"heart",...}
await handle_text_msg(new_conn, data["text"]) await handle_text_msg(new_conn, data["text"])
elif "bytes" in data: elif "bytes" in data:
# 兼容客户端发送二进制先转Base64再处理 # 处理客户端明文二进制图像(直接入队,无需解密)
base64_data = base64.b64encode(data["bytes"]).decode("utf-8") frame_data = data["bytes"]
await handle_binary_msg(new_conn, base64_data) try:
new_conn.frame_queue.put_nowait(frame_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}")
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally: finally:
# 清理资源(断开后处理) # 清理资源
if client_ip in connected_clients: if client_ip in connected_clients:
conn = connected_clients[client_ip] conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel() conn.consumer_task.cancel()
# 仅上线成功的客户端,才标记离线
if is_online_updated: if is_online_updated:
try: try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
@ -352,4 +307,4 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None) connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")