Compare commits

..

19 Commits

Author SHA1 Message Date
d9192bd964 引擎计算超时 2025-09-12 20:30:54 +08:00
5ecbac0f9c 可以成功动态更换yolo模型并重启服务生效 2025-09-12 20:09:07 +08:00
206652d6bb 可以成功动态更换yolo模型并重启服务生效 2025-09-12 18:28:43 +08:00
4be7f7bf14 目前可以成功动态更换模型运行的 2025-09-12 14:05:09 +08:00
435b2a0e6c 路径写入数据库 2025-09-10 10:53:07 +08:00
ae177ca14a 从服务器读取IP并将检测数据写入数据库 2025-09-10 08:57:56 +08:00
d3c4820b73 识别结果保存到对应目录下后不显示完整路径 2025-09-09 17:09:34 +08:00
532a9e75e9 识别结果保存到对应目录下 2025-09-09 16:30:12 +08:00
0fe49bf829 paddleocr 2025-09-09 09:42:23 +08:00
2571da3c2d 去除本地存储 | 优化代码风格 2025-09-08 18:24:32 +08:00
1dd832e18d 修改WS兼容检测的Future对象 2025-09-08 18:10:49 +08:00
8ceb92c572 优化代码风格 2025-09-08 17:34:23 +08:00
9b3d20511a 最新可用 2025-09-05 17:23:50 +08:00
30bf7c9fcb 最新可用 2025-09-04 22:59:27 +08:00
ec6dbfde90 优化 2025-09-04 17:33:20 +08:00
3ed73bd9eb 1 2025-09-04 17:29:52 +08:00
08f8a0e44e 优化 2025-09-04 17:08:25 +08:00
b5d870a19c 优化 2025-09-04 12:29:27 +08:00
ea82a33a8f 平均特征值计算 2025-09-04 10:46:05 +08:00
59 changed files with 3343 additions and 2228 deletions

2
.idea/Video.iml generated
View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="video" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>

283
app.py Normal file
View File

@ -0,0 +1,283 @@
from flask import Flask, send_from_directory, abort, request
import os
import logging
from functools import wraps
from pathlib import Path
# 跨域依赖必须安装pip install flask-cors
from flask_cors import CORS
# 配置日志(保持原有格式)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 初始化 Flask 应用(供 main.py 导入)
app = Flask(__name__)
# ------------------------------
# 核心修改:与 FastAPI 对齐的跨域配置
# ------------------------------
# 1. 允许的前端域名(完全复制 FastAPI 的 ALLOWED_ORIGINS确保前后端一致
ALLOWED_ORIGINS = [
# "http://localhost:8080", # 本地前端开发地址(必改:替换为你的前端实际地址)
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址替换为你的服务器IP/域名)
# # "*" 仅开发环境临时使用,生产环境必须删除(安全风险)
"*"
]
# 2. 配置 CORS与 FastAPI 规则完全对齐)
CORS(
app,
resources={
r"/*": { # 对所有 Flask 路由生效(覆盖图片、模型下载所有接口)
"origins": ALLOWED_ORIGINS, # 允许的前端域名(与 FastAPI 一致)
"allow_credentials": True, # 允许携带 Cookie与 FastAPI 一致,需登录态必开)
"methods": ["*"], # 允许所有 HTTP 方法FastAPI 用 "*",此处同步)
"allow_headers": ["*"], # 允许所有请求头(与 FastAPI 一致)
}
},
)
# ------------------------------
# 核心路径配置(不变,确保资源目录正确)
# ------------------------------
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent # 项目根目录video/
# 资源目录(图片、模型)
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
# 打印路径配置(调试用,确认目录正确)
# logger.info(f"[Flask 配置] 项目根目录:{PROJECT_ROOT}")
# logger.info(f"[Flask 配置] 模型目录:{BASE_MODEL_DIR}")
# logger.info(f"[Flask 配置] 人脸图片目录:{BASE_IMAGE_DIR_UP_IMAGES}")
# logger.info(f"[Flask 配置] 检测图片目录:{BASE_IMAGE_DIR_DECT}")
# ------------------------------
# 安全检查装饰器(不变,防路径遍历/非法文件)
# ------------------------------
def safe_path_check(root_dir: str):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
resource_path = kwargs.get('resource_path', '').strip()
# 统一路径分隔符(兼容 Windows \ 和 Linux /
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
# 拼接完整路径(防止路径遍历)
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
logger.debug(
f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}"
)
# 1. 禁止路径遍历(确保请求文件在根目录内)
if not full_file_path.startswith(root_dir):
logger.warning(
f"[Flask 安全拦截] 非法路径遍历IP{request.remote_addr} | 请求路径:{resource_path}"
)
abort(403)
# 2. 检查文件存在且为有效文件(非目录)
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
logger.warning(
f"[Flask 资源错误] 文件不存在/非文件IP{request.remote_addr} | 路径:{full_file_path}"
)
abort(404)
# 3. 限制文件大小模型200MB图片10MB避免超大文件攻击
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
if os.path.getsize(full_file_path) > max_size:
logger.warning(
f"[Flask 大小超限] 文件超过{max_size//1024//1024}MBIP{request.remote_addr} | 路径:{full_file_path}"
)
abort(413)
# 安全检查通过,传递根目录给视图函数
return func(*args, **kwargs, root_dir=root_dir)
return wrapper
return decorator
# ------------------------------
# 1. 模型下载接口(/model/download/*
# ------------------------------
@app.route('/model/download/<path:resource_path>')
@safe_path_check(root_dir=BASE_MODEL_DIR)
def download_model(resource_path, root_dir):
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许 .pt 格式YOLO 模型)
if not file_name.lower().endswith('.pt'):
logger.warning(
f"[Flask 格式错误] 非 .pt 模型文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 模型下载] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
# 强制浏览器下载(而非预览),设置二进制文件类型
return send_from_directory(
full_dir,
file_name,
as_attachment=True,
mimetype="application/octet-stream"
)
except Exception as e:
logger.error(
f"[Flask 模型下载异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 2. 人脸图片访问接口(/up_images/*
# ------------------------------
@app.route('/up_images/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES)
def get_face_image(resource_path, root_dir):
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许常见图片格式
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 人脸图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
# 允许浏览器预览图片(而非下载)
return send_from_directory(full_dir, file_name, as_attachment=False)
except Exception as e:
logger.error(
f"[Flask 人脸图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 3. 检测图片访问接口(/resource/dect/*
# ------------------------------
@app.route('/resource/dect/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
def get_dect_image(resource_path, root_dir):
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许常见图片格式
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 检测图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
except Exception as e:
logger.error(
f"[Flask 检测图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*
# ------------------------------
@app.route('/images/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
def get_compatible_image(resource_path, root_dir):
try:
# 逻辑与检测图片接口一致仅URL前缀不同兼容旧前端
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 兼容图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
except Exception as e:
logger.error(
f"[Flask 兼容图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 全局错误处理器(友好提示,与 FastAPI 错误信息风格一致)
# ------------------------------
@app.errorhandler(403)
def forbidden_error(error):
return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403
@app.errorhandler(404)
def not_found_error(error):
return "❌ 资源不存在请检查URL路径IP、目录、文件名是否正确", 404
@app.errorhandler(413)
def too_large_error(error):
return "❌ 文件过大图片最大10MB模型最大200MB", 413
@app.errorhandler(415)
def unsupported_type_error(error):
return "❌ 不支持的文件类型图片支持png/jpg/jpeg/gif/bmp模型仅支持pt", 415
@app.errorhandler(500)
def server_error(error):
return "❌ 服务器内部错误:请联系管理员查看后台日志", 500
# ------------------------------
# Flask 独立启动入口(供测试,实际由 main.py 子线程启动)
# ------------------------------
if __name__ == '__main__':
# 确保所有资源目录存在(防止初始化失败)
required_dirs = [
(BASE_IMAGE_DIR_DECT, "检测图片目录"),
(BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"),
(BASE_MODEL_DIR, "模型文件目录")
]
for dir_path, dir_desc in required_dirs:
if not os.path.exists(dir_path):
logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}")
os.makedirs(dir_path, exist_ok=True)
# 启动提示(含访问示例)
logger.info("\n[Flask 服务启动成功!] 支持的接口:")
logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt")
logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg")
logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n")
# 启动服务(禁用 debug 和自动重载,避免多线程冲突)
app.run(
host="0.0.0.0", # 允许外部IP访问
port=5000, # 与 main.py 中 Flask 端口一致
debug=False,
use_reloader=False
)

View File

@ -12,8 +12,4 @@ charset = utf8mb4
[jwt]
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
algorithm = HS256
access_token_expire_minutes = 30
[live]
rtmp_url = rtmp://192.168.110.65:1935/live/
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=
access_token_expire_minutes = 30

95
core/all.py Normal file
View File

@ -0,0 +1,95 @@
import cv2
import numpy as np
from PIL.Image import Image
from core.establish import get_image_save_path
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# 导入保存路径函数(根据实际文件位置调整导入路径)
import numpy as np
import base64
from io import BytesIO
from PIL import Image
from ds.db import db
from mysql.connector import Error as MySQLError
# 模型加载状态标记(避免重复加载)
_model_loaded = False
def load_model():
"""加载所有检测模型(仅首次调用时执行)"""
global _model_loaded
if _model_loaded:
print("模型已加载,无需重复执行")
return
# 依次加载OCR、人脸、YOLO模型
ocrLoadModel()
faceLoadModel()
yoloLoadModel()
_model_loaded = True
print("所有检测模型加载完成")
def save_db(model_type, client_ip, result):
conn = None
cursor = None
try:
# 连接数据库
conn = db.get_connection()
# 往表插入数据
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
insert_query = """
INSERT INTO device_danger (client_ip, type, result)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (client_ip, model_type, result))
conn.commit()
except MySQLError as e:
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 修正后的 detect 函数关键部分
def detect(client_ip, frame):
# 1. YOLO检测
yolo_flag, yolo_result = yoloDetect(frame)
if yolo_flag:
# model_type 传入 "yolo"(正确)
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ yolo违规图片已保存{display_path}") # 日志也修正
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
return (True, yolo_result, "yolo")
# 2. 人脸检测
face_flag, face_result = faceDetect(frame)
if face_flag:
full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip) # 这里改了
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ face违规图片已保存{display_path}") # 日志也修正
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
return (True, face_result, "face")
# 3. OCR检测
ocr_flag, ocr_result = ocrDetect(frame)
if ocr_flag:
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了
if full_save_path:
cv2.imwrite(full_save_path, frame)
print(f"✅ ocr违规图片已保存{display_path}") # 日志也修正
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片")
return (False, "未检测到任何内容", "none")

119
core/establish.py Normal file
View File

@ -0,0 +1,119 @@
import datetime
from pathlib import Path
from service.device_service import get_unique_client_ips
def create_directory_structure():
"""创建项目所需的目录结构为所有客户端IP预创建基础目录"""
try:
# 1. 创建根目录下的resource文件夹存在则跳过不覆盖子内容
resource_dir = Path("resource")
resource_dir.mkdir(exist_ok=True)
# print(f"确保resource目录存在: {resource_dir.absolute()}")
# 2. 在resource下创建dect文件夹
dect_dir = resource_dir / "dect"
dect_dir.mkdir(exist_ok=True)
# print(f"确保dect目录存在: {dect_dir.absolute()}")
# 3. 在dect下创建三个模型文件夹
model_dirs = ["ocr", "face", "yolo"]
for model in model_dirs:
model_dir = dect_dir / model
model_dir.mkdir(exist_ok=True)
# print(f"确保{model}模型目录存在: {model_dir.absolute()}")
# 4. 调用外部方法获取所有客户端IP地址
try:
# 调用外部ip_read()方法获取所有客户端IP地址列表
all_ip_addresses = get_unique_client_ips()
# 确保返回的是列表类型
if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses]
# 过滤有效IP去除空字符串和空格
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if not valid_ips:
print("警告: 未获取到有效的客户端IP地址")
return
print(f"获取到的所有客户端IP地址: {valid_ips}")
# 5. 获取当前日期(年、月)
now = datetime.datetime.now()
current_year = str(now.year)
current_month = str(now.month)
# 6. 为每个客户端IP在每个模型文件夹下创建年->月的基础目录结构
for ip in valid_ips:
# 处理IP地址中的特殊字符将.替换为_避免路径问题
safe_ip = ip.replace(".", "_")
for model in model_dirs:
# 构建路径: resource/dect/{model}/{safe_ip}/{year}/{month}
ip_dir = dect_dir / model / safe_ip
year_dir = ip_dir / current_year
month_dir = year_dir / current_month
# 递归创建目录(存在则跳过,不覆盖)
month_dir.mkdir(parents=True, exist_ok=True)
# print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}")
except Exception as e:
print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
except Exception as e:
print(f"创建基础目录结构时发生错误: {str(e)}")
def get_image_save_path(model_type: str, client_ip: str) -> tuple:
"""
获取图片保存的「完整路径」和「显示用短路径」
参数:
model_type: 模型类型,应为"ocr""face""yolo"
client_ip: 检测到违禁的客户端IP地址原始格式如192.168.1.101
返回:
元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "")
"""
try:
# 1. 验证客户端IP有效性检查是否在已知IP列表中
all_ip_addresses = get_unique_client_ips()
if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses]
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if client_ip.strip() not in valid_ips:
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中无法保存文件")
# 2. 处理IP地址与目录创建逻辑一致将.替换为_
safe_ip = client_ip.strip().replace(".", "_")
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
now = datetime.datetime.now()
current_year = str(now.year)
current_month = str(now.month)
current_day = str(now.day)
# 时间戳格式年月日时分秒毫秒如20250910143050123
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]
# 4. 定义基础目录(用于生成相对路径)
base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀
# 构建日级目录完整路径resource/dect/{model}/{safe_ip}/{年}/{月}/{日}
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
day_dir.mkdir(parents=True, exist_ok=True) # 确保日目录存在
# 5. 构建唯一文件名
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
# 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印)
full_path = day_dir / image_filename # 完整路径resource/dect/.../xxx.jpg
display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg去掉resource/dect
return str(full_path), str(display_path)
except Exception as e:
print(f"获取图片保存路径时发生错误: {str(e)}")
return "", ""

330
core/face.py Normal file
View File

@ -0,0 +1,330 @@
import os
import numpy as np
import gc
import time
import threading
from insightface.app import FaceAnalysis
from service.face_service import get_all_face_name_with_eigenvalue
# GPU状态检查支持
try:
import pynvml
pynvml.nvmlInit()
_nvml_available = True
except ImportError:
print("警告: pynvml库未安装无法检测GPU状态默认尝试使用GPU")
_nvml_available = False
# 全局人脸引擎与特征库
_face_app = None
_known_faces_embeddings = {} # 姓名 -> 归一化特征值的映射
_known_faces_names = [] # 已知人脸姓名列表
# GPU使用状态标记
_using_gpu = False # 是否使用GPU
_used_gpu_id = -1 # 使用的GPU ID-1表示CPU
# 资源管理变量
_ref_count = 0 # 引擎引用计数(记录当前使用次数)
# 修复点1初始值设为当前时间避免未加载引擎时用0计算超时
_last_used_time = time.time()
_lock = threading.Lock() # 线程安全锁
_release_timeout = 8 # 闲置超时时间(秒)
_is_releasing = False # 资源释放中标记
_monitor_thread_running = False # 监控线程运行标记
# 调试计数器
_debug_counter = {
"engine_created": 0, # 引擎创建次数
"engine_released": 0, # 引擎释放次数
"detection_calls": 0 # 检测函数调用次数
}
def check_gpu_availability(gpu_id, memory_threshold=0.7):
"""检查指定GPU的内存使用率是否低于阈值判定为“可用”"""
if not _nvml_available:
return True
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
memory_usage = mem_info.used / mem_info.total
return memory_usage < memory_threshold
except Exception as e:
print(f"检查GPU {gpu_id} 状态失败: {e}")
return False
def select_best_gpu(preferred_gpus=[0, 1]):
"""按优先级选择可用GPU优先0号均不可用则返回-1CPU"""
for gpu_id in preferred_gpus:
try:
# 验证GPU是否存在
if _nvml_available:
pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
# 验证GPU内存是否充足
if check_gpu_availability(gpu_id):
print(f"GPU {gpu_id} 可用将使用该GPU")
return gpu_id
else:
if gpu_id == 0:
print("GPU 0 内存使用率过高尝试其他GPU")
except Exception as e:
print(f"GPU {gpu_id} 不可用或访问失败: {e}")
print("所有指定GPU均不可用将使用CPU计算")
return -1
def _release_engine_resources():
"""释放人脸引擎的所有资源模型、特征库、GPU缓存等"""
global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names, _last_used_time
if not _face_app or _is_releasing:
return
try:
_is_releasing = True
print("开始释放人脸引擎资源...")
# 释放InsightFace模型资源
if hasattr(_face_app, "model"):
_face_app.model = None # 显式置空模型引用
_face_app = None # 释放引擎实例
# 清空人脸特征库
_known_faces_embeddings.clear()
_known_faces_names.clear()
_debug_counter["engine_released"] += 1
print(f"人脸引擎已释放,调试统计: {_debug_counter}")
# 强制垃圾回收
gc.collect()
# 清理各深度学习框架的GPU缓存
# Torch 缓存清理
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
print("Torch GPU缓存已清理")
except ImportError:
pass
# TensorFlow 缓存清理
try:
import tensorflow as tf
tf.keras.backend.clear_session()
print("TensorFlow会话已清理")
except ImportError:
pass
# MXNet 缓存清理InsightFace底层常用MXNet
try:
import mxnet as mx
mx.nd.waitall() # 等待所有计算完成并释放资源
print("MXNet资源已等待释放")
except ImportError:
pass
except Exception as e:
print(f"释放资源过程中出错: {e}")
finally:
_is_releasing = False
# 修复点2释放完成后重置最后使用时间避免下次加载时复用旧值
_last_used_time = time.time()
def _resource_monitor_thread():
"""后台监控线程:检测引擎闲置超时,触发资源释放"""
global _ref_count, _last_used_time, _face_app, _monitor_thread_running
_monitor_thread_running = True
while _monitor_thread_running:
time.sleep(2) # 缩短检查间隔,加快闲置检测响应
with _lock:
# 当“引擎存在 + 无引用 + 未在释放中”时,检查闲置时间
if _face_app and _ref_count == 0 and not _is_releasing:
idle_time = time.time() - _last_used_time
if idle_time > _release_timeout:
print(f"引擎闲置超时({idle_time:.1f}s > {_release_timeout}s释放资源")
_release_engine_resources()
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
"""加载人脸识别引擎及已知人脸特征库默认优先用0号GPU"""
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id, _last_used_time
# 启动后台监控线程(确保仅启动一次)
if not _monitor_thread_running:
threading.Thread(
target=_resource_monitor_thread,
daemon=True,
name="FaceEngineMonitor"
).start()
print("人脸引擎监控线程已启动")
# 若正在释放资源,等待释放完成
while _is_releasing:
time.sleep(0.1)
# 若引擎已初始化,直接返回
if _face_app:
return True
# 初始化InsightFace引擎
try:
print("正在初始化InsightFace人脸识别引擎...")
_face_app = FaceAnalysis(name="buffalo_l", root=os.path.expanduser("~/.insightface"))
# 选择GPU优先用0号
ctx_id = 0
if prefer_gpu:
ctx_id = select_best_gpu(preferred_gpus)
_using_gpu = ctx_id != -1
_used_gpu_id = ctx_id if _using_gpu else -1
if _using_gpu:
print(f"引擎初始化成功将使用GPU {ctx_id} 计算")
else:
print("引擎初始化成功将使用CPU计算")
# 准备模型(加载到指定设备)
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
print("InsightFace引擎初始化完成")
# 修复点3引擎初始化成功后立即更新“最后使用时间”核心修复
_last_used_time = time.time()
_debug_counter["engine_created"] += 1
print(f"引擎调试统计: {_debug_counter}")
except Exception as e:
print(f"引擎初始化失败: {e}")
return False
# 从服务加载已知人脸的姓名和特征值
try:
face_data = get_all_face_name_with_eigenvalue()
for person_name, eigenvalue_data in face_data.items():
# 兼容“numpy数组”和“字符串”格式的特征值
if isinstance(eigenvalue_data, np.ndarray):
eigenvalue = eigenvalue_data.astype(np.float32)
elif isinstance(eigenvalue_data, str):
# 清理字符串中的括号、换行等干扰符
cleaned = eigenvalue_data.replace("[", "").replace("]", "").replace("\n", "").strip()
# 分割并转换为浮点数数组
values = [v for v in cleaned.split() if v] # 兼容空格/逗号分隔
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
else:
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
continue
# 特征值归一化(保证后续相似度计算的一致性)
norm = np.linalg.norm(eigenvalue)
if norm != 0:
eigenvalue = eigenvalue / norm
_known_faces_embeddings[person_name] = eigenvalue
_known_faces_names.append(person_name)
print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库")
except Exception as e:
print(f"加载人脸特征库失败: {e}")
return _face_app is not None
def detect(frame, similarity_threshold=0.4):
"""
检测并识别人脸
返回:(是否匹配到已知人脸, 结果描述字符串)
"""
global _face_app, _known_faces_embeddings, _known_faces_names, _ref_count, _last_used_time
# 校验输入帧有效性
if frame is None or frame.size == 0:
return (False, "无效的输入帧数据")
# 加锁并更新引用计数、最后使用时间
engine = None
with _lock:
_ref_count += 1
_last_used_time = time.time()
_debug_counter["detection_calls"] += 1
# 若引擎未初始化且未在释放中,尝试初始化
if not _face_app and not _is_releasing:
if not load_model(prefer_gpu=True):
# 初始化失败,恢复引用计数
with _lock:
_ref_count = max(0, _ref_count - 1)
return (False, "人脸引擎初始化失败")
engine = _face_app # 获取引擎引用
# 校验引擎可用性
if not engine or len(_known_faces_names) == 0:
with _lock:
_ref_count = max(0, _ref_count - 1)
return (False, "人脸引擎不可用或特征库为空")
try:
# GPU计算时确保帧数据是连续内存避免CUDA错误
if _using_gpu and engine is not None and not frame.flags.contiguous:
frame = np.ascontiguousarray(frame)
# 执行人脸检测与特征提取
faces = engine.get(frame)
except Exception as e:
print(f"人脸检测过程出错: {e}")
# 出错时尝试重新初始化引擎可能是GPU状态变化导致
print("尝试重新初始化人脸引擎...")
with _lock:
_ref_count = max(0, _ref_count - 1)
load_model(prefer_gpu=True)
return (False, f"检测错误: {str(e)}")
result_parts = []
has_matched_known_face = False # 是否有任意人脸匹配到已知库
for face in faces:
# 归一化当前检测到的人脸特征
face_embedding = face.embedding.astype(np.float32)
norm = np.linalg.norm(face_embedding)
if norm == 0:
continue
face_embedding = face_embedding / norm
# 与已知人脸特征逐一比对
max_similarity, best_match_name = -1.0, "Unknown"
for name in _known_faces_names:
known_emb = _known_faces_embeddings[name]
similarity = np.dot(face_embedding, known_emb) # 余弦相似度
if similarity > max_similarity:
max_similarity = similarity
best_match_name = name
# 判断是否匹配成功
is_matched = max_similarity >= similarity_threshold
if is_matched:
has_matched_known_face = True
# 记录该人脸的检测结果
bbox = face.bbox # 人脸边界框
result_parts.append(
f"{'匹配' if is_matched else '未匹配'}: {best_match_name} "
f"(相似度: {max_similarity:.2f}, 边界框: {bbox.astype(int).tolist()})"
)
# 构建最终结果字符串
result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts)
# 释放引用计数(线程安全)
with _lock:
_ref_count = max(0, _ref_count - 1)
# 若仍有引用更新最后使用时间若引用为0也立即标记加快闲置检测
_last_used_time = time.time()
return (has_matched_known_face, result_str)

BIN
core/models/best.pt Normal file

Binary file not shown.

253
core/ocr.py Normal file
View File

@ -0,0 +1,253 @@
import os
import cv2
import gc
import time
import threading
import numpy as np
from paddleocr import PaddleOCR
from service.sensitive_service import get_all_sensitive_words
# 解决NumPy 1.20+版本中np.int已移除的兼容性问题
try:
if not hasattr(np, 'int'):
np.int = int
except Exception as e:
print(f"处理NumPy兼容性时出错: {e}")
# 全局变量
_ocr_engine = None
_forbidden_words = set()
_conf_threshold = 0.5
# 资源管理变量
_ref_count = 0
_last_used_time = 0
_lock = threading.Lock()
_release_timeout = 5 # 30秒无使用则释放
_is_releasing = False # 标记是否正在释放
# 并行处理配置
_max_workers = 4 # 并行处理的线程数
# 调试用计数器
_debug_counter = {
"created": 0,
"released": 0,
"detected": 0
}
def _release_engine():
"""释放OCR引擎资源"""
global _ocr_engine, _is_releasing
if not _ocr_engine or _is_releasing:
return
try:
_is_releasing = True
_ocr_engine = None
_debug_counter["released"] += 1
print(f"OCR engine released. Stats: {_debug_counter}")
# 清理GPU缓存
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
except ImportError:
pass
try:
import paddle
if paddle.is_compiled_with_cuda():
paddle.device.cuda.empty_cache()
except ImportError:
pass
finally:
_is_releasing = False
def _monitor_thread():
"""监控线程,优化检查逻辑"""
global _ref_count, _last_used_time, _ocr_engine
while True:
time.sleep(5) # 每5秒检查一次
with _lock:
if _ocr_engine and _ref_count == 0 and not _is_releasing:
elapsed = time.time() - _last_used_time
if elapsed > _release_timeout:
print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine")
_release_engine()
def load_model():
"""加载违禁词列表和初始化监控线程"""
global _forbidden_words
# 确保监控线程只启动一次
if not any(t.name == "OCRMonitor" for t in threading.enumerate()):
threading.Thread(target=_monitor_thread, daemon=True, name="OCRMonitor").start()
print("OCR monitor thread started")
# 加载违禁词
try:
_forbidden_words = get_all_sensitive_words()
print(f"Loaded {len(_forbidden_words)} forbidden words")
except Exception as e:
print(f"Forbidden words load error: {e}")
return False
return True
def detect(frame):
"""OCR检测支持并行处理"""
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time, _max_workers
# 验证前置条件
if not _forbidden_words:
return (False, "违禁词未初始化")
if frame is None or frame.size == 0:
return (False, "无效帧数据")
# 增加引用计数并获取引擎实例
engine = None
with _lock:
_ref_count += 1
_last_used_time = time.time()
_debug_counter["detected"] += 1
# 初始化引擎(如果未初始化且不在释放中)
if not _ocr_engine and not _is_releasing:
try:
# 初始化PaddleOCR设置并行处理参数
_ocr_engine = PaddleOCR(
use_angle_cls=True,
lang="ch",
show_log=False,
use_gpu=True,
max_text_length=1024,
threads=_max_workers
)
_debug_counter["created"] += 1
print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}")
except Exception as e:
print(f"OCR model load failed: {e}")
_ref_count -= 1
return (False, f"引擎初始化失败: {str(e)}")
engine = _ocr_engine
# 检查引擎是否可用
if not engine:
with _lock:
_ref_count -= 1
return (False, "OCR引擎不可用")
try:
# 执行OCR检测
ocr_res = engine.ocr(frame, cls=True)
# 验证OCR结果格式
if not ocr_res or not isinstance(ocr_res, list):
return (False, "无OCR结果")
# 处理OCR结果 - 兼容多种格式
texts = []
confs = []
for line in ocr_res:
if line is None:
continue
# 处理line可能是列表或直接是文本信息的情况
if isinstance(line, list):
items_to_process = line
else:
items_to_process = [line]
for item in items_to_process:
# 精确识别并忽略图片坐标位置信息 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
if isinstance(item, list) and len(item) == 4: # 四边形有4个顶点
is_coordinate = True
for point in item:
# 每个顶点应该是包含2个数字的列表
if not (isinstance(point, list) and len(point) == 2 and
all(isinstance(coord, (int, float)) for coord in point)):
is_coordinate = False
break
if is_coordinate:
continue # 是坐标信息,直接忽略
# 跳过纯数字列表(其他可能的坐标形式)
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
continue
# 处理元组形式的文本和置信度 (text, confidence)
if isinstance(item, tuple) and len(item) == 2:
text, conf = item
if isinstance(text, str) and isinstance(conf, (int, float)):
texts.append(text.strip())
confs.append(float(conf))
continue
# 处理列表形式的[坐标信息, (text, confidence)]
if isinstance(item, list) and len(item) >= 2:
# 尝试从列表中提取文本和置信度
text_data = item[1]
if isinstance(text_data, tuple) and len(text_data) == 2:
text, conf = text_data
if isinstance(text, str) and isinstance(conf, (int, float)):
texts.append(text.strip())
confs.append(float(conf))
continue
elif isinstance(text_data, str):
# 只有文本没有置信度的情况
texts.append(text_data.strip())
confs.append(1.0) # 默认最高置信度
continue
# 无法识别的格式,记录日志
print(f"无法解析的OCR结果格式: {item}")
if len(texts) != len(confs):
return (False, "OCR结果格式异常")
# 筛选违禁词
vio_info = []
for txt, conf in zip(texts, confs):
if conf < _conf_threshold:
continue
matched = [w for w in _forbidden_words if w in txt]
if matched:
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
# 构建结果
has_text = len(texts) > 0
has_violation = len(vio_info) > 0
if not has_text:
return (False, "未识别到文本")
elif has_violation:
return (True, "; ".join(vio_info))
else:
return (False, "未检测到违禁词")
except Exception as e:
print(f"OCR detect error: {e}")
return (False, f"检测错误: {str(e)}")
finally:
# 减少引用计数,确保线程安全
with _lock:
_ref_count = max(0, _ref_count - 1)
if _ref_count > 0:
_last_used_time = time.time()
def batch_detect(frames):
"""批量检测接口,充分利用并行能力"""
results = []
for frame in frames:
results.append(detect(frame))
return results

View File

@ -1,137 +0,0 @@
import asyncio
import logging
from aiortc import RTCPeerConnection, RTCSessionDescription
import aiohttp
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whep_video_puller")
async def whep_pull_video_stream(ip,whep_url):
"""
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
Args:
whep_url: WHEP端点的URL
"""
pc = RTCPeerConnection()
# 添加连接状态变化监听
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print(f"连接状态: {pc.connectionState}")
# 添加ICE连接状态变化监听
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
print(f"ICE连接状态: {pc.iceConnectionState}")
# 添加视频接收器
pc.addTransceiver("video", direction="recvonly")
# 处理接收到的视频轨道
@pc.on("track")
def on_track(track):
print(f"接收到轨道: {track.kind}")
if track.kind == "video":
print(f"轨道ID: {track.id}")
print(f"轨道就绪状态: {track.readyState}")
# 创建异步任务来处理视频帧
asyncio.ensure_future(handle_video_track(track))
async def handle_video_track(track):
"""处理视频轨道,接收并打印每一帧"""
frame_count = 0
print("开始处理视频轨道...")
while True:
try:
# 尝试接收帧
frame = await track.recv()
frame_count += 1
print(f"收到原始帧 (第{frame_count}帧)")
# 打印帧的基本信息
if hasattr(frame, 'width') and hasattr(frame, 'height'):
print(f" 尺寸: {frame.width}x{frame.height}")
if hasattr(frame, 'time_base'):
print(f" 时间基准: {frame.time_base}")
if hasattr(frame, 'pts'):
print(f" 显示时间戳: {frame.pts}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
except Exception as e:
print(f"接收帧时出错: {e}")
# 等待一段时间后重试
await asyncio.sleep(0.1)
continue
# 创建offer
offer = await pc.createOffer()
await pc.setLocalDescription(offer)
print(f"本地SDP信息:\n{offer.sdp}")
# 通过HTTP POST发送offer到WHEP端点
async with aiohttp.ClientSession() as session:
async with session.post(
whep_url,
data=offer.sdp,
headers={"Content-Type": "application/sdp"}
) as response:
if response.status != 201:
print(f"WHEP服务器返回错误: {response.status}")
print(f"响应内容: {await response.text()}")
raise Exception(f"WHEP服务器返回错误: {response.status}")
# 获取answer SDP
answer_sdp = await response.text()
# 创建RTCSessionDescription对象
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
print(f"收到远程SDP:\n{answer_sdp}")
# 设置远程描述
await pc.setRemoteDescription(answer)
print("连接已建立,开始接收视频流...")
# 保持连接,直到用户中断
try:
while True:
await asyncio.sleep(1)
# 检查连接状态
print(f"当前连接状态: {pc.connectionState}")
except KeyboardInterrupt:
print("用户中断,关闭连接...")
finally:
await pc.close()
if __name__ == "__main__":
# 替换为你的WHEP端点URL
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
# 运行拉流任务
asyncio.run(whep_pull_video_stream(WHEP_URL))

View File

@ -1,112 +0,0 @@
import asyncio
import logging
import cv2
import time
from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正从core目录向上一级找ocr文件夹
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
KNOWN_FACES_DIR = "../ocr/known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并进行违规检测
"""
cap = None # 初始化视频捕获对象
try:
# 异步打开RTMP流
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
)
# 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 获取RTMP流基础信息
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况
fps = fps if fps > 0 else 30.0
width, height = int(width), int(height)
# 打印流初始化成功信息
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 初始化帧统计参数
frame_count = 0
start_time = time.time()
# 循环读取视频帧
while True:
ret, frame = await asyncio.to_thread(cap.read)
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
frame_count += 1
# 打印当前帧信息
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
if frame is not None:
has_violation, violation_type, details = detector.detect_violations(frame)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
else:
print("未检测到任何违规内容")
else:
print(f"无法读取测试图像")
# 每100帧统计一次实际接收帧率
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

55
core/yolo.py Normal file
View File

@ -0,0 +1,55 @@
from ultralytics import YOLO
from service.model_service import get_current_yolo_model # 带版本校验的模型获取
def load_model(model_path=None):
"""加载YOLO模型优先使用带版本校验的默认模型"""
if model_path is None:
# 调用带版本校验的模型获取函数(自动判断是否需要重新加载)
return get_current_yolo_model()
try:
# 加载指定路径模型(用于特殊场景)
return YOLO(model_path)
except Exception as e:
print(f"YOLO模型加载失败指定路径{str(e)}")
return None
def detect(frame, conf_threshold=0.7):
"""执行目标检测(仅模型版本变化时重新加载,平时复用缓存)"""
# 获取模型(内部已做版本校验,未变化则直接返回缓存)
current_model = load_model()
if not current_model:
return (False, "未加载到最新YOLO模型")
if frame is None:
return (False, "无效输入帧")
try:
# 用当前模型执行检测(复用缓存,无额外加载耗时)
results = current_model(frame, conf=conf_threshold, verbose=False)
has_results = len(results[0].boxes) > 0 if results else False
if not has_results:
return (False, "未检测到目标")
# 构建结果字符串
result_parts = []
for box in results[0].boxes:
cls = int(box.cls[0])
conf = float(box.conf[0])
bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
# 从当前模型中获取类别名(确保与模型匹配)
class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}"
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox}")
# 打印当前使用的模型路径和版本(用于验证)
# model_path = getattr(current_model, "model_path", "未知路径")
# from service.model_service import _current_model_version
# print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...")
return (True, "; ".join(result_parts))
except Exception as e:
print(f"YOLO检测过程出错{str(e)}")
return (False, f"检测错误:{str(e)}")

View File

@ -14,4 +14,3 @@ config.read(config_path, encoding="utf-8")
SERVER_CONFIG = config["server"]
MYSQL_CONFIG = config["mysql"]
JWT_CONFIG = config["jwt"]
LIVE_CONFIG = config["live"]

View File

@ -3,6 +3,8 @@ from mysql.connector import Error
from .config import MYSQL_CONFIG
# 关键:声明类级别的连接池实例(必须有这一行!)
_connection_pool = None # 确保这一行存在,且拼写正确
class Database:
"""MySQL 连接池管理类"""
@ -41,6 +43,18 @@ class Database:
except Error as e:
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
@classmethod
def close_all_connections(cls):
"""清理连接池(服务重启前调用)"""
try:
# 先检查属性是否存在,再判断是否有值
if hasattr(cls, "_connection_pool") and cls._connection_pool:
cls._connection_pool = None # 重置连接池
print("[Database] 连接池已重置,旧连接将被自动清理")
else:
print("[Database] 连接池未初始化或已重置,无需操作")
except Exception as e:
print(f"[Database] 重置连接池失败: {str(e)}")
# 暴露数据库操作工具
db = Database()

127
main.py
View File

@ -1,43 +1,142 @@
from PIL import Image
import numpy as np
import uvicorn
import threading
import time
import os
from fastapi import FastAPI
# 新增:导入 CORS 相关依赖
from fastapi.middleware.cors import CORSMiddleware
# 导入 Flask 服务实例
from app import app as flask_app
# 原有业务导入
from core.all import load_model, detect
from ds.config import SERVER_CONFIG
from middle.error_handler import global_exception_handler
from service.user_service import router as user_router
from service.sensitive_service import router as sensitive_router
from service.face_service import router as face_router
from service.device_service import router as device_router
from service.model_service import router as model_router # 模型管理路由
from ws.ws import ws_router, lifespan
from core.establish import create_directory_structure
# ------------------------------
# 初始化 FastAPI 应用、指定生命周期管理
# ------------------------------
# Flask 服务启动函数(不变)
def start_flask_service():
try:
print(f"\n[Flask 服务] 准备启动端口5000")
print(f"[Flask 服务] 访问示例http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n")
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
if not os.path.exists(BASE_IMAGE_DIR):
print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
flask_app.run(
host="0.0.0.0",
port=5000,
debug=False,
use_reloader=False
)
except Exception as e:
print(f"[Flask 服务] 启动失败:{str(e)}")
# 初始化 FastAPI 应用(新增 CORS 配置)
app = FastAPI(
title="内容安全审核后台",
description="内容安全审核后台",
description="含图片访问服务和动态模型管理",
version="1.0.0",
lifespan=lifespan
)
# ------------------------------
# 注册路由
# 新增:完整 CORS 配置(解决跨域问题)
# ------------------------------
# 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等)
ALLOWED_ORIGINS = [
# "http://localhost:8080", # 前端本地开发地址(必改,填实际前端地址)
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址(如适用)
"*" #表示允许所有域名(开发环境可用,生产环境不推荐)
]
# 2. 配置 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
allow_credentials=True, # 允许携带 Cookie如需登录态则必开
allow_methods=["*"], # 允许所有 HTTP 方法(包括 PUT/DELETE
allow_headers=["*"], # 允许所有请求头(包括 Content-Type
)
# 注册路由(不变)
app.include_router(user_router)
app.include_router(device_router)
app.include_router(face_router)
app.include_router(sensitive_router)
app.include_router(model_router) # 模型管理路由
app.include_router(ws_router)
# ------------------------------
# 注册全局异常处理器
# ------------------------------
# 注册全局异常处理器(不变)
app.add_exception_handler(Exception, global_exception_handler)
# ------------------------------
# 启动服务
# ------------------------------
# 主服务启动入口(不变)
if __name__ == "__main__":
# 1. 初始化资源
create_directory_structure()
print(f"[初始化] 目录结构创建完成")
# 创建模型保存目录
MODEL_SAVE_DIR = os.path.join("core", "models")
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}")
# # 模型路径配置
# YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt")
# OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml")
# print(f"[初始化] 默认YOLO模型路径{YOLO_MODEL_PATH}")
# print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}")
# 加载检测模型
try:
load_success = load_model()
if load_success:
print(f"[初始化] 检测模型加载完成")
else:
print(f"[初始化] 未找到默认模型可通过API上传并设置默认模型")
except Exception as e:
print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)")
# 2. 启动 Flask 服务(子线程)
flask_thread = threading.Thread(
target=start_flask_service,
daemon=True
)
flask_thread.start()
# 等待 Flask 初始化
time.sleep(1)
if flask_thread.is_alive():
print(f"[Flask 服务] 启动成功(运行中)")
else:
print(f"[Flask 服务] 启动失败!图片访问不可用")
# 3. 启动 FastAPI 主服务
port = int(SERVER_CONFIG.get("port", 8000))
print(f"\n[FastAPI 服务] 准备启动,端口:{port}")
print(f"[FastAPI 服务] 接口文档http://服务器IP:{port}/docs\n")
uvicorn.run(
app="main:app",
host="0.0.0.0",
port=port,
reload=True,
ws="websockets"
)
workers=1,
ws="websockets",
reload=False
)

View File

@ -8,7 +8,6 @@ from passlib.context import CryptContext
from ds.config import JWT_CONFIG
from ds.db import db
from service.user_service import UserResponse
# ------------------------------
# 密码加密配置
@ -22,9 +21,10 @@ SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token>
# OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------
# 密码工具函数
# ------------------------------
@ -32,10 +32,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证明文密码与加密密码是否匹配"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""对明文密码进行 bcrypt 加密"""
return pwd_context.hash(password)
# ------------------------------
# JWT 工具函数
# ------------------------------
@ -53,11 +55,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# ------------------------------
# 认证依赖(获取当前登录用户)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
# 延迟导入、打破循环依赖
from schema.user_schema import UserResponse # 在这里导入
# 认证失败异常
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -89,8 +95,8 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
raise credentials_exception # 用户不存在
# 转换为 UserResponse 模型(自动校验字段)
return UserResponse(** user)
return UserResponse(**user)
except Exception as e:
raise credentials_exception from e
finally:
db.close_connection(conn, cursor)
db.close_connection(conn, cursor)

View File

@ -8,7 +8,7 @@ from schema.response_schema import APIResponse
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理器所有未捕获的异常都会在这里统一处理"""
"""全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
# 1. 请求参数验证错误Pydantic 校验失败)
if isinstance(exc, RequestValidationError):
error_details = []
@ -18,7 +18,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_400_BAD_REQUEST,
content=APIResponse(
code=400,
message=f"请求参数错误{'; '.join(error_details)}",
message=f"请求参数错误: {'; '.join(error_details)}",
data=None
).model_dump()
)
@ -52,7 +52,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"数据库错误{str(exc)}",
message=f"数据库错误: {str(exc)}",
data=None
).model_dump()
)
@ -62,7 +62,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"服务器内部错误{str(exc)}",
message=f"服务器内部错误: {str(exc)}",
data=None
).model_dump()
)

View File

@ -1,139 +0,0 @@
import os
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
class FaceRecognizer:
"""
封装InsightFace人脸识别功能支持从文件夹加载已知人脸。
"""
def __init__(self, known_faces_dir: str):
self.known_faces_dir = known_faces_dir
self.app = self._initialize_insightface()
self.known_faces_embeddings = {}
self.known_faces_names = []
self._load_known_faces()
def _initialize_insightface(self):
"""初始化InsightFace FaceAnalysis应用"""
print("初始化InsightFace引擎...")
try:
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
print("请检查依赖是否安装及模型是否可访问")
return None
def _load_known_faces(self):
"""加载已知人脸特征"""
if not os.path.exists(self.known_faces_dir):
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
os.makedirs(self.known_faces_dir, exist_ok=True)
return
print(f"从目录加载人脸特征: {self.known_faces_dir}")
for person_name in os.listdir(self.known_faces_dir):
person_dir = os.path.join(self.known_faces_dir, person_name)
if os.path.isdir(person_dir):
print(f"处理人物: {person_name}")
embeddings = []
for filename in os.listdir(person_dir):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(person_dir, filename)
try:
img = cv2.imread(image_path)
if img is None:
print(f"无法读取图片: {image_path},已跳过")
continue
faces = self.app.get(img)
if faces:
embeddings.append(faces[0].embedding)
print(f"提取特征成功: {filename}")
else:
print(f"未检测到人脸: {filename},已跳过")
except Exception as e:
print(f"处理图片出错 {image_path}: {e}")
if embeddings:
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
self.known_faces_names.append(person_name)
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
else:
print(f"人物 {person_name} 无有效特征,已跳过")
print(f"人脸加载完成,共 {len(self.known_faces_names)}")
def recognize(self, frame, threshold=0.4):
"""识别人脸并返回结果"""
if not self.app or not self.known_faces_names:
return False, None, None
faces = self.app.get(frame)
if not faces:
return False, None, None
for face in faces:
for known_name in self.known_faces_names:
known_embedding = self.known_faces_embeddings[known_name]
embedding1 = face.embedding.astype(np.float32)
embedding2 = known_embedding.astype(np.float32)
dot_product = np.dot(embedding1, embedding2)
norm_embedding1 = np.linalg.norm(embedding1)
norm_embedding2 = np.linalg.norm(embedding2)
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
dot_product / (norm_embedding1 * norm_embedding2)
)
if similarity >= threshold:
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
return True, known_name, similarity
return False, None, None
def test_single_image(self, image_path: str, threshold=0.4):
"""测试单张图片识别"""
if not os.path.exists(image_path):
print(f"图片不存在: {image_path}")
return False, None, None
frame = cv2.imread(image_path)
if frame is None:
print(f"无法读取图片: {image_path}")
return False, None, None
result, name, similarity = self.recognize(frame, threshold)
if result:
print(f"识别结果: {name} (相似度: {similarity:.4f})")
faces = self.app.get(frame)
for face in faces:
bbox = face.bbox.astype(int)
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
text = f"{name}: {similarity:.2f}"
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
cv2.imshow('识别结果', frame)
print("按任意键关闭窗口...")
cv2.waitKey(0)
cv2.destroyAllWindows()
else:
print("未识别到已知人脸")
return result, name, similarity
#
# if __name__ == "__main__":
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
# recognizer.test_single_image(test_image_path, threshold=0.4)

View File

@ -1,156 +0,0 @@
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
from io import BytesIO
from PIL import Image
class BinaryFaceFeatureHandler:
"""
专门处理图片二进制数据的特征提取器,支持分批次接收二进制数据并累积计算平均特征
"""
def __init__(self):
self.app = self._init_insightface()
self.feature_list = [] # 存储所有图片二进制数据提取的特征
def _init_insightface(self):
"""初始化InsightFace引擎"""
try:
print("正在初始化InsightFace引擎...")
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
return None
def add_binary_data(self, binary_data):
"""
接收单张图片的二进制数据,提取特征并保存
参数:
binary_data: 图片的二进制数据bytes类型
返回:
成功提取特征时返回 (True, 特征值numpy数组)
失败时返回 (False, None)
"""
if not self.app:
print("引擎未初始化,无法处理")
return False, None
try:
# 直接处理二进制数据:转换为图像格式
img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# 提取特征
faces = self.app.get(frame)
if faces:
# 获取当前提取的特征值
current_feature = faces[0].embedding
# 添加到特征列表
self.feature_list.append(current_feature)
print(f"已累计 {len(self.feature_list)} 个特征")
# 返回成功标志和当前特征值
return True,current_feature
else:
print("二进制数据中未检测到人脸")
return False, None
except Exception as e:
print(f"处理二进制数据出错: {e}")
return False, None
def get_average_feature(self, features):
"""
计算多个特征向量的平均值
参数:
features: 特征值列表每个元素可以是字符串格式或numpy数组
例如: [feature1, feature2, ...]
返回:
单一平均特征向量的numpy数组若无可计算数据则返回None
"""
try:
# 验证输入是否为列表且不为空
if not isinstance(features, list) or len(features) == 0:
print("输入必须是包含至少一个特征值的列表")
return None
# 处理每个特征值
processed_features = []
for i, embedding in enumerate(features):
try:
if isinstance(embedding, str):
# 处理包含括号和逗号的字符串格式
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32)
else:
embedding_np = np.array(embedding, dtype=np.float32)
# 验证特征值格式
if len(embedding_np.shape) == 1:
processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}")
# 确保有有效的特征值
if not processed_features:
print("没有有效的特征值用于计算平均值")
return None
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
return avg_feature
except Exception as e:
print(f"计算平均特征值时出错: {e}")
return None
# def clear(self):
# """清空已存储的特征数据"""
# self.feature_list = []
# print("已清空所有特征数据")
# # 使用示例
# if __name__ == "__main__":
# handler = BinaryFaceFeatureHandler()
#
# # 模拟接收图片二进制数据
# try:
# # 第一次接收
# with open(r"D:\Git\bin\video\ocr\known_faces\B\102-f.jpg_1140x855.jpg", "rb") as f:
# bin_data1 = f.read()
# success, feature1 = handler.add_binary_data(bin_data1)
# if success:
# print(f"第一次提取的特征值前5个: {feature1[:5]}")
#
# # 第二次接收
# with open(r"D:\Git\bin\video\ocr\known_faces\B\104-1.jpg", "rb") as f:
# bin_data2 = f.read()
# success, feature2 = handler.add_binary_data(bin_data2)
# if success:
# print(f"第二次提取的特征值前5个: {feature2[:5]}")
#
# # 计算平均值
# avg_feature = handler.get_average_feature()
#
# except Exception as e:
# print(f"处理过程出错: {e}")

View File

@ -1,18 +0,0 @@
打倒习近平
打到习近平
打倒毛泽东
打到毛泽东
打到主席
打倒主席
打到共产主义
打倒共产主义
打到共产党
打倒共产党
胖猫
法轮功
法轮大法
法轮大法好
法轮功大法好
法轮
李洪志
习近平

Binary file not shown.

Before

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 657 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 386 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 62 KiB

View File

@ -1,49 +0,0 @@
#日志文件
import logging
import sys
def setup_logger():
"""
配置一个全局日志记录器,支持输出到控制台和文件。
"""
# 创建一个日志记录器
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = logging.getLogger("ViolationDetectorLogger")
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
if logger.hasHandlers():
return logger
# --- 控制台处理器 ---
console_handler = logging.StreamHandler(sys.stdout)
# 对于控制台我们只显示INFO及以上级别的信息
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_formatter)
# --- 文件处理器 ---
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
# 对于文件我们记录所有DEBUG及以上级别的信息
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
# 将处理器添加到日志记录器
logger.addHandler(console_handler)
logger.addHandler(file_handler)
return logger
# 创建并导出logger实例
logger = setup_logger()

View File

@ -1,136 +0,0 @@
import os
import cv2
import yaml
from pathlib import Path
from .ocr_violation_detector import OCRViolationDetector
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
from .face_recognizer import FaceRecognizer
class MultiModelViolationDetector:
"""
多模型违规检测封装类串行调用OCR、人脸识别和YOLO模型任一模型检测到违规即返回结果
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
yolo_model_path: str,
known_faces_dir: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化所有检测模型
"""
# 初始化OCR检测器
self.ocr_detector = OCRViolationDetector(
forbidden_words_path=forbidden_words_path,
ocr_config_path=ocr_config_path,
ocr_confidence_threshold=ocr_confidence_threshold
)
# 初始化人脸识别器
self.face_recognizer = FaceRecognizer(
known_faces_dir=known_faces_dir
)
# 初始化YOLO检测器
self.yolo_detector = YoloViolationDetector(
model_path=yolo_model_path
)
print("多模型违规检测器初始化完成")
def detect_violations(self, frame):
"""
串行调用三个检测模型OCR → 人脸识别 → YOLO任一检测到违规即返回结果
"""
# 1. 首先进行OCR违禁词检测
try:
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
if ocr_has_violation:
details = {
"words": ocr_words,
"confidences": ocr_confs
}
print(f"警告: OCR检测到违禁内容: {details}")
return (True, "ocr", details)
except Exception as e:
print(f"错误: OCR检测出错: {str(e)}")
# 2. 接着进行人脸识别检测
try:
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
if face_has_violation:
details = {
"name": face_name,
"similarity": face_similarity
}
print(f"警告: 人脸识别到违规人员: {details}")
return (True, "face", details)
except Exception as e:
print(f"错误: 人脸识别出错: {str(e)}")
# 3. 最后进行YOLO目标检测
try:
yolo_results = self.yolo_detector.detect(frame)
if len(yolo_results.boxes) > 0:
details = {
"classes": yolo_results.names,
"boxes": yolo_results.boxes.xyxy.tolist(),
"confidences": yolo_results.boxes.conf.tolist(),
"class_ids": yolo_results.boxes.cls.tolist()
}
print(f"警告: YOLO检测到违规目标: {details}")
return (True, "yolo", details)
except Exception as e:
print(f"错误: YOLO检测出错: {str(e)}")
# 所有检测均未发现违规
return (False, None, None)
def load_config(config_path: str) -> dict:
"""加载YAML配置文件"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
except FileNotFoundError:
print(f"错误: 配置文件未找到: {config_path}")
raise
except yaml.YAMLError as e:
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
raise
except Exception as e:
print(f"错误: 加载配置文件出错: {str(e)}")
raise
# 使用示例
# if __name__ == "__main__":
# # 加载配置文件
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
#
# # 初始化多模型检测器
# detector = MultiModelViolationDetector(
# forbidden_words_path=config["forbidden_words_path"],
# ocr_config_path=config["ocr_config_path"],
# yolo_model_path=config["yolo_model_path"],
# known_faces_dir=config["known_faces_dir"],
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
# )
#
# # 读取测试图像(可替换为视频帧读取逻辑)
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
# if test_image_path:
# frame = cv2.imread(test_image_path)
#
# if frame is not None:
# has_violation, violation_type, details = detector.detect_violations(frame)
# if has_violation:
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# else:
# print("未检测到任何违规内容")
# else:
# print(f"无法读取测试图像: {test_image_path}")
# else:
# print("配置文件中未指定测试图像路径")

Binary file not shown.

View File

@ -1,178 +0,0 @@
import os
import cv2
from rapidocr import RapidOCR
class OCRViolationDetector:
"""
封装RapidOCR引擎用于检测图像帧中的违禁词。
核心功能加载违禁词、初始化OCR引擎、单帧图像违禁词检测
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化OCR引擎和违禁词列表。
Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_config_path (str): OCR配置文件如1.yaml的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值0~1
"""
# 加载违禁词
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化RapidOCR引擎
self.ocr_engine = self._initialize_ocr(ocr_config_path)
# 校验核心依赖是否就绪
self._check_dependencies()
# 设置置信度阈值限制在0~1范围
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
print(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _load_forbidden_words(self, path: str) -> set:
"""
从TXT文件加载违禁词去重、过滤空行支持UTF-8编码
"""
forbidden_words = set()
# 检查文件是否存在
if not os.path.exists(path):
print(f"错误:违禁词文件不存在: {path}")
return forbidden_words
# 读取文件并处理内容
try:
with open(path, 'r', encoding='utf-8') as f:
forbidden_words = {
line.strip() for line in f
if line.strip() # 跳过空行或纯空格行
}
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
except UnicodeDecodeError:
print(f"错误违禁词文件编码错误需UTF-8: {path}")
except PermissionError:
print(f"错误:无权限读取违禁词文件: {path}")
except Exception as e:
print(f"错误:加载违禁词失败: {str(e)}")
return forbidden_words
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
"""
初始化RapidOCR引擎校验配置文件、捕获初始化异常
"""
print("开始初始化RapidOCR引擎...")
# 检查配置文件是否存在
if not os.path.exists(config_path):
print(f"错误OCR配置文件不存在: {config_path}")
return None
# 初始化OCR引擎
try:
ocr_engine = RapidOCR(config_path=config_path)
print("RapidOCR引擎初始化成功")
return ocr_engine
except ImportError:
print("错误RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
except Exception as e:
print(f"错误RapidOCR初始化失败: {str(e)}")
return None
def _check_dependencies(self) -> None:
"""校验OCR引擎和违禁词列表是否就绪"""
if not self.ocr_engine:
print("警告:⚠️ OCR引擎未就绪违禁词检测功能将禁用")
if not self.forbidden_words:
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
def detect(self, frame) -> tuple[bool, list, list]:
"""
对单帧图像进行OCR违禁词检测核心方法
Args:
frame: 输入图像帧NumPy数组BGR格式cv2读取的图像
Returns:
tuple[bool, list, list]:
- 第一个元素是否检测到违禁词True/False
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
"""
# 初始化返回结果
has_violation = False
violation_words = []
violation_confs = []
# 前置校验
if frame is None or frame.size == 0:
print("警告输入图像帧为空或无效跳过OCR检测")
return has_violation, violation_words, violation_confs
if not self.ocr_engine or not self.forbidden_words:
print("OCR引擎未就绪或违禁词为空跳过OCR检测")
return has_violation, violation_words, violation_confs
try:
# 执行OCR识别
print("开始执行OCR识别...")
ocr_result = self.ocr_engine(frame)
print(f"RapidOCR原始结果: {ocr_result}")
# 校验OCR结果是否有效
if ocr_result is None:
print("OCR识别未返回任何结果图像无文本或识别失败")
return has_violation, violation_words, violation_confs
# 检查txts和scores是否存在且不为None
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
print("警告OCR结果中txts为None或不存在")
return has_violation, violation_words, violation_confs
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
print("警告OCR结果中scores为None或不存在")
return has_violation, violation_words, violation_confs
# 转为列表并去None
if not isinstance(ocr_result.txts, (list, tuple)):
print(f"警告OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
texts = []
else:
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
if not isinstance(ocr_result.scores, (list, tuple)):
print(f"警告OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
confidences = []
else:
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
# 校验文本和置信度列表长度是否一致
if len(texts) != len(confidences):
print(f"警告OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
return has_violation, violation_words, violation_confs
if len(texts) == 0:
print("OCR未识别到任何有效文本")
return has_violation, violation_words, violation_confs
# 遍历识别结果,筛选违禁词
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
for text, conf in zip(texts, confidences):
if conf < self.OCR_CONFIDENCE_THRESHOLD:
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
continue
matched_words = [word for word in self.forbidden_words if word in text]
if matched_words:
has_violation = True
violation_words.extend(matched_words)
violation_confs.extend([conf] * len(matched_words))
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
except Exception as e:
print(f"错误OCR检测过程异常: {str(e)}")
return has_violation, violation_words, violation_confs

View File

@ -1,47 +0,0 @@
from ultralytics import YOLO
import cv2
class ViolationDetector:
"""
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
"""
def __init__(self, model_path):
"""
初始化检测器。
Args:
model_path (str): YOLO .pt模型的路径。
"""
print(f"正在从 '{model_path}' 加载YOLO模型...")
self.model = YOLO(model_path)
print("YOLO模型加载成功。")
def detect(self, frame):
"""
对单帧图像进行目标检测。
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)。
Returns:
ultralytics.engine.results.Results: YOLO的检测结果对象。
"""
# conf可以根据您的模型效果进行调整
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
results = self.model(frame, conf=0.2)
return results[0]
def draw_boxes(self, frame, result):
"""
在图像帧上绘制检测框。
Args:
frame: 原始图像帧。
result: YOLO的检测结果对象。
Returns:
numpy.ndarray: 绘制了检测框的图像帧。
"""
# 使用YOLO自带的plot功能方便快捷
annotated_frame = result.plot()
return annotated_frame

View File

@ -1,164 +0,0 @@
import queue
import asyncio
import aiohttp
import threading
import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
# 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1)
class VideoTrack(MediaStreamTrack):
"""自定义视频轨道类继承自MediaStreamTrack"""
kind = "video"
def __init__(self, max_frames=100):
super().__init__()
self.frames = queue.Queue(maxsize=max_frames)
async def recv(self):
return await super().recv()
def webrtc_producer(webrtc_url):
"""
生产者方法从WEBRTC读取视频帧并放入队列
仅当队列空时才放入新帧,否则丢弃
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 创建RTCPeerConnection对象不使用ICE服务器
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
pc.addTrack(video_track)
@pc.on("track")
async def on_track(track):
if track.kind == "video":
print("接收到视频轨道,开始接收视频帧")
while True:
# 从轨道接收视频帧
frame = await track.recv()
# 转换为BGR24格式的NumPy数组
frame_bgr24 = frame.to_ndarray(format='bgr24')
# 检查队列是否为空,为空则加入,否则丢弃
if frame_queue.empty():
try:
frame_queue.put_nowait(frame_bgr24)
print("帧已放入队列")
except queue.Full:
print("队列已满,丢弃帧")
else:
print("队列非空,丢弃帧")
async def main():
# 创建并发送SDP Offer
offer = await pc.createOffer()
print("已创建本地SDP Offer")
await pc.setLocalDescription(offer)
# 发送Offer到服务器并接收Answer
async with aiohttp.ClientSession() as session:
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
async with session.post(
webrtc_url,
data=offer.sdp.encode(),
headers={
"Content-Type": "application/sdp",
"Content-Length": str(len(offer.sdp))
},
ssl=False
) as response:
print("已接收到服务器的响应")
answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
# 保持连接
try:
while True:
await asyncio.sleep(0.1)
except KeyboardInterrupt:
pass
finally:
print("关闭RTCPeerConnection")
await pc.close()
try:
loop.run_until_complete(main())
finally:
loop.close()
def frame_consumer(ip):
"""
消费者方法:从队列中读取帧并处理
每次处理后休眠200ms模拟延迟
"""
print("消费者启动,开始等待帧...")
try:
while True:
# 阻塞等待队列中的帧
frame = frame_queue.get()
print(f"消费帧,大小: {frame.shape}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
# 标记任务完成
frame_queue.task_done()
except KeyboardInterrupt:
print("消费者退出")
def start_webrtc_stream(ip, webrtc_url):
"""
启动WebRTC视频流处理的主方法
参数: webrtc_url - WebRTC服务器地址
"""
print(f"开始连接到WebRTC服务器: {webrtc_url}")
# 启动生产者线程
producer_thread = threading.Thread(
target=webrtc_producer,
args=(webrtc_url,),
daemon=True,
name="webrtc-producer"
)
# 启动消费者线程
consumer_thread = threading.Thread(
target=frame_consumer(ip),
daemon=True,
name="frame-consumer"
)
producer_thread.start()
consumer_thread.start()
print("生产者和消费者线程已启动")
try:
# 保持主线程运行
while True:
time.sleep(1)
except KeyboardInterrupt:
print("程序正在退出...")
if __name__ == "__main__":
# 示例用法
# 实际使用时替换为真实的WebRTC服务器地址
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
start_webrtc_stream(webrtc_server_url)

View File

@ -1,101 +0,0 @@
import asyncio
import logging
import cv2
import time
# 配置日志与WHEP代码保持一致的日志风格
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
功能与WHEP拉流函数对齐流状态反馈、帧信息打印、帧率统计、异常处理
Args:
rtmp_url: RTMP流的URL地址如 rtmp://xxx/live/stream_key
"""
cap = None # 初始化视频捕获对象
try:
# 1. 异步打开RTMP流指定FFmpeg后端确保RTMP兼容性同步操作通过to_thread避免阻塞事件循环
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 必须指定FFmpeg后端RTMP协议依赖该后端解析
)
# 2. 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 3. 异步获取RTMP流基础信息分辨率、帧率
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况部分RTMP流未返回帧率时默认30FPS
fps = fps if fps > 0 else 30.0
# 分辨率转为整数(视频尺寸必然是整数)
width, height = int(width), int(height)
# 打印流初始化成功信息与WHEP连接成功信息风格一致
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 4. 初始化帧统计参数
frame_count = 0 # 总接收帧数
start_time = time.time() # 统计起始时间
# 5. 循环异步读取视频帧(核心逻辑)
while True:
# 异步读取一帧cv2.read是同步操作用to_thread适配异步环境
ret, frame = await asyncio.to_thread(cap.read)
# 检查帧是否读取成功(流中断/结束时ret为False
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
# 帧计数累加
frame_count += 1
# 6. 打印当前帧基础信息与WHEP帧信息打印风格对齐
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
# 7. 每100帧统计一次实际接收帧率补充性能监控与原RTMP示例逻辑一致
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
# 示例yield frame (若需生成器模式,可调整函数为异步生成器)
# 8. 异常处理(覆盖用户中断、通用错误)
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
# 运行RTMP拉流任务与WHEP一致的异步执行方式
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

View File

@ -0,0 +1,36 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ------------------------------
# 请求模型
# ------------------------------
class DeviceActionCreate(BaseModel):
"""设备操作记录创建模型0=离线、1=上线)"""
client_ip: str = Field(..., description="客户端IP")
action: int = Field(..., ge=0, le=1, description="操作状态0=离线、1=上线)")
# ------------------------------
# 响应模型(单条记录)
# ------------------------------
class DeviceActionResponse(BaseModel):
"""设备操作记录响应模型(与自增表对齐)"""
id: int = Field(..., description="自增主键ID")
client_ip: Optional[str] = Field(None, description="客户端IP")
action: Optional[int] = Field(None, description="操作状态0=离线、1=上线)")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库结果直接转换
model_config = {"from_attributes": True}
# ------------------------------
# 列表响应模型(仅含 total + device_actions
# ------------------------------
class DeviceActionListResponse(BaseModel):
"""设备操作记录列表(仅核心返回字段)"""
total: int = Field(..., description="总记录数")
device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表")

View File

@ -1,4 +1,3 @@
import hashlib
from datetime import datetime
from typing import Optional, List, Dict
@ -6,46 +5,51 @@ from pydantic import BaseModel, Field
# ------------------------------
# 请求模型(前端传参校验)
# 请求模型
# ------------------------------
class DeviceCreateRequest(BaseModel):
"""设备流信息创建请求模型"""
"""设备创建请求模型"""
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
params: Optional[Dict] = Field(None, description="设备详细信息")
def md5_encrypt(text: str) -> str:
"""对字符串进行MD5加密"""
if not text:
return ""
md5_hash = hashlib.md5()
md5_hash.update(text.encode('utf-8'))
return md5_hash.hexdigest()
params: Optional[Dict] = Field(None, description="设备扩展参数JSON格式")
# ------------------------------
# 响应模型(后端返回设备数据)
# 响应模型
# ------------------------------
class DeviceResponse(BaseModel):
"""设备信息响应模型(字段与表结构完全对齐)"""
id: int = Field(..., description="设备ID")
"""设备信息响应模型(与数据库表字段对齐)"""
id: int = Field(..., description="设备主键ID")
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址")
live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址")
detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址")
device_online_status: int = Field(..., description="设备在线状态1-在线、0-离线)")
device_online_status: int = Field(..., description="在线状态1-在线、0-离线)")
device_type: Optional[str] = Field(None, description="设备类型")
alarm_count: int = Field(..., description="报警次数")
params: Optional[str] = Field(None, description="设备详细信息")
params: Optional[str] = Field(None, description="扩展参数JSON字符串")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库查询结果转换
model_config = {"from_attributes": True}
model_config = {"from_attributes": True} # 支持从数据库结果直接转换
class DeviceListResponse(BaseModel):
"""设备流信息列表响应模型"""
"""设备列表响应模型"""
total: int = Field(..., description="设备总数")
devices: List[DeviceResponse] = Field(..., description="设备列表")
class DeviceStatusHistoryResponse(BaseModel):
"""设备上下线记录响应模型"""
id: int = Field(..., description="记录ID")
device_id: int = Field(..., description="关联设备ID")
client_ip: Optional[str] = Field(None, description="设备IP地址")
status: int = Field(..., description="状态1-在线、0-离线)")
status_time: datetime = Field(..., description="状态变更时间")
model_config = {"from_attributes": True}
class DeviceStatusHistoryListResponse(BaseModel):
"""设备上下线记录列表响应模型"""
total: int = Field(..., description="记录总数")
history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表")

View File

@ -1,30 +1,41 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------
# 请求模型(前端传参校验)
# 请求模型(前端传参校验)- 保留update的eigenvalue如需更新特征值
# ------------------------------
class FaceCreateRequest(BaseModel):
"""创建人脸记录请求模型无需ID由数据库自增)"""
name: str = Field(None, max_length=255, description="名称(可选最长255字符")
"""创建人脸记录请求模型无需ID由数据库自增)"""
name: Optional[str] = Field(None, max_length=255, description="名称(可选最长255字符")
class FaceUpdateRequest(BaseModel):
"""更新人脸记录请求模型(不变"""
name: str = Field(None, max_length=255, description="名称")
eigenvalue: str = Field(None, max_length=255, description="特征文件处理后可更新)")
"""更新人脸记录请求模型 - 保留eigenvalue如需更新特征值不影响返回"""
name: Optional[str] = Field(None, max_length=255, description="名称(可选)")
eigenvalue: Optional[str] = Field(None, description="特征值(可选,文件处理后可更新)") # 保留更新能力
address: Optional[str] = Field(None, description="图片完整路径(可选,更新图片时使用)")
# ------------------------------
# 响应模型(后端返回数据)
# 响应模型(后端返回数据)- 核心修改删除eigenvalue字段
# ------------------------------
class FaceResponse(BaseModel):
"""人脸记录响应模型(仍包含ID由数据库生成后返回"""
"""人脸记录响应模型(仅返回需要的字段移除eigenvalue"""
id: int = Field(..., description="主键ID数据库自增")
name: str = Field(None, description="名称")
eigenvalue: str = Field(None, description="特征暂为None")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")
name: Optional[str] = Field(None, description="名称")
address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)") # 仅保留address
created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)")
updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)")
# 关键配置:支持从数据库查询结果(字典)直接转换
model_config = {"from_attributes": True}
class FaceListResponse(BaseModel):
"""人脸列表分页响应模型结构不变内部FaceResponse已移除eigenvalue"""
total: int = Field(..., description="筛选后的总记录数")
faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表")
model_config = {"from_attributes": True}

37
schema/model_schema.py Normal file
View File

@ -0,0 +1,37 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# 请求模型
class ModelCreateRequest(BaseModel):
name: str = Field(..., max_length=255, description="模型名称必填yolo-v8s-car")
description: Optional[str] = Field(None, description="模型描述(可选)")
is_default: Optional[bool] = Field(False, description="是否设为默认模型")
class ModelUpdateRequest(BaseModel):
name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)")
description: Optional[str] = Field(None, description="模型描述(可选修改)")
is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)")
# 响应模型
class ModelResponse(BaseModel):
id: int = Field(..., description="模型ID")
name: str = Field(..., description="模型名称")
path: str = Field(..., description="模型文件相对路径")
is_default: bool = Field(..., description="是否默认模型")
description: Optional[str] = Field(None, description="模型描述")
file_size: Optional[int] = Field(None, description="文件大小(字节)")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
model_config = {"from_attributes": True}
class ModelListResponse(BaseModel):
total: int = Field(..., description="总记录数")
models: List[ModelResponse] = Field(..., description="当前页模型列表")
model_config = {"from_attributes": True}

View File

@ -5,9 +5,9 @@ from pydantic import BaseModel, Field
class APIResponse(BaseModel):
"""统一 API 响应模型(所有接口必返此格式)"""
code: int = Field(..., description="状态码200=成功、4xx=客户端错误、5xx=服务端错误")
message: str = Field(..., description="响应信息成功/错误描述")
data: Optional[Any] = Field(None, description="响应数据成功时返回、错误时为 None")
code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
message: str = Field(..., description="响应信息: 成功/错误描述")
data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
# Pydantic V2 配置(支持从 ORM 对象转换)
model_config = {"from_attributes": True}

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
# ------------------------------
class SensitiveCreateRequest(BaseModel):
"""创建敏感信息记录请求模型"""
# 移除了id字段由数据库自动生成
# 移除了id字段由数据库自动生成
name: str = Field(None, max_length=255, description="名称")

View File

@ -1,6 +1,6 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------
@ -30,3 +30,11 @@ class UserResponse(BaseModel):
# Pydantic V2 配置(支持从数据库查询结果转换)
model_config = {"from_attributes": True}
class UserListResponse(BaseModel):
"""用户列表分页响应模型(与设备/人脸列表结构对齐)"""
total: int = Field(..., description="用户总数")
users: List[UserResponse] = Field(..., description="当前页用户列表")
model_config = {"from_attributes": True}

View File

@ -0,0 +1,158 @@
from fastapi import APIRouter, Query, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from schema.device_action_schema import (
DeviceActionCreate,
DeviceActionResponse,
DeviceActionListResponse
)
from schema.response_schema import APIResponse
# 路由配置
router = APIRouter(
prefix="/device/actions",
tags=["设备操作记录"]
)
# ------------------------------
# 内部方法: 新增设备操作记录适配id自增
# ------------------------------
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
"""
新增设备操作记录(内部方法、非接口)
:param action_data: 含client_ip和action0/1
:return: 新增的完整记录
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入SQLid自增、依赖数据库自动生成
insert_query = """
INSERT INTO device_action
(client_ip, action, created_at, updated_at)
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (
action_data.client_ip,
action_data.action
))
conn.commit()
# 获取新增记录通过自增ID查询
new_id = cursor.lastrowid
cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
new_action = cursor.fetchone()
return DeviceActionResponse(**new_action)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"新增记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口: 分页查询操作记录列表(仅返回 total + device_actions
# ------------------------------
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
async def get_device_action_list(
page: int = Query(1, ge=1, description="页码、默认1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
client_ip: str = Query(None, description="按客户端IP筛选"),
action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线、1=上线)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 构建筛选条件(参数化查询、避免注入)
where_clause = []
params = []
if client_ip:
where_clause.append("client_ip = %s")
params.append(client_ip)
if action is not None:
where_clause.append("action = %s")
params.append(action)
# 2. 查询总记录数(用于返回 total
count_sql = "SELECT COUNT(*) AS total FROM device_action"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 3. 分页查询记录(按创建时间倒序、确保最新记录在前)
offset = (page - 1) * page_size
list_sql = "SELECT * FROM device_action"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数page/page_size仅用于查询、不返回
cursor.execute(list_sql, params)
action_list = cursor.fetchall()
# 4. 仅返回 total + device_actions
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
async def get_device_actions_by_ip(
client_ip: str = Path(..., description="客户端IP地址")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 查询总记录数
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
cursor.execute(count_sql, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 查询该IP的所有记录按创建时间倒序
list_sql = """
SELECT * FROM device_action
WHERE client_ip = %s
ORDER BY created_at DESC
"""
cursor.execute(list_sql, (client_ip,))
action_list = cursor.fetchall()
# 3. 返回结果
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -1,25 +1,15 @@
import json
import threading
import time
from datetime import date
from fastapi import HTTPException, Query, APIRouter, Depends, Request
from fastapi import APIRouter, Query, HTTPException, Request, Path
from mysql.connector import Error as MySQLError
from ds.config import LIVE_CONFIG
from ds.db import db
from middle.auth_middleware import get_current_user
# 注意导入的Schema已更新字段
from schema.device_schema import (
DeviceCreateRequest,
DeviceResponse,
DeviceListResponse,
md5_encrypt
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
)
from schema.response_schema import APIResponse
from schema.user_schema import UserResponse
# 导入之前封装的WEBRTC处理函数
from core.rtmp import rtmp_pull_video_stream
router = APIRouter(
prefix="/devices",
@ -27,90 +17,192 @@ router = APIRouter(
)
# 在后台线程中运行WEBRTC处理
def run_webrtc_processing(ip, webrtc_url):
try:
print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}")
rtmp_pull_video_stream(webrtc_url)
except Exception as e:
print(f"WEBRTC处理出错: {str(e)}")
# ------------------------------
# 1. 创建设备信息
# 内部工具方法 - 记录设备状态变更历史
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(request: Request, device_data: DeviceCreateRequest):
def record_status_change(client_ip: str, status: int) -> bool:
"""
记录设备状态变更历史(写入 device_action 表)
:param client_ip: 设备IP
:param status: 状态1-在线、0-离线)
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
if status not in (0, 1):
raise ValueError("状态必须是0离线或1在线")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查client_ip是否已存在
# 插入状态变更记录到 device_action
insert_query = """
INSERT INTO device_action
(client_ip, action, created_at, updated_at)
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (client_ip, status))
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"记录设备状态变更失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法 - 通过客户端IP增加设备报警次数
# ------------------------------
def increment_alarm_count_by_ip(client_ip: str) -> bool:
"""
通过客户端IP增加设备的报警次数
:param client_ip: 客户端IP地址
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
device = cursor.fetchone()
if not device:
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 报警次数加1、并更新时间戳
update_query = """
UPDATE devices
SET alarm_count = alarm_count + 1,
updated_at = CURRENT_TIMESTAMP
WHERE client_ip = %s
"""
cursor.execute(update_query, (client_ip,))
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新报警次数失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法 - 通过客户端IP更新设备在线状态
# ------------------------------
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
"""
通过客户端IP更新设备的在线状态
:param client_ip: 客户端IP地址
:param online_status: 在线状态1-在线、0-离线)
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
if online_status not in (0, 1):
raise ValueError("在线状态必须是0离线或1在线")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在并获取设备ID
cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,))
device = cursor.fetchone()
if not device:
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 状态无变化则不操作
if device['device_online_status'] == online_status:
return True
# 更新在线状态和时间戳
update_query = """
UPDATE devices
SET device_online_status = %s,
updated_at = CURRENT_TIMESTAMP
WHERE client_ip = %s
"""
cursor.execute(update_query, (online_status, client_ip))
# 记录状态变更历史
record_status_change(client_ip, online_status)
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 创建设备信息接口
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(device_data: DeviceCreateRequest, request: Request):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否已存在
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
existing_device = cursor.fetchone()
if existing_device:
# 设备创建成功后在后台线程启动WEBRTC流处理
threading.Thread(
target=run_webrtc_processing,
# args=(device_data.ip, existing_device["live_webrtc_url"]),
args=(device_data.ip, existing_device["rtmp_push_url"]),
daemon=True # 设为守护线程,主程序退出时自动结束
).start()
# IP已存在时返回该设备信息
# 更新设备为在线状态
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
return APIResponse(
code=200,
message=f"客户端IP {device_data.ip} 已存在",
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
data=DeviceResponse(**existing_device)
)
# 获取RTMP URL和WEBRTC URL配置
rtmp_url = str(LIVE_CONFIG.get("rtmp_url", ""))
webrtc_url = str(LIVE_CONFIG.get("webrtc_url", ""))
# 将设备详细信息params转换为JSON字符串
device_params_json = json.dumps(device_data.params) if device_data.params else None
# 对JSON字符串进行MD5加密
device_md5 = md5_encrypt(device_params_json) if device_params_json else ""
# 解析User-Agent获取设备类型
# 通过 User-Agent 判断设备类型
user_agent = request.headers.get("User-Agent", "").lower()
# 优先处理User-Agent为default的情况
device_type = "unknown"
if user_agent == "default":
# 检查params中是否存在os键
if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params:
device_type = device_data.params["os"]
else:
device_type = "unknown"
device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
elif "windows" in user_agent:
device_type = "windows"
elif "android" in user_agent:
device_type = "android"
elif "linux" in user_agent:
device_type = "linux"
else:
device_type = "unknown"
# 构建完整的WEBRTC URL
full_webrtc_url = webrtc_url + device_md5
device_params_json = json.dumps(device_data.params) if device_data.params else None
# SQL插入语句
# 插入新设备
insert_query = """
INSERT INTO devices
(client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url,
device_online_status, device_type, alarm_count, params)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
(client_ip, hostname, device_online_status, device_type, alarm_count, params)
VALUES (%s, %s, %s, %s, %s, %s)
"""
cursor.execute(insert_query, (
device_data.ip,
device_data.hostname,
rtmp_url + device_md5,
full_webrtc_url, # 存储完整的WEBRTC URL
"",
1,
device_type,
0,
@ -118,28 +210,26 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
))
conn.commit()
# 获取刚创建的设备信息
# 获取新设备并返回
device_id = cursor.lastrowid
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone()
new_device = cursor.fetchone()
# 记录上线历史
record_status_change(device_data.ip, 1)
# 设备创建成功后在后台线程启动WEBRTC流处理
threading.Thread(
target=run_webrtc_processing,
args=(device_data.ip, full_webrtc_url),
daemon=True # 设为守护线程,主程序退出时自动结束
).start()
return APIResponse(
code=200,
message="设备创建成功已开始处理WEBRTC流",
data=DeviceResponse(**device)
message="设备创建成功",
data=DeviceResponse(**new_device)
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建设备失败{str(e)}") from e
raise Exception(f"创建设备失败: {str(e)}") from e
except json.JSONDecodeError as e:
raise Exception(f"设备信息JSON序列化失败{str(e)}") from e
raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
except Exception as e:
if conn:
conn.rollback()
@ -149,14 +239,14 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
# ------------------------------
# 2. 获取设备列表
# 获取设备列表接口
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备列表")
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
async def get_device_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页条数"),
device_type: str = Query(None, description="设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选1-在线、0-离线)")
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
device_type: str = Query(None, description="设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选")
):
conn = None
cursor = None
@ -164,58 +254,59 @@ async def get_device_list(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 构建查询条件
where_clause = []
params = []
if device_type:
where_clause.append("device_type = %s")
params.append(device_type)
if online_status is not None:
where_clause.append("device_online_status = %s")
params.append(online_status)
# 总条数查询
count_query = "SELECT COUNT(*) as total FROM devices"
# 统计总数
count_query = "SELECT COUNT(*) AS total FROM devices"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 分页查询SELECT * 会自动匹配表字段、响应模型已对齐)
# 分页查询列表
offset = (page - 1) * page_size
query = "SELECT * FROM devices"
list_query = "SELECT * FROM devices"
if where_clause:
query += " WHERE " + " AND ".join(where_clause)
query += " ORDER BY id DESC LIMIT %s OFFSET %s"
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(query, params)
devices = cursor.fetchall()
# 响应模型已更新为params字段、直接转换即可
device_list = [DeviceResponse(**device) for device in devices]
cursor.execute(list_query, params)
device_list = cursor.fetchall()
return APIResponse(
code=200,
message="获取设备列表成功",
data=DeviceListResponse(total=total, devices=device_list)
data=DeviceListResponse(
total=total,
devices=[DeviceResponse(**device) for device in device_list]
)
)
except MySQLError as e:
raise Exception(f"获取设备列表失败{str(e)}") from e
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 3. 获取单个设备详情
# 获取设备上下线记录接口
# ------------------------------
@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情")
async def get_device_detail(
device_id: int,
current_user: UserResponse = Depends(get_current_user)
@router.get("/{device_id}/status-history", response_model=APIResponse, summary="获取设备上下线记录")
async def get_device_status_history(
device_id: int = Path(..., description="设备ID"),
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
start_date: date = Query(None, description="开始日期格式YYYY-MM-DD"),
end_date: date = Query(None, description="结束日期格式YYYY-MM-DD")
):
conn = None
cursor = None
@ -223,36 +314,75 @@ async def get_device_detail(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询设备信息SELECT * 匹配表字段)
query = "SELECT * FROM devices WHERE id = %s"
cursor.execute(query, (device_id,))
# 检查设备是否存在并获取 client_ip
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone()
if not device:
raise HTTPException(
status_code=404,
detail=f"设备ID为 {device_id} 的设备不存在"
)
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
client_ip = device['client_ip']
where_clause = ["client_ip = %s"]
params = [client_ip]
# 日期筛选
if start_date:
where_clause.append("DATE(created_at) >= %s")
params.append(start_date.strftime("%Y-%m-%d"))
if end_date:
where_clause.append("DATE(created_at) <= %s")
params.append(end_date.strftime("%Y-%m-%d"))
# 统计记录总数
count_query = "SELECT COUNT(*) AS total FROM device_action WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 分页查询记录
offset = (page - 1) * page_size
list_query = f"""
SELECT * FROM device_action
WHERE {' AND '.join(where_clause)}
ORDER BY created_at DESC
LIMIT %s OFFSET %s
"""
params.extend([page_size, offset])
cursor.execute(list_query, params)
history_list = cursor.fetchall()
# 格式化为响应模型结构
formatted_history = []
for item in history_list:
formatted_item = {
"id": item["id"],
"device_id": device_id,
"client_ip": item["client_ip"],
"status": item["action"],
"status_time": item["created_at"]
}
formatted_history.append(formatted_item)
# 响应模型已更新为params字段
return APIResponse(
code=200,
message="获取设备详情成功",
data=DeviceResponse(**device)
message="获取设备上下线记录成功",
data=DeviceStatusHistoryListResponse(
total=total,
history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
)
)
except MySQLError as e:
raise Exception(f"获取设备详情失败:{str(e)}") from e
raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 4. 删除设备信息
# 手动更新设备在线状态接口
# ------------------------------
@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息")
async def delete_device(
device_id: int,
current_user: UserResponse = Depends(get_current_user)
@router.put("/{device_id}/status", response_model=APIResponse, summary="更新设备在线状态")
async def update_device_status(
device_id: int = Path(..., description="设备ID"),
status: int = Query(..., ge=0, le=1, description="在线状态1-在线、0-离线)")
):
conn = None
cursor = None
@ -260,27 +390,51 @@ async def delete_device(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,))
if not cursor.fetchone():
raise HTTPException(
status_code=404,
detail=f"设备ID为 {device_id} 的设备不存在"
# 获取设备 client_ip
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone()
if not device:
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
# 更新状态
success = update_online_status_by_ip(device['client_ip'], status)
if success:
status_text = "在线" if status == 1 else "离线"
return APIResponse(
code=200,
message=f"设备已更新为{status_text}状态",
data={"device_id": device_id, "status": status, "status_text": status_text}
)
# 执行删除
delete_query = "DELETE FROM devices WHERE id = %s"
cursor.execute(delete_query, (device_id,))
conn.commit()
return APIResponse(
code=200,
message=f"设备ID为 {device_id} 的设备已成功删除",
code=500,
message="更新设备状态失败",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除设备失败:{str(e)}") from e
raise Exception(f"更新设备状态失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 获取所有去重的客户端IP列表
# ------------------------------
def get_unique_client_ips() -> list[str]:
"""获取所有去重的客户端IP列表"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
cursor.execute(query)
results = cursor.fetchall()
return [item['client_ip'] for item in results]
except MySQLError as e:
raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -1,118 +1,140 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
import os
from pathlib import Path
from ds.db import db
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse
from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse
router = APIRouter(
prefix="/faces",
tags=["人脸管理"]
from schema.face_schema import (
FaceCreateRequest,
FaceUpdateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse
from util.face_util import add_binary_data, get_average_feature
from util.file_util import save_face_to_up_images
router = APIRouter(prefix="/faces", tags=["人脸管理"])
# ------------------------------
# 1. 创建人脸记录(核心修正ID 数据库自增,前端无需传
# 1. 创建人脸记录(使用修复后的路径
# ------------------------------
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件ID自增")
@router.post("", response_model=APIResponse, summary="创建人脸记录")
async def create_face(
# 前端仅需传name可选Form格式、file必传文件
request: Request,
name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传,暂不处理内容")
file: UploadFile = File(..., description="人脸文件(必传)")
):
"""
创建人脸记录:
- 需登录认证
- 前端传参multipart/form-data 表单name 可选file 必传)
- ID 由数据库自动生成,无需前端传入
- 暂不处理文件内容eigenvalue 设为 None
"""
conn = None
cursor = None
try:
# 1. 用模型校验 name仅校验长度无需ID
face_create = FaceCreateRequest(name=name)
client_ip = request.client.host if request.client else ""
if not client_ip:
raise HTTPException(status_code=400, detail="无法获取客户端IP")
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 把文件转为二进制数组
# 读取图片并保存(使用修复后的路径逻辑)
file_content = await file.read()
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg"
save_result = save_face_to_up_images(
client_ip=client_ip,
face_name=name,
image_bytes=file_content,
image_format=file_ext
)
if not save_result["success"]:
raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}")
db_image_path = save_result["db_path"] # 从修复后的方法获取路径
# 调用人脸识别得到特征
# 提取人脸特征
detect_success, detect_result = add_binary_data(file_content)
if not detect_success:
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
eigenvalue = detect_result
# 2. 插入数据库:无需传 ID自增只传 name 和 eigenvalueNone
# 插入数据库
insert_query = """
INSERT INTO face (name, eigenvalue)
VALUES (%s, %s)
INSERT INTO face (name, eigenvalue, address)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (face_create.name, None))
cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path))
conn.commit()
# 3. 获取数据库自动生成的 ID关键用 LAST_INSERT_ID() 查刚插入的记录
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
cursor.execute(select_new_query)
# 查询新记录
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = LAST_INSERT_ID()
""")
created_face = cursor.fetchone()
if not created_face:
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
return APIResponse(
code=201,
message=f"人脸记录创建成功ID{created_face['id']},文件名:{file.filename}",
message=f"人脸记录创建成功ID: {created_face['id']}",
data=FaceResponse(**created_face)
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建人脸记录失败:{str(e)}") from e
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
finally:
await file.close() # 关闭文件流
await file.close()
db.close_connection(conn, cursor)
# 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改
# ------------------------------
# 2. 获取单个人脸记录不变用自增ID查询
# 2. 获取单个人脸记录
# ------------------------------
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
async def get_face(
face_id: int, # 这里的 ID 是数据库自增的,前端从创建响应中获取
current_user: UserResponse = Depends(get_current_user)
):
async def get_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face WHERE id = %s"
query = """
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
"""
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
return APIResponse(
code=200,
message="人脸记录查询成功",
message="查询成功",
data=FaceResponse(**face)
)
except MySQLError as e:
raise Exception(f"查询人脸记录失败:{str(e)}") from e
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改)
# ------------------------------
# 3. 获取所有人脸记录(不变)
# 3. 获取人脸列表
# ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
async def get_all_faces(
current_user: UserResponse = Depends(get_current_user)
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
async def get_face_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
has_eigenvalue: bool = Query(None)
):
conn = None
cursor = None
@ -120,47 +142,66 @@ async def get_all_faces(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序
cursor.execute(query)
faces = cursor.fetchall()
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if has_eigenvalue is not None:
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
# 总记录数
count_query = "SELECT COUNT(*) AS total FROM face"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 列表数据
offset = (page - 1) * page_size
list_query = """
SELECT id, name, address, created_at, updated_at
FROM face
"""
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
face_list = cursor.fetchall()
return APIResponse(
code=200,
message="所有人脸记录查询成功",
data=[FaceResponse(**face) for face in faces]
message=f"获取成功(共{total}条)",
data=FaceListResponse(
total=total,
faces=[FaceResponse(**face) for face in face_list]
)
)
except MySQLError as e:
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 4. 更新人脸记录不变用自增ID更新
# 4. 更新人脸记录
# ------------------------------
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
async def update_face(
face_id: int,
face_update: FaceUpdateRequest,
current_user: UserResponse = Depends(get_current_user)
):
async def update_face(face_id: int, face_update: FaceUpdateRequest):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查记录是否存在
check_query = "SELECT id FROM face WHERE id = %s"
cursor.execute(check_query, (face_id,))
existing_face = cursor.fetchone()
if not existing_face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
# 构建更新语句
update_fields = []
params = []
if face_update.name is not None:
@ -169,6 +210,18 @@ async def update_face(
if face_update.eigenvalue is not None:
update_fields.append("eigenvalue = %s")
params.append(face_update.eigenvalue)
if face_update.address is not None:
# 删除旧图片(相对路径转绝对路径)
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink() # 使用Path方法删除更安全
print(f"[FaceRouter] 已删除旧图片:{old_abs_path}")
except Exception as e:
print(f"[FaceRouter] 删除旧图片失败:{str(e)}")
update_fields.append("address = %s")
params.append(face_update.address)
if not update_fields:
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
@ -178,92 +231,143 @@ async def update_face(
cursor.execute(update_query, params)
conn.commit()
# 查询更新后记录
select_query = "SELECT * FROM face WHERE id = %s"
cursor.execute(select_query, (face_id,))
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
""", (face_id,))
updated_face = cursor.fetchone()
return APIResponse(
code=200,
message="人脸记录更新成功",
message="更新成功",
data=FaceResponse(**updated_face)
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新人脸记录失败:{str(e)}") from e
raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 5. 删除人脸记录不变用自增ID删除
# 5. 删除人脸记录
# ------------------------------
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
async def delete_face(
face_id: int,
current_user: UserResponse = Depends(get_current_user)
):
async def delete_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
check_query = "SELECT id FROM face WHERE id = %s"
cursor.execute(check_query, (face_id,))
existing_face = cursor.fetchone()
if not existing_face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
delete_query = "DELETE FROM face WHERE id = %s"
cursor.execute(delete_query, (face_id,))
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
conn.commit()
# 删除图片
extra_msg = ""
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink()
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
extra_msg = "(已同步删除图片)"
except Exception as e:
print(f"[FaceRouter] 删除图片失败:{str(e)}")
extra_msg = "(图片删除失败)"
else:
extra_msg = "(图片不存在)"
else:
extra_msg = "(无关联图片)"
return APIResponse(
code=200,
message=f"ID为 {face_id}人脸记录删除成功",
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除人脸记录失败:{str(e)}") from e
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def get_all_face_name_with_eigenvalue() -> dict:
"""
获取所有人脸的名称及其对应的特征值,组成字典返回
key: 人脸名称name
value: 人脸特征值eigenvalue
过滤掉name为None的记录避免字典key为None的情况
"""
# ------------------------------
# 6. 获取人脸图片
# ------------------------------
@router.get("/{face_id}/image", summary="获取人脸图片")
async def get_face_image(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT address, name FROM face WHERE id = %s"
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
db_path = face["address"]
abs_path = Path(db_path).resolve() # 转为绝对路径
if not db_path or not abs_path.exists():
raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path}")
return FileResponse(
path=abs_path,
filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}",
media_type=f"image/{db_path.split('.')[-1]}"
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法
# ------------------------------
def get_all_face_name_with_eigenvalue() -> dict:
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 只查询需要的字段,提高效率
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query)
faces = cursor.fetchall()
# 构建name到eigenvalue的映射字典
face_dict = {
face["name"]: face["eigenvalue"]
for face in faces
}
name_to_eigenvalues = {}
for face in faces:
name = face["name"]
eigenvalue = face["eigenvalue"]
if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue)
else:
name_to_eigenvalues[name] = [eigenvalue]
face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items():
if len(eigenvalues) > 1:
face_dict[name] = get_average_feature(eigenvalues)
else:
face_dict[name] = eigenvalues[0]
return face_dict
except MySQLError as e:
raise Exception(f"获取人脸名称与特征失败{str(e)}") from e
raise Exception(f"获取人脸特征失败: {str(e)}") from e
finally:
# 确保资源释放
db.close_connection(conn, cursor)

668
service/model_service.py Normal file
View File

@ -0,0 +1,668 @@
import subprocess
import os
import sys
import shutil
import threading
from pathlib import Path
from datetime import datetime
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
# 复用项目依赖
from ds.db import db
from schema.model_schema import (
ModelCreateRequest,
ModelUpdateRequest,
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from util.model_util import load_yolo_model # 模型加载工具
# 路径配置
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
# 模型限制
ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
# 全局模型变量(带版本标识)
global _yolo_model, _current_model_version
_yolo_model = None
_current_model_version = None # 模型版本标识(用于检测模型是否变化)
router = APIRouter(prefix="/models", tags=["模型管理"])
# 服务重启核心工具函数
def restart_service():
"""重启当前FastAPI服务进程"""
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
try:
# 关闭所有WebSocket连接
try:
from ws import connected_clients
if connected_clients:
print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接")
for ip, conn in list(connected_clients.items()):
try:
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
conn.websocket.close(code=1001, reason="模型更新,服务重启")
connected_clients.pop(ip)
except Exception as e:
print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}")
except ImportError:
print("[服务重启] 未找到WebSocket连接管理模块跳过连接关闭")
# 关闭数据库连接
if hasattr(db, "close_all_connections"):
db.close_all_connections()
else:
print("[警告] db模块未实现close_all_connections可能存在连接泄漏")
# 启动新进程
python_exec = sys.executable
current_argv = sys.argv
print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}")
subprocess.Popen(
[python_exec] + current_argv,
close_fds=True,
start_new_session=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# 退出当前进程
print("[服务重启] 新进程已启动,当前进程退出")
sys.exit(0)
except Exception as e:
print(f"[服务重启] 重启失败:{str(e)}")
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
# 模型路径验证工具函数
def get_valid_model_abs_path(relative_path: str) -> str:
try:
relative_path = relative_path.replace("/", os.sep)
model_abs_path = PROJECT_ROOT / relative_path
model_abs_path = model_abs_path.resolve()
model_abs_path_str = str(model_abs_path)
if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)):
raise HTTPException(
status_code=400,
detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
)
if not model_abs_path.exists():
raise HTTPException(
status_code=404,
detail=f"模型文件不存在!路径:{model_abs_path_str}"
)
if not model_abs_path.is_file():
raise HTTPException(
status_code=400,
detail=f"路径不是文件!路径:{model_abs_path_str}"
)
file_size = model_abs_path.stat().st_size
if file_size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"模型文件过大({file_size // 1024 // 1024}MB超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
)
file_ext = model_abs_path.suffix.lower()
if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
raise HTTPException(
status_code=400,
detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
)
print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
return model_abs_path_str
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"路径处理失败:{str(e)}"
) from e
# 对外提供当前模型(带版本校验)
def get_current_yolo_model():
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
global _yolo_model, _current_model_version
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT path FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
print("[get_current_yolo_model] 暂无默认模型")
return None
# 1. 计算当前默认模型的唯一版本标识
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
valid_abs_path = get_valid_model_abs_path(default_model["path"])
model_stat = os.stat(valid_abs_path)
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
# 2. 版本未变化则复用已有模型(核心优化点)
if _yolo_model and _current_model_version == model_version:
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...")
return _yolo_model
# 3. 版本变化时重新加载模型
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
_current_model_version = model_version # 更新版本标识
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...")
else:
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
return _yolo_model
except Exception as e:
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
return None
finally:
db.close_connection(conn, cursor)
# 1. 上传模型
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
async def upload_model(
name: str = Form(..., description="模型名称"),
description: str = Form(None, description="模型描述"),
is_default: bool = Form(False, description="是否设为默认模型"),
file: UploadFile = File(..., description=f"YOLO模型文件.pt最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
conn = None
cursor = None
saved_file_path = None
try:
# 校验文件
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
if file_ext not in ALLOWED_MODEL_EXT:
raise HTTPException(
status_code=400,
detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
)
if file.size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB当前{file.size // 1024 // 1024}MB"
)
# 保存文件
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}"
saved_file_path = MODEL_SAVE_ROOT / safe_filename
with open(saved_file_path, "wb") as f:
shutil.copyfileobj(file.file, f)
saved_file_path.chmod(0o644) # 设置权限
# 数据库路径处理
db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")
# 数据库操作
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
if is_default:
cursor.execute("UPDATE model SET is_default = 0")
insert_sql = """
INSERT INTO model (name, path, is_default, description, file_size)
VALUES (%s, %s, %s, %s, %s)
"""
cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size))
conn.commit()
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
new_model = cursor.fetchone()
if not new_model:
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
# 加载默认模型并更新版本
global _yolo_model, _current_model_version
if is_default:
valid_abs_path = get_valid_model_abs_path(db_relative_path)
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException(
status_code=500,
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path}"
)
return APIResponse(
code=201,
message=f"模型上传成功ID{new_model['id']}",
data=ModelResponse(**new_model)
)
except MySQLError as e:
if conn:
conn.rollback()
if saved_file_path and saved_file_path.exists():
saved_file_path.unlink()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
except Exception as e:
if saved_file_path and saved_file_path.exists():
saved_file_path.unlink()
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
finally:
await file.close()
db.close_connection(conn, cursor)
# 2. 获取模型列表
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
async def get_model_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
is_default: bool = Query(None)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if is_default is not None:
where_clause.append("is_default = %s")
params.append(1 if is_default else 0)
# 总记录数
count_sql = "SELECT COUNT(*) AS total FROM model"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 分页数据
offset = (page - 1) * page_size
list_sql = "SELECT * FROM model"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_sql, params)
model_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取成功!共{total}条记录",
data=ModelListResponse(
total=total,
models=[ModelResponse(**model) for model in model_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 3. 获取默认模型
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
async def get_default_model():
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
raise HTTPException(status_code=404, detail="暂无默认模型")
valid_abs_path = get_valid_model_abs_path(default_model["path"])
global _yolo_model, _current_model_version
if not _yolo_model:
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException(
status_code=500,
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path}"
)
return APIResponse(
code=200,
message="默认模型查询成功",
data=ModelResponse(**default_model)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 4. 获取单个模型详情
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
async def get_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
model = cursor.fetchone()
if not model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
try:
model_abs_path = get_valid_model_abs_path(model["path"])
except HTTPException as e:
return APIResponse(
code=200,
message=f"查询成功,但路径异常:{e.detail}",
data=ModelResponse(**model)
)
return APIResponse(
code=200,
message="查询成功",
data=ModelResponse(**model)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 5. 更新模型信息
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
async def update_model(model_id: int, model_update: ModelUpdateRequest):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
update_fields = []
params = []
if model_update.name is not None:
update_fields.append("name = %s")
params.append(model_update.name)
if model_update.description is not None:
update_fields.append("description = %s")
params.append(model_update.description)
need_load_default = False
if model_update.is_default is not None:
if model_update.is_default:
cursor.execute("UPDATE model SET is_default = 0")
update_fields.append("is_default = 1")
need_load_default = True
else:
cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
default_count = cursor.fetchone()["cnt"]
if default_count == 1 and exist_model["is_default"]:
raise HTTPException(
status_code=400,
detail="当前是唯一默认模型,不可取消!"
)
update_fields.append("is_default = 0")
if not update_fields:
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
params.append(model_id)
update_sql = f"""
UPDATE model
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
WHERE id = %s
"""
cursor.execute(update_sql, params)
conn.commit()
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone()
# 更新模型后重置版本标识
global _yolo_model, _current_model_version
if need_load_default:
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException(
status_code=500,
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path}"
)
return APIResponse(
code=200,
message="模型更新成功",
data=ModelResponse(**updated_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 5.1 更换默认模型(自动重启服务)
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
async def set_default_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
conn.autocommit = False # 开启事务
# 1. 校验目标模型是否存在
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
target_model = cursor.fetchone()
if not target_model:
raise HTTPException(status_code=404, detail=f"目标模型不存在ID{model_id}")
# 2. 检查是否已为默认模型
if target_model["is_default"]:
return APIResponse(
code=200,
message=f"模型ID{model_id} 已是默认模型,无需更换和重启",
data=ModelResponse(**target_model)
)
# 3. 校验目标模型文件合法性
try:
valid_abs_path = get_valid_model_abs_path(target_model["path"])
except HTTPException as e:
raise HTTPException(
status_code=400,
detail=f"目标模型文件非法,无法设为默认:{e.detail}"
) from e
# 4. 数据库事务:更新默认模型状态
try:
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
cursor.execute(
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
(model_id,)
)
conn.commit()
except MySQLError as e:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
) from e
# 5. 验证新模型可加载性
test_model = load_yolo_model(valid_abs_path)
if not test_model:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path}"
)
# 6. 重新查询更新后的模型信息
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone()
# 7. 重置版本标识(关键:确保下次检测加载新模型)
global _current_model_version
_current_model_version = None
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型")
# 8. 延迟重启服务
print(f"[更换默认模型] 成功将在1秒后重启服务以应用新模型ID{model_id}")
threading.Timer(
interval=1.0,
function=restart_service
).start()
# 9. 返回成功响应
return APIResponse(
code=200,
message=f"已成功更换默认模型ID{model_id}服务将在1秒后自动重启以应用新模型",
data=ModelResponse(**updated_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
if conn:
conn.autocommit = True
db.close_connection(conn, cursor)
# 6. 删除模型
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
async def delete_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
if exist_model["is_default"]:
raise HTTPException(status_code=400, detail="默认模型不可删除!")
try:
model_abs_path_str = get_valid_model_abs_path(exist_model["path"])
model_abs_path = Path(model_abs_path_str)
except HTTPException as e:
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
return APIResponse(
code=200,
message=f"记录删除成功,文件异常:{e.detail}",
data=None
)
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
extra_msg = ""
try:
model_abs_path.unlink()
extra_msg = f"(已删除文件)"
except Exception as e:
extra_msg = f"(文件删除失败:{str(e)}"
# 如果删除的是当前加载的模型,重置缓存
global _yolo_model, _current_model_version
if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str:
_yolo_model = None
_current_model_version = None
print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str}")
return APIResponse(
code=200,
message=f"模型删除成功ID{model_id} {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 7. 下载模型文件
@router.get("/{model_id}/download", summary="下载模型文件")
async def download_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
model = cursor.fetchone()
if not model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
valid_abs_path = get_valid_model_abs_path(model["path"])
model_abs_path = Path(valid_abs_path)
return FileResponse(
path=model_abs_path,
filename=f"model_{model_id}_{model['name']}.pt",
media_type="application/octet-stream"
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -21,7 +21,7 @@ router = APIRouter(
async def create_sensitive(
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
"""
创建敏感信息记录
创建敏感信息记录:
- 需登录认证
- 插入新的敏感信息记录到数据库ID由数据库自动生成
- 返回创建成功信息
@ -32,7 +32,7 @@ async def create_sensitive(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入新敏感信息记录到数据库不包含ID由数据库自动生成)
# 插入新敏感信息记录到数据库不包含ID由数据库自动生成)
insert_query = """
INSERT INTO sensitives (name)
VALUES (%s)
@ -56,7 +56,7 @@ async def create_sensitive(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建敏感信息记录失败{str(e)}") from e
raise Exception(f"创建敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -71,7 +71,7 @@ async def get_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
获取单个敏感信息记录
获取单个敏感信息记录:
- 需登录认证
- 根据ID查询敏感信息记录
- 返回查询到的敏感信息
@ -98,7 +98,7 @@ async def get_sensitive(
data=SensitiveResponse(**sensitive)
)
except MySQLError as e:
raise Exception(f"查询敏感信息记录失败{str(e)}") from e
raise Exception(f"查询敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -109,7 +109,7 @@ async def get_sensitive(
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
async def get_all_sensitives():
"""
获取所有敏感信息记录
获取所有敏感信息记录:
- 需登录认证
- 查询所有敏感信息记录(不需要分页)
- 返回所有敏感信息列表
@ -130,7 +130,7 @@ async def get_all_sensitives():
data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
)
except MySQLError as e:
raise Exception(f"查询所有敏感信息记录失败{str(e)}") from e
raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -145,7 +145,7 @@ async def update_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
更新敏感信息记录
更新敏感信息记录:
- 需登录认证
- 根据ID更新敏感信息记录
- 返回更新后的敏感信息
@ -203,7 +203,7 @@ async def update_sensitive(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新敏感信息记录失败{str(e)}") from e
raise Exception(f"更新敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -217,7 +217,7 @@ async def delete_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
删除敏感信息记录
删除敏感信息记录:
- 需登录认证
- 根据ID删除敏感信息记录
- 返回删除成功信息
@ -251,14 +251,14 @@ async def delete_sensitive(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除敏感信息记录失败{str(e)}") from e
raise Exception(f"删除敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def get_all_sensitive_words() -> list[str]:
"""
获取所有敏感词返回字符串数组
获取所有敏感词返回字符串数组
返回:
list[str]: 包含所有敏感词的数组
@ -273,7 +273,7 @@ def get_all_sensitive_words() -> list[str]:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 执行查询只获取敏感词字段
# 执行查询只获取敏感词字段
query = "SELECT name FROM sensitives ORDER BY id"
cursor.execute(query)
sensitive_records = cursor.fetchall()
@ -283,7 +283,7 @@ def get_all_sensitive_words() -> list[str]:
except MySQLError as e:
# 数据库错误处理
raise MySQLError(f"查询敏感词失败{str(e)}") from e
raise MySQLError(f"查询敏感词失败: {str(e)}") from e
finally:
# 确保资源正确释放
db.close_connection(conn, cursor)

View File

@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from mysql.connector import Error as MySQLError
from ds.db import db
@ -11,7 +12,7 @@ from middle.auth_middleware import (
verify_password,
create_access_token,
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user
get_current_user # 仅保留登录用户校验移除is_admin导入
)
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
@ -27,7 +28,7 @@ router = APIRouter(
@router.post("/register", response_model=APIResponse, summary="用户注册")
async def user_register(request: UserRegisterRequest):
"""
用户注册
用户注册:
- 校验用户名是否已存在
- 加密密码后插入数据库
- 返回注册成功信息
@ -67,7 +68,7 @@ async def user_register(request: UserRegisterRequest):
)
except MySQLError as e:
conn.rollback() # 数据库错误时回滚事务
raise Exception(f"注册失败{str(e)}") from e
raise Exception(f"注册失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -78,7 +79,7 @@ async def user_register(request: UserRegisterRequest):
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
async def user_login(request: UserLoginRequest):
"""
用户登录
用户登录:
- 校验用户名是否存在
- 校验密码是否正确
- 生成 JWT Token 并返回
@ -89,7 +90,7 @@ async def user_login(request: UserLoginRequest):
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 修复SQL查询添加 created_at 和 updated_at 字段
# 修复: SQL查询添加 created_at 和 updated_at 字段
query = """
SELECT id, username, password, created_at, updated_at
FROM users
@ -129,7 +130,7 @@ async def user_login(request: UserLoginRequest):
}
)
except MySQLError as e:
raise Exception(f"登录失败{str(e)}") from e
raise Exception(f"登录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -142,8 +143,8 @@ async def get_current_user_info(
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
):
"""
获取当前登录用户信息
- 需在请求头携带 Token格式Bearer <token>
获取当前登录用户信息:
- 需在请求头携带 Token格式: Bearer <token>
- 认证通过后返回用户信息
"""
return APIResponse(
@ -152,3 +153,98 @@ async def get_current_user_info(
data=current_user
)
# ------------------------------
# 4. 获取用户列表(仅需登录权限)
# ------------------------------
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
async def get_user_list(
page: int = Query(1, ge=1, description="页码从1开始"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
username: Optional[str] = Query(None, description="用户名模糊搜索"),
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
):
"""
获取用户列表:
- 需登录权限(请求头携带 Token: Bearer <token>
- 支持分页查询page=页码page_size=每页条数)
- 支持用户名模糊搜索(如输入"test"可匹配"test123""admin_test"等)
- 仅返回用户ID、用户名、创建时间、更新时间不包含密码等敏感信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 计算分页偏移量page从1开始偏移量=(页码-1*每页条数)
offset = (page - 1) * page_size
# 基础查询(仅查非敏感字段)
base_query = """
SELECT id, username, created_at, updated_at
FROM users
"""
# 总条数查询(用于分页计算)
count_query = "SELECT COUNT(*) as total FROM users"
# 条件拼接(支持用户名模糊搜索)
conditions = []
params = []
if username:
conditions.append("username LIKE %s")
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
# 构建最终查询语句
if conditions:
where_clause = " WHERE " + " AND ".join(conditions)
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
final_count_query = f"{count_query}{where_clause}"
params.extend([page_size, offset]) # 追加分页参数
else:
final_query = f"{base_query} LIMIT %s OFFSET %s"
final_count_query = count_query
params = [page_size, offset]
# 1. 查询用户列表数据
cursor.execute(final_query, params)
users = cursor.fetchall()
# 2. 查询总条数(用于计算总页数)
count_params = [f"%{username}%"] if username else []
cursor.execute(final_count_query, count_params)
total = cursor.fetchone()["total"]
# 3. 转换为UserResponse模型确保字段匹配
user_list = [
UserResponse(
id=user["id"],
username=user["username"],
created_at=user["created_at"],
updated_at=user["updated_at"]
)
for user in users
]
# 4. 计算总页数向上取整如11条数据每页10条=2页
total_pages = (total + page_size - 1) // page_size
# 返回结果(包含列表和分页信息)
return APIResponse(
code=200,
message="获取用户列表成功",
data={
"users": user_list,
"pagination": {
"page": page, # 当前页码
"page_size": page_size, # 每页条数
"total": total, # 总数据量
"total_pages": total_pages # 总页数
}
}
)
except MySQLError as e:
raise Exception(f"获取用户列表失败: {str(e)}") from e
finally:
# 无论成功失败,都关闭数据库连接
db.close_connection(conn, cursor)

156
util/face_util.py Normal file
View File

@ -0,0 +1,156 @@
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
from io import BytesIO
from PIL import Image
import logging
# 配置日志(便于排查)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - [FaceUtil] - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 全局变量存储InsightFace引擎和特征列表
_insightface_app = None
_feature_list = []
def init_insightface():
"""初始化InsightFace引擎确保成功后再使用"""
global _insightface_app
try:
if _insightface_app is not None:
logger.info("InsightFace引擎已初始化无需重复执行")
return _insightface_app
logger.info("正在初始化InsightFace引擎模型buffalo_l...")
# 手动指定模型下载路径(避免权限问题,可选)
app = FaceAnalysis(
name='buffalo_l',
root='~/.insightface', # 模型默认下载路径
providers=['CPUExecutionProvider'] # 强制用CPU若有GPU可加'CUDAExecutionProvider'
)
app.prepare(ctx_id=0, det_size=(640, 640)) # det_size越大小人脸检测越准
logger.info("InsightFace引擎初始化完成")
_insightface_app = app
return app
except Exception as e:
logger.error(f"InsightFace初始化失败{str(e)}", exc_info=True) # 打印详细堆栈
_insightface_app = None
return None
def add_binary_data(binary_data):
"""
接收单张图片的二进制数据、提取特征并保存
返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串)
"""
global _insightface_app, _feature_list
# 1. 先检查引擎是否初始化成功
if not _insightface_app:
init_result = init_insightface() # 尝试重新初始化
if not init_result:
error_msg = "InsightFace引擎未初始化无法检测人脸"
logger.error(error_msg)
return False, error_msg
try:
# 2. 验证二进制数据有效性
if len(binary_data) < 1024: # 过滤过小的无效图片小于1KB
error_msg = f"图片过小({len(binary_data)}字节),可能不是有效图片"
logger.warning(error_msg)
return False, error_msg
# 3. 二进制数据转CV2格式关键步骤避免通道错误
try:
img = Image.open(BytesIO(binary_data)).convert("RGB") # 强制转RGB
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) # InsightFace需要BGR格式
except Exception as e:
error_msg = f"图片格式转换失败:{str(e)}"
logger.error(error_msg, exc_info=True)
return False, error_msg
# 4. 检查图片尺寸(避免极端尺寸导致检测失败)
height, width = frame.shape[:2]
if height < 64 or width < 64: # 人脸检测最小建议尺寸
error_msg = f"图片尺寸过小({width}x{height}需至少64x64像素"
logger.warning(error_msg)
return False, error_msg
# 5. 调用InsightFace检测人脸
logger.info(f"开始检测人脸(图片尺寸:{width}x{height}格式BGR")
faces = _insightface_app.get(frame)
if not faces:
error_msg = "未检测到人脸(请确保图片包含清晰正面人脸,无遮挡、不模糊)"
logger.warning(error_msg)
return False, error_msg
# 6. 提取特征并保存
current_feature = faces[0].embedding
_feature_list.append(current_feature)
logger.info(f"人脸检测成功,提取特征值(维度:{current_feature.shape[0]}),累计特征数:{len(_feature_list)}")
return True, current_feature
except Exception as e:
error_msg = f"处理图片时发生异常:{str(e)}"
logger.error(error_msg, exc_info=True)
return False, error_msg
# 以下函数保持不变get_average_feature/clear_features/get_feature_list
def get_average_feature(features=None):
global _feature_list
try:
if features is None:
features = _feature_list
if not isinstance(features, list) or len(features) == 0:
logger.warning("输入必须是包含至少一个特征值的列表")
return None
processed_features = []
for i, embedding in enumerate(features):
try:
if isinstance(embedding, str):
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32)
else:
embedding_np = np.array(embedding, dtype=np.float32)
if len(embedding_np.shape) == 1:
processed_features.append(embedding_np)
logger.info(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
logger.warning(f"跳过第 {i + 1} 个特征值:不是一维数组")
except Exception as e:
logger.error(f"处理第 {i + 1} 个特征值时出错:{str(e)}")
if not processed_features:
logger.warning("没有有效的特征值用于计算平均值")
return None
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
logger.error(f"特征值维度不一致:{dims},无法计算平均值")
return None
avg_feature = np.mean(processed_features, axis=0)
logger.info(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]}")
return avg_feature
except Exception as e:
logger.error(f"计算平均特征值出错:{str(e)}", exc_info=True)
return None
def clear_features():
global _feature_list
_feature_list = []
logger.info("已清空所有特征数据")
def get_feature_list():
global _feature_list
logger.info(f"当前特征列表长度:{len(_feature_list)}")
return _feature_list.copy()

84
util/file_util.py Normal file
View File

@ -0,0 +1,84 @@
import os
import datetime
from pathlib import Path
from typing import Dict
def save_face_to_up_images(
client_ip: str,
face_name: str,
image_bytes: bytes,
image_format: str = "jpg"
) -> Dict[str, str]:
"""
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
确保db_path以up_images开头且统一使用正斜杠
参数:
client_ip: 客户端IP原始格式如192.168.1.101
face_name: 人脸名称(用户输入,可为空)
image_bytes: 人脸图片二进制数据
image_format: 图片格式默认jpg
返回:
字典success是否成功、db_path存数据库的相对路径、local_abs_path本地绝对路径、msg提示
"""
try:
# 1. 基础参数校验
if not client_ip.strip():
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"}
if not image_bytes:
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "图片二进制数据为空"}
if image_format.lower() not in ["jpg", "jpeg", "png"]:
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"}
# 2. 处理特殊字符(避免路径错误)
safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_
safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名"
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
root_dir = Path("up_images").resolve() # 转为绝对路径如D:/Git/bin/video/up_images
if not root_dir.exists():
root_dir.mkdir(parents=True, exist_ok=True)
print(f"[FileUtil] 已创建up_images根目录{root_dir}")
# 4. 构建文件层级路径确保在root_dir子目录下
ip_dir = root_dir / safe_ip
face_name_dir = ip_dir / safe_face_name
face_name_dir.mkdir(parents=True, exist_ok=True) # 自动创建目录
print(f"[FileUtil] 图片存储目录:{face_name_dir}")
# 5. 生成唯一文件名(毫秒级时间戳)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
# 6. 计算路径确保所有路径都是绝对路径且在root_dir下
local_abs_path = face_name_dir / image_filename # 绝对路径
# 验证路径是否在root_dir下防止路径穿越攻击
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
raise Exception(f"图片路径不在up_images根目录下安全校验失败{local_abs_path}")
# 数据库存储路径强制包含up_images前缀统一使用正斜杠
relative_path = local_abs_path.relative_to(root_dir.parent) # 相对于root_dir的父目录
db_path = str(relative_path).replace("\\", "/") # 此时会包含up_images部分
# 7. 写入图片文件
with open(local_abs_path, "wb") as f:
f.write(image_bytes)
print(f"[FileUtil] 图片保存成功:")
print(f" 数据库路径:{db_path}")
print(f" 本地绝对路径:{local_abs_path}")
return {
"success": True,
"db_path": db_path, # 格式为 up_images/192_168_110_31/小龙/xxx.jpg
"local_abs_path": str(local_abs_path), # 本地绝对路径(完整路径)
"msg": "图片保存成功"
}
except Exception as e:
error_msg = f"图片保存失败:{str(e)}"
print(f"[FileUtil] 错误:{error_msg}")
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}

61
util/model_util.py Normal file
View File

@ -0,0 +1,61 @@
import os
import numpy as np
import traceback
from ultralytics import YOLO
from typing import Optional
def load_yolo_model(model_path: str) -> Optional[YOLO]:
"""
加载YOLO模型支持v5/v8并校验模型有效性
:param model_path: 模型文件的绝对路径
:return: 加载成功返回YOLO模型实例失败返回None
"""
try:
# 加载前的基础信息检查
print(f"\n[模型工具] 开始加载模型:{model_path}")
print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}")
if os.path.exists(model_path):
print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB")
# 强制重新加载模型,避免缓存问题
model = YOLO(model_path)
# 兼容性校验使用numpy空数组测试模型
dummy_image = np.zeros((640, 640, 3), dtype=np.uint8)
try:
# 优先使用新版本参数
model.predict(
source=dummy_image,
imgsz=640,
conf=0.25,
verbose=False,
stream=False
)
except Exception as pred_e:
print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}")
# 兼容旧版本YOLO参数
model.predict(
img=dummy_image,
imgsz=640,
conf=0.25,
verbose=False
)
# 验证模型基本属性
if not hasattr(model, 'names'):
print("[模型工具] 警告:模型缺少类别名称属性")
else:
print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...") # 显示前5个类别
print(f"[模型工具] 模型加载成功!")
return model
except Exception as e:
# 详细错误信息输出
print(f"\n[模型工具] 加载模型失败!路径:{model_path}")
print(f"[模型工具] 异常类型:{type(e).__name__}")
print(f"[模型工具] 异常详情:{str(e)}")
print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}")
return None

482
ws.html
View File

@ -1,482 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WebSocket 测试工具</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
}
body {
max-width: 1200px;
margin: 20px auto;
padding: 0 20px;
background-color: #f5f7fa;
}
.container {
background: white;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
padding: 25px;
margin-bottom: 20px;
}
h1 {
color: #2c3e50;
margin-bottom: 20px;
font-size: 24px;
border-bottom: 2px solid #3498db;
padding-bottom: 10px;
}
.status-bar {
display: flex;
align-items: center;
gap: 15px;
margin-bottom: 20px;
padding: 12px 15px;
background-color: #f8f9fa;
border-radius: 6px;
}
.status-label {
font-weight: bold;
color: #495057;
}
.status-value {
padding: 4px 10px;
border-radius: 4px;
font-weight: bold;
}
.status-connected {
background-color: #d4edda;
color: #155724;
}
.status-disconnected {
background-color: #f8d7da;
color: #721c24;
}
.status-connecting {
background-color: #fff3cd;
color: #856404;
}
.btn {
padding: 8px 16px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
font-weight: 500;
transition: background-color 0.2s;
}
.btn-primary {
background-color: #3498db;
color: white;
}
.btn-primary:hover {
background-color: #2980b9;
}
.btn-danger {
background-color: #e74c3c;
color: white;
}
.btn-danger:hover {
background-color: #c0392b;
}
.btn-success {
background-color: #2ecc71;
color: white;
}
.btn-success:hover {
background-color: #27ae60;
}
.control-group {
display: flex;
gap: 15px;
margin-bottom: 20px;
align-items: center;
}
.input-group {
display: flex;
gap: 10px;
align-items: center;
}
.input-group label {
color: #495057;
font-weight: 500;
}
.input-group input, .input-group select {
padding: 8px 12px;
border: 1px solid #ced4da;
border-radius: 4px;
font-size: 14px;
}
.message-area {
margin-top: 20px;
}
.message-input {
width: 100%;
height: 100px;
padding: 12px;
border: 1px solid #ced4da;
border-radius: 6px;
resize: none;
font-size: 14px;
margin-bottom: 10px;
}
.log-area {
width: 100%;
height: 300px;
padding: 15px;
border: 1px solid #ced4da;
border-radius: 6px;
background-color: #f8f9fa;
overflow-y: auto;
font-size: 14px;
line-height: 1.6;
}
.log-item {
margin-bottom: 8px;
padding-bottom: 8px;
border-bottom: 1px dashed #e9ecef;
}
.log-time {
color: #6c757d;
font-size: 12px;
margin-right: 10px;
}
.log-send {
color: #2980b9;
}
.log-receive {
color: #27ae60;
}
.log-status {
color: #856404;
}
.log-error {
color: #e74c3c;
}
</style>
</head>
<body>
<div class="container">
<h1>WebSocket 测试工具</h1>
<!-- 连接状态区 -->
<div class="status-bar">
<div class="status-label">连接状态:</div>
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
<div class="status-label">服务地址:</div>
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
<div class="status-label">连接时间:</div>
<div id="connectTime" class="status-value">-</div>
</div>
<!-- 控制按钮区 -->
<div class="control-group">
<button id="connectBtn" class="btn btn-primary">建立连接</button>
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
<!-- 心跳控制 -->
<div class="input-group">
<label>自动心跳:</label>
<select id="autoHeartbeat">
<option value="on">开启</option>
<option value="off">关闭</option>
</select>
<label>间隔(秒)</label>
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
</div>
</div>
<!-- 自定义消息发送区 -->
<div class="message-area">
<h3>发送自定义消息</h3>
<textarea id="messageInput" class="message-input"
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
</div>
<!-- 日志显示区 -->
<div class="message-area">
<h3>消息日志</h3>
<div id="logContainer" class="log-area">
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
</div>
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
</div>
</div>
<script>
// 全局变量
let ws = null;
let heartbeatTimer = null;
const wsUrl = "ws://192.168.110.25:8000/ws";
// DOM 元素
const connectionStatus = document.getElementById('connectionStatus');
const connectTime = document.getElementById('connectTime');
const connectBtn = document.getElementById('connectBtn');
const disconnectBtn = document.getElementById('disconnectBtn');
const sendMessageBtn = document.getElementById('sendMessageBtn');
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
const autoHeartbeat = document.getElementById('autoHeartbeat');
const heartbeatInterval = document.getElementById('heartbeatInterval');
const messageInput = document.getElementById('messageInput');
const logContainer = document.getElementById('logContainer');
const clearLogBtn = document.getElementById('clearLogBtn');
// 工具函数:添加日志
function addLog(content, type = 'status') {
const now = new Date().toLocaleString('zh-CN', {
year: 'numeric', month: '2-digit', day: '2-digit',
hour: '2-digit', minute: '2-digit', second: '2-digit'
});
const logItem = document.createElement('div');
logItem.className = 'log-item';
let logClass = '';
switch (type) {
case 'send':
logClass = 'log-send';
break;
case 'receive':
logClass = 'log-receive';
break;
case 'error':
logClass = 'log-error';
break;
default:
logClass = 'log-status';
}
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
logContainer.appendChild(logItem);
// 滚动到最新日志
logContainer.scrollTop = logContainer.scrollHeight;
}
// 工具函数格式化JSON便于日志显示
function formatJson(jsonStr) {
try {
const obj = JSON.parse(jsonStr);
return JSON.stringify(obj, null, 2);
} catch (e) {
return jsonStr; // 非JSON格式直接返回
}
}
// 建立WebSocket连接
function connectWebSocket() {
if (ws) {
addLog('已存在连接,无需重复建立', 'error');
return;
}
try {
ws = new WebSocket(wsUrl);
// 连接成功
ws.onopen = function () {
connectionStatus.className = 'status-value status-connected';
connectionStatus.textContent = '已连接';
const now = new Date().toLocaleString('zh-CN');
connectTime.textContent = now;
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
// 更新按钮状态
connectBtn.disabled = true;
disconnectBtn.disabled = false;
sendMessageBtn.disabled = false;
// 开启自动心跳(默认开启)
if (autoHeartbeat.value === 'on') {
startAutoHeartbeat();
}
};
// 接收消息
ws.onmessage = function (event) {
const message = event.data;
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
};
// 连接关闭
ws.onclose = function (event) {
connectionStatus.className = 'status-value status-disconnected';
connectionStatus.textContent = '已断开';
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
// 清除自动心跳
stopAutoHeartbeat();
// 更新按钮状态
connectBtn.disabled = false;
disconnectBtn.disabled = true;
sendMessageBtn.disabled = true;
// 重置WebSocket对象
ws = null;
};
// 连接错误
ws.onerror = function (error) {
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
};
} catch (e) {
addLog(`建立连接失败:${e.message}`, 'error');
ws = null;
}
}
// 断开WebSocket连接
function disconnectWebSocket() {
if (!ws) {
addLog('当前无连接,无需断开', 'error');
return;
}
ws.close(1000, '手动断开连接');
}
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"}
function sendHeartbeat() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送心跳失败:当前无有效连接', 'error');
return;
}
const heartbeatMsg = {
timestamp: Date.now(), // 当前毫秒时间戳
type: "heartbeat"
};
const msgStr = JSON.stringify(heartbeatMsg);
ws.send(msgStr);
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
}
// 开启自动心跳
function startAutoHeartbeat() {
// 先停止已有定时器
stopAutoHeartbeat();
const interval = parseInt(heartbeatInterval.value) * 1000;
if (isNaN(interval) || interval < 10000) {
addLog('自动心跳间隔无效已重置为30秒', 'error');
heartbeatInterval.value = 30;
return startAutoHeartbeat();
}
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}`, 'status');
heartbeatTimer = setInterval(sendHeartbeat, interval);
}
// 停止自动心跳
function stopAutoHeartbeat() {
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
addLog('已停止自动心跳', 'status');
}
}
// 发送自定义消息
function sendCustomMessage() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送消息失败:当前无有效连接', 'error');
return;
}
const msgStr = messageInput.value.trim();
if (!msgStr) {
addLog('发送消息失败:消息内容不能为空', 'error');
return;
}
try {
// 验证JSON格式可选仅提示不强制
JSON.parse(msgStr);
ws.send(msgStr);
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
} catch (e) {
addLog(`JSON格式错误${e.message},仍尝试发送原始内容`, 'error');
ws.send(msgStr);
addLog(`发送自定义消息非JSON\n${msgStr}`, 'send');
}
}
// 绑定按钮事件
connectBtn.addEventListener('click', connectWebSocket);
disconnectBtn.addEventListener('click', disconnectWebSocket);
sendMessageBtn.addEventListener('click', sendCustomMessage);
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
clearLogBtn.addEventListener('click', () => {
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
});
// 自动心跳开关变更事件
autoHeartbeat.addEventListener('change', function () {
if (ws && ws.readyState === WebSocket.OPEN) {
if (this.value === 'on') {
startAutoHeartbeat();
} else {
stopAutoHeartbeat();
}
} else {
addLog('需先建立有效连接才能控制自动心跳', 'error');
// 重置选择
this.value = 'off';
}
});
// 心跳间隔变更事件(实时生效)
heartbeatInterval.addEventListener('change', function () {
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
startAutoHeartbeat();
}
});
// 快捷键支持Ctrl+Enter发送消息
messageInput.addEventListener('keydown', function (e) {
if (e.ctrlKey && e.key === 'Enter') {
sendCustomMessage();
e.preventDefault();
}
});
</script>
</body>
</html>

338
ws/ws.py
View File

@ -3,288 +3,296 @@ import datetime
import json
import os
from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator
from typing import Dict, Optional
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect, load_model
import cv2
import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# -------------------------- 配置常量 --------------------------
# 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制保持1确保单帧处理
# -------------------------- 核心数据结构与全局变量 --------------------------
ws_router = APIRouter()
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# 客户端连接封装(包含帧队列)
# 工具函数: 获取格式化时间字符串
def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# 客户端连接封装
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.client_ip = client_ip # 已初始化客户端IP用于传递给detect
self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 帧队列长度为1
self.consumer_task: Optional[asyncio.Task] = None # 消费者任务
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None
# 更新心跳时间
def update_heartbeat(self):
"""更新心跳时间"""
self.last_heartbeat = datetime.datetime.now()
# 检查是否存活超时返回False
def is_alive(self) -> bool:
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT
"""判断客户端是否存活"""
timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout_seconds < HEARTBEAT_TIMEOUT
# 启动帧消费任务
def start_consumer(self):
"""启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task
# ---------- 新增:发送“允许发送二进制帧”的信号给客户端 ----------
async def send_allow_send_frame(self):
"""向客户端发送JSON信号通知其可发送下一帧二进制数据"""
async def send_frame_permit(self):
"""发送帧发送许可信号"""
try:
allow_msg = {
"type": "allow_send_frame", # 信号类型,与客户端约定
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"status": "ready", # 表示服务器已准备好接收下一帧
"client_ip": self.client_ip # 可选:便于客户端确认自身身份
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
await self.websocket.send_json(allow_msg)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}已发送「允许发送帧」信号")
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号")
except Exception as e:
# 发送失败大概率是客户端已断开,不影响主流程,仅日志记录
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:发送「允许发送帧」信号失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
# 帧消费协程
async def consume_frames(self) -> None:
"""队列中获取帧并进行处理,处理完后通知客户端可发送下一帧"""
"""消费队列中的帧并处理"""
try:
while True:
# 从队列获取帧数据(队列空时会阻塞,等待客户端发送)
# 取出帧并立即发送下一帧许可
frame_data = await self.frame_queue.get()
await self.send_frame_permit()
try:
# 处理帧数据
await self.process_frame(frame_data)
finally:
# 标记任务完成(队列计数-1此时队列回到空状态
self.frame_queue.task_done()
# ---------- 修改:处理完当前帧后,立即通知客户端可发送下一帧 ----------
await self.send_allow_send_frame()
except asyncio.CancelledError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}帧消费任务已取消")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧处理错误 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(原有逻辑不变"""
# 二进制数据转换为NumPy数组uint8类型
"""处理单帧图像数据(核心修改detect函数传入 client_ip + img 双参数"""
# 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
# 解码为图像返回与cv2.imread相同的格式BGR通道的ndarray
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# 确保images文件夹存在
if not os.path.exists('images'):
os.makedirs('images')
# 生成唯一的文件名包含时间戳和客户端IP避免文件名冲突
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"images/{self.client_ip.replace('.', '_')}_{timestamp}.jpg"
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像数据")
return
try:
# 保存图像到本地
cv2.imwrite(filename, img)
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}")
# -------------------------- 核心修改按要求传入参数1.client_ip 2.img --------------------------
# detect函数参数顺序第一个为client_ip第二个为图像数据img
# 保持返回值解包(是否违规, 结果数据, 检测器类型)不变
has_violation, data, detector_type = await asyncio.to_thread(
detect, # 调用检测函数
self.client_ip, # 第一个参数客户端IP新增按需求顺序
img # 第二个参数:图像数据(原参数,调整顺序)
)
# -------------------------------------------------------------------------------------
# 进行检测
if img is not None:
has_violation, violation_type, details = detector.detect_violations(img)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# 发送检测结果回客户端(原有逻辑不变)
await self.websocket.send_json({
"type": "detection_result",
"has_violation": has_violation,
"violation_type": violation_type,
"details": details,
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:未检测到任何违规内容")
# 打印检测结果包含客户端IP与传入参数对应
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}")
# 处理违规逻辑逻辑不变基于detect返回结果执行
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
f"类型: {detector_type}, 详情: {data}")
# 违规次数+1
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 发送危险通知
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip,
"detail": data
}
await self.websocket.send_json(danger_msg)
else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:无法解析图像数据")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}图像处理错误 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
# 全局连接管理IP -> 连接实例)
# 全局状态管理
connected_clients: Dict[str, ClientConnection] = {}
# 心跳任务(全局引用,用于关闭时清理)
heartbeat_task: Optional[asyncio.Task] = None
# -------------------------- 心跳检查逻辑(原有逻辑不变) --------------------------
# 心跳检查任务
async def heartbeat_checker():
while True:
now = datetime.datetime.now()
# 1. 筛选超时客户端(避免遍历中修改字典)
current_time = get_current_time_str()
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
# 2. 处理超时连接(关闭+移除)
if timeout_ips:
print(f"[{now:%H:%M:%S}] 心跳检查{len(timeout_ips)}个客户端超时({timeout_ips}")
print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时(IP: {timeout_ips}")
for ip in timeout_ips:
try:
# 取消消费者任务
if connected_clients[ip].consumer_task and not connected_clients[ip].consumer_task.done():
connected_clients[ip].consumer_task.cancel()
await connected_clients[ip].websocket.close(code=1008, reason="心跳超时")
conn = connected_clients[ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}: 已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}: 离线状态更新失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{now:%H:%M:%S}] 心跳检查{len(connected_clients)}个客户端在线,无超时")
print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
# 3. 等待下一轮检查
await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期(原有逻辑不变) --------------------------
# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
global heartbeat_task
# 启动心跳任务
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务启动ID{id(heartbeat_task)}")
print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID: {id(heartbeat_task)}")
yield
# 关闭时取消心跳任务
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
try:
await heartbeat_task
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务已取消")
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
except asyncio.CancelledError:
pass
# -------------------------- 消息处理(文本/心跳逻辑不变,二进制逻辑保留) --------------------------
async def send_heartbeat_ack(client_ip: str):
"""回复心跳确认(原有逻辑不变)"""
if client_ip not in connected_clients:
return False
# 消息处理工具函数
async def send_heartbeat_ack(conn: ClientConnection):
try:
ack = {
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"type": "heartbeat"
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
await connected_clients[client_ip].websocket.send_json(ack)
await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送心跳确认")
return True
except Exception:
connected_clients.pop(client_ip, None)
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认发送失败 - {str(e)}")
return False
async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection):
"""处理文本消息(核心:心跳+JSON解析原有逻辑不变"""
async def handle_text_msg(conn: ClientConnection, text: str):
try:
msg = json.loads(text)
# 仅处理心跳类型消息
if msg.get("type") == "heartbeat":
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(client_ip)
await send_heartbeat_ack(conn)
else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到文本消息{msg}")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}无效JSON消息")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON文本消息")
async def handle_binary_msg(client_ip: str, data: bytes):
"""处理二进制消息(原有逻辑不变,因客户端仅在收到允许信号后发送,队列不会满)"""
if client_ip not in connected_clients:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接不存在,丢弃{len(data)}字节数据")
return
conn = connected_clients[client_ip]
# 检查队列是否已满(理论上不会触发,因客户端按信号发送)
if conn.frame_queue.full():
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列已满,丢弃{len(data)}字节数据")
return
# 队列未满,添加帧到队列
async def handle_binary_msg(conn: ClientConnection, data: bytes):
try:
conn.frame_queue.put_nowait(data)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:已接收{len(data)}字节二进制数据,加入队列")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列突然满了,丢弃{len(data)}字节数据")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满、丢弃当前图像数据")
# WebSocket路由配置
ws_router = APIRouter()
# -------------------------- WebSocket核心端点修改连接初始化逻辑 --------------------------
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
# 接受连接 + 获取客户端IP
load_model() # 加载检测模型(仅在连接建立时加载一次,避免重复加载)
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown"
now = datetime.datetime.now()
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功")
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}: WebSocket连接已建立")
is_online_updated = False
consumer_task = None
try:
# 处理重复连接(关闭旧连接)
# 处理重复连接(同一IP断开旧连接)
if client_ip in connected_clients:
# 取消旧连接的消费者任务
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
connected_clients[client_ip].consumer_task.cancel()
await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接")
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip)
print(f"[{now:%H:%M:%S}] 客户端{client_ip}关闭旧连接")
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接
# 注册新连接绑定client_ip和WebSocket
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer() # 启动帧消费任务
await new_conn.send_frame_permit() # 发送首次帧许可
# 启动帧消费任务
consumer_task = new_conn.start_consumer()
# ---------- 修改:客户端刚连接时,队列空,立即发送「允许发送帧」信号 ----------
await new_conn.send_allow_send_frame()
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,已启动帧消费任务,当前在线{len(connected_clients)}")
# 标记客户端上线
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线并记录操作")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
# 循环接收消息(原有逻辑不变)
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
# 消息循环(持续接收客户端消息)
while True:
data = await websocket.receive()
if "text" in data:
await handle_text_msg(client_ip, data["text"], new_conn)
await handle_text_msg(new_conn, data["text"])
elif "bytes" in data:
await handle_binary_msg(client_ip, data["bytes"])
await handle_binary_msg(new_conn, data["bytes"])
# 异常处理(断开/错误)
except WebSocketDisconnect as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}主动断开(代码{e.code}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}连接异常{str(e)[:50]}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理连接和任务
# 清理资源(断开后标记离线+删除连接)
if client_ip in connected_clients:
# 取消消费者任务
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
connected_clients[client_ip].consumer_task.cancel()
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 仅当上线状态更新成功时,才执行离线更新
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}")