最新可用
This commit is contained in:
45
core/all.py
Normal file
45
core/all.py
Normal file
@ -0,0 +1,45 @@
|
||||
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
|
||||
from core.face import load_model as faceLoadModel, detect as faceDetect
|
||||
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
|
||||
|
||||
# 添加一个标记变量,用于监控load_model是否已被调用
|
||||
_model_loaded = False
|
||||
|
||||
|
||||
def load_model():
|
||||
global _model_loaded
|
||||
|
||||
# 如果已经调用过,直接忽略
|
||||
if _model_loaded:
|
||||
return
|
||||
|
||||
# 首次调用时加载模型
|
||||
ocrLoadModel()
|
||||
faceLoadModel()
|
||||
yoloLoadModel()
|
||||
|
||||
# 标记为已调用
|
||||
_model_loaded = True
|
||||
|
||||
|
||||
def detect(frame):
|
||||
# 先进行YOLO检测
|
||||
yolo_flag, yolo_result = yoloDetect(frame)
|
||||
print("YOLO检测结果:", yolo_result)
|
||||
if yolo_flag:
|
||||
return (True, yolo_result, "yolo")
|
||||
|
||||
# YOLO未检测到,进行人脸检测
|
||||
face_flag, face_result = faceDetect(frame)
|
||||
print("人脸检测结果:", face_result)
|
||||
if face_flag:
|
||||
return (True, face_result, "face")
|
||||
|
||||
# 人脸未检测到,进行OCR检测
|
||||
ocr_flag, ocr_result = ocrDetect(frame)
|
||||
print("OCR检测结果:", ocr_result)
|
||||
if ocr_flag:
|
||||
return (True, ocr_result, "ocr")
|
||||
|
||||
# 所有检测都未检测到
|
||||
return (False, "未检测到任何内容", "none")
|
119
core/config/config.yaml
Normal file
119
core/config/config.yaml
Normal file
@ -0,0 +1,119 @@
|
||||
Global:
|
||||
text_score: 0.5
|
||||
|
||||
use_det: true
|
||||
use_cls: true
|
||||
use_rec: true
|
||||
|
||||
min_height: 30
|
||||
width_height_ratio: 8
|
||||
max_side_len: 2000
|
||||
min_side_len: 30
|
||||
|
||||
return_word_box: false
|
||||
return_single_char_box: false
|
||||
|
||||
font_path: null
|
||||
|
||||
EngineConfig:
|
||||
onnxruntime:
|
||||
intra_op_num_threads: -1
|
||||
inter_op_num_threads: -1
|
||||
enable_cpu_mem_arena: false
|
||||
|
||||
cpu_ep_cfg:
|
||||
arena_extend_strategy: "kSameAsRequested"
|
||||
|
||||
use_cuda: true # 改为true以启用CUDA
|
||||
cuda_ep_cfg:
|
||||
device_id: 0
|
||||
arena_extend_strategy: "kNextPowerOfTwo"
|
||||
cudnn_conv_algo_search: "EXHAUSTIVE"
|
||||
do_copy_in_default_stream: true
|
||||
|
||||
use_dml: false
|
||||
dm_ep_cfg: null
|
||||
|
||||
use_cann: false
|
||||
cann_ep_cfg:
|
||||
device_id: 0
|
||||
arena_extend_strategy: "kNextPowerOfTwo"
|
||||
npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024
|
||||
op_select_impl_mode: "high_performance"
|
||||
optypelist_for_implmode: "Gelu"
|
||||
enable_cann_graph: true
|
||||
|
||||
openvino:
|
||||
inference_num_threads: -1
|
||||
performance_hint: null
|
||||
performance_num_requests: -1
|
||||
enable_cpu_pinning: null
|
||||
num_streams: -1
|
||||
enable_hyper_threading: null
|
||||
scheduling_core_type: null
|
||||
|
||||
paddle:
|
||||
cpu_math_library_num_threads: -1
|
||||
use_npu: false
|
||||
npu_id: 0
|
||||
use_cuda: true # 改为true以启用CUDA
|
||||
gpu_id: 0
|
||||
gpu_mem: 500
|
||||
|
||||
torch:
|
||||
use_cuda: true # 已经是true
|
||||
gpu_id: 0
|
||||
|
||||
Det:
|
||||
engine_type: "torch"
|
||||
lang_type: "ch"
|
||||
model_type: "mobile"
|
||||
ocr_version: "PP-OCRv4"
|
||||
|
||||
task_type: "det"
|
||||
|
||||
model_path: null
|
||||
model_dir: null
|
||||
|
||||
limit_side_len: 736
|
||||
limit_type: min
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
|
||||
thresh: 0.3
|
||||
box_thresh: 0.5
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.6
|
||||
use_dilation: true
|
||||
score_mode: fast
|
||||
|
||||
Cls:
|
||||
engine_type: "torch"
|
||||
lang_type: "ch"
|
||||
model_type: "mobile"
|
||||
ocr_version: "PP-OCRv4"
|
||||
|
||||
task_type: "cls"
|
||||
|
||||
model_path: null
|
||||
model_dir: null
|
||||
|
||||
cls_image_shape: [3, 48, 192]
|
||||
cls_batch_num: 6
|
||||
cls_thresh: 0.9
|
||||
label_list: ["0", "180"]
|
||||
|
||||
Rec:
|
||||
engine_type: "torch"
|
||||
lang_type: "ch"
|
||||
model_type: "mobile"
|
||||
ocr_version: "PP-OCRv4"
|
||||
|
||||
task_type: "rec"
|
||||
|
||||
model_path: null
|
||||
model_dir: null
|
||||
|
||||
rec_keys_path: null
|
||||
rec_img_shape: [3, 48, 320]
|
||||
rec_batch_num: 6
|
113
core/face.py
Normal file
113
core/face.py
Normal file
@ -0,0 +1,113 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image # 确保正确导入Image类
|
||||
from insightface.app import FaceAnalysis
|
||||
# 导入获取人脸信息的服务
|
||||
from service.face_service import get_all_face_name_with_eigenvalue
|
||||
|
||||
# 全局变量
|
||||
_face_app = None
|
||||
_known_faces_embeddings = {} # 存储姓名到特征值的映射
|
||||
_known_faces_names = [] # 存储所有已知姓名
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载人脸识别模型及已知人脸特征库"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names
|
||||
|
||||
# 初始化InsightFace模型
|
||||
try:
|
||||
_face_app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface'))
|
||||
_face_app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
except Exception as e:
|
||||
print(f"Face model load failed: {e}")
|
||||
return False
|
||||
|
||||
# 从服务获取所有人脸姓名和特征值
|
||||
try:
|
||||
face_data = get_all_face_name_with_eigenvalue()
|
||||
|
||||
# 处理获取到的人脸数据
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 处理特征值数据 - 兼容数组和字符串两种格式
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
# 如果已经是numpy数组,直接使用
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 清理字符串:移除方括号、换行符和多余空格
|
||||
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
|
||||
# 按空格或逗号分割(处理可能的不同分隔符)
|
||||
values = [v for v in cleaned.split() if v]
|
||||
# 转换为数组
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
# 不支持的类型
|
||||
print(f"Unsupported eigenvalue type for {person_name}")
|
||||
continue
|
||||
|
||||
# 归一化处理
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm != 0:
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading face data from service: {e}")
|
||||
|
||||
return True if _face_app else False
|
||||
|
||||
|
||||
def detect(frame, threshold=0.4):
|
||||
"""检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)"""
|
||||
global _face_app, _known_faces_embeddings, _known_faces_names
|
||||
|
||||
if not _face_app or not _known_faces_names or frame is None:
|
||||
return (False, "未初始化或无效帧")
|
||||
|
||||
try:
|
||||
faces = _face_app.get(frame)
|
||||
except Exception as e:
|
||||
print(f"Face detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
result_parts = []
|
||||
has_matched = False # 新增标记:是否有匹配到的已知人脸
|
||||
|
||||
for face in faces:
|
||||
# 特征归一化
|
||||
embedding = face.embedding.astype(np.float32)
|
||||
norm = np.linalg.norm(embedding)
|
||||
if norm == 0:
|
||||
continue
|
||||
embedding = embedding / norm
|
||||
|
||||
# 对比已知人脸
|
||||
max_sim, best_name = -1.0, "Unknown"
|
||||
for name in _known_faces_names:
|
||||
known_emb = _known_faces_embeddings[name]
|
||||
sim = np.dot(embedding, known_emb)
|
||||
if sim > max_sim:
|
||||
max_sim = sim
|
||||
best_name = name
|
||||
|
||||
# 判断匹配结果
|
||||
is_match = max_sim >= threshold
|
||||
if is_match:
|
||||
has_matched = True # 只要有一个匹配成功,就标记为True
|
||||
|
||||
bbox = face.bbox
|
||||
result_parts.append(
|
||||
f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})"
|
||||
)
|
||||
|
||||
# 构建结果字符串
|
||||
if not result_parts:
|
||||
result_str = "未检测到人脸"
|
||||
else:
|
||||
result_str = "; ".join(result_parts)
|
||||
|
||||
# 第一个返回值改为:是否匹配到已知人脸
|
||||
return (has_matched, result_str)
|
BIN
core/models/best.pt
Normal file
BIN
core/models/best.pt
Normal file
Binary file not shown.
76
core/ocr.py
Normal file
76
core/ocr.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import cv2
|
||||
from rapidocr import RapidOCR
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
# 全局变量
|
||||
_ocr_engine = None
|
||||
_forbidden_words = set()
|
||||
_conf_threshold = 0.5
|
||||
|
||||
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载OCR引擎及违禁词列表"""
|
||||
global _ocr_engine, _forbidden_words, _conf_threshold
|
||||
|
||||
# 加载违禁词
|
||||
try:
|
||||
_forbidden_words = get_all_sensitive_words()
|
||||
except Exception as e:
|
||||
print(f"Forbidden words load error: {e}")
|
||||
|
||||
# 初始化OCR引擎
|
||||
if not os.path.exists(ocr_config_path):
|
||||
print(f"OCR config not found: {ocr_config_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
_ocr_engine = RapidOCR(config_path=ocr_config_path)
|
||||
except Exception as e:
|
||||
print(f"OCR model load failed: {e}")
|
||||
return False
|
||||
|
||||
return True if _ocr_engine else False
|
||||
|
||||
|
||||
def detect(frame):
|
||||
"""OCR检测并筛选违禁词,返回(是否检测到违禁词, 结果字符串)"""
|
||||
if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
|
||||
return (False, "未初始化或无效帧")
|
||||
|
||||
try:
|
||||
ocr_res = _ocr_engine(frame)
|
||||
except Exception as e:
|
||||
print(f"OCR detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'):
|
||||
return (False, "无OCR结果")
|
||||
|
||||
# 处理OCR结果
|
||||
texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)]
|
||||
confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))]
|
||||
if len(texts) != len(confs):
|
||||
return (False, "OCR结果格式异常")
|
||||
|
||||
# 筛选违禁词
|
||||
vio_info = []
|
||||
for txt, conf in zip(texts, confs):
|
||||
if conf < _conf_threshold:
|
||||
continue
|
||||
matched = [w for w in _forbidden_words if w in txt]
|
||||
if matched:
|
||||
vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")
|
||||
|
||||
# 构建结果字符串
|
||||
has_text = len(texts) > 0
|
||||
has_violation = len(vio_info) > 0
|
||||
|
||||
if not has_text:
|
||||
return (False, "未识别到文本")
|
||||
elif has_violation:
|
||||
return (True, "; ".join(vio_info))
|
||||
else:
|
||||
return (False, "未检测到违禁词")
|
137
core/rtc.py
137
core/rtc.py
@ -1,137 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
import aiohttp
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
|
||||
import logging
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("whep_video_puller")
|
||||
|
||||
|
||||
async def whep_pull_video_stream(ip,whep_url):
|
||||
"""
|
||||
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
|
||||
Args:
|
||||
whep_url: WHEP端点的URL
|
||||
"""
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
# 添加连接状态变化监听
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
print(f"连接状态: {pc.connectionState}")
|
||||
|
||||
# 添加ICE连接状态变化监听
|
||||
@pc.on("iceconnectionstatechange")
|
||||
async def on_iceconnectionstatechange():
|
||||
print(f"ICE连接状态: {pc.iceConnectionState}")
|
||||
|
||||
# 添加视频接收器
|
||||
pc.addTransceiver("video", direction="recvonly")
|
||||
|
||||
# 处理接收到的视频轨道
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
print(f"接收到轨道: {track.kind}")
|
||||
if track.kind == "video":
|
||||
print(f"轨道ID: {track.id}")
|
||||
print(f"轨道就绪状态: {track.readyState}")
|
||||
# 创建异步任务来处理视频帧
|
||||
asyncio.ensure_future(handle_video_track(track))
|
||||
|
||||
async def handle_video_track(track):
|
||||
"""处理视频轨道,接收并打印每一帧"""
|
||||
frame_count = 0
|
||||
print("开始处理视频轨道...")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 尝试接收帧
|
||||
frame = await track.recv()
|
||||
frame_count += 1
|
||||
print(f"收到原始帧 (第{frame_count}帧)")
|
||||
|
||||
# 打印帧的基本信息
|
||||
if hasattr(frame, 'width') and hasattr(frame, 'height'):
|
||||
print(f" 尺寸: {frame.width}x{frame.height}")
|
||||
if hasattr(frame, 'time_base'):
|
||||
print(f" 时间基准: {frame.time_base}")
|
||||
if hasattr(frame, 'pts'):
|
||||
print(f" 显示时间戳: {frame.pts}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
except Exception as e:
|
||||
print(f"接收帧时出错: {e}")
|
||||
# 等待一段时间后重试
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# 创建offer
|
||||
offer = await pc.createOffer()
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
print(f"本地SDP信息:\n{offer.sdp}")
|
||||
|
||||
# 通过HTTP POST发送offer到WHEP端点
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
whep_url,
|
||||
data=offer.sdp,
|
||||
headers={"Content-Type": "application/sdp"}
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
print(f"WHEP服务器返回错误: {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
raise Exception(f"WHEP服务器返回错误: {response.status}")
|
||||
|
||||
# 获取answer SDP
|
||||
answer_sdp = await response.text()
|
||||
|
||||
# 创建RTCSessionDescription对象
|
||||
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
|
||||
|
||||
print(f"收到远程SDP:\n{answer_sdp}")
|
||||
|
||||
# 设置远程描述
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
print("连接已建立,开始接收视频流...")
|
||||
|
||||
# 保持连接,直到用户中断
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
# 检查连接状态
|
||||
print(f"当前连接状态: {pc.connectionState}")
|
||||
except KeyboardInterrupt:
|
||||
print("用户中断,关闭连接...")
|
||||
finally:
|
||||
await pc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 替换为你的WHEP端点URL
|
||||
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行拉流任务
|
||||
asyncio.run(whep_pull_video_stream(WHEP_URL))
|
112
core/rtmp.py
112
core/rtmp.py
@ -1,112 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
|
||||
# 配置文件相对路径(根据实际目录结构调整)
|
||||
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹
|
||||
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
|
||||
KNOWN_FACES_DIR = "../ocr/known_faces"
|
||||
|
||||
# 创建检测器实例
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并进行违规检测
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 异步打开RTMP流
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
|
||||
)
|
||||
|
||||
# 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 获取RTMP流基础信息
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况
|
||||
fps = fps if fps > 0 else 30.0
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 初始化帧统计参数
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
# 循环读取视频帧
|
||||
while True:
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 打印当前帧信息
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
if frame is not None:
|
||||
has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
if has_violation:
|
||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
else:
|
||||
print("未检测到任何违规内容")
|
||||
else:
|
||||
print(f"无法读取测试图像")
|
||||
|
||||
# 每100帧统计一次实际接收帧率
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
print(f"最终统计: 共接收 {frame_count if 'frame_count' in locals() else 0} 帧")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
print(f"程序启动失败: {str(e)}")
|
55
core/yolo.py
Normal file
55
core/yolo.py
Normal file
@ -0,0 +1,55 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
# 全局变量
|
||||
_yolo_model = None
|
||||
|
||||
|
||||
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载YOLO目标检测模型"""
|
||||
global _yolo_model
|
||||
|
||||
try:
|
||||
_yolo_model = YOLO(model_path)
|
||||
except Exception as e:
|
||||
print(f"YOLO model load failed: {e}")
|
||||
return False
|
||||
|
||||
return True if _yolo_model else False
|
||||
|
||||
|
||||
def detect(frame, conf_threshold=0.2):
|
||||
"""YOLO目标检测,返回(是否识别到, 结果字符串)"""
|
||||
global _yolo_model
|
||||
|
||||
if not _yolo_model or frame is None:
|
||||
return (False, "未初始化或无效帧")
|
||||
|
||||
try:
|
||||
results = _yolo_model(frame, conf=conf_threshold)
|
||||
# 检查是否有检测结果
|
||||
has_results = len(results[0].boxes) > 0 if results else False
|
||||
|
||||
if not has_results:
|
||||
return (False, "未检测到目标")
|
||||
|
||||
# 构建结果字符串
|
||||
result_parts = []
|
||||
for box in results[0].boxes:
|
||||
cls = int(box.cls[0])
|
||||
conf = float(box.conf[0])
|
||||
bbox = [float(x) for x in box.xyxy[0]]
|
||||
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
||||
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
|
||||
|
||||
result_str = "; ".join(result_parts)
|
||||
return (has_results, result_str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"YOLO detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
Reference in New Issue
Block a user