yolo模型识别不到

This commit is contained in:
2025-09-16 20:17:48 +08:00
parent 396505d8c2
commit de6d1b957a
15 changed files with 568 additions and 441 deletions

View File

@ -67,7 +67,7 @@ def detect(client_ip, frame):
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ yolo违规图片已保存{display_path}") # 日志也修正
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
save_db(model_type="yolo", client_ip=client_ip, result=str(display_path))
return (True, yolo_result, "yolo")
# 2. 人脸检测
@ -77,17 +77,19 @@ def detect(client_ip, frame):
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ face违规图片已保存{display_path}") # 日志也修正
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
save_db(model_type="face", client_ip=client_ip, result=str(display_path))
return (True, face_result, "face")
# 3. OCR检测
ocr_flag, ocr_result = ocrDetect(frame)
if ocr_flag:
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip)
print(f"✅ ocr违规图片已保存{display_path}")
# 这里改了
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ ocr违规图片已保存{display_path}") # 日志也修正
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
save_db(model_type="ocr", client_ip=client_ip, result=str(display_path))
return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片)

View File

@ -1,30 +1,29 @@
import datetime
from pathlib import Path
from typing import List, Tuple
from service.device_service import get_unique_client_ips
def create_directory_structure():
"""创建项目所需的目录结构为所有客户端IP预创建基础目录"""
try:
# 1. 创建根目录下的resource文件夹存在则跳过不覆盖子内容
resource_dir = Path("resource")
resource_dir.mkdir(exist_ok=True)
# print(f"确保resource目录存在: {resource_dir.absolute()}")
# 2. 在resource下创建dect文件夹
dect_dir = resource_dir / "dect"
dect_dir.mkdir(exist_ok=True)
# print(f"确保dect目录存在: {dect_dir.absolute()}")
# 3. 在dect下创建三个模型文件夹
model_dirs = ["ocr", "face", "yolo"]
for model in model_dirs:
model_dir = dect_dir / model
model_dir.mkdir(exist_ok=True)
# print(f"确保{model}模型目录存在: {model_dir.absolute()}")
# 4. 调用外部方法获取所有客户端IP地址
try:
# 调用外部ip_read()方法获取所有客户端IP地址列表
all_ip_addresses = get_unique_client_ips()
# 确保返回的是列表类型
@ -58,7 +57,6 @@ def create_directory_structure():
# 递归创建目录(存在则跳过,不覆盖)
month_dir.mkdir(parents=True, exist_ok=True)
# print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}")
except Exception as e:
print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
@ -67,52 +65,68 @@ def create_directory_structure():
print(f"创建基础目录结构时发生错误: {str(e)}")
def get_image_save_path(model_type: str, client_ip: str) -> tuple:
def get_image_save_path(model_type: str, client_ip: str) -> Tuple[str, str]:
"""
获取图片保存的「完整路径」和「显示用短路径」
获取图片保存的「本地完整路径」和「带路由前缀的显示路径」
参数:
model_type: 模型类型,应为"ocr""face""yolo"
client_ip: 检测到违禁的客户端IP地址原始格式如192.168.1.101
返回:
元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "")
元组 (本地完整保存路径, 带/api/file/前缀的显示路径);若出错则返回 ("", "")
"""
try:
# 验证模型类型有效性
valid_models = ["ocr", "face", "yolo"]
if model_type not in valid_models:
raise ValueError(f"无效的模型类型: {model_type},必须是{valid_models}之一")
# 1. 验证客户端IP有效性检查是否在已知IP列表中
all_ip_addresses = get_unique_client_ips()
if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses]
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if client_ip.strip() not in valid_ips:
client_ip_stripped = client_ip.strip()
if client_ip_stripped not in valid_ips:
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中无法保存文件")
# 2. 处理IP地址与目录创建逻辑一致,将.替换为_
safe_ip = client_ip.strip().replace(".", "_")
# 2. 处理IP地址将.替换为_,避免路径问题
safe_ip = client_ip_stripped.replace(".", "_")
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
now = datetime.datetime.now()
current_year = str(now.year)
current_month = str(now.month)
current_day = str(now.day)
# 时间戳格式年月日时分秒毫秒如20250910143050123
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]
current_month = str(now.month).zfill(2) # 确保月份为两位数
current_day = str(now.day).zfill(2) # 确保日期为两位数
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3] # 取毫秒级时间戳
# 4. 定义基础目录(用于生成相对路径)
base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀
base_dir = Path("resource") / "dect"
# 构建日级目录完整路径resource/dect/{model}/{safe_ip}/{年}/{月}/{日}
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
day_dir.mkdir(parents=True, exist_ok=True) # 确保目录存在
day_dir.mkdir(parents=True, exist_ok=True) # 确保目录存在
# 5. 构建唯一文件名
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
# 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印
full_path = day_dir / image_filename # 完整路径resource/dect/.../xxx.jpg
display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg去掉resource/dect
# 6. 生成「本地完整路径」(使用系统路径,但在字符串表示时统一为正斜杠
local_full_path = day_dir / image_filename
# 转换为字符串并统一使用正斜杠
local_full_path_str = str(local_full_path).replace("\\", "/")
return str(full_path), str(display_path)
# 7. 生成带路由前缀的显示路径(核心修改部分)
# 获取项目根目录base_dir是resource/dect向上两级即为项目根目录
project_root = base_dir.parents[1]
# 计算相对于项目根目录的路径包含resource/dect层级
relative_path = local_full_path.relative_to(project_root)
# 转换为字符串并统一使用正斜杠
relative_path_str = str(relative_path).replace("\\", "/")
# 拼接路由前缀
routed_display_path = f"/api/file/{relative_path_str}"
return local_full_path_str, routed_display_path
except Exception as e:
print(f"获取图片保存路径时发生错误: {str(e)}")

85
main.py
View File

@ -1,10 +1,8 @@
import uvicorn
import threading
import time
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from service.file_service import app as flask_app
# 原有业务导入
from core.all import load_model
@ -15,32 +13,12 @@ from service.sensitive_service import router as sensitive_router
from service.face_service import router as face_router
from service.device_service import router as device_router
from service.model_service import router as model_router
from service.file_service import router as file_router
from service.device_danger_service import router as device_danger_router
from ws.ws import ws_router, lifespan
from core.establish import create_directory_structure
# Flask 服务启动函数(不变)
def start_flask_service():
try:
print(f"\n[Flask 服务] 准备启动端口5000")
print(f"[Flask 服务] 访问示例http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n")
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
if not os.path.exists(BASE_IMAGE_DIR):
print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
flask_app.run(
host="0.0.0.0",
port=5000,
debug=False,
use_reloader=False
)
except Exception as e:
print(f"[Flask 服务] 启动失败:{str(e)}")
# 初始化 FastAPI 应用(新增 CORS 配置)
# 初始化 FastAPI 应用
app = FastAPI(
title="内容安全审核后台",
description="含图片访问服务和动态模型管理",
@ -48,38 +26,33 @@ app = FastAPI(
lifespan=lifespan
)
# ------------------------------
# 新增:完整 CORS 配置(解决跨域问题)
# ------------------------------
# 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等)
ALLOWED_ORIGINS = [
# "http://localhost:8080", # 前端本地开发地址(必改,填实际前端地址)
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址(如适用)
"*" #表示允许所有域名(开发环境可用,生产环境不推荐)
"*"
]
# 2. 配置 CORS 中间件
# 配置 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
allow_credentials=True, # 允许携带 Cookie(如需登录态则必开)
allow_methods=["*"], # 允许所有 HTTP 方法(包括 PUT/DELETE
allow_headers=["*"], # 允许所有请求头(包括 Content-Type
allow_credentials=True, # 允许携带 Cookie
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有请求头
)
# 注册路由(不变)
# 注册路由
app.include_router(user_router)
app.include_router(device_router)
app.include_router(face_router)
app.include_router(sensitive_router)
app.include_router(model_router) # 模型管理路由
app.include_router(model_router)
app.include_router(file_router)
app.include_router(device_danger_router)
app.include_router(ws_router)
# 注册全局异常处理器(不变)
# 注册全局异常处理器
app.add_exception_handler(Exception, global_exception_handler)
# 主服务启动入口(不变)
# 主服务启动入口
if __name__ == "__main__":
create_directory_structure()
print(f"[初始化] 目录结构创建完成")
@ -89,11 +62,11 @@ if __name__ == "__main__":
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}")
# # 模型路径配置
# YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt")
# OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml")
# print(f"[初始化] 默认YOLO模型路径{YOLO_MODEL_PATH}")
# print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}")
# 确保图片目录存在原Flask服务负责的目录
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
if not os.path.exists(BASE_IMAGE_DIR):
print(f"[初始化] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
# 加载检测模型
try:
@ -105,23 +78,7 @@ if __name__ == "__main__":
except Exception as e:
print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)")
# 2. 启动 Flask 服务(子线程)
flask_thread = threading.Thread(
target=start_flask_service,
daemon=True
)
flask_thread.start()
# 等待 Flask 初始化
time.sleep(1)
if flask_thread.is_alive():
print(f"[Flask 服务] 启动成功(运行中)")
else:
print(f"[Flask 服务] 启动失败!图片访问不可用")
# 3. 启动 FastAPI 主服务
# 启动 FastAPI 主服务仅使用8000端口
port = int(SERVER_CONFIG.get("port", 8000))
print(f"\n[FastAPI 服务] 准备启动,端口:{port}")
print(f"[FastAPI 服务] 接口文档http://服务器IP:{port}/docs\n")
@ -133,4 +90,4 @@ if __name__ == "__main__":
workers=1,
ws="websockets",
reload=False
)
)

View File

@ -0,0 +1,33 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ------------------------------
# 请求模型
# ------------------------------
class DeviceDangerCreateRequest(BaseModel):
"""设备危险记录创建请求模型"""
client_ip: str = Field(..., max_length=100, description="设备IP地址必须与devices表中IP对应")
type: str = Field(..., max_length=50, description="危险类型(如:病毒检测、端口异常、权限泄露等)")
result: str = Field(..., description="危险检测结果/处理结果检测到木马病毒已隔离端口22异常开放已关闭")
# ------------------------------
# 响应模型
# ------------------------------
class DeviceDangerResponse(BaseModel):
"""单条设备危险记录响应模型与device_danger表字段对齐updated_at允许为null"""
id: int = Field(..., description="危险记录主键ID")
client_ip: str = Field(..., max_length=100, description="设备IP地址")
type: str = Field(..., max_length=50, description="危险类型")
result: str = Field(..., description="危险检测结果/处理结果")
created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)")
updated_at: Optional[datetime] = Field(None, description="记录更新时间数据库中该字段当前为null")
model_config = {"from_attributes": True}
class DeviceDangerListResponse(BaseModel):
"""设备危险记录列表响应模型(带分页)"""
total: int = Field(..., description="危险记录总数")
dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表")

View File

@ -28,7 +28,6 @@ class DeviceResponse(BaseModel):
params: Optional[str] = Field(None, description="扩展参数JSON字符串")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")
model_config = {"from_attributes": True} # 支持从数据库结果直接转换

View File

@ -12,7 +12,7 @@ from schema.response_schema import APIResponse
# 路由配置
router = APIRouter(
prefix="/device/actions",
prefix="/api/device/actions",
tags=["设备操作记录"]
)

View File

@ -0,0 +1,267 @@
import json
from datetime import date
from fastapi import APIRouter, Query, HTTPException, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_danger_schema import (
DeviceDangerCreateRequest, DeviceDangerResponse, DeviceDangerListResponse
)
from schema.response_schema import APIResponse
# 路由初始化(前缀与设备管理相关,标签区分功能)
router = APIRouter(
prefix="/api/devices/dangers",
tags=["设备管理-危险记录"]
)
# ------------------------------
# 内部工具方法 - 检查设备是否存在(复用设备表逻辑)
# ------------------------------
def check_device_exist(client_ip: str) -> bool:
"""
检查指定IP的设备是否在devices表中存在
:param client_ip: 设备IP地址
:return: 存在返回True不存在返回False
"""
if not client_ip:
raise ValueError("设备IP不能为空")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
return cursor.fetchone() is not None
except MySQLError as e:
raise Exception(f"检查设备存在性失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法 - 创建设备危险记录(核心插入逻辑)
# ------------------------------
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
"""
内部工具方法向device_danger表插入新的危险记录
:param danger_data: 危险记录创建请求数据
:return: 创建成功的危险记录模型对象
"""
# 先检查设备是否存在
if not check_device_exist(danger_data.client_ip):
raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在,无法创建危险记录")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入危险记录id自增时间自动填充
insert_query = """
INSERT INTO device_danger
(client_ip, type, result, created_at, updated_at)
VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (
danger_data.client_ip,
danger_data.type,
danger_data.result
))
conn.commit()
# 获取刚创建的记录用自增ID查询
danger_id = cursor.lastrowid
cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
new_danger = cursor.fetchone()
return DeviceDangerResponse(**new_danger)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"插入危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口1创建设备危险记录
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备危险记录")
@encrypt_response()
async def add_device_danger(danger_data: DeviceDangerCreateRequest):
try:
# 调用内部方法创建记录
new_danger = create_danger_record(danger_data)
return APIResponse(
code=200,
message=f"设备[{danger_data.client_ip}]危险记录创建成功",
data=new_danger
)
except ValueError as e:
# 设备不存在等业务异常
raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
# 数据库异常等系统错误
raise HTTPException(status_code=500, detail=str(e)) from e
# ------------------------------
# 接口2获取危险记录列表支持多条件筛选+分页)
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
@encrypt_response()
async def get_danger_list(
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
danger_type: str = Query(None, max_length=50, alias="type", description="按危险类型筛选"),
start_date: date = Query(None, description="按创建时间筛选开始日期格式YYYY-MM-DD"),
end_date: date = Query(None, description="按创建时间筛选结束日期格式YYYY-MM-DD")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 构建筛选条件
where_clause = []
params = []
if client_ip:
where_clause.append("client_ip = %s")
params.append(client_ip)
if danger_type:
where_clause.append("type = %s")
params.append(danger_type)
if start_date:
where_clause.append("DATE(created_at) >= %s")
params.append(start_date.strftime("%Y-%m-%d"))
if end_date:
where_clause.append("DATE(created_at) <= %s")
params.append(end_date.strftime("%Y-%m-%d"))
# 1. 统计符合条件的总记录数
count_query = "SELECT COUNT(*) AS total FROM device_danger"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 2. 分页查询记录(按创建时间倒序,最新的在前)
offset = (page - 1) * page_size
list_query = "SELECT * FROM device_danger"
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数
cursor.execute(list_query, params)
danger_list = cursor.fetchall()
# 转换为响应模型
return APIResponse(
code=200,
message="获取危险记录列表成功",
data=DeviceDangerListResponse(
total=total,
dangers=[DeviceDangerResponse(**item) for item in danger_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口3获取单个设备的所有危险记录
# ------------------------------
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
# @encrypt_response()
async def get_device_dangers(
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间")
):
# 先检查设备是否存在
if not check_device_exist(client_ip):
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 统计该设备的危险记录总数
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
cursor.execute(count_query, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 分页查询该设备的危险记录
offset = (page - 1) * page_size
list_query = """
SELECT * FROM device_danger
WHERE client_ip = %s
ORDER BY created_at DESC
LIMIT %s OFFSET %s
"""
cursor.execute(list_query, (client_ip, page_size, offset))
danger_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
data=DeviceDangerListResponse(
total=total,
dangers=[DeviceDangerResponse(**item) for item in danger_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口4根据ID获取单个危险记录详情
# ------------------------------
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
@encrypt_response()
async def get_danger_detail(
danger_id: int = Path(..., ge=1, description="危险记录ID")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询单个危险记录
query = "SELECT * FROM device_danger WHERE id = %s"
cursor.execute(query, (danger_id,))
danger = cursor.fetchone()
if not danger:
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
return APIResponse(
code=200,
message="获取危险记录详情成功",
data=DeviceDangerResponse(**danger)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -13,7 +13,7 @@ from schema.device_schema import (
from schema.response_schema import APIResponse
router = APIRouter(
prefix="/devices",
prefix="/api/devices",
tags=["设备管理"]
)

View File

@ -17,7 +17,7 @@ from schema.response_schema import APIResponse
from util.face_util import add_binary_data, get_average_feature
from util.file_util import save_face_to_up_images
router = APIRouter(prefix="/faces", tags=["人脸管理"])
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
# ------------------------------

View File

@ -1,276 +1,174 @@
from flask import Flask, send_from_directory, abort, request
from fastapi import FastAPI, HTTPException, Request, Depends, APIRouter
from fastapi.responses import FileResponse
import os
import logging
from functools import wraps
from pathlib import Path
from flask_cors import CORS
from fastapi.middleware.cors import CORSMiddleware
from typing import Annotated
# 配置日志(保持原有格式)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 初始化 Flask 应用(供 main.py 导入)
app = Flask(__name__)
# ------------------------------
# 核心修改:与 FastAPI 对齐的跨域配置
# ------------------------------
# 1. 允许的前端域名(根据实际环境修改,生产环境删除 "*"
ALLOWED_ORIGINS = [
# "http://localhost:8080", # 本地前端开发地址
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址
"*"
]
# 2. 配置 CORS与 FastAPI 规则完全对齐)
CORS(
app,
resources={
r"/*": {
"origins": ALLOWED_ORIGINS,
"allow_credentials": True,
"methods": ["*"],
"allow_headers": ["*"],
}
},
router = APIRouter(
prefix="/api/file",
tags=["文件管理"]
)
# ------------------------------
# 核心路径配置(关键修改:修正 PROJECT_ROOT 计算)
# 原问题file_service.py 在 service 文件夹内,需向上跳一级到项目根目录
# 4. 路径配置
# ------------------------------
CURRENT_FILE_PATH = Path(__file__).resolve() # 当前文件路径service/file_service.py
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录service 文件夹的父目录)
# 资源目录(现在正确指向项目根目录下的文件夹)
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 根目录/resource/dect
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 根目录/up_images
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 根目录/resource/models
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent # 项目根目录
# 资源目录定义
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
# 确保资源目录存在
for dir_path in [BASE_IMAGE_DIR_DECT, BASE_IMAGE_DIR_UP_IMAGES, BASE_MODEL_DIR]:
if not os.path.exists(dir_path):
os.makedirs(dir_path, exist_ok=True)
print(f"[创建目录] {dir_path}")
# ------------------------------
# 安全检查装饰器(不变
# 5. 安全依赖项替代Flask装饰器
# ------------------------------
def safe_path_check(root_dir: str):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
resource_path = kwargs.get('resource_path', '').strip()
# 统一路径分隔符(兼容 Windows \ 和 Linux /
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
# 拼接完整路径(防止路径遍历)
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
logger.debug(
f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}"
)
# 1. 禁止路径遍历(确保请求文件在根目录内)
if not full_file_path.startswith(root_dir):
logger.warning(
f"[Flask 安全拦截] 非法路径遍历IP{request.remote_addr} | 请求路径:{resource_path}"
)
abort(403)
# 2. 检查文件存在且为有效文件(非目录)
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
logger.warning(
f"[Flask 资源错误] 文件不存在/非文件IP{request.remote_addr} | 路径:{full_file_path}"
)
abort(404)
# 3. 限制文件大小模型200MB图片10MB
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
if os.path.getsize(full_file_path) > max_size:
logger.warning(
f"[Flask 大小超限] 文件超过{max_size//1024//1024}MBIP{request.remote_addr} | 路径:{full_file_path}"
)
abort(413)
# 安全检查通过,传递根目录给视图函数
return func(*args, **kwargs, root_dir=root_dir)
return wrapper
return decorator
# ------------------------------
# 1. 模型下载接口(/model/download/*
# ------------------------------
@app.route('/model/download/<path:resource_path>')
@safe_path_check(root_dir=BASE_MODEL_DIR)
def download_model(resource_path, root_dir):
try:
"""
安全路径校验依赖项:
1. 禁止路径遍历(确保请求文件在根目录内)
2. 校验文件存在且为有效文件(非目录)
3. 限制文件大小模型200MB图片10MB
"""
async def dependency(request: Request, resource_path: str):
# 统一路径分隔符
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 拼接完整路径
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
# 仅允许 .pt 格式YOLO 模型)
if not file_name.lower().endswith('.pt'):
logger.warning(
f"[Flask 格式错误] 非 .pt 模型文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
# 校验1禁止路径遍历
if not full_file_path.startswith(root_dir):
print(f"[安全检查] 禁止路径遍历IP{request.client.host} | 请求路径:{resource_path}")
raise HTTPException(status_code=403, detail="非法路径访问")
logger.info(
f"[Flask 模型下载] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
# 校验2文件存在且为有效文件
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
print(f"[资源错误] 文件不存在/非文件IP{request.client.host} | 路径:{full_file_path}")
raise HTTPException(status_code=404, detail="文件不存在")
# 校验3文件大小限制
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
if os.path.getsize(full_file_path) > max_size:
print(f"[大小超限] 文件超过{max_size//1024//1024}MBIP{request.client.host} | 路径:{full_file_path}")
raise HTTPException(status_code=413, detail=f"文件大小超过限制({max_size//1024//1024}MB)")
return full_file_path
return dependency
# ------------------------------
# 6. 核心接口
# ------------------------------
@router.get("/model/download/{resource_path:path}", summary="模型下载接口")
async def download_model(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_MODEL_DIR))],
request: Request
):
"""模型下载接口(仅允许 .pt 格式,强制浏览器下载)"""
try:
dir_path, file_name = os.path.split(full_file_path)
# 额外校验:仅允许 YOLO 模型格式(.pt
if not file_name.lower().endswith(".pt"):
print(f"[格式错误] 非 .pt 模型文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持 .pt 格式的模型文件")
print(f"[模型下载] 尝试下载IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
# 强制下载
return FileResponse(
full_file_path,
filename=file_name,
media_type="application/octet-stream"
)
# 强制浏览器下载(而非预览)
return send_from_directory(
full_dir,
file_name,
as_attachment=True,
mimetype="application/octet-stream"
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"[Flask 模型下载异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
print(f"[下载异常] IP{request.client.host} | 错误:{str(e)}")
raise HTTPException(status_code=500, detail="服务器内部错误")
# ------------------------------
# 2. 人脸图片访问接口(/up_images/*
# ------------------------------
@app.route('/up_images/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES)
def get_face_image(resource_path, root_dir):
@router.get("/up_images/{resource_path:path}", summary="人脸图片访问接口")
async def get_face_image(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES))],
request: Request
):
"""人脸图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
dir_path, file_name = os.path.split(full_file_path)
# 仅允许常见图片格式
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
# 图片格式校验
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
logger.info(
f"[Flask 人脸图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
# 允许浏览器预览图片
return send_from_directory(full_dir, file_name, as_attachment=False)
print(f"[人脸图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e:
logger.error(
f"[Flask 人脸图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
print(f"[人脸图片异常] IP{request.client.host} | 错误:{str(e)}")
# ------------------------------
# 3. 检测图片访问接口(/resource/dect/*
# ------------------------------
@app.route('/resource/dect/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
def get_dect_image(resource_path, root_dir):
@router.get("/resource/dect/{resource_path:path}", summary="检测图片访问接口")
async def get_dect_image(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
request: Request
):
"""检测图片访问接口(允许浏览器预览,仅支持常见图片格式)"""
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
dir_path, file_name = os.path.split(full_file_path)
# 仅允许常见图片格式
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
# 图片格式校验
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
logger.info(
f"[Flask 检测图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
print(f"[检测图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e:
logger.error(
f"[Flask 检测图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
print(f"[检测图片异常] IP{request.client.host} | 错误:{str(e)}")
raise HTTPException(status_code=500, detail="服务器内部错误")
# ------------------------------
# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*
# ------------------------------
@app.route('/images/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
def get_compatible_image(resource_path, root_dir):
@router.get("/images/{resource_path:path}", summary="兼容旧接口")
async def get_compatible_image(
resource_path: str,
full_file_path: Annotated[str, Depends(safe_path_check(root_dir=BASE_IMAGE_DIR_DECT))],
request: Request
):
"""兼容旧接口(/images/* → 映射到 /resource/dect/*,保留历史兼容性)"""
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
dir_path, file_name = os.path.split(full_file_path)
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
# 图片格式校验
allowed_ext = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 兼容图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
print(f"[格式错误] 非图片文件IP{request.client.host} | 文件名:{file_name}")
raise HTTPException(status_code=415, detail="仅支持常见图片格式")
print(f"[兼容图片] 尝试访问IP{request.client.host} | 文件:{file_name} | 目录:{dir_path}")
return FileResponse(full_file_path)
except HTTPException:
raise
except Exception as e:
logger.error(
f"[Flask 兼容图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 全局错误处理器(不变)
# ------------------------------
@app.errorhandler(403)
def forbidden_error(error):
return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403
@app.errorhandler(404)
def not_found_error(error):
return "❌ 资源不存在请检查URL路径IP、目录、文件名是否正确", 404
@app.errorhandler(413)
def too_large_error(error):
return "❌ 文件过大图片最大10MB模型最大200MB", 413
@app.errorhandler(415)
def unsupported_type_error(error):
return "❌ 不支持的文件类型图片支持png/jpg/jpeg/gif/bmp模型仅支持pt", 415
@app.errorhandler(500)
def server_error(error):
return "❌ 服务器内部错误:请联系管理员查看后台日志", 500
# ------------------------------
# Flask 独立启动入口(供测试,实际由 main.py 子线程启动)
# ------------------------------
if __name__ == '__main__':
# 确保所有资源目录存在
required_dirs = [
(BASE_IMAGE_DIR_DECT, "检测图片目录"),
(BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"),
(BASE_MODEL_DIR, "模型文件目录")
]
for dir_path, dir_desc in required_dirs:
if not os.path.exists(dir_path):
logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}")
os.makedirs(dir_path, exist_ok=True)
# 启动提示
logger.info("\n[Flask 服务启动成功!] 支持的接口:")
logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt")
logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg")
logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n")
# 启动服务(禁用 debug 和自动重载)
app.run(
host="0.0.0.0",
port=5000,
debug=False,
use_reloader=False
)
print(f"[兼容图片异常] IP{request.client.host} | 错误:{str(e)}")
raise HTTPException(status_code=500, detail="服务器内部错误")

View File

@ -38,7 +38,7 @@ _yolo_model = None
_current_model_version = None # 模型版本标识
_current_conf_threshold = 0.8 # 默认置信度初始值
router = APIRouter(prefix="/models", tags=["模型管理"])
router = APIRouter(prefix="/api/models", tags=["模型管理"])
# 服务重启核心工具函数(保持不变)

View File

@ -16,7 +16,7 @@ from schema.user_schema import UserResponse
# 创建敏感信息接口路由(前缀 /sensitives、标签用于 Swagger 分类)
router = APIRouter(
prefix="/sensitives",
prefix="/api/sensitives",
tags=["敏感信息管理"]
)

View File

@ -18,7 +18,7 @@ from middle.auth_middleware import (
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
router = APIRouter(
prefix="/users",
prefix="/api/users",
tags=["用户管理"]
)

View File

@ -12,7 +12,8 @@ def save_face_to_up_images(
) -> Dict[str, str]:
"""
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
确保db_path以up_images开头且统一使用正斜杠
确保db_path以 /api/file/up_images 开头,且统一使用正斜杠
本地不创建/api/file/文件夹仅URL访问时使用该前缀路由
参数:
client_ip: 客户端IP原始格式如192.168.1.101
@ -21,10 +22,10 @@ def save_face_to_up_images(
image_format: 图片格式默认jpg
返回:
字典success是否成功、db_path存数据库的相对路径、local_abs_path本地绝对路径、msg提示
字典success是否成功、db_path存数据库的路径,带/api/file/前缀、local_abs_path本地绝对路径、msg提示
"""
try:
# 1. 基础参数校验
# 1. 基础参数校验(不变)
if not client_ip.strip():
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"}
if not image_bytes:
@ -32,53 +33,54 @@ def save_face_to_up_images(
if image_format.lower() not in ["jpg", "jpeg", "png"]:
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"}
# 2. 处理特殊字符(避免路径错误)
# 2. 处理特殊字符(避免路径错误)(不变)
safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_
safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名"
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
root_dir = Path("up_images").resolve() # 转为绝对路径如D:/Git/bin/video/up_images
root_dir = Path("up_images").resolve()
if not root_dir.exists():
root_dir.mkdir(parents=True, exist_ok=True)
print(f"[FileUtil] 已创建up_images根目录{root_dir}")
# 4. 构建文件层级路径确保在root_dir子目录下
# 4. 构建文件层级路径确保在root_dir子目录下(不变)
ip_dir = root_dir / safe_ip
face_name_dir = ip_dir / safe_face_name
face_name_dir.mkdir(parents=True, exist_ok=True) # 自动创建目录
print(f"[FileUtil] 图片存储目录:{face_name_dir}")
face_name_dir.mkdir(parents=True, exist_ok=True)
print(f"[FileUtil] 图片存储目录(本地){face_name_dir}")
# 5. 生成唯一文件名(毫秒级时间戳)
# 5. 生成唯一文件名(毫秒级时间戳)(不变)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
local_abs_path = face_name_dir / image_filename
# 6. 计算路径确保所有路径都是绝对路径且在root_dir下
local_abs_path = face_name_dir / image_filename # 绝对路径
# 验证路径是否在root_dir下防止路径穿越攻击
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
raise Exception(f"图片路径不在up_images根目录下安全校验失败{local_abs_path}")
# 数据库存储路径:强制包含up_images前缀统一使用正斜杠
relative_path = local_abs_path.relative_to(root_dir.parent) # 相对于root_dir的父目录
db_path = str(relative_path).replace("\\", "/") # 此时会包含up_images部分
# 数据库存储路径:核心修改——在原有relative_path前添加 /api/file/ 前缀
relative_path = local_abs_path.relative_to(root_dir.parent)
# 7. 写入图片文件
relative_path_str = str(relative_path).replace("\\", "/")
# 2. 再拼接/api/file/前缀
db_path = f"/api/file/{relative_path_str}"
# 7. 写入图片文件(不变)
with open(local_abs_path, "wb") as f:
f.write(image_bytes)
print(f"[FileUtil] 图片保存成功:")
print(f" 数据库路径:{db_path}")
print(f" 本地绝对路径:{local_abs_path}")
print(f" 数据库路径(带/api/file/前缀){db_path}")
print(f" 本地绝对路径(无/api/file/{local_abs_path}")
return {
"success": True,
"db_path": db_path, # 格式为 up_images/192_168_110_31/小龙/xxx.jpg
"local_abs_path": str(local_abs_path), # 本地绝对路径(完整路径)
"db_path": db_path,
"local_abs_path": str(local_abs_path),
"msg": "图片保存成功"
}
except Exception as e:
error_msg = f"图片保存失败:{str(e)}"
print(f"[FileUtil] 错误:{error_msg}")
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}

129
ws/ws.py
View File

@ -17,17 +17,16 @@ from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect, load_model
# -------------------------- 1. AES 加密解密工具(固定密钥--------------------------
# -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息--------------------------
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥32字节
AES_BLOCK_SIZE = 16 # AES固定块大小
def aes_encrypt(plaintext: str) -> dict:
"""AES-CBC加密返回{密文, IV, 算法标识}均Base64编码"""
"""AES-CBC加密返回{密文, IV, 算法标识}均Base64编码- 仅用于服务器发消息"""
try:
iv = os.urandom(AES_BLOCK_SIZE) # 随机IV16字节
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
# 明文填充+加密+Base64编码
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
iv_base64 = base64.b64encode(iv).decode("utf-8")
@ -40,20 +39,6 @@ def aes_encrypt(plaintext: str) -> dict:
raise Exception(f"AES加密失败: {str(e)}") from e
def aes_decrypt(encrypted_dict: dict) -> str:
"""AES-CBC解密输入加密字典返回原始文本"""
try:
# Base64解码密文和IV
ciphertext = base64.b64decode(encrypted_dict["ciphertext"])
iv = base64.b64decode(encrypted_dict["iv"])
# 解密+去除填充
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8")
return decrypted
except Exception as e:
raise Exception(f"AES解密失败: {str(e)}") from e
# -------------------------- 2. 配置常量(保持原有)--------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
@ -72,7 +57,7 @@ def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# -------------------------- 4. 客户端连接封装(新增消息加密)--------------------------
# -------------------------- 4. 客户端连接封装(服务器发消息加密,接收消息改明文--------------------------
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
@ -96,28 +81,25 @@ class ClientConnection:
return self.consumer_task
async def send_frame_permit(self):
"""发送加密的帧许可信号"""
"""发送加密的帧许可信号(服务器→客户端:加密)"""
try:
# 1. 构建原始消息
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
# 2. AES加密消息
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg))
# 3. 发送加密消息
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) # 保持加密
await self.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列中的帧并处理"""
"""消费队列中的明文图像帧并处理"""
try:
while True:
frame_data = await self.frame_queue.get()
await self.send_frame_permit() # 发送下一帧许可
await self.send_frame_permit() # 回复仍加密
try:
await self.process_frame(frame_data)
finally:
@ -128,23 +110,22 @@ class ClientConnection:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像(含加密危险通知"""
# 二进制转OpenCV图像
"""处理明文图像帧(危险通知仍加密发送"""
# 二进制转OpenCV图像(客户端发的是明文二进制,直接解析)
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像")
return
try:
# 调用检测函数client_ip + img 双参数)
has_violation, data, detector_type = await asyncio.to_thread(
detect, self.client_ip, img
)
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}")
# 处理违规逻辑(发送加密危险通知)
# 违规通知:服务器→客户端,仍加密
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}")
# 违规次数+1
@ -154,19 +135,17 @@ class ClientConnection:
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 1. 构建原始危险通知
# 构建危险通知并加密发送
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip,
"detail": data
}
# 2. AES加密通知
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg))
# 3. 发送加密通知
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) # 保持加密
await self.websocket.send_json(encrypted_danger_msg)
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}")
# -------------------------- 5. 全局状态与心跳管理(保持原有)--------------------------
@ -178,7 +157,6 @@ async def heartbeat_checker():
"""全局心跳检查任务"""
while True:
current_time = get_current_time_str()
# 筛选超时客户端
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips:
@ -186,11 +164,9 @@ async def heartbeat_checker():
for ip in timeout_ips:
try:
conn = connected_clients[ip]
# 取消消费任务+关闭连接
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
@ -205,19 +181,16 @@ async def heartbeat_checker():
await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 6. 消息处理工具(新增消息解密--------------------------
# -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑--------------------------
async def send_heartbeat_ack(conn: ClientConnection):
"""发送加密的心跳确认"""
"""发送加密的心跳确认(服务器→客户端:加密)"""
try:
# 1. 构建原始心跳确认
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
# 2. AES加密
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg))
# 3. 发送
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) # 保持加密
await conn.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认")
return True
@ -228,44 +201,22 @@ async def send_heartbeat_ack(conn: ClientConnection):
async def handle_text_msg(conn: ClientConnection, text: str):
"""处理加密的文本消息(如心跳)"""
"""处理客户端明文文本消息(如心跳)- 关键修改无需解密直接解析JSON"""
try:
# 1. 解析加密字典
encrypted_dict = json.loads(text)
# 2. AES解密
decrypted_text = aes_decrypt(encrypted_dict)
# 3. 解析业务消息
msg = json.loads(decrypted_text)
# 客户端发的是明文JSON直接解析删除原解密步骤
msg = json.loads(text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn)
await send_heartbeat_ack(conn) # 服务器回复仍加密
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')}")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}")
async def handle_binary_msg(conn: ClientConnection, data: str):
"""处理加密的图像消息客户端需先转Base64+加密)"""
try:
# 1. 解密得到Base64编码的图像
encrypted_dict = json.loads(data)
decrypted_base64 = aes_decrypt(encrypted_dict)
# 2. Base64解码为二进制图像
frame_data = base64.b64decode(decrypted_base64)
# 3. 加入帧队列
conn.frame_queue.put_nowait(frame_data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满,丢弃数据")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}")
# -------------------------- 7. WebSocket路由与生命周期保持原有结构--------------------------
# -------------------------- 7. WebSocket路由与生命周期关键修改处理明文二进制图像--------------------------
ws_router = APIRouter()
@ -276,7 +227,6 @@ async def lifespan(app: FastAPI):
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 心跳检查任务启动ID: {id(heartbeat_task)}")
yield
# 关闭时清理
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
await heartbeat_task
@ -285,8 +235,8 @@ async def lifespan(app: FastAPI):
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket连接处理入口"""
load_model() # 加载检测模型(仅一次
"""WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像"""
load_model() # 加载检测模型(建议移到全局,避免重复加载
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
@ -306,8 +256,8 @@ async def websocket_endpoint(websocket: WebSocket):
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer() # 启动帧消费
await new_conn.send_frame_permit() # 发送首次许可
new_conn.start_consumer()
await new_conn.send_frame_permit() # 首次许可仍加密
# 标记客户端上线
try:
@ -321,28 +271,33 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}")
# 消息循环接收客户端消息
# 消息循环接收客户端明文消息(关键修改
while True:
data = await websocket.receive()
if "text" in data:
# 处理加密文本消息(心跳、客户端指令
# 处理客户端明文文本(如心跳:{"type":"heart",...}
await handle_text_msg(new_conn, data["text"])
elif "bytes" in data:
# 兼容客户端发送二进制先转Base64再处理
base64_data = base64.b64encode(data["bytes"]).decode("utf-8")
await handle_binary_msg(new_conn, base64_data)
# 处理客户端明文二进制图像(直接入队,无需解密)
frame_data = data["bytes"]
try:
new_conn.frame_queue.put_nowait(frame_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}")
except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理资源(断开后处理)
# 清理资源
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 仅上线成功的客户端,才标记离线
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
@ -352,4 +307,4 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")