读取帧优化
This commit is contained in:
267
rtc/rtc.py
267
rtc/rtc.py
@ -1,64 +1,71 @@
|
||||
import queue
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
import cv2
|
||||
import numpy as np
|
||||
import threading
|
||||
import time
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
from ws.ws import send_message_to_client
|
||||
from aiortc.mediastreams import MediaStreamTrack
|
||||
|
||||
# 创建一个长度为1的队列,用于生产者和消费者之间的通信
|
||||
frame_queue = queue.Queue(maxsize=1)
|
||||
|
||||
|
||||
async def rtc_frame_receiver(url, frame_queue, stop_event):
|
||||
class VideoTrack(MediaStreamTrack):
|
||||
"""自定义视频轨道类,继承自MediaStreamTrack"""
|
||||
kind = "video"
|
||||
|
||||
def __init__(self, max_frames=100):
|
||||
super().__init__()
|
||||
self.frames = queue.Queue(maxsize=max_frames)
|
||||
|
||||
async def recv(self):
|
||||
return await super().recv()
|
||||
|
||||
|
||||
def webrtc_producer(webrtc_url):
|
||||
"""
|
||||
接收RTC帧并往队列放入cv2格式的帧数据
|
||||
当队列已满时直接丢弃新帧,不阻塞等待
|
||||
当stop_event被设置时停止接收
|
||||
生产者方法:从WEBRTC读取视频帧并放入队列
|
||||
仅当队列空时才放入新帧,否则丢弃
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 创建RTCPeerConnection对象,不使用ICE服务器
|
||||
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
||||
|
||||
# 累计帧计数器和丢弃帧计数器
|
||||
total_frames = 0
|
||||
dropped_frames = 0
|
||||
video_track = VideoTrack()
|
||||
pc.addTrack(video_track)
|
||||
|
||||
@pc.on("track")
|
||||
async def on_track(track):
|
||||
nonlocal total_frames, dropped_frames
|
||||
if track.kind == "video":
|
||||
print("接收到视频轨道、开始接收视频帧")
|
||||
while not stop_event.is_set(): # 检查是否需要停止
|
||||
# 接收当前帧并累计计数
|
||||
print("接收到视频轨道,开始接收视频帧")
|
||||
while True:
|
||||
# 从轨道接收视频帧
|
||||
frame = await track.recv()
|
||||
total_frames += 1
|
||||
# 转换为BGR24格式的NumPy数组
|
||||
frame_bgr24 = frame.to_ndarray(format='bgr24')
|
||||
|
||||
# 转换为cv2兼容的BGR格式numpy数组
|
||||
frame_cv2 = frame.to_ndarray(format='bgr24')
|
||||
|
||||
# 验证是否为cv2兼容格式
|
||||
if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3:
|
||||
# 检查队列是否已满
|
||||
if frame_queue.full():
|
||||
# 队列已满,丢弃当前帧
|
||||
dropped_frames += 1
|
||||
print(f"第{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames})")
|
||||
else:
|
||||
# 队列未满,放入当前帧
|
||||
await frame_queue.put(frame_cv2)
|
||||
print(f"第{total_frames}帧:已放入队列")
|
||||
# 检查队列是否为空,为空则加入,否则丢弃
|
||||
if frame_queue.empty():
|
||||
try:
|
||||
frame_queue.put_nowait(frame_bgr24)
|
||||
print("帧已放入队列")
|
||||
except queue.Full:
|
||||
print("队列已满,丢弃帧")
|
||||
else:
|
||||
print("帧格式转换失败,不是有效的cv2格式")
|
||||
print("队列非空,丢弃帧")
|
||||
|
||||
# 创建并设置本地offer
|
||||
offer = await pc.createOffer()
|
||||
print("已创建本地 SDP Offer")
|
||||
await pc.setLocalDescription(offer)
|
||||
async def main():
|
||||
# 创建并发送SDP Offer
|
||||
offer = await pc.createOffer()
|
||||
print("已创建本地SDP Offer")
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
# 发送offer到服务器
|
||||
async with aiohttp.ClientSession() as session:
|
||||
print("开始向服务器发送 SDP Offer")
|
||||
try:
|
||||
# 发送Offer到服务器并接收Answer
|
||||
async with aiohttp.ClientSession() as session:
|
||||
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
|
||||
async with session.post(
|
||||
url,
|
||||
webrtc_url,
|
||||
data=offer.sdp.encode(),
|
||||
headers={
|
||||
"Content-Type": "application/sdp",
|
||||
@ -66,107 +73,83 @@ async def rtc_frame_receiver(url, frame_queue, stop_event):
|
||||
},
|
||||
ssl=False
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
print("已接收到服务器的响应、开始处理 SDP Answer")
|
||||
answer_sdp = await response.text()
|
||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
||||
else:
|
||||
print(f"服务器响应错误: {response.status}")
|
||||
stop_event.set()
|
||||
except Exception as e:
|
||||
print(f"发送SDP Offer失败: {str(e)}")
|
||||
stop_event.set()
|
||||
print("已接收到服务器的响应")
|
||||
answer_sdp = await response.text()
|
||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
||||
|
||||
try:
|
||||
# 保持连接,直到收到停止信号
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("用户中断")
|
||||
finally:
|
||||
print(f"开始关闭 RTCPeerConnection,共接收{total_frames}帧,丢弃{dropped_frames}帧")
|
||||
await pc.close()
|
||||
print("已关闭 RTCPeerConnection")
|
||||
|
||||
|
||||
async def frame_consumer(ip, frame_queue, stop_event):
|
||||
"""
|
||||
从队列中阻塞读取cv2帧并处理(队列为空时阻塞等待)
|
||||
检测到违规内容后设置stop_event以终止所有任务
|
||||
"""
|
||||
# 创建OCR检测器实例
|
||||
ocr_detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.5, )
|
||||
|
||||
while not stop_event.is_set(): # 检查是否需要停止
|
||||
# 保持连接
|
||||
try:
|
||||
# 阻塞等待队列中的帧
|
||||
current_frame = await frame_queue.get()
|
||||
|
||||
# 进行OCR检测
|
||||
has_violation, words, confidences = ocr_detector.detect(current_frame)
|
||||
print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}")
|
||||
print(f"检测到的词: {words}")
|
||||
print(f"置信度: {confidences}")
|
||||
|
||||
# 输出所有检测到的违禁词
|
||||
if has_violation:
|
||||
print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
|
||||
response_data = {
|
||||
"status": "stop",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"violations": [{"word": w, "confidence": c} for w, c in zip(words, confidences)]
|
||||
}
|
||||
await send_message_to_client(ip, response_data)
|
||||
|
||||
for word, conf in zip(words, confidences):
|
||||
print(f"- {word}(置信度:{conf:.4f})")
|
||||
|
||||
# 检测到违规,设置停止事件
|
||||
print("检测到违规内容,准备关闭AI检测")
|
||||
stop_event.set()
|
||||
|
||||
# 标记任务完成,允许生产者放入新的帧
|
||||
frame_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理帧时发生错误: {str(e)}")
|
||||
frame_queue.task_done()
|
||||
|
||||
|
||||
def process_webrtc_stream(ip, webrtc_url):
|
||||
"""
|
||||
处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭
|
||||
队列大小为1,满时直接丢弃新帧
|
||||
|
||||
Args:
|
||||
ip: IP地址
|
||||
webrtc_url: WEBRTC服务器地址
|
||||
"""
|
||||
# 创建队列(大小为1)和停止事件
|
||||
frame_queue = asyncio.Queue(maxsize=1) # 只存储一帧
|
||||
stop_event = asyncio.Event() # 用于控制任务停止的事件
|
||||
|
||||
# 定义事件循环中的主任务
|
||||
async def main_task():
|
||||
# 创建任务
|
||||
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event))
|
||||
consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event))
|
||||
|
||||
# 等待任务完成
|
||||
await asyncio.gather(receiver_task, consumer_task)
|
||||
|
||||
# 确保队列处理完毕
|
||||
await frame_queue.join()
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
print("关闭RTCPeerConnection")
|
||||
await pc.close()
|
||||
|
||||
try:
|
||||
# 运行事件循环
|
||||
asyncio.run(main_task())
|
||||
except KeyboardInterrupt:
|
||||
print("用户中断处理流程")
|
||||
stop_event.set()
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
# 确保关闭所有cv2窗口
|
||||
cv2.destroyAllWindows()
|
||||
print("AI检测已关闭")
|
||||
loop.close()
|
||||
|
||||
|
||||
def frame_consumer():
|
||||
"""
|
||||
消费者方法:从队列中读取帧并处理
|
||||
每次处理后休眠200ms模拟延迟
|
||||
"""
|
||||
print("消费者启动,开始等待帧...")
|
||||
try:
|
||||
while True:
|
||||
# 阻塞等待队列中的帧
|
||||
frame = frame_queue.get()
|
||||
print(f"消费帧,大小: {frame.shape}")
|
||||
|
||||
# 模拟处理延迟
|
||||
time.sleep(0.2) # 200ms
|
||||
|
||||
# 标记任务完成
|
||||
frame_queue.task_done()
|
||||
except KeyboardInterrupt:
|
||||
print("消费者退出")
|
||||
|
||||
|
||||
def start_webrtc_stream(webrtc_url):
|
||||
"""
|
||||
启动WebRTC视频流处理的主方法
|
||||
参数: webrtc_url - WebRTC服务器地址
|
||||
"""
|
||||
print(f"开始连接到WebRTC服务器: {webrtc_url}")
|
||||
|
||||
# 启动生产者线程
|
||||
producer_thread = threading.Thread(
|
||||
target=webrtc_producer,
|
||||
args=(webrtc_url,),
|
||||
daemon=True,
|
||||
name="webrtc-producer"
|
||||
)
|
||||
|
||||
# 启动消费者线程
|
||||
consumer_thread = threading.Thread(
|
||||
target=frame_consumer,
|
||||
daemon=True,
|
||||
name="frame-consumer"
|
||||
)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
print("生产者和消费者线程已启动")
|
||||
|
||||
try:
|
||||
# 保持主线程运行
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("程序正在退出...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
# 实际使用时替换为真实的WebRTC服务器地址
|
||||
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
|
||||
start_webrtc_stream(webrtc_server_url)
|
||||
|
Reference in New Issue
Block a user