Compare commits

..

17 Commits

Author SHA1 Message Date
435b2a0e6c 路径写入数据库 2025-09-10 10:53:07 +08:00
ae177ca14a 从服务器读取IP并将检测数据写入数据库 2025-09-10 08:57:56 +08:00
d3c4820b73 识别结果保存到对应目录下后不显示完整路径 2025-09-09 17:09:34 +08:00
532a9e75e9 识别结果保存到对应目录下 2025-09-09 16:30:12 +08:00
0fe49bf829 paddleocr 2025-09-09 09:42:23 +08:00
2571da3c2d 去除本地存储 | 优化代码风格 2025-09-08 18:24:32 +08:00
1dd832e18d 修改WS兼容检测的Future对象 2025-09-08 18:10:49 +08:00
8ceb92c572 优化代码风格 2025-09-08 17:34:23 +08:00
9b3d20511a 最新可用 2025-09-05 17:23:50 +08:00
30bf7c9fcb 最新可用 2025-09-04 22:59:27 +08:00
ec6dbfde90 优化 2025-09-04 17:33:20 +08:00
3ed73bd9eb 1 2025-09-04 17:29:52 +08:00
08f8a0e44e 优化 2025-09-04 17:08:25 +08:00
b5d870a19c 优化 2025-09-04 12:29:27 +08:00
ea82a33a8f 平均特征值计算 2025-09-04 10:46:05 +08:00
bae7785a97 Merge remote-tracking branch 'origin/master' 2025-09-04 10:40:06 +08:00
49d2c71fdd 1 2025-09-04 10:39:41 +08:00
51 changed files with 1707 additions and 2006 deletions

2
.idea/Video.iml generated
View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="video" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
</module> </module>

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="Black"> <component name="Black">
<option name="sdkName" value="video" /> <option name="sdkName" value="video" />
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="video" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project> </project>

View File

@ -12,8 +12,4 @@ charset = utf8mb4
[jwt] [jwt]
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
algorithm = HS256 algorithm = HS256
access_token_expire_minutes = 30 access_token_expire_minutes = 30
[live]
rtmp_url = rtmp://192.168.110.65:1935/live/
webrtc_url = http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=

106
core/all.py Normal file
View File

@ -0,0 +1,106 @@
import cv2
import numpy as np
from PIL.Image import Image
from core.establish import get_image_save_path
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# 导入保存路径函数(根据实际文件位置调整导入路径)
import numpy as np
import base64
from io import BytesIO
from PIL import Image
from ds.db import db
from mysql.connector import Error as MySQLError
# 模型加载状态标记(避免重复加载)
_model_loaded = False
def load_model():
"""加载所有检测模型(仅首次调用时执行)"""
global _model_loaded
if _model_loaded:
print("模型已加载,无需重复执行")
return
# 依次加载OCR、人脸、YOLO模型
ocrLoadModel()
faceLoadModel()
yoloLoadModel()
_model_loaded = True
print("所有检测模型加载完成")
def save_db(model_type, client_ip, result):
conn = None
cursor = None
try:
# 连接数据库
conn = db.get_connection()
# 往表插入数据
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
insert_query = """
INSERT INTO device_danger (client_ip, type, result)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (client_ip, model_type, result))
conn.commit()
except MySQLError as e:
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def detect(client_ip, frame):
"""
执行模型检测,检测到违规时按指定格式保存图片
参数:
frame: 待检测的图像帧OpenCV格式numpy.ndarray类型
返回:
(检测结果布尔值, 检测详情, 检测模型类型)
"""
# 1. YOLO检测优先级1
yolo_flag, yolo_result = yoloDetect(frame)
print(f"YOLO检测结果{yolo_result}")
if yolo_flag:
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ YOLO违规图片已保存{display_path}")
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
return (True, yolo_result, "yolo")
#
# # 2. 人脸检测优先级2
face_flag, face_result = faceDetect(frame)
print(f"人脸检测结果:{face_result}")
if face_flag:
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ face违规图片已保存{display_path}")
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
return (True, face_result, "face")
# 3. OCR检测优先级3
ocr_flag, ocr_result = ocrDetect(frame)
print(f"OCR检测结果{ocr_result}")
if ocr_flag:
# 解构元组,保存用完整路径,打印用短路径
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式
print(f"✅ ocr违规图片已保存{display_path}")
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片")
return (False, "未检测到任何内容", "none")

120
core/establish.py Normal file
View File

@ -0,0 +1,120 @@
import os
import datetime
from pathlib import Path
from service.device_service import get_unique_client_ips
def create_directory_structure():
"""创建项目所需的目录结构为所有客户端IP预创建基础目录"""
try:
# 1. 创建根目录下的resource文件夹存在则跳过不覆盖子内容
resource_dir = Path("resource")
resource_dir.mkdir(exist_ok=True)
print(f"确保resource目录存在: {resource_dir.absolute()}")
# 2. 在resource下创建dect文件夹
dect_dir = resource_dir / "dect"
dect_dir.mkdir(exist_ok=True)
print(f"确保dect目录存在: {dect_dir.absolute()}")
# 3. 在dect下创建三个模型文件夹
model_dirs = ["ocr", "face", "yolo"]
for model in model_dirs:
model_dir = dect_dir / model
model_dir.mkdir(exist_ok=True)
print(f"确保{model}模型目录存在: {model_dir.absolute()}")
# 4. 调用外部方法获取所有客户端IP地址
try:
# 调用外部ip_read()方法获取所有客户端IP地址列表
all_ip_addresses = get_unique_client_ips()
# 确保返回的是列表类型
if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses]
# 过滤有效IP去除空字符串和空格
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if not valid_ips:
print("警告: 未获取到有效的客户端IP地址")
return
print(f"获取到的所有客户端IP地址: {valid_ips}")
# 5. 获取当前日期(年、月)
now = datetime.datetime.now()
current_year = str(now.year)
current_month = str(now.month)
# 6. 为每个客户端IP在每个模型文件夹下创建年->月的基础目录结构
for ip in valid_ips:
# 处理IP地址中的特殊字符将.替换为_避免路径问题
safe_ip = ip.replace(".", "_")
for model in model_dirs:
# 构建路径: resource/dect/{model}/{safe_ip}/{year}/{month}
ip_dir = dect_dir / model / safe_ip
year_dir = ip_dir / current_year
month_dir = year_dir / current_month
# 递归创建目录(存在则跳过,不覆盖)
month_dir.mkdir(parents=True, exist_ok=True)
print(f"为客户端IP {ip} 创建/确保目录存在: {month_dir.absolute()}")
except Exception as e:
print(f"处理客户端IP和日期目录时发生错误: {str(e)}")
except Exception as e:
print(f"创建基础目录结构时发生错误: {str(e)}")
def get_image_save_path(model_type: str, client_ip: str) -> tuple:
"""
获取图片保存的「完整路径」和「显示用短路径」
参数:
model_type: 模型类型,应为"ocr""face""yolo"
client_ip: 检测到违禁的客户端IP地址原始格式如192.168.1.101
返回:
元组 (完整保存路径, 显示用短路径);若出错则返回 ("", "")
"""
try:
# 1. 验证客户端IP有效性检查是否在已知IP列表中
all_ip_addresses = get_unique_client_ips()
if not isinstance(all_ip_addresses, list):
all_ip_addresses = [all_ip_addresses]
valid_ips = [ip.strip() for ip in all_ip_addresses if ip.strip()]
if client_ip.strip() not in valid_ips:
raise ValueError(f"客户端IP {client_ip} 不在已知IP列表中无法保存文件")
# 2. 处理IP地址与目录创建逻辑一致将.替换为_
safe_ip = client_ip.strip().replace(".", "_")
# 3. 获取当前日期和毫秒级时间戳(确保文件名唯一)
now = datetime.datetime.now()
current_year = str(now.year)
current_month = str(now.month)
current_day = str(now.day)
# 时间戳格式年月日时分秒毫秒如20250910143050123
timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]
# 4. 定义基础目录(用于生成相对路径)
base_dir = Path("resource") / "dect" # 显示路径会去掉这个前缀
# 构建日级目录完整路径resource/dect/{model}/{safe_ip}/{年}/{月}/{日}
day_dir = base_dir / model_type / safe_ip / current_year / current_month / current_day
day_dir.mkdir(parents=True, exist_ok=True) # 确保日目录存在
# 5. 构建唯一文件名
image_filename = f"dect_{model_type}_{safe_ip}_{current_year}{current_month}{current_day}_{timestamp}.jpg"
# 6. 生成完整路径(用于实际保存图片)和显示路径(用于打印)
full_path = day_dir / image_filename # 完整路径resource/dect/.../xxx.jpg
display_path = full_path.relative_to(base_dir) # 短路径:{model}/.../xxx.jpg去掉resource/dect
return str(full_path), str(display_path)
except Exception as e:
print(f"获取图片保存路径时发生错误: {str(e)}")
return "", ""

326
core/face.py Normal file
View File

@ -0,0 +1,326 @@
import os
import numpy as np
import cv2
import gc
import time
import threading
from PIL import Image
from insightface.app import FaceAnalysis
# 假设service.face_service中get_all_face_name_with_eigenvalue可获取人脸数据
from service.face_service import get_all_face_name_with_eigenvalue
# GPU状态检查支持
try:
import pynvml
pynvml.nvmlInit()
_nvml_available = True
except ImportError:
print("警告: pynvml库未安装无法检测GPU状态默认尝试使用GPU")
_nvml_available = False
# 全局人脸引擎与特征库
_face_app = None
_known_faces_embeddings = {} # 姓名 -> 归一化特征值的映射
_known_faces_names = [] # 已知人脸姓名列表
# GPU使用状态标记
_using_gpu = False # 是否使用GPU
_used_gpu_id = -1 # 使用的GPU ID-1表示CPU
# 资源管理变量
_ref_count = 0 # 引擎引用计数(记录当前使用次数)
_last_used_time = 0 # 最后一次使用引擎的时间
_lock = threading.Lock() # 线程安全锁
_release_timeout = 8 # 闲置超时时间(秒)
_is_releasing = False # 资源释放中标记
_monitor_thread_running = False # 监控线程运行标记
# 调试计数器
_debug_counter = {
"engine_created": 0, # 引擎创建次数
"engine_released": 0, # 引擎释放次数
"detection_calls": 0 # 检测函数调用次数
}
def check_gpu_availability(gpu_id, memory_threshold=0.7):
"""检查指定GPU的内存使用率是否低于阈值判定为“可用”"""
if not _nvml_available:
return True
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
memory_usage = mem_info.used / mem_info.total
return memory_usage < memory_threshold
except Exception as e:
print(f"检查GPU {gpu_id} 状态失败: {e}")
return False
def select_best_gpu(preferred_gpus=[0, 1]):
"""按优先级选择可用GPU优先0号均不可用则返回-1CPU"""
for gpu_id in preferred_gpus:
try:
# 验证GPU是否存在
if _nvml_available:
pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
# 验证GPU内存是否充足
if check_gpu_availability(gpu_id):
print(f"GPU {gpu_id} 可用将使用该GPU")
return gpu_id
else:
if gpu_id == 0:
print("GPU 0 内存使用率过高尝试其他GPU")
except Exception as e:
print(f"GPU {gpu_id} 不可用或访问失败: {e}")
print("所有指定GPU均不可用将使用CPU计算")
return -1
def _release_engine_resources():
"""释放人脸引擎的所有资源模型、特征库、GPU缓存等"""
global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names
if not _face_app or _is_releasing:
return
try:
_is_releasing = True
print("开始释放人脸引擎资源...")
# 释放InsightFace模型资源
if hasattr(_face_app, "model"):
_face_app.model = None # 显式置空模型引用
_face_app = None # 释放引擎实例
# 清空人脸特征库
_known_faces_embeddings.clear()
_known_faces_names.clear()
_debug_counter["engine_released"] += 1
print(f"人脸引擎已释放,调试统计: {_debug_counter}")
# 强制垃圾回收
gc.collect()
# 清理各深度学习框架的GPU缓存
# Torch 缓存清理
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
print("Torch GPU缓存已清理")
except ImportError:
pass
# TensorFlow 缓存清理
try:
import tensorflow as tf
tf.keras.backend.clear_session()
print("TensorFlow会话已清理")
except ImportError:
pass
# MXNet 缓存清理InsightFace底层常用MXNet
try:
import mxnet as mx
mx.nd.waitall() # 等待所有计算完成并释放资源
print("MXNet资源已等待释放")
except ImportError:
pass
except Exception as e:
print(f"释放资源过程中出错: {e}")
finally:
_is_releasing = False
def _resource_monitor_thread():
"""后台监控线程:检测引擎闲置超时,触发资源释放"""
global _ref_count, _last_used_time, _face_app, _monitor_thread_running
_monitor_thread_running = True
while _monitor_thread_running:
time.sleep(2) # 缩短检查间隔,加快闲置检测响应
with _lock:
# 当“引擎存在 + 无引用 + 未在释放中”时,检查闲置时间
if _face_app and _ref_count == 0 and not _is_releasing:
idle_time = time.time() - _last_used_time
if idle_time > _release_timeout:
print(f"引擎闲置超时({idle_time:.1f}s > {_release_timeout}s释放资源")
_release_engine_resources()
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
"""加载人脸识别引擎及已知人脸特征库默认优先用0号GPU"""
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
# 启动后台监控线程(确保仅启动一次)
if not _monitor_thread_running:
threading.Thread(
target=_resource_monitor_thread,
daemon=True,
name="FaceEngineMonitor"
).start()
print("人脸引擎监控线程已启动")
# 若正在释放资源,等待释放完成
while _is_releasing:
time.sleep(0.1)
# 若引擎已初始化,直接返回
if _face_app:
return True
# 初始化InsightFace引擎
try:
print("正在初始化InsightFace人脸识别引擎...")
_face_app = FaceAnalysis(name="buffalo_l", root=os.path.expanduser("~/.insightface"))
# 选择GPU优先用0号
ctx_id = 0
if prefer_gpu:
ctx_id = select_best_gpu(preferred_gpus)
_using_gpu = ctx_id != -1
_used_gpu_id = ctx_id if _using_gpu else -1
if _using_gpu:
print(f"引擎初始化成功将使用GPU {ctx_id} 计算")
else:
print("引擎初始化成功将使用CPU计算")
# 准备模型(加载到指定设备)
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
print("InsightFace引擎初始化完成")
_debug_counter["engine_created"] += 1
print(f"引擎调试统计: {_debug_counter}")
except Exception as e:
print(f"引擎初始化失败: {e}")
return False
# 从服务加载已知人脸的姓名和特征值
try:
face_data = get_all_face_name_with_eigenvalue()
for person_name, eigenvalue_data in face_data.items():
# 兼容“numpy数组”和“字符串”格式的特征值
if isinstance(eigenvalue_data, np.ndarray):
eigenvalue = eigenvalue_data.astype(np.float32)
elif isinstance(eigenvalue_data, str):
# 清理字符串中的括号、换行等干扰符
cleaned = eigenvalue_data.replace("[", "").replace("]", "").replace("\n", "").strip()
# 分割并转换为浮点数数组
values = [v for v in cleaned.split() if v] # 兼容空格/逗号分隔
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
else:
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
continue
# 特征值归一化(保证后续相似度计算的一致性)
norm = np.linalg.norm(eigenvalue)
if norm != 0:
eigenvalue = eigenvalue / norm
_known_faces_embeddings[person_name] = eigenvalue
_known_faces_names.append(person_name)
print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库")
except Exception as e:
print(f"加载人脸特征库失败: {e}")
return _face_app is not None
def detect(frame, similarity_threshold=0.4):
"""
检测并识别人脸
返回:(是否匹配到已知人脸, 结果描述字符串)
"""
global _face_app, _known_faces_embeddings, _known_faces_names, _ref_count, _last_used_time
# 校验输入帧有效性
if frame is None or frame.size == 0:
return (False, "无效的输入帧数据")
# 加锁并更新引用计数、最后使用时间
engine = None
with _lock:
_ref_count += 1
_last_used_time = time.time()
_debug_counter["detection_calls"] += 1
# 若引擎未初始化且未在释放中,尝试初始化
if not _face_app and not _is_releasing:
if not load_model(prefer_gpu=True):
# 初始化失败,恢复引用计数
with _lock:
_ref_count = max(0, _ref_count - 1)
return (False, "人脸引擎初始化失败")
engine = _face_app # 获取引擎引用
# 校验引擎可用性
if not engine or len(_known_faces_names) == 0:
with _lock:
_ref_count = max(0, _ref_count - 1)
return (False, "人脸引擎不可用或特征库为空")
try:
# GPU计算时确保帧数据是连续内存避免CUDA错误
if _using_gpu and engine is not None and not frame.flags.contiguous:
frame = np.ascontiguousarray(frame)
# 执行人脸检测与特征提取
faces = engine.get(frame)
except Exception as e:
print(f"人脸检测过程出错: {e}")
# 出错时尝试重新初始化引擎可能是GPU状态变化导致
print("尝试重新初始化人脸引擎...")
with _lock:
_ref_count = max(0, _ref_count - 1)
load_model(prefer_gpu=True)
return (False, f"检测错误: {str(e)}")
result_parts = []
has_matched_known_face = False # 是否有任意人脸匹配到已知库
for face in faces:
# 归一化当前检测到的人脸特征
face_embedding = face.embedding.astype(np.float32)
norm = np.linalg.norm(face_embedding)
if norm == 0:
continue
face_embedding = face_embedding / norm
# 与已知人脸特征逐一比对
max_similarity, best_match_name = -1.0, "Unknown"
for name in _known_faces_names:
known_emb = _known_faces_embeddings[name]
similarity = np.dot(face_embedding, known_emb) # 余弦相似度
if similarity > max_similarity:
max_similarity = similarity
best_match_name = name
# 判断是否匹配成功
is_matched = max_similarity >= similarity_threshold
if is_matched:
has_matched_known_face = True
# 记录该人脸的检测结果
bbox = face.bbox # 人脸边界框
result_parts.append(
f"{'匹配' if is_matched else '未匹配'}: {best_match_name} "
f"(相似度: {max_similarity:.2f}, 边界框: {bbox.astype(int).tolist()})"
)
# 构建最终结果字符串
result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts)
# 释放引用计数(线程安全)
with _lock:
_ref_count = max(0, _ref_count - 1)
# 若仍有引用更新最后使用时间若引用为0也立即标记加快闲置检测
_last_used_time = time.time()
return (has_matched_known_face, result_str)

BIN
core/models/best.pt Normal file

Binary file not shown.

253
core/ocr.py Normal file
View File

@ -0,0 +1,253 @@
import os
import cv2
import gc
import time
import threading
import numpy as np
from paddleocr import PaddleOCR
from service.sensitive_service import get_all_sensitive_words
# 解决NumPy 1.20+版本中np.int已移除的兼容性问题
try:
if not hasattr(np, 'int'):
np.int = int
except Exception as e:
print(f"处理NumPy兼容性时出错: {e}")
# 全局变量
_ocr_engine = None
_forbidden_words = set()
_conf_threshold = 0.5
# 资源管理变量
_ref_count = 0
_last_used_time = 0
_lock = threading.Lock()
_release_timeout = 5 # 30秒无使用则释放
_is_releasing = False # 标记是否正在释放
# 并行处理配置
_max_workers = 4 # 并行处理的线程数
# 调试用计数器
_debug_counter = {
"created": 0,
"released": 0,
"detected": 0
}
def _release_engine():
"""释放OCR引擎资源"""
global _ocr_engine, _is_releasing
if not _ocr_engine or _is_releasing:
return
try:
_is_releasing = True
_ocr_engine = None
_debug_counter["released"] += 1
print(f"OCR engine released. Stats: {_debug_counter}")
# 清理GPU缓存
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
except ImportError:
pass
try:
import paddle
if paddle.is_compiled_with_cuda():
paddle.device.cuda.empty_cache()
except ImportError:
pass
finally:
_is_releasing = False
def _monitor_thread():
"""监控线程,优化检查逻辑"""
global _ref_count, _last_used_time, _ocr_engine
while True:
time.sleep(5) # 每5秒检查一次
with _lock:
if _ocr_engine and _ref_count == 0 and not _is_releasing:
elapsed = time.time() - _last_used_time
if elapsed > _release_timeout:
print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine")
_release_engine()
def load_model():
"""加载违禁词列表和初始化监控线程"""
global _forbidden_words
# 确保监控线程只启动一次
if not any(t.name == "OCRMonitor" for t in threading.enumerate()):
threading.Thread(target=_monitor_thread, daemon=True, name="OCRMonitor").start()
print("OCR monitor thread started")
# 加载违禁词
try:
_forbidden_words = get_all_sensitive_words()
print(f"Loaded {len(_forbidden_words)} forbidden words")
except Exception as e:
print(f"Forbidden words load error: {e}")
return False
return True
def detect(frame):
"""OCR检测支持并行处理"""
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time, _max_workers
# 验证前置条件
if not _forbidden_words:
return (False, "违禁词未初始化")
if frame is None or frame.size == 0:
return (False, "无效帧数据")
# 增加引用计数并获取引擎实例
engine = None
with _lock:
_ref_count += 1
_last_used_time = time.time()
_debug_counter["detected"] += 1
# 初始化引擎(如果未初始化且不在释放中)
if not _ocr_engine and not _is_releasing:
try:
# 初始化PaddleOCR设置并行处理参数
_ocr_engine = PaddleOCR(
use_angle_cls=True,
lang="ch",
show_log=False,
use_gpu=True,
max_text_length=1024,
threads=_max_workers
)
_debug_counter["created"] += 1
print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}")
except Exception as e:
print(f"OCR model load failed: {e}")
_ref_count -= 1
return (False, f"引擎初始化失败: {str(e)}")
engine = _ocr_engine
# 检查引擎是否可用
if not engine:
with _lock:
_ref_count -= 1
return (False, "OCR引擎不可用")
try:
# 执行OCR检测
ocr_res = engine.ocr(frame, cls=True)
# 验证OCR结果格式
if not ocr_res or not isinstance(ocr_res, list):
return (False, "无OCR结果")
# 处理OCR结果 - 兼容多种格式
texts = []
confs = []
for line in ocr_res:
if line is None:
continue
# 处理line可能是列表或直接是文本信息的情况
if isinstance(line, list):
items_to_process = line
else:
items_to_process = [line]
for item in items_to_process:
# 精确识别并忽略图片坐标位置信息 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
if isinstance(item, list) and len(item) == 4: # 四边形有4个顶点
is_coordinate = True
for point in item:
# 每个顶点应该是包含2个数字的列表
if not (isinstance(point, list) and len(point) == 2 and
all(isinstance(coord, (int, float)) for coord in point)):
is_coordinate = False
break
if is_coordinate:
continue # 是坐标信息,直接忽略
# 跳过纯数字列表(其他可能的坐标形式)
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
continue
# 处理元组形式的文本和置信度 (text, confidence)
if isinstance(item, tuple) and len(item) == 2:
text, conf = item
if isinstance(text, str) and isinstance(conf, (int, float)):
texts.append(text.strip())
confs.append(float(conf))
continue
# 处理列表形式的[坐标信息, (text, confidence)]
if isinstance(item, list) and len(item) >= 2:
# 尝试从列表中提取文本和置信度
text_data = item[1]
if isinstance(text_data, tuple) and len(text_data) == 2:
text, conf = text_data
if isinstance(text, str) and isinstance(conf, (int, float)):
texts.append(text.strip())
confs.append(float(conf))
continue
elif isinstance(text_data, str):
# 只有文本没有置信度的情况
texts.append(text_data.strip())
confs.append(1.0) # 默认最高置信度
continue
# 无法识别的格式,记录日志
print(f"无法解析的OCR结果格式: {item}")
if len(texts) != len(confs):
return (False, "OCR结果格式异常")
# 筛选违禁词
vio_info = []
for txt, conf in zip(texts, confs):
if conf < _conf_threshold:
continue
matched = [w for w in _forbidden_words if w in txt]
if matched:
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
# 构建结果
has_text = len(texts) > 0
has_violation = len(vio_info) > 0
if not has_text:
return (False, "未识别到文本")
elif has_violation:
return (True, "; ".join(vio_info))
else:
return (False, "未检测到违禁词")
except Exception as e:
print(f"OCR detect error: {e}")
return (False, f"检测错误: {str(e)}")
finally:
# 减少引用计数,确保线程安全
with _lock:
_ref_count = max(0, _ref_count - 1)
if _ref_count > 0:
_last_used_time = time.time()
def batch_detect(frames):
"""批量检测接口,充分利用并行能力"""
results = []
for frame in frames:
results.append(detect(frame))
return results

View File

@ -1,137 +0,0 @@
import asyncio
import logging
from aiortc import RTCPeerConnection, RTCSessionDescription
import aiohttp
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whep_video_puller")
async def whep_pull_video_stream(ip,whep_url):
"""
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
Args:
whep_url: WHEP端点的URL
"""
pc = RTCPeerConnection()
# 添加连接状态变化监听
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print(f"连接状态: {pc.connectionState}")
# 添加ICE连接状态变化监听
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
print(f"ICE连接状态: {pc.iceConnectionState}")
# 添加视频接收器
pc.addTransceiver("video", direction="recvonly")
# 处理接收到的视频轨道
@pc.on("track")
def on_track(track):
print(f"接收到轨道: {track.kind}")
if track.kind == "video":
print(f"轨道ID: {track.id}")
print(f"轨道就绪状态: {track.readyState}")
# 创建异步任务来处理视频帧
asyncio.ensure_future(handle_video_track(track))
async def handle_video_track(track):
"""处理视频轨道,接收并打印每一帧"""
frame_count = 0
print("开始处理视频轨道...")
while True:
try:
# 尝试接收帧
frame = await track.recv()
frame_count += 1
print(f"收到原始帧 (第{frame_count}帧)")
# 打印帧的基本信息
if hasattr(frame, 'width') and hasattr(frame, 'height'):
print(f" 尺寸: {frame.width}x{frame.height}")
if hasattr(frame, 'time_base'):
print(f" 时间基准: {frame.time_base}")
if hasattr(frame, 'pts'):
print(f" 显示时间戳: {frame.pts}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
except Exception as e:
print(f"接收帧时出错: {e}")
# 等待一段时间后重试
await asyncio.sleep(0.1)
continue
# 创建offer
offer = await pc.createOffer()
await pc.setLocalDescription(offer)
print(f"本地SDP信息:\n{offer.sdp}")
# 通过HTTP POST发送offer到WHEP端点
async with aiohttp.ClientSession() as session:
async with session.post(
whep_url,
data=offer.sdp,
headers={"Content-Type": "application/sdp"}
) as response:
if response.status != 201:
print(f"WHEP服务器返回错误: {response.status}")
print(f"响应内容: {await response.text()}")
raise Exception(f"WHEP服务器返回错误: {response.status}")
# 获取answer SDP
answer_sdp = await response.text()
# 创建RTCSessionDescription对象
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
print(f"收到远程SDP:\n{answer_sdp}")
# 设置远程描述
await pc.setRemoteDescription(answer)
print("连接已建立,开始接收视频流...")
# 保持连接,直到用户中断
try:
while True:
await asyncio.sleep(1)
# 检查连接状态
print(f"当前连接状态: {pc.connectionState}")
except KeyboardInterrupt:
print("用户中断,关闭连接...")
finally:
await pc.close()
if __name__ == "__main__":
# 替换为你的WHEP端点URL
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
# 运行拉流任务
asyncio.run(whep_pull_video_stream(WHEP_URL))

View File

@ -1,112 +0,0 @@
import asyncio
import logging
import cv2
import time
from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正从core目录向上一级找ocr文件夹
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
KNOWN_FACES_DIR = "../ocr/known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并进行违规检测
"""
cap = None # 初始化视频捕获对象
try:
# 异步打开RTMP流
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
)
# 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 获取RTMP流基础信息
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况
fps = fps if fps > 0 else 30.0
width, height = int(width), int(height)
# 打印流初始化成功信息
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 初始化帧统计参数
frame_count = 0
start_time = time.time()
# 循环读取视频帧
while True:
ret, frame = await asyncio.to_thread(cap.read)
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
frame_count += 1
# 打印当前帧信息
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
if frame is not None:
has_violation, violation_type, details = detector.detect_violations(frame)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
else:
print("未检测到任何违规内容")
else:
print(f"无法读取测试图像")
# 每100帧统计一次实际接收帧率
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

54
core/yolo.py Normal file
View File

@ -0,0 +1,54 @@
import os
from ultralytics import YOLO
# 全局变量
_yolo_model = None
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
def load_model():
"""加载YOLO目标检测模型"""
global _yolo_model
try:
_yolo_model = YOLO(model_path)
except Exception as e:
print(f"YOLO model load failed: {e}")
return False
return True if _yolo_model else False
def detect(frame, conf_threshold=0.2):
"""YOLO目标检测、返回(是否识别到, 结果字符串)"""
global _yolo_model
if not _yolo_model or frame is None:
return (False, "未初始化或无效帧")
try:
results = _yolo_model(frame, conf=conf_threshold)
# 检查是否有检测结果
has_results = len(results[0].boxes) > 0 if results else False
if not has_results:
return (False, "未检测到目标")
# 构建结果字符串
result_parts = []
for box in results[0].boxes:
cls = int(box.cls[0])
conf = float(box.conf[0])
bbox = [float(x) for x in box.xyxy[0]]
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
result_str = "; ".join(result_parts)
return (has_results, result_str)
except Exception as e:
print(f"YOLO detect error: {e}")
return (False, f"检测错误: {str(e)}")

View File

@ -14,4 +14,3 @@ config.read(config_path, encoding="utf-8")
SERVER_CONFIG = config["server"] SERVER_CONFIG = config["server"]
MYSQL_CONFIG = config["mysql"] MYSQL_CONFIG = config["mysql"]
JWT_CONFIG = config["jwt"] JWT_CONFIG = config["jwt"]
LIVE_CONFIG = config["live"]

25
main.py
View File

@ -1,11 +1,18 @@
import uvicorn from PIL import Image # 正确导入
from fastapi import FastAPI import numpy as np
import uvicorn
from PIL import Image
from fastapi import FastAPI
from core.all import load_model,detect
from ds.config import SERVER_CONFIG from ds.config import SERVER_CONFIG
from middle.error_handler import global_exception_handler from middle.error_handler import global_exception_handler
from service.user_service import router as user_router from service.user_service import router as user_router
from service.sensitive_service import router as sensitive_router
from service.face_service import router as face_router
from service.device_service import router as device_router from service.device_service import router as device_router
from ws.ws import ws_router, lifespan from ws.ws import ws_router, lifespan
from core.establish import create_directory_structure
# ------------------------------ # ------------------------------
# 初始化 FastAPI 应用、指定生命周期管理 # 初始化 FastAPI 应用、指定生命周期管理
@ -22,6 +29,8 @@ app = FastAPI(
# ------------------------------ # ------------------------------
app.include_router(user_router) app.include_router(user_router)
app.include_router(device_router) app.include_router(device_router)
app.include_router(face_router)
app.include_router(sensitive_router)
app.include_router(ws_router) app.include_router(ws_router)
# ------------------------------ # ------------------------------
@ -33,11 +42,21 @@ app.add_exception_handler(Exception, global_exception_handler)
# 启动服务 # 启动服务
# ------------------------------ # ------------------------------
if __name__ == "__main__": if __name__ == "__main__":
# -------------------------- 配置调整 --------------------------
# 模型配置路径(建议改为环境变量)
YOLO_MODEL_PATH = r"/core/models\best.pt"
OCR_CONFIG_PATH = r"/core/config\config.yaml"
create_directory_structure()
# 初始化项目默认端口设为8000、避免初始化失败时port未定义
port = int(SERVER_CONFIG.get("port", 8000)) port = int(SERVER_CONFIG.get("port", 8000))
# 启动 UVicorn 服务
uvicorn.run( uvicorn.run(
app="main:app", app="main:app",
host="0.0.0.0", host="0.0.0.0",
port=port, port=port,
reload=True, workers=8,
ws="websockets" ws="websockets"
) )

View File

@ -8,7 +8,6 @@ from passlib.context import CryptContext
from ds.config import JWT_CONFIG from ds.config import JWT_CONFIG
from ds.db import db from ds.db import db
from service.user_service import UserResponse
# ------------------------------ # ------------------------------
# 密码加密配置 # 密码加密配置
@ -22,9 +21,10 @@ SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"] ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"]) ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token> # OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------ # ------------------------------
# 密码工具函数 # 密码工具函数
# ------------------------------ # ------------------------------
@ -32,10 +32,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证明文密码与加密密码是否匹配""" """验证明文密码与加密密码是否匹配"""
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
"""对明文密码进行 bcrypt 加密""" """对明文密码进行 bcrypt 加密"""
return pwd_context.hash(password) return pwd_context.hash(password)
# ------------------------------ # ------------------------------
# JWT 工具函数 # JWT 工具函数
# ------------------------------ # ------------------------------
@ -53,11 +55,15 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
# ------------------------------ # ------------------------------
# 认证依赖(获取当前登录用户) # 认证依赖(获取当前登录用户)
# ------------------------------ # ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse: def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
"""从 Token 中解析用户信息、验证通过后返回当前用户""" """从 Token 中解析用户信息、验证通过后返回当前用户"""
# 延迟导入、打破循环依赖
from schema.user_schema import UserResponse # 在这里导入
# 认证失败异常 # 认证失败异常
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -89,8 +95,8 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> UserResponse:
raise credentials_exception # 用户不存在 raise credentials_exception # 用户不存在
# 转换为 UserResponse 模型(自动校验字段) # 转换为 UserResponse 模型(自动校验字段)
return UserResponse(** user) return UserResponse(**user)
except Exception as e: except Exception as e:
raise credentials_exception from e raise credentials_exception from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -8,7 +8,7 @@ from schema.response_schema import APIResponse
async def global_exception_handler(request: Request, exc: Exception): async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理器所有未捕获的异常都会在这里统一处理""" """全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
# 1. 请求参数验证错误Pydantic 校验失败) # 1. 请求参数验证错误Pydantic 校验失败)
if isinstance(exc, RequestValidationError): if isinstance(exc, RequestValidationError):
error_details = [] error_details = []
@ -18,7 +18,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
content=APIResponse( content=APIResponse(
code=400, code=400,
message=f"请求参数错误{'; '.join(error_details)}", message=f"请求参数错误: {'; '.join(error_details)}",
data=None data=None
).model_dump() ).model_dump()
) )
@ -52,7 +52,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse( content=APIResponse(
code=500, code=500,
message=f"数据库错误{str(exc)}", message=f"数据库错误: {str(exc)}",
data=None data=None
).model_dump() ).model_dump()
) )
@ -62,7 +62,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse( content=APIResponse(
code=500, code=500,
message=f"服务器内部错误{str(exc)}", message=f"服务器内部错误: {str(exc)}",
data=None data=None
).model_dump() ).model_dump()
) )

View File

@ -1,139 +0,0 @@
import os
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
class FaceRecognizer:
"""
封装InsightFace人脸识别功能支持从文件夹加载已知人脸。
"""
def __init__(self, known_faces_dir: str):
self.known_faces_dir = known_faces_dir
self.app = self._initialize_insightface()
self.known_faces_embeddings = {}
self.known_faces_names = []
self._load_known_faces()
def _initialize_insightface(self):
"""初始化InsightFace FaceAnalysis应用"""
print("初始化InsightFace引擎...")
try:
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
print("请检查依赖是否安装及模型是否可访问")
return None
def _load_known_faces(self):
"""加载已知人脸特征"""
if not os.path.exists(self.known_faces_dir):
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
os.makedirs(self.known_faces_dir, exist_ok=True)
return
print(f"从目录加载人脸特征: {self.known_faces_dir}")
for person_name in os.listdir(self.known_faces_dir):
person_dir = os.path.join(self.known_faces_dir, person_name)
if os.path.isdir(person_dir):
print(f"处理人物: {person_name}")
embeddings = []
for filename in os.listdir(person_dir):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(person_dir, filename)
try:
img = cv2.imread(image_path)
if img is None:
print(f"无法读取图片: {image_path},已跳过")
continue
faces = self.app.get(img)
if faces:
embeddings.append(faces[0].embedding)
print(f"提取特征成功: {filename}")
else:
print(f"未检测到人脸: {filename},已跳过")
except Exception as e:
print(f"处理图片出错 {image_path}: {e}")
if embeddings:
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
self.known_faces_names.append(person_name)
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
else:
print(f"人物 {person_name} 无有效特征,已跳过")
print(f"人脸加载完成,共 {len(self.known_faces_names)}")
def recognize(self, frame, threshold=0.4):
"""识别人脸并返回结果"""
if not self.app or not self.known_faces_names:
return False, None, None
faces = self.app.get(frame)
if not faces:
return False, None, None
for face in faces:
for known_name in self.known_faces_names:
known_embedding = self.known_faces_embeddings[known_name]
embedding1 = face.embedding.astype(np.float32)
embedding2 = known_embedding.astype(np.float32)
dot_product = np.dot(embedding1, embedding2)
norm_embedding1 = np.linalg.norm(embedding1)
norm_embedding2 = np.linalg.norm(embedding2)
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
dot_product / (norm_embedding1 * norm_embedding2)
)
if similarity >= threshold:
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
return True, known_name, similarity
return False, None, None
def test_single_image(self, image_path: str, threshold=0.4):
"""测试单张图片识别"""
if not os.path.exists(image_path):
print(f"图片不存在: {image_path}")
return False, None, None
frame = cv2.imread(image_path)
if frame is None:
print(f"无法读取图片: {image_path}")
return False, None, None
result, name, similarity = self.recognize(frame, threshold)
if result:
print(f"识别结果: {name} (相似度: {similarity:.4f})")
faces = self.app.get(frame)
for face in faces:
bbox = face.bbox.astype(int)
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
text = f"{name}: {similarity:.2f}"
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
cv2.imshow('识别结果', frame)
print("按任意键关闭窗口...")
cv2.waitKey(0)
cv2.destroyAllWindows()
else:
print("未识别到已知人脸")
return result, name, similarity
#
# if __name__ == "__main__":
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
# recognizer.test_single_image(test_image_path, threshold=0.4)

View File

@ -1,18 +0,0 @@
打倒习近平
打到习近平
打倒毛泽东
打到毛泽东
打到主席
打倒主席
打到共产主义
打倒共产主义
打到共产党
打倒共产党
胖猫
法轮功
法轮大法
法轮大法好
法轮功大法好
法轮
李洪志
习近平

Binary file not shown.

Before

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 657 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 386 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 62 KiB

View File

@ -1,49 +0,0 @@
#日志文件
import logging
import sys
def setup_logger():
"""
配置一个全局日志记录器,支持输出到控制台和文件。
"""
# 创建一个日志记录器
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = logging.getLogger("ViolationDetectorLogger")
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
# 如果已经有处理器了,就不要重复添加,防止日志重复打印
if logger.hasHandlers():
return logger
# --- 控制台处理器 ---
console_handler = logging.StreamHandler(sys.stdout)
# 对于控制台我们只显示INFO及以上级别的信息
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_formatter)
# --- 文件处理器 ---
file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
# 对于文件我们记录所有DEBUG及以上级别的信息
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
# 将处理器添加到日志记录器
logger.addHandler(console_handler)
logger.addHandler(file_handler)
return logger
# 创建并导出logger实例
logger = setup_logger()

View File

@ -1,136 +0,0 @@
import os
import cv2
import yaml
from pathlib import Path
from .ocr_violation_detector import OCRViolationDetector
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
from .face_recognizer import FaceRecognizer
class MultiModelViolationDetector:
"""
多模型违规检测封装类串行调用OCR、人脸识别和YOLO模型任一模型检测到违规即返回结果
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
yolo_model_path: str,
known_faces_dir: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化所有检测模型
"""
# 初始化OCR检测器
self.ocr_detector = OCRViolationDetector(
forbidden_words_path=forbidden_words_path,
ocr_config_path=ocr_config_path,
ocr_confidence_threshold=ocr_confidence_threshold
)
# 初始化人脸识别器
self.face_recognizer = FaceRecognizer(
known_faces_dir=known_faces_dir
)
# 初始化YOLO检测器
self.yolo_detector = YoloViolationDetector(
model_path=yolo_model_path
)
print("多模型违规检测器初始化完成")
def detect_violations(self, frame):
"""
串行调用三个检测模型OCR → 人脸识别 → YOLO任一检测到违规即返回结果
"""
# 1. 首先进行OCR违禁词检测
try:
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
if ocr_has_violation:
details = {
"words": ocr_words,
"confidences": ocr_confs
}
print(f"警告: OCR检测到违禁内容: {details}")
return (True, "ocr", details)
except Exception as e:
print(f"错误: OCR检测出错: {str(e)}")
# 2. 接着进行人脸识别检测
try:
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
if face_has_violation:
details = {
"name": face_name,
"similarity": face_similarity
}
print(f"警告: 人脸识别到违规人员: {details}")
return (True, "face", details)
except Exception as e:
print(f"错误: 人脸识别出错: {str(e)}")
# 3. 最后进行YOLO目标检测
try:
yolo_results = self.yolo_detector.detect(frame)
if len(yolo_results.boxes) > 0:
details = {
"classes": yolo_results.names,
"boxes": yolo_results.boxes.xyxy.tolist(),
"confidences": yolo_results.boxes.conf.tolist(),
"class_ids": yolo_results.boxes.cls.tolist()
}
print(f"警告: YOLO检测到违规目标: {details}")
return (True, "yolo", details)
except Exception as e:
print(f"错误: YOLO检测出错: {str(e)}")
# 所有检测均未发现违规
return (False, None, None)
def load_config(config_path: str) -> dict:
"""加载YAML配置文件"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
except FileNotFoundError:
print(f"错误: 配置文件未找到: {config_path}")
raise
except yaml.YAMLError as e:
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
raise
except Exception as e:
print(f"错误: 加载配置文件出错: {str(e)}")
raise
# 使用示例
# if __name__ == "__main__":
# # 加载配置文件
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
#
# # 初始化多模型检测器
# detector = MultiModelViolationDetector(
# forbidden_words_path=config["forbidden_words_path"],
# ocr_config_path=config["ocr_config_path"],
# yolo_model_path=config["yolo_model_path"],
# known_faces_dir=config["known_faces_dir"],
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
# )
#
# # 读取测试图像(可替换为视频帧读取逻辑)
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
# if test_image_path:
# frame = cv2.imread(test_image_path)
#
# if frame is not None:
# has_violation, violation_type, details = detector.detect_violations(frame)
# if has_violation:
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# else:
# print("未检测到任何违规内容")
# else:
# print(f"无法读取测试图像: {test_image_path}")
# else:
# print("配置文件中未指定测试图像路径")

Binary file not shown.

View File

@ -1,178 +0,0 @@
import os
import cv2
from rapidocr import RapidOCR
class OCRViolationDetector:
"""
封装RapidOCR引擎用于检测图像帧中的违禁词。
核心功能加载违禁词、初始化OCR引擎、单帧图像违禁词检测
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化OCR引擎和违禁词列表。
Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_config_path (str): OCR配置文件如1.yaml的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值0~1
"""
# 加载违禁词
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化RapidOCR引擎
self.ocr_engine = self._initialize_ocr(ocr_config_path)
# 校验核心依赖是否就绪
self._check_dependencies()
# 设置置信度阈值限制在0~1范围
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
print(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _load_forbidden_words(self, path: str) -> set:
"""
从TXT文件加载违禁词去重、过滤空行支持UTF-8编码
"""
forbidden_words = set()
# 检查文件是否存在
if not os.path.exists(path):
print(f"错误:违禁词文件不存在: {path}")
return forbidden_words
# 读取文件并处理内容
try:
with open(path, 'r', encoding='utf-8') as f:
forbidden_words = {
line.strip() for line in f
if line.strip() # 跳过空行或纯空格行
}
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
except UnicodeDecodeError:
print(f"错误违禁词文件编码错误需UTF-8: {path}")
except PermissionError:
print(f"错误:无权限读取违禁词文件: {path}")
except Exception as e:
print(f"错误:加载违禁词失败: {str(e)}")
return forbidden_words
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
"""
初始化RapidOCR引擎校验配置文件、捕获初始化异常
"""
print("开始初始化RapidOCR引擎...")
# 检查配置文件是否存在
if not os.path.exists(config_path):
print(f"错误OCR配置文件不存在: {config_path}")
return None
# 初始化OCR引擎
try:
ocr_engine = RapidOCR(config_path=config_path)
print("RapidOCR引擎初始化成功")
return ocr_engine
except ImportError:
print("错误RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
except Exception as e:
print(f"错误RapidOCR初始化失败: {str(e)}")
return None
def _check_dependencies(self) -> None:
"""校验OCR引擎和违禁词列表是否就绪"""
if not self.ocr_engine:
print("警告:⚠️ OCR引擎未就绪违禁词检测功能将禁用")
if not self.forbidden_words:
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
def detect(self, frame) -> tuple[bool, list, list]:
"""
对单帧图像进行OCR违禁词检测核心方法
Args:
frame: 输入图像帧NumPy数组BGR格式cv2读取的图像
Returns:
tuple[bool, list, list]:
- 第一个元素是否检测到违禁词True/False
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
"""
# 初始化返回结果
has_violation = False
violation_words = []
violation_confs = []
# 前置校验
if frame is None or frame.size == 0:
print("警告输入图像帧为空或无效跳过OCR检测")
return has_violation, violation_words, violation_confs
if not self.ocr_engine or not self.forbidden_words:
print("OCR引擎未就绪或违禁词为空跳过OCR检测")
return has_violation, violation_words, violation_confs
try:
# 执行OCR识别
print("开始执行OCR识别...")
ocr_result = self.ocr_engine(frame)
print(f"RapidOCR原始结果: {ocr_result}")
# 校验OCR结果是否有效
if ocr_result is None:
print("OCR识别未返回任何结果图像无文本或识别失败")
return has_violation, violation_words, violation_confs
# 检查txts和scores是否存在且不为None
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
print("警告OCR结果中txts为None或不存在")
return has_violation, violation_words, violation_confs
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
print("警告OCR结果中scores为None或不存在")
return has_violation, violation_words, violation_confs
# 转为列表并去None
if not isinstance(ocr_result.txts, (list, tuple)):
print(f"警告OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
texts = []
else:
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
if not isinstance(ocr_result.scores, (list, tuple)):
print(f"警告OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
confidences = []
else:
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
# 校验文本和置信度列表长度是否一致
if len(texts) != len(confidences):
print(f"警告OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
return has_violation, violation_words, violation_confs
if len(texts) == 0:
print("OCR未识别到任何有效文本")
return has_violation, violation_words, violation_confs
# 遍历识别结果,筛选违禁词
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
for text, conf in zip(texts, confidences):
if conf < self.OCR_CONFIDENCE_THRESHOLD:
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
continue
matched_words = [word for word in self.forbidden_words if word in text]
if matched_words:
has_violation = True
violation_words.extend(matched_words)
violation_confs.extend([conf] * len(matched_words))
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
except Exception as e:
print(f"错误OCR检测过程异常: {str(e)}")
return has_violation, violation_words, violation_confs

View File

@ -1,47 +0,0 @@
from ultralytics import YOLO
import cv2
class ViolationDetector:
"""
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
"""
def __init__(self, model_path):
"""
初始化检测器。
Args:
model_path (str): YOLO .pt模型的路径。
"""
print(f"正在从 '{model_path}' 加载YOLO模型...")
self.model = YOLO(model_path)
print("YOLO模型加载成功。")
def detect(self, frame):
"""
对单帧图像进行目标检测。
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)。
Returns:
ultralytics.engine.results.Results: YOLO的检测结果对象。
"""
# conf可以根据您的模型效果进行调整
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
results = self.model(frame, conf=0.2)
return results[0]
def draw_boxes(self, frame, result):
"""
在图像帧上绘制检测框。
Args:
frame: 原始图像帧。
result: YOLO的检测结果对象。
Returns:
numpy.ndarray: 绘制了检测框的图像帧。
"""
# 使用YOLO自带的plot功能方便快捷
annotated_frame = result.plot()
return annotated_frame

View File

@ -1,164 +0,0 @@
import queue
import asyncio
import aiohttp
import threading
import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
# 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1)
class VideoTrack(MediaStreamTrack):
"""自定义视频轨道类继承自MediaStreamTrack"""
kind = "video"
def __init__(self, max_frames=100):
super().__init__()
self.frames = queue.Queue(maxsize=max_frames)
async def recv(self):
return await super().recv()
def webrtc_producer(webrtc_url):
"""
生产者方法从WEBRTC读取视频帧并放入队列
仅当队列空时才放入新帧,否则丢弃
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 创建RTCPeerConnection对象不使用ICE服务器
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
pc.addTrack(video_track)
@pc.on("track")
async def on_track(track):
if track.kind == "video":
print("接收到视频轨道,开始接收视频帧")
while True:
# 从轨道接收视频帧
frame = await track.recv()
# 转换为BGR24格式的NumPy数组
frame_bgr24 = frame.to_ndarray(format='bgr24')
# 检查队列是否为空,为空则加入,否则丢弃
if frame_queue.empty():
try:
frame_queue.put_nowait(frame_bgr24)
print("帧已放入队列")
except queue.Full:
print("队列已满,丢弃帧")
else:
print("队列非空,丢弃帧")
async def main():
# 创建并发送SDP Offer
offer = await pc.createOffer()
print("已创建本地SDP Offer")
await pc.setLocalDescription(offer)
# 发送Offer到服务器并接收Answer
async with aiohttp.ClientSession() as session:
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
async with session.post(
webrtc_url,
data=offer.sdp.encode(),
headers={
"Content-Type": "application/sdp",
"Content-Length": str(len(offer.sdp))
},
ssl=False
) as response:
print("已接收到服务器的响应")
answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
# 保持连接
try:
while True:
await asyncio.sleep(0.1)
except KeyboardInterrupt:
pass
finally:
print("关闭RTCPeerConnection")
await pc.close()
try:
loop.run_until_complete(main())
finally:
loop.close()
def frame_consumer(ip):
"""
消费者方法:从队列中读取帧并处理
每次处理后休眠200ms模拟延迟
"""
print("消费者启动,开始等待帧...")
try:
while True:
# 阻塞等待队列中的帧
frame = frame_queue.get()
print(f"消费帧,大小: {frame.shape}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
# 标记任务完成
frame_queue.task_done()
except KeyboardInterrupt:
print("消费者退出")
def start_webrtc_stream(ip, webrtc_url):
"""
启动WebRTC视频流处理的主方法
参数: webrtc_url - WebRTC服务器地址
"""
print(f"开始连接到WebRTC服务器: {webrtc_url}")
# 启动生产者线程
producer_thread = threading.Thread(
target=webrtc_producer,
args=(webrtc_url,),
daemon=True,
name="webrtc-producer"
)
# 启动消费者线程
consumer_thread = threading.Thread(
target=frame_consumer(ip),
daemon=True,
name="frame-consumer"
)
producer_thread.start()
consumer_thread.start()
print("生产者和消费者线程已启动")
try:
# 保持主线程运行
while True:
time.sleep(1)
except KeyboardInterrupt:
print("程序正在退出...")
if __name__ == "__main__":
# 示例用法
# 实际使用时替换为真实的WebRTC服务器地址
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
start_webrtc_stream(webrtc_server_url)

View File

@ -1,101 +0,0 @@
import asyncio
import logging
import cv2
import time
# 配置日志与WHEP代码保持一致的日志风格
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
"""
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
功能与WHEP拉流函数对齐流状态反馈、帧信息打印、帧率统计、异常处理
Args:
rtmp_url: RTMP流的URL地址如 rtmp://xxx/live/stream_key
"""
cap = None # 初始化视频捕获对象
try:
# 1. 异步打开RTMP流指定FFmpeg后端确保RTMP兼容性同步操作通过to_thread避免阻塞事件循环
cap = await asyncio.to_thread(
cv2.VideoCapture,
rtmp_url,
cv2.CAP_FFMPEG # 必须指定FFmpeg后端RTMP协议依赖该后端解析
)
# 2. 检查RTMP流是否成功打开
is_opened = await asyncio.to_thread(cap.isOpened)
if not is_opened:
raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")
# 3. 异步获取RTMP流基础信息分辨率、帧率
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
# 处理异常情况部分RTMP流未返回帧率时默认30FPS
fps = fps if fps > 0 else 30.0
# 分辨率转为整数(视频尺寸必然是整数)
width, height = int(width), int(height)
# 打印流初始化成功信息与WHEP连接成功信息风格一致
print(f"RTMP流状态: 已成功连接")
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
print("开始接收视频帧...(按 Ctrl+C 中断)")
# 4. 初始化帧统计参数
frame_count = 0 # 总接收帧数
start_time = time.time() # 统计起始时间
# 5. 循环异步读取视频帧(核心逻辑)
while True:
# 异步读取一帧cv2.read是同步操作用to_thread适配异步环境
ret, frame = await asyncio.to_thread(cap.read)
# 检查帧是否读取成功(流中断/结束时ret为False
if not ret:
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
break
# 帧计数累加
frame_count += 1
# 6. 打印当前帧基础信息与WHEP帧信息打印风格对齐
print(f"收到帧 (第{frame_count}帧)")
print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS")
# 7. 每100帧统计一次实际接收帧率补充性能监控与原RTMP示例逻辑一致
if frame_count % 100 == 0:
elapsed_time = time.time() - start_time
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
# 示例yield frame (若需生成器模式,可调整函数为异步生成器)
# 8. 异常处理(覆盖用户中断、通用错误)
except KeyboardInterrupt:
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
except Exception as e:
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
print(f"错误信息: {str(e)}")
finally:
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
if cap is not None:
await asyncio.to_thread(cap.release)
print(f"\n资源释放: RTMP流已关闭")
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0}")
if __name__ == "__main__":
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
# 运行RTMP拉流任务与WHEP一致的异步执行方式
try:
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
except Exception as e:
print(f"程序启动失败: {str(e)}")

View File

@ -0,0 +1,36 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ------------------------------
# 请求模型
# ------------------------------
class DeviceActionCreate(BaseModel):
"""设备操作记录创建模型0=离线、1=上线)"""
client_ip: str = Field(..., description="客户端IP")
action: int = Field(..., ge=0, le=1, description="操作状态0=离线、1=上线)")
# ------------------------------
# 响应模型(单条记录)
# ------------------------------
class DeviceActionResponse(BaseModel):
"""设备操作记录响应模型(与自增表对齐)"""
id: int = Field(..., description="自增主键ID")
client_ip: Optional[str] = Field(None, description="客户端IP")
action: Optional[int] = Field(None, description="操作状态0=离线、1=上线)")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库结果直接转换
model_config = {"from_attributes": True}
# ------------------------------
# 列表响应模型(仅含 total + device_actions
# ------------------------------
class DeviceActionListResponse(BaseModel):
"""设备操作记录列表(仅核心返回字段)"""
total: int = Field(..., description="总记录数")
device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表")

View File

@ -1,4 +1,3 @@
import hashlib
from datetime import datetime from datetime import datetime
from typing import Optional, List, Dict from typing import Optional, List, Dict
@ -6,42 +5,31 @@ from pydantic import BaseModel, Field
# ------------------------------ # ------------------------------
# 请求模型(前端传参校验) # 请求模型
# ------------------------------ # ------------------------------
class DeviceCreateRequest(BaseModel): class DeviceCreateRequest(BaseModel):
"""设备流信息创建请求模型""" """设备流信息创建请求模型(与数据库表字段对齐)"""
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名") hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
params: Optional[Dict] = Field(None, description="设备详细信息") params: Optional[Dict] = Field(None, description="设备详细信息JSON格式")
def md5_encrypt(text: str) -> str:
"""对字符串进行MD5加密"""
if not text:
return ""
md5_hash = hashlib.md5()
md5_hash.update(text.encode('utf-8'))
return md5_hash.hexdigest()
# ------------------------------ # ------------------------------
# 响应模型(后端返回设备数据) # 响应模型(后端返回数据)- 严格对齐数据库表字段
# ------------------------------ # ------------------------------
class DeviceResponse(BaseModel): class DeviceResponse(BaseModel):
"""设备流信息响应模型(字段与表结构完全对齐""" """设备流信息响应模型(与数据库表字段完全一致"""
id: int = Field(..., description="设备ID") id: int = Field(..., description="设备主键ID")
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名") hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址")
live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址")
detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址")
device_online_status: int = Field(..., description="设备在线状态1-在线、0-离线)") device_online_status: int = Field(..., description="设备在线状态1-在线、0-离线)")
device_type: Optional[str] = Field(None, description="设备类型") device_type: Optional[str] = Field(None, description="设备类型")
alarm_count: int = Field(..., description="报警次数") alarm_count: int = Field(..., description="报警次数")
params: Optional[str] = Field(None, description="设备详细信息") params: Optional[str] = Field(None, description="设备详细信息JSON字符串")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库查询结果转换 # 支持从数据库查询结果直接转换
model_config = {"from_attributes": True} model_config = {"from_attributes": True}

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
# 请求模型(前端传参校验) # 请求模型(前端传参校验)
# ------------------------------ # ------------------------------
class FaceCreateRequest(BaseModel): class FaceCreateRequest(BaseModel):
"""创建人脸记录请求模型无需ID由数据库自增)""" """创建人脸记录请求模型无需ID由数据库自增)"""
name: str = Field(None, max_length=255, description="名称(可选最长255字符") name: str = Field(None, max_length=255, description="名称(可选最长255字符")
class FaceUpdateRequest(BaseModel): class FaceUpdateRequest(BaseModel):
@ -20,11 +20,11 @@ class FaceUpdateRequest(BaseModel):
# 响应模型(后端返回数据) # 响应模型(后端返回数据)
# ------------------------------ # ------------------------------
class FaceResponse(BaseModel): class FaceResponse(BaseModel):
"""人脸记录响应模型仍包含ID由数据库生成后返回)""" """人脸记录响应模型仍包含ID由数据库生成后返回)"""
id: int = Field(..., description="主键ID数据库自增") id: int = Field(..., description="主键ID数据库自增")
name: str = Field(None, description="名称") name: str = Field(None, description="名称")
eigenvalue: str = Field(None, description="特征(暂为None") eigenvalue: str | None = Field(None, description="特征(可为空")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")
model_config = {"from_attributes": True} model_config = {"from_attributes": True}

View File

@ -5,9 +5,9 @@ from pydantic import BaseModel, Field
class APIResponse(BaseModel): class APIResponse(BaseModel):
"""统一 API 响应模型(所有接口必返此格式)""" """统一 API 响应模型(所有接口必返此格式)"""
code: int = Field(..., description="状态码200=成功、4xx=客户端错误、5xx=服务端错误") code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
message: str = Field(..., description="响应信息成功/错误描述") message: str = Field(..., description="响应信息: 成功/错误描述")
data: Optional[Any] = Field(None, description="响应数据成功时返回、错误时为 None") data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
# Pydantic V2 配置(支持从 ORM 对象转换) # Pydantic V2 配置(支持从 ORM 对象转换)
model_config = {"from_attributes": True} model_config = {"from_attributes": True}

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
# ------------------------------ # ------------------------------
class SensitiveCreateRequest(BaseModel): class SensitiveCreateRequest(BaseModel):
"""创建敏感信息记录请求模型""" """创建敏感信息记录请求模型"""
# 移除了id字段由数据库自动生成 # 移除了id字段由数据库自动生成
name: str = Field(None, max_length=255, description="名称") name: str = Field(None, max_length=255, description="名称")

View File

@ -0,0 +1,158 @@
from fastapi import APIRouter, Query, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from schema.device_action_schema import (
DeviceActionCreate,
DeviceActionResponse,
DeviceActionListResponse
)
from schema.response_schema import APIResponse
# 路由配置
router = APIRouter(
prefix="/device/actions",
tags=["设备操作记录"]
)
# ------------------------------
# 内部方法: 新增设备操作记录适配id自增
# ------------------------------
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
"""
新增设备操作记录(内部方法、非接口)
:param action_data: 含client_ip和action0/1
:return: 新增的完整记录
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入SQLid自增、依赖数据库自动生成
insert_query = """
INSERT INTO device_action
(client_ip, action, created_at, updated_at)
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (
action_data.client_ip,
action_data.action
))
conn.commit()
# 获取新增记录通过自增ID查询
new_id = cursor.lastrowid
cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
new_action = cursor.fetchone()
return DeviceActionResponse(**new_action)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"新增记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口: 分页查询操作记录列表(仅返回 total + device_actions
# ------------------------------
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
async def get_device_action_list(
page: int = Query(1, ge=1, description="页码、默认1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
client_ip: str = Query(None, description="按客户端IP筛选"),
action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线、1=上线)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 构建筛选条件(参数化查询、避免注入)
where_clause = []
params = []
if client_ip:
where_clause.append("client_ip = %s")
params.append(client_ip)
if action is not None:
where_clause.append("action = %s")
params.append(action)
# 2. 查询总记录数(用于返回 total
count_sql = "SELECT COUNT(*) AS total FROM device_action"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 3. 分页查询记录(按创建时间倒序、确保最新记录在前)
offset = (page - 1) * page_size
list_sql = "SELECT * FROM device_action"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数page/page_size仅用于查询、不返回
cursor.execute(list_sql, params)
action_list = cursor.fetchall()
# 4. 仅返回 total + device_actions
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
async def get_device_actions_by_ip(
client_ip: str = Path(..., description="客户端IP地址")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 查询总记录数
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
cursor.execute(count_sql, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 查询该IP的所有记录按创建时间倒序
list_sql = """
SELECT * FROM device_action
WHERE client_ip = %s
ORDER BY created_at DESC
"""
cursor.execute(list_sql, (client_ip,))
action_list = cursor.fetchall()
# 3. 返回结果
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -1,25 +1,11 @@
import json import json
import threading
import time
from fastapi import HTTPException, Query, APIRouter, Depends, Request from fastapi import APIRouter, Query, HTTPException,Request
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
from ds.config import LIVE_CONFIG
from ds.db import db from ds.db import db
from middle.auth_middleware import get_current_user from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse
# 注意导入的Schema已更新字段
from schema.device_schema import (
DeviceCreateRequest,
DeviceResponse,
DeviceListResponse,
md5_encrypt
)
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
from schema.user_schema import UserResponse
# 导入之前封装的WEBRTC处理函数
from core.rtmp import rtmp_pull_video_stream
router = APIRouter( router = APIRouter(
prefix="/devices", prefix="/devices",
@ -27,65 +13,127 @@ router = APIRouter(
) )
# 在后台线程中运行WEBRTC处理
def run_webrtc_processing(ip, webrtc_url):
try:
print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}")
rtmp_pull_video_stream(webrtc_url)
except Exception as e:
print(f"WEBRTC处理出错: {str(e)}")
# ------------------------------ # ------------------------------
# 1. 创建设备信息 # 内部工具方法 - 通过客户端IP增加设备报警次数
# ------------------------------ # ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息") def increment_alarm_count_by_ip(client_ip: str) -> bool:
async def create_device(request: Request, device_data: DeviceCreateRequest): """
通过客户端IP增加设备的报警次数内部服务方法
:param client_ip: 客户端IP地址
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
if not cursor.fetchone():
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 报警次数加1、并更新时间戳
update_query = """
UPDATE devices
SET alarm_count = alarm_count + 1,
updated_at = CURRENT_TIMESTAMP
WHERE client_ip = %s
"""
cursor.execute(update_query, (client_ip,))
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新报警次数失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法 - 通过客户端IP更新设备在线状态
# ------------------------------
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
"""
通过客户端IP更新设备的在线状态内部服务方法
:param client_ip: 客户端IP地址
:param online_status: 在线状态1-在线、0-离线)
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
# 验证状态值有效性
if online_status not in (0, 1):
raise ValueError("在线状态必须是0离线或1在线")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
if not cursor.fetchone():
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 更新在线状态和时间戳
update_query = """
UPDATE devices
SET device_online_status = %s,
updated_at = CURRENT_TIMESTAMP
WHERE client_ip = %s
"""
cursor.execute(update_query, (online_status, client_ip))
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 原有接口保持不变
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象
# 原有代码保持不变
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 检查client_ip是否已存在
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
existing_device = cursor.fetchone() existing_device = cursor.fetchone()
if existing_device: if existing_device:
# 设备创建成功后在后台线程启动WEBRTC流处理 # 更新设备状态为在线
threading.Thread( update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
target=run_webrtc_processing, # 返回信息
# args=(device_data.ip, existing_device["live_webrtc_url"]),
args=(device_data.ip, existing_device["rtmp_push_url"]),
daemon=True # 设为守护线程,主程序退出时自动结束
).start()
# IP已存在时返回该设备信息
return APIResponse( return APIResponse(
code=200, code=200,
message=f"客户端IP {device_data.ip} 已存在", message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
data=DeviceResponse(**existing_device) data=DeviceResponse(** existing_device)
) )
# 获取RTMP URL和WEBRTC URL配置 # 直接使用注入的request对象获取用户代理
rtmp_url = str(LIVE_CONFIG.get("rtmp_url", ""))
webrtc_url = str(LIVE_CONFIG.get("webrtc_url", ""))
# 将设备详细信息params转换为JSON字符串
device_params_json = json.dumps(device_data.params) if device_data.params else None
# 对JSON字符串进行MD5加密
device_md5 = md5_encrypt(device_params_json) if device_params_json else ""
# 解析User-Agent获取设备类型
user_agent = request.headers.get("User-Agent", "").lower() user_agent = request.headers.get("User-Agent", "").lower()
# 优先处理User-Agent为default的情况
if user_agent == "default": if user_agent == "default":
# 检查params中是否存在os键 device_type = device_data.params.get("os") if (
if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params: device_data.params and isinstance(device_data.params, dict)) else "unknown"
device_type = device_data.params["os"]
else:
device_type = "unknown"
elif "windows" in user_agent: elif "windows" in user_agent:
device_type = "windows" device_type = "windows"
elif "android" in user_agent: elif "android" in user_agent:
@ -95,22 +143,16 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
else: else:
device_type = "unknown" device_type = "unknown"
# 构建完整的WEBRTC URL device_params_json = json.dumps(device_data.params) if device_data.params else None
full_webrtc_url = webrtc_url + device_md5
# SQL插入语句
insert_query = """ insert_query = """
INSERT INTO devices INSERT INTO devices
(client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url, (client_ip, hostname, device_online_status, device_type, alarm_count, params)
device_online_status, device_type, alarm_count, params) VALUES (%s, %s, %s, %s, %s, %s)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
""" """
cursor.execute(insert_query, ( cursor.execute(insert_query, (
device_data.ip, device_data.ip,
device_data.hostname, device_data.hostname,
rtmp_url + device_md5,
full_webrtc_url, # 存储完整的WEBRTC URL
"",
1, 1,
device_type, device_type,
0, 0,
@ -118,28 +160,22 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
)) ))
conn.commit() conn.commit()
# 获取刚创建的设备信息
device_id = cursor.lastrowid device_id = cursor.lastrowid
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone() new_device = cursor.fetchone()
# 设备创建成功后在后台线程启动WEBRTC流处理
threading.Thread(
target=run_webrtc_processing,
args=(device_data.ip, full_webrtc_url),
daemon=True # 设为守护线程,主程序退出时自动结束
).start()
return APIResponse( return APIResponse(
code=200, code=200,
message="设备创建成功已开始处理WEBRTC流", message="设备创建成功",
data=DeviceResponse(**device) data=DeviceResponse(**new_device)
) )
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建设备失败{str(e)}") from e raise Exception(f"创建设备失败: {str(e)}") from e
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise Exception(f"设备信息JSON序列化失败{str(e)}") from e raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e
except Exception as e: except Exception as e:
if conn: if conn:
conn.rollback() conn.rollback()
@ -147,140 +183,82 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
# ------------------------------
# 2. 获取设备列表
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备列表")
async def get_device_list( async def get_device_list(
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码、默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数"), page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
device_type: str = Query(None, description="设备类型筛选"), device_type: str = Query(None, description="设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选1-在线、0-离线)") online_status: int = Query(None, ge=0, le=1, description="在线状态筛选")
): ):
# 原有代码保持不变
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 构建查询条件
where_clause = [] where_clause = []
params = [] params = []
if device_type: if device_type:
where_clause.append("device_type = %s") where_clause.append("device_type = %s")
params.append(device_type) params.append(device_type)
if online_status is not None: if online_status is not None:
where_clause.append("device_online_status = %s") where_clause.append("device_online_status = %s")
params.append(online_status) params.append(online_status)
# 总条数查询 count_query = "SELECT COUNT(*) AS total FROM devices"
count_query = "SELECT COUNT(*) as total FROM devices"
if where_clause: if where_clause:
count_query += " WHERE " + " AND ".join(where_clause) count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params) cursor.execute(count_query, params)
total = cursor.fetchone()["total"] total = cursor.fetchone()["total"]
# 分页查询SELECT * 会自动匹配表字段、响应模型已对齐)
offset = (page - 1) * page_size offset = (page - 1) * page_size
query = "SELECT * FROM devices" list_query = "SELECT * FROM devices"
if where_clause: if where_clause:
query += " WHERE " + " AND ".join(where_clause) list_query += " WHERE " + " AND ".join(where_clause)
query += " ORDER BY id DESC LIMIT %s OFFSET %s" list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) params.extend([page_size, offset])
cursor.execute(query, params) cursor.execute(list_query, params)
devices = cursor.fetchall() device_list = cursor.fetchall()
# 响应模型已更新为params字段、直接转换即可
device_list = [DeviceResponse(**device) for device in devices]
return APIResponse( return APIResponse(
code=200, code=200,
message="获取设备列表成功", message="获取设备列表成功",
data=DeviceListResponse(total=total, devices=device_list) data=DeviceListResponse(
total=total,
devices=[DeviceResponse(**device) for device in device_list]
)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"获取设备列表失败{str(e)}") from e raise Exception(f"获取设备列表失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ def get_unique_client_ips() -> list[str]:
# 3. 获取单个设备详情 """
# ------------------------------ 获取所有去重的客户端IP列表
@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情")
async def get_device_detail( :return: 去重后的客户端IP字符串列表如果没有数据则返回空列表
device_id: int, """
current_user: UserResponse = Depends(get_current_user)
):
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 查询设备信息SELECT * 匹配表字段) # 查询去重的客户端IP
query = "SELECT * FROM devices WHERE id = %s" query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
cursor.execute(query, (device_id,)) cursor.execute(query)
device = cursor.fetchone()
if not device: # 提取结果并转换为字符串列表
raise HTTPException( results = cursor.fetchall()
status_code=404, return [item['client_ip'] for item in results]
detail=f"设备ID为 {device_id} 的设备不存在"
)
# 响应模型已更新为params字段
return APIResponse(
code=200,
message="获取设备详情成功",
data=DeviceResponse(**device)
)
except MySQLError as e: except MySQLError as e:
raise Exception(f"获取设备详情失败:{str(e)}") from e raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------
# 4. 删除设备信息
# ------------------------------
@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息")
async def delete_device(
device_id: int,
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,))
if not cursor.fetchone():
raise HTTPException(
status_code=404,
detail=f"设备ID为 {device_id} 的设备不存在"
)
# 执行删除
delete_query = "DELETE FROM devices WHERE id = %s"
cursor.execute(delete_query, (device_id,))
conn.commit()
return APIResponse(
code=200,
message=f"设备ID为 {device_id} 的设备已成功删除",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除设备失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -7,32 +7,38 @@ from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse from schema.user_schema import UserResponse
from util.face_util import add_binary_data,get_average_feature
#初始化实例
router = APIRouter( router = APIRouter(
prefix="/faces", prefix="/faces",
tags=["人脸管理"] tags=["人脸管理"]
) )
# ------------------------------ # ------------------------------
# 1. 创建人脸记录(核心修正ID 数据库自增前端无需传) # 1. 创建人脸记录(核心修正: ID 数据库自增前端无需传)
# ------------------------------ # ------------------------------
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件ID自增") @router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件ID自增")
async def create_face( async def create_face(
# 前端仅需传name可选Form格式、file必传文件) # 前端仅需传: name可选Form格式、file必传文件)
name: str = Form(None, max_length=255, description="名称(可选)"), name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传暂不处理内容)") file: UploadFile = File(..., description="人脸文件(必传暂不处理内容)")
): ):
""" """
创建人脸记录 创建人脸记录:
- 需登录认证 - 需登录认证
- 前端传参multipart/form-data 表单name 可选file 必传) - 前端传参: multipart/form-data 表单name 可选file 必传)
- ID 由数据库自动生成无需前端传入 - ID 由数据库自动生成无需前端传入
- 暂不处理文件内容eigenvalue 设为 None - 暂不处理文件内容eigenvalue 设为 None
""" """
# 调用你的方法
conn = None conn = None
cursor = None cursor = None
try: try:
# 1. 用模型校验 name仅校验长度无需ID # 1. 用模型校验 name仅校验长度无需ID
face_create = FaceCreateRequest(name=name) face_create = FaceCreateRequest(name=name)
conn = db.get_connection() conn = db.get_connection()
@ -41,42 +47,77 @@ async def create_face(
# 把文件转为二进制数组 # 把文件转为二进制数组
file_content = await file.read() file_content = await file.read()
# 调用人脸识别得到特征值 # 计算特征值
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 2. 插入数据库:无需传 ID自增只传 name 和 eigenvalueNone # 打印数组长度
print(f"文件大小: {len(file_content)} 字节")
# 2. 插入数据库: 无需传 ID自增、只传 name 和 eigenvalueNone
insert_query = """ insert_query = """
INSERT INTO face (name, eigenvalue) INSERT INTO face (name, eigenvalue)
VALUES (%s, %s) VALUES (%s, %s)
""" """
cursor.execute(insert_query, (face_create.name, None)) cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
conn.commit() conn.commit()
# 3. 获取数据库自动生成的 ID关键用 LAST_INSERT_ID() 查刚插入的记录) # 3. 获取数据库自动生成的 ID关键: 用 LAST_INSERT_ID() 查刚插入的记录)
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()" select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
cursor.execute(select_new_query) cursor.execute(select_new_query)
created_face = cursor.fetchone() created_face = cursor.fetchone()
if not created_face:
raise HTTPException(
status_code=500,
detail="创建人脸记录成功、但无法获取新创建的记录"
)
return APIResponse( return APIResponse(
code=201, code=201,
message=f"人脸记录创建成功ID{created_face['id']}文件名{file.filename}", message=f"人脸记录创建成功ID: {created_face['id']}文件名: {file.filename}",
data=FaceResponse(**created_face) data=FaceResponse(** created_face)
) )
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建人脸记录失败:{str(e)}") from e # 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"创建人脸记录失败: {str(e)}"
) from e
except Exception as e:
# 捕获其他可能的异常
raise HTTPException(
status_code=500,
detail=f"服务器错误: {str(e)}"
) from e
finally: finally:
await file.close() # 关闭文件流 await file.close() # 关闭文件流
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 将 eigenvalue 转为 str
eigenvalue = str(eigenvalue)
# ------------------------------ # ------------------------------
# 2. 获取单个人脸记录(不变用自增ID查询 # 2. 获取单个人脸记录(不变用自增ID查询
# ------------------------------ # ------------------------------
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录") @router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
async def get_face( async def get_face(
face_id: int, # 这里的 ID 是数据库自增的前端从创建响应中获取 face_id: int, # 这里的 ID 是数据库自增的前端从创建响应中获取
current_user: UserResponse = Depends(get_current_user) current_user: UserResponse = Depends(get_current_user)
): ):
conn = None conn = None
@ -101,18 +142,21 @@ async def get_face(
data=FaceResponse(**face) data=FaceResponse(**face)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询人脸记录失败:{str(e)}") from e # 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"查询人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改 # 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理
# ------------------------------ # ------------------------------
# 3. 获取所有人脸记录(不变) # 3. 获取所有人脸记录(不变)
# ------------------------------ # ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有人脸记录") @router.get("", response_model=APIResponse, summary="获取所有人脸记录")
async def get_all_faces( async def get_all_faces(
current_user: UserResponse = Depends(get_current_user)
): ):
conn = None conn = None
cursor = None cursor = None
@ -127,16 +171,19 @@ async def get_all_faces(
return APIResponse( return APIResponse(
code=200, code=200,
message="所有人脸记录查询成功", message="所有人脸记录查询成功",
data=[FaceResponse(**face) for face in faces] data=[FaceResponse(** face) for face in faces]
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e raise HTTPException(
status_code=500,
detail=f"查询所有人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 4. 更新人脸记录(不变用自增ID更新 # 4. 更新人脸记录(不变用自增ID更新
# ------------------------------ # ------------------------------
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录") @router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
async def update_face( async def update_face(
@ -191,13 +238,16 @@ async def update_face(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"更新人脸记录失败:{str(e)}") from e raise HTTPException(
status_code=500,
detail=f"更新人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 5. 删除人脸记录(不变用自增ID删除 # 5. 删除人脸记录(不变用自增ID删除
# ------------------------------ # ------------------------------
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录") @router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
async def delete_face( async def delete_face(
@ -231,39 +281,61 @@ async def delete_face(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"删除人脸记录失败:{str(e)}") from e raise HTTPException(
status_code=500,
detail=f"删除人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
def get_all_face_name_with_eigenvalue() -> dict: def get_all_face_name_with_eigenvalue() -> dict:
""" """
获取所有人脸的名称及其对应的特征值组成字典返回 获取所有人脸的名称及其对应的特征值组成字典返回
key: 人脸名称name key: 人脸名称name
value: 人脸特征值eigenvalue value: 人脸特征值eigenvalue、若名称重复则返回平均特征值
过滤掉name为None的记录避免字典key为None的情况 : 过滤掉name为None的记录避免字典key为None的情况
""" """
conn = None conn = None
cursor = None cursor = None
try: try:
# 1. 建立数据库连接并获取游标dictionary=True使结果以字典形式返回
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 只查询需要的字段,提高效率 # 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query) cursor.execute(query)
faces = cursor.fetchall() faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...]
# 构建name到eigenvalue的映射字典 # 3. 收集同一名称对应的所有特征值(处理名称重复场景)
face_dict = { name_to_eigenvalues = {}
face["name"]: face["eigenvalue"] for face in faces:
for face in faces name = face["name"]
} eigenvalue = face["eigenvalue"]
# 若名称已存在、追加特征值;否则新建列表存储
if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue)
else:
name_to_eigenvalues[name] = [eigenvalue]
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items():
# 处理特征值: 多个则求平均、单个则直接使用
if len(eigenvalues) > 1:
# 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = get_average_feature(eigenvalues)
else:
# 取列表中唯一的特征值避免value为列表类型
face_dict[name] = eigenvalues[0]
return face_dict return face_dict
except MySQLError as e: except MySQLError as e:
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e # 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
finally: finally:
# 确保资源释放 # 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -21,7 +21,7 @@ router = APIRouter(
async def create_sensitive( async def create_sensitive(
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖 sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
""" """
创建敏感信息记录 创建敏感信息记录:
- 需登录认证 - 需登录认证
- 插入新的敏感信息记录到数据库ID由数据库自动生成 - 插入新的敏感信息记录到数据库ID由数据库自动生成
- 返回创建成功信息 - 返回创建成功信息
@ -32,7 +32,7 @@ async def create_sensitive(
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 插入新敏感信息记录到数据库不包含ID由数据库自动生成) # 插入新敏感信息记录到数据库不包含ID由数据库自动生成)
insert_query = """ insert_query = """
INSERT INTO sensitives (name) INSERT INTO sensitives (name)
VALUES (%s) VALUES (%s)
@ -56,7 +56,7 @@ async def create_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建敏感信息记录失败{str(e)}") from e raise Exception(f"创建敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -71,7 +71,7 @@ async def get_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
获取单个敏感信息记录 获取单个敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID查询敏感信息记录 - 根据ID查询敏感信息记录
- 返回查询到的敏感信息 - 返回查询到的敏感信息
@ -98,7 +98,7 @@ async def get_sensitive(
data=SensitiveResponse(**sensitive) data=SensitiveResponse(**sensitive)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询敏感信息记录失败{str(e)}") from e raise Exception(f"查询敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -109,7 +109,7 @@ async def get_sensitive(
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录") @router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
async def get_all_sensitives(): async def get_all_sensitives():
""" """
获取所有敏感信息记录 获取所有敏感信息记录:
- 需登录认证 - 需登录认证
- 查询所有敏感信息记录(不需要分页) - 查询所有敏感信息记录(不需要分页)
- 返回所有敏感信息列表 - 返回所有敏感信息列表
@ -130,7 +130,7 @@ async def get_all_sensitives():
data=[SensitiveResponse(**sensitive) for sensitive in sensitives] data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询所有敏感信息记录失败{str(e)}") from e raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -145,7 +145,7 @@ async def update_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
更新敏感信息记录 更新敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID更新敏感信息记录 - 根据ID更新敏感信息记录
- 返回更新后的敏感信息 - 返回更新后的敏感信息
@ -203,7 +203,7 @@ async def update_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"更新敏感信息记录失败{str(e)}") from e raise Exception(f"更新敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -217,7 +217,7 @@ async def delete_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
删除敏感信息记录 删除敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID删除敏感信息记录 - 根据ID删除敏感信息记录
- 返回删除成功信息 - 返回删除成功信息
@ -251,14 +251,14 @@ async def delete_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"删除敏感信息记录失败{str(e)}") from e raise Exception(f"删除敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
def get_all_sensitive_words() -> list[str]: def get_all_sensitive_words() -> list[str]:
""" """
获取所有敏感词返回字符串数组 获取所有敏感词返回字符串数组
返回: 返回:
list[str]: 包含所有敏感词的数组 list[str]: 包含所有敏感词的数组
@ -273,7 +273,7 @@ def get_all_sensitive_words() -> list[str]:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 执行查询只获取敏感词字段 # 执行查询只获取敏感词字段
query = "SELECT name FROM sensitives ORDER BY id" query = "SELECT name FROM sensitives ORDER BY id"
cursor.execute(query) cursor.execute(query)
sensitive_records = cursor.fetchall() sensitive_records = cursor.fetchall()
@ -283,7 +283,7 @@ def get_all_sensitive_words() -> list[str]:
except MySQLError as e: except MySQLError as e:
# 数据库错误处理 # 数据库错误处理
raise MySQLError(f"查询敏感词失败{str(e)}") from e raise MySQLError(f"查询敏感词失败: {str(e)}") from e
finally: finally:
# 确保资源正确释放 # 确保资源正确释放
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -27,7 +27,7 @@ router = APIRouter(
@router.post("/register", response_model=APIResponse, summary="用户注册") @router.post("/register", response_model=APIResponse, summary="用户注册")
async def user_register(request: UserRegisterRequest): async def user_register(request: UserRegisterRequest):
""" """
用户注册 用户注册:
- 校验用户名是否已存在 - 校验用户名是否已存在
- 加密密码后插入数据库 - 加密密码后插入数据库
- 返回注册成功信息 - 返回注册成功信息
@ -67,7 +67,7 @@ async def user_register(request: UserRegisterRequest):
) )
except MySQLError as e: except MySQLError as e:
conn.rollback() # 数据库错误时回滚事务 conn.rollback() # 数据库错误时回滚事务
raise Exception(f"注册失败{str(e)}") from e raise Exception(f"注册失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -78,7 +78,7 @@ async def user_register(request: UserRegisterRequest):
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token") @router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
async def user_login(request: UserLoginRequest): async def user_login(request: UserLoginRequest):
""" """
用户登录 用户登录:
- 校验用户名是否存在 - 校验用户名是否存在
- 校验密码是否正确 - 校验密码是否正确
- 生成 JWT Token 并返回 - 生成 JWT Token 并返回
@ -89,7 +89,7 @@ async def user_login(request: UserLoginRequest):
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 修复SQL查询添加 created_at 和 updated_at 字段 # 修复: SQL查询添加 created_at 和 updated_at 字段
query = """ query = """
SELECT id, username, password, created_at, updated_at SELECT id, username, password, created_at, updated_at
FROM users FROM users
@ -129,7 +129,7 @@ async def user_login(request: UserLoginRequest):
} }
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"登录失败{str(e)}") from e raise Exception(f"登录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -142,8 +142,8 @@ async def get_current_user_info(
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件 current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
): ):
""" """
获取当前登录用户信息 获取当前登录用户信息:
- 需在请求头携带 Token格式Bearer <token> - 需在请求头携带 Token格式: Bearer <token>
- 认证通过后返回用户信息 - 认证通过后返回用户信息
""" """
return APIResponse( return APIResponse(

145
util/face_util.py Normal file
View File

@ -0,0 +1,145 @@
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
from io import BytesIO
from PIL import Image
# 全局变量存储InsightFace引擎和特征列表
_insightface_app = None
_feature_list = []
def init_insightface():
"""初始化InsightFace引擎"""
global _insightface_app
try:
print("正在初始化InsightFace引擎...")
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
app.prepare(ctx_id=0, det_size=(640, 640))
print("InsightFace引擎初始化完成")
_insightface_app = app
return app
except Exception as e:
print(f"InsightFace初始化失败: {e}")
return None
def add_binary_data(binary_data):
"""
接收单张图片的二进制数据、提取特征并保存
参数:
binary_data: 图片的二进制数据bytes类型
返回:
成功提取特征时返回 (True, 特征值numpy数组)
失败时返回 (False, None)
"""
global _insightface_app, _feature_list
if not _insightface_app:
print("引擎未初始化、无法处理")
return False, None
try:
# 直接处理二进制数据: 转换为图像格式
img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# 提取特征
faces = _insightface_app.get(frame)
if faces:
# 获取当前提取的特征值
current_feature = faces[0].embedding
# 添加到特征列表
_feature_list.append(current_feature)
print(f"已累计 {len(_feature_list)} 个特征")
# 返回成功标志和当前特征值
return True, current_feature
else:
print("二进制数据中未检测到人脸")
return False, None
except Exception as e:
print(f"处理二进制数据出错: {e}")
return False, None
def get_average_feature(features=None):
"""
计算多个特征向量的平均值
参数:
features: 可选、特征值列表。如果未提供、则使用全局存储的_feature_list
每个元素可以是字符串格式或numpy数组
返回:
单一平均特征向量的numpy数组、若无可计算数据则返回None
"""
global _feature_list
# 如果未提供features参数、则使用全局特征列表
if features is None:
features = _feature_list
try:
# 验证输入是否为列表且不为空
if not isinstance(features, list) or len(features) == 0:
print("输入必须是包含至少一个特征值的列表")
return None
# 处理每个特征值
processed_features = []
for i, embedding in enumerate(features):
try:
if isinstance(embedding, str):
# 处理包含括号和逗号的字符串格式
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32)
else:
embedding_np = np.array(embedding, dtype=np.float32)
# 验证特征值格式
if len(embedding_np.shape) == 1:
processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
print(f"跳过第 {i + 1} 个特征值、不是一维数组")
except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}")
# 确保有有效的特征值
if not processed_features:
print("没有有效的特征值用于计算平均值")
return None
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
print(f"特征值维度不一致、无法计算平均值。检测到的维度: {dims}")
return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量、维度: {avg_feature.shape[0]}")
return avg_feature
except Exception as e:
print(f"计算平均特征值时出错: {e}")
return None
def clear_features():
"""清空已存储的特征数据"""
global _feature_list
_feature_list = []
print("已清空所有特征数据")
def get_feature_list():
"""获取当前存储的特征列表"""
global _feature_list
return _feature_list.copy() # 返回副本防止外部直接修改

482
ws.html
View File

@ -1,482 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WebSocket 测试工具</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
font-family: 'Arial', 'Microsoft YaHei', sans-serif;
}
body {
max-width: 1200px;
margin: 20px auto;
padding: 0 20px;
background-color: #f5f7fa;
}
.container {
background: white;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
padding: 25px;
margin-bottom: 20px;
}
h1 {
color: #2c3e50;
margin-bottom: 20px;
font-size: 24px;
border-bottom: 2px solid #3498db;
padding-bottom: 10px;
}
.status-bar {
display: flex;
align-items: center;
gap: 15px;
margin-bottom: 20px;
padding: 12px 15px;
background-color: #f8f9fa;
border-radius: 6px;
}
.status-label {
font-weight: bold;
color: #495057;
}
.status-value {
padding: 4px 10px;
border-radius: 4px;
font-weight: bold;
}
.status-connected {
background-color: #d4edda;
color: #155724;
}
.status-disconnected {
background-color: #f8d7da;
color: #721c24;
}
.status-connecting {
background-color: #fff3cd;
color: #856404;
}
.btn {
padding: 8px 16px;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
font-weight: 500;
transition: background-color 0.2s;
}
.btn-primary {
background-color: #3498db;
color: white;
}
.btn-primary:hover {
background-color: #2980b9;
}
.btn-danger {
background-color: #e74c3c;
color: white;
}
.btn-danger:hover {
background-color: #c0392b;
}
.btn-success {
background-color: #2ecc71;
color: white;
}
.btn-success:hover {
background-color: #27ae60;
}
.control-group {
display: flex;
gap: 15px;
margin-bottom: 20px;
align-items: center;
}
.input-group {
display: flex;
gap: 10px;
align-items: center;
}
.input-group label {
color: #495057;
font-weight: 500;
}
.input-group input, .input-group select {
padding: 8px 12px;
border: 1px solid #ced4da;
border-radius: 4px;
font-size: 14px;
}
.message-area {
margin-top: 20px;
}
.message-input {
width: 100%;
height: 100px;
padding: 12px;
border: 1px solid #ced4da;
border-radius: 6px;
resize: none;
font-size: 14px;
margin-bottom: 10px;
}
.log-area {
width: 100%;
height: 300px;
padding: 15px;
border: 1px solid #ced4da;
border-radius: 6px;
background-color: #f8f9fa;
overflow-y: auto;
font-size: 14px;
line-height: 1.6;
}
.log-item {
margin-bottom: 8px;
padding-bottom: 8px;
border-bottom: 1px dashed #e9ecef;
}
.log-time {
color: #6c757d;
font-size: 12px;
margin-right: 10px;
}
.log-send {
color: #2980b9;
}
.log-receive {
color: #27ae60;
}
.log-status {
color: #856404;
}
.log-error {
color: #e74c3c;
}
</style>
</head>
<body>
<div class="container">
<h1>WebSocket 测试工具</h1>
<!-- 连接状态区 -->
<div class="status-bar">
<div class="status-label">连接状态:</div>
<div id="connectionStatus" class="status-value status-disconnected">未连接</div>
<div class="status-label">服务地址:</div>
<div id="wsUrl" class="status-value">ws://192.168.110.25:8000/ws</div>
<div class="status-label">连接时间:</div>
<div id="connectTime" class="status-value">-</div>
</div>
<!-- 控制按钮区 -->
<div class="control-group">
<button id="connectBtn" class="btn btn-primary">建立连接</button>
<button id="disconnectBtn" class="btn btn-danger" disabled>断开连接</button>
<!-- 心跳控制 -->
<div class="input-group">
<label>自动心跳:</label>
<select id="autoHeartbeat">
<option value="on">开启</option>
<option value="off">关闭</option>
</select>
<label>间隔(秒)</label>
<input type="number" id="heartbeatInterval" value="30" min="10" max="120" style="width: 80px;">
<button id="sendHeartbeatBtn" class="btn btn-success">手动发送心跳</button>
</div>
</div>
<!-- 自定义消息发送区 -->
<div class="message-area">
<h3>发送自定义消息</h3>
<textarea id="messageInput" class="message-input"
placeholder='示例:{"type":"test","content":"Hello WebSocket"}'>{"type":"test","content":"Hello WebSocket"}</textarea>
<button id="sendMessageBtn" class="btn btn-primary" disabled>发送消息</button>
</div>
<!-- 日志显示区 -->
<div class="message-area">
<h3>消息日志</h3>
<div id="logContainer" class="log-area">
<div class="log-item"><span class="log-time">[加载完成]</span> 请点击「建立连接」开始测试</div>
</div>
<button id="clearLogBtn" class="btn btn-primary" style="margin-top: 10px;">清空日志</button>
</div>
</div>
<script>
// 全局变量
let ws = null;
let heartbeatTimer = null;
const wsUrl = "ws://192.168.110.25:8000/ws";
// DOM 元素
const connectionStatus = document.getElementById('connectionStatus');
const connectTime = document.getElementById('connectTime');
const connectBtn = document.getElementById('connectBtn');
const disconnectBtn = document.getElementById('disconnectBtn');
const sendMessageBtn = document.getElementById('sendMessageBtn');
const sendHeartbeatBtn = document.getElementById('sendHeartbeatBtn');
const autoHeartbeat = document.getElementById('autoHeartbeat');
const heartbeatInterval = document.getElementById('heartbeatInterval');
const messageInput = document.getElementById('messageInput');
const logContainer = document.getElementById('logContainer');
const clearLogBtn = document.getElementById('clearLogBtn');
// 工具函数:添加日志
function addLog(content, type = 'status') {
const now = new Date().toLocaleString('zh-CN', {
year: 'numeric', month: '2-digit', day: '2-digit',
hour: '2-digit', minute: '2-digit', second: '2-digit'
});
const logItem = document.createElement('div');
logItem.className = 'log-item';
let logClass = '';
switch (type) {
case 'send':
logClass = 'log-send';
break;
case 'receive':
logClass = 'log-receive';
break;
case 'error':
logClass = 'log-error';
break;
default:
logClass = 'log-status';
}
logItem.innerHTML = `<span class="log-time">[${now}]</span> <span class="${logClass}">${content}</span>`;
logContainer.appendChild(logItem);
// 滚动到最新日志
logContainer.scrollTop = logContainer.scrollHeight;
}
// 工具函数格式化JSON便于日志显示
function formatJson(jsonStr) {
try {
const obj = JSON.parse(jsonStr);
return JSON.stringify(obj, null, 2);
} catch (e) {
return jsonStr; // 非JSON格式直接返回
}
}
// 建立WebSocket连接
function connectWebSocket() {
if (ws) {
addLog('已存在连接,无需重复建立', 'error');
return;
}
try {
ws = new WebSocket(wsUrl);
// 连接成功
ws.onopen = function () {
connectionStatus.className = 'status-value status-connected';
connectionStatus.textContent = '已连接';
const now = new Date().toLocaleString('zh-CN');
connectTime.textContent = now;
addLog(`连接成功!服务地址:${wsUrl}`, 'status');
// 更新按钮状态
connectBtn.disabled = true;
disconnectBtn.disabled = false;
sendMessageBtn.disabled = false;
// 开启自动心跳(默认开启)
if (autoHeartbeat.value === 'on') {
startAutoHeartbeat();
}
};
// 接收消息
ws.onmessage = function (event) {
const message = event.data;
addLog(`收到消息:\n${formatJson(message)}`, 'receive');
};
// 连接关闭
ws.onclose = function (event) {
connectionStatus.className = 'status-value status-disconnected';
connectionStatus.textContent = '已断开';
addLog(`连接断开!代码:${event.code},原因:${event.reason || '未知'}`, 'status');
// 清除自动心跳
stopAutoHeartbeat();
// 更新按钮状态
connectBtn.disabled = false;
disconnectBtn.disabled = true;
sendMessageBtn.disabled = true;
// 重置WebSocket对象
ws = null;
};
// 连接错误
ws.onerror = function (error) {
addLog(`连接错误:${error.message || '未知错误'}`, 'error');
};
} catch (e) {
addLog(`建立连接失败:${e.message}`, 'error');
ws = null;
}
}
// 断开WebSocket连接
function disconnectWebSocket() {
if (!ws) {
addLog('当前无连接,无需断开', 'error');
return;
}
ws.close(1000, '手动断开连接');
}
// 发送心跳消息(符合约定格式:{"timestamp":xxxxx, "type":"heartbeat"}
function sendHeartbeat() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送心跳失败:当前无有效连接', 'error');
return;
}
const heartbeatMsg = {
timestamp: Date.now(), // 当前毫秒时间戳
type: "heartbeat"
};
const msgStr = JSON.stringify(heartbeatMsg);
ws.send(msgStr);
addLog(`发送心跳:\n${formatJson(msgStr)}`, 'send');
}
// 开启自动心跳
function startAutoHeartbeat() {
// 先停止已有定时器
stopAutoHeartbeat();
const interval = parseInt(heartbeatInterval.value) * 1000;
if (isNaN(interval) || interval < 10000) {
addLog('自动心跳间隔无效已重置为30秒', 'error');
heartbeatInterval.value = 30;
return startAutoHeartbeat();
}
addLog(`开启自动心跳,间隔:${heartbeatInterval.value}`, 'status');
heartbeatTimer = setInterval(sendHeartbeat, interval);
}
// 停止自动心跳
function stopAutoHeartbeat() {
if (heartbeatTimer) {
clearInterval(heartbeatTimer);
heartbeatTimer = null;
addLog('已停止自动心跳', 'status');
}
}
// 发送自定义消息
function sendCustomMessage() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
addLog('发送消息失败:当前无有效连接', 'error');
return;
}
const msgStr = messageInput.value.trim();
if (!msgStr) {
addLog('发送消息失败:消息内容不能为空', 'error');
return;
}
try {
// 验证JSON格式可选仅提示不强制
JSON.parse(msgStr);
ws.send(msgStr);
addLog(`发送自定义消息:\n${formatJson(msgStr)}`, 'send');
} catch (e) {
addLog(`JSON格式错误${e.message},仍尝试发送原始内容`, 'error');
ws.send(msgStr);
addLog(`发送自定义消息非JSON\n${msgStr}`, 'send');
}
}
// 绑定按钮事件
connectBtn.addEventListener('click', connectWebSocket);
disconnectBtn.addEventListener('click', disconnectWebSocket);
sendMessageBtn.addEventListener('click', sendCustomMessage);
sendHeartbeatBtn.addEventListener('click', sendHeartbeat);
clearLogBtn.addEventListener('click', () => {
logContainer.innerHTML = '<div class="log-item"><span class="log-time">[日志已清空]</span> 请继续操作...</div>';
});
// 自动心跳开关变更事件
autoHeartbeat.addEventListener('change', function () {
if (ws && ws.readyState === WebSocket.OPEN) {
if (this.value === 'on') {
startAutoHeartbeat();
} else {
stopAutoHeartbeat();
}
} else {
addLog('需先建立有效连接才能控制自动心跳', 'error');
// 重置选择
this.value = 'off';
}
});
// 心跳间隔变更事件(实时生效)
heartbeatInterval.addEventListener('change', function () {
if (autoHeartbeat.value === 'on' && ws && ws.readyState === WebSocket.OPEN) {
startAutoHeartbeat();
}
});
// 快捷键支持Ctrl+Enter发送消息
messageInput.addEventListener('keydown', function (e) {
if (e.ctrlKey && e.key === 'Enter') {
sendCustomMessage();
e.preventDefault();
}
});
</script>
</body>
</html>

338
ws/ws.py
View File

@ -3,288 +3,296 @@ import datetime
import json import json
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator from typing import Dict, Optional
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect, load_model
import cv2 import cv2
import numpy as np import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from ocr.model_violation_detector import MultiModelViolationDetector # 配置常量
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# -------------------------- 配置常量 --------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径 WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制保持1确保单帧处理 FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# -------------------------- 核心数据结构与全局变量 --------------------------
ws_router = APIRouter()
# 客户端连接封装(包含帧队列) # 工具函数: 获取格式化时间字符串
def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# 客户端连接封装
class ClientConnection: class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str): def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket self.websocket = websocket
self.client_ip = client_ip self.client_ip = client_ip # 已初始化客户端IP用于传递给detect
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 帧队列长度为1 self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None # 消费者任务 self.consumer_task: Optional[asyncio.Task] = None
# 更新心跳时间
def update_heartbeat(self): def update_heartbeat(self):
"""更新心跳时间"""
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
# 检查是否存活超时返回False
def is_alive(self) -> bool: def is_alive(self) -> bool:
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() """判断客户端是否存活"""
return timeout < HEARTBEAT_TIMEOUT timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout_seconds < HEARTBEAT_TIMEOUT
# 启动帧消费任务
def start_consumer(self): def start_consumer(self):
"""启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames()) self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task return self.consumer_task
# ---------- 新增:发送“允许发送二进制帧”的信号给客户端 ---------- async def send_frame_permit(self):
async def send_allow_send_frame(self): """发送帧发送许可信号"""
"""向客户端发送JSON信号通知其可发送下一帧二进制数据"""
try: try:
allow_msg = { frame_permit_msg = {
"type": "allow_send_frame", # 信号类型,与客户端约定 "type": "frame",
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "timestamp": get_current_time_str(),
"status": "ready", # 表示服务器已准备好接收下一帧 "client_ip": self.client_ip
"client_ip": self.client_ip # 可选:便于客户端确认自身身份
} }
await self.websocket.send_json(allow_msg) await self.websocket.send_json(frame_permit_msg)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}已发送「允许发送帧」信号") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号")
except Exception as e: except Exception as e:
# 发送失败大概率是客户端已断开,不影响主流程,仅日志记录 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:发送「允许发送帧」信号失败 - {str(e)}")
# 帧消费协程
async def consume_frames(self) -> None: async def consume_frames(self) -> None:
"""队列中获取帧并进行处理,处理完后通知客户端可发送下一帧""" """消费队列中的帧并处理"""
try: try:
while True: while True:
# 从队列获取帧数据(队列空时会阻塞,等待客户端发送) # 取出帧并立即发送下一帧许可
frame_data = await self.frame_queue.get() frame_data = await self.frame_queue.get()
await self.send_frame_permit()
try: try:
# 处理帧数据
await self.process_frame(frame_data) await self.process_frame(frame_data)
finally: finally:
# 标记任务完成(队列计数-1此时队列回到空状态
self.frame_queue.task_done() self.frame_queue.task_done()
# ---------- 修改:处理完当前帧后,立即通知客户端可发送下一帧 ----------
await self.send_allow_send_frame()
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}帧消费任务已取消") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
except Exception as e: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(原有逻辑不变""" """处理单帧图像数据(核心修改detect函数传入 client_ip + img 双参数"""
# 二进制数据转换为NumPy数组uint8类型 # 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8) nparr = np.frombuffer(frame_data, np.uint8)
# 解码为图像返回与cv2.imread相同的格式BGR通道的ndarray
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
# 确保images文件夹存在 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像数据")
if not os.path.exists('images'): return
os.makedirs('images')
# 生成唯一的文件名包含时间戳和客户端IP避免文件名冲突
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"images/{self.client_ip.replace('.', '_')}_{timestamp}.jpg"
try: try:
# 保存图像到本地 # -------------------------- 核心修改按要求传入参数1.client_ip 2.img --------------------------
cv2.imwrite(filename, img) # detect函数参数顺序第一个为client_ip第二个为图像数据img
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}") # 保持返回值解包(是否违规, 结果数据, 检测器类型)不变
has_violation, data, detector_type = await asyncio.to_thread(
detect, # 调用检测函数
self.client_ip, # 第一个参数客户端IP新增按需求顺序
img # 第二个参数:图像数据(原参数,调整顺序)
)
# -------------------------------------------------------------------------------------
# 进行检测 # 打印检测结果包含客户端IP与传入参数对应
if img is not None: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
has_violation, violation_type, details = detector.detect_violations(img) f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}")
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") # 处理违规逻辑逻辑不变基于detect返回结果执行
# 发送检测结果回客户端(原有逻辑不变) if has_violation:
await self.websocket.send_json({ print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
"type": "detection_result", f"类型: {detector_type}, 详情: {data}")
"has_violation": has_violation,
"violation_type": violation_type, # 违规次数+1
"details": details, try:
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
}) print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1")
else: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:未检测到任何违规内容") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 发送危险通知
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip,
"detail": data
}
await self.websocket.send_json(danger_msg)
else: else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:无法解析图像数据") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
except Exception as e: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}图像处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
# 全局连接管理IP -> 连接实例) # 全局状态管理
connected_clients: Dict[str, ClientConnection] = {} connected_clients: Dict[str, ClientConnection] = {}
# 心跳任务(全局引用,用于关闭时清理)
heartbeat_task: Optional[asyncio.Task] = None heartbeat_task: Optional[asyncio.Task] = None
# -------------------------- 心跳检查逻辑(原有逻辑不变) -------------------------- # 心跳检查任务
async def heartbeat_checker(): async def heartbeat_checker():
while True: while True:
now = datetime.datetime.now() current_time = get_current_time_str()
# 1. 筛选超时客户端(避免遍历中修改字典)
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
# 2. 处理超时连接(关闭+移除)
if timeout_ips: if timeout_ips:
print(f"[{now:%H:%M:%S}] 心跳检查{len(timeout_ips)}个客户端超时({timeout_ips}") print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时(IP: {timeout_ips}")
for ip in timeout_ips: for ip in timeout_ips:
try: try:
# 取消消费者任务 conn = connected_clients[ip]
if connected_clients[ip].consumer_task and not connected_clients[ip].consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
connected_clients[ip].consumer_task.cancel() conn.consumer_task.cancel()
await connected_clients[ip].websocket.close(code=1008, reason="心跳超时") await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}: 已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}: 离线状态更新失败 - {str(e)}")
finally: finally:
connected_clients.pop(ip, None) connected_clients.pop(ip, None)
else: else:
print(f"[{now:%H:%M:%S}] 心跳检查{len(connected_clients)}个客户端在线,无超时") print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
# 3. 等待下一轮检查
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期(原有逻辑不变) -------------------------- # 应用生命周期管理
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global heartbeat_task global heartbeat_task
# 启动心跳任务
heartbeat_task = asyncio.create_task(heartbeat_checker()) heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务启动ID{id(heartbeat_task)}") print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID: {id(heartbeat_task)}")
yield yield
# 关闭时取消心跳任务
if heartbeat_task and not heartbeat_task.done(): if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel() heartbeat_task.cancel()
try: try:
await heartbeat_task await heartbeat_task
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务已取消") print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# -------------------------- 消息处理(文本/心跳逻辑不变,二进制逻辑保留) -------------------------- # 消息处理工具函数
async def send_heartbeat_ack(client_ip: str): async def send_heartbeat_ack(conn: ClientConnection):
"""回复心跳确认(原有逻辑不变)"""
if client_ip not in connected_clients:
return False
try: try:
ack = { heartbeat_ack_msg = {
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "type": "heart",
"type": "heartbeat" "timestamp": get_current_time_str(),
"client_ip": conn.client_ip
} }
await connected_clients[client_ip].websocket.send_json(ack) await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送心跳确认")
return True return True
except Exception: except Exception as e:
connected_clients.pop(client_ip, None) connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认发送失败 - {str(e)}")
return False return False
async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection): async def handle_text_msg(conn: ClientConnection, text: str):
"""处理文本消息(核心:心跳+JSON解析原有逻辑不变"""
try: try:
msg = json.loads(text) msg = json.loads(text)
# 仅处理心跳类型消息 if msg.get("type") == "heart":
if msg.get("type") == "heartbeat":
conn.update_heartbeat() conn.update_heartbeat()
await send_heartbeat_ack(client_ip) await send_heartbeat_ack(conn)
else: else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到文本消息{msg}") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}无效JSON消息") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON文本消息")
async def handle_binary_msg(client_ip: str, data: bytes): async def handle_binary_msg(conn: ClientConnection, data: bytes):
"""处理二进制消息(原有逻辑不变,因客户端仅在收到允许信号后发送,队列不会满)"""
if client_ip not in connected_clients:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接不存在,丢弃{len(data)}字节数据")
return
conn = connected_clients[client_ip]
# 检查队列是否已满(理论上不会触发,因客户端按信号发送)
if conn.frame_queue.full():
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列已满,丢弃{len(data)}字节数据")
return
# 队列未满,添加帧到队列
try: try:
conn.frame_queue.put_nowait(data) conn.frame_queue.put_nowait(data)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:已接收{len(data)}字节二进制数据,加入队列") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull: except asyncio.QueueFull:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列突然满了,丢弃{len(data)}字节数据") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满、丢弃当前图像数据")
# WebSocket路由配置
ws_router = APIRouter()
# -------------------------- WebSocket核心端点修改连接初始化逻辑 --------------------------
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
# 接受连接 + 获取客户端IP load_model() # 加载检测模型(仅在连接建立时加载一次,避免重复加载)
await websocket.accept() await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown" client_ip = websocket.client.host if websocket.client else "unknown_ip"
now = datetime.datetime.now() current_time = get_current_time_str()
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功") print(f"[{current_time}] 客户端{client_ip}: WebSocket连接已建立")
is_online_updated = False
consumer_task = None
try: try:
# 处理重复连接(关闭旧连接) # 处理重复连接(同一IP断开旧连接)
if client_ip in connected_clients: if client_ip in connected_clients:
# 取消旧连接的消费者任务 old_conn = connected_clients[client_ip]
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done(): if old_conn.consumer_task and not old_conn.consumer_task.done():
connected_clients[client_ip].consumer_task.cancel() old_conn.consumer_task.cancel()
await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接") await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip) connected_clients.pop(client_ip)
print(f"[{now:%H:%M:%S}] 客户端{client_ip}关闭旧连接") print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接 # 注册新连接绑定client_ip和WebSocket
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn connected_clients[client_ip] = new_conn
new_conn.start_consumer() # 启动帧消费任务
await new_conn.send_frame_permit() # 发送首次帧许可
# 启动帧消费任务 # 标记客户端上线
consumer_task = new_conn.start_consumer() try:
# ---------- 修改:客户端刚连接时,队列空,立即发送「允许发送帧」信号 ---------- await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
await new_conn.send_allow_send_frame() action_data = DeviceActionCreate(client_ip=client_ip, action=1)
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,已启动帧消费任务,当前在线{len(connected_clients)}") await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线并记录操作")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
# 循环接收消息(原有逻辑不变) print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
# 消息循环(持续接收客户端消息)
while True: while True:
data = await websocket.receive() data = await websocket.receive()
if "text" in data: if "text" in data:
await handle_text_msg(client_ip, data["text"], new_conn) await handle_text_msg(new_conn, data["text"])
elif "bytes" in data: elif "bytes" in data:
await handle_binary_msg(client_ip, data["bytes"]) await handle_binary_msg(new_conn, data["bytes"])
# 异常处理(断开/错误)
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}主动断开(代码{e.code}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}连接异常{str(e)[:50]}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally: finally:
# 清理连接和任务 # 清理资源(断开后标记离线+删除连接)
if client_ip in connected_clients: if client_ip in connected_clients:
# 取消消费者任务 conn = connected_clients[client_ip]
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
connected_clients[client_ip].consumer_task.cancel() conn.consumer_task.cancel()
# 仅当上线状态更新成功时,才执行离线更新
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None) connected_clients.pop(client_ip, None)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}")