读取帧优化

This commit is contained in:
ZZX9599
2025-09-03 10:35:14 +08:00
parent be5383d752
commit 1816b5c5dd

View File

@ -1,64 +1,71 @@
import queue
import asyncio import asyncio
from datetime import datetime
import aiohttp import aiohttp
import cv2 import threading
import numpy as np import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from ocr.ocr_violation_detector import OCRViolationDetector from aiortc.mediastreams import MediaStreamTrack
from ws.ws import send_message_to_client
# 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1)
async def rtc_frame_receiver(url, frame_queue, stop_event): class VideoTrack(MediaStreamTrack):
"""自定义视频轨道类继承自MediaStreamTrack"""
kind = "video"
def __init__(self, max_frames=100):
super().__init__()
self.frames = queue.Queue(maxsize=max_frames)
async def recv(self):
return await super().recv()
def webrtc_producer(webrtc_url):
""" """
接收RTC帧并往队列放入cv2格式的帧数据 生产者方法从WEBRTC读取视频帧并放入队列
当队列已满时直接丢弃新帧,不阻塞等待 当队列空时才放入新帧,否则丢弃
当stop_event被设置时停止接收
""" """
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 创建RTCPeerConnection对象不使用ICE服务器
pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
# 累计帧计数器和丢弃帧计数器 pc.addTrack(video_track)
total_frames = 0
dropped_frames = 0
@pc.on("track") @pc.on("track")
async def on_track(track): async def on_track(track):
nonlocal total_frames, dropped_frames
if track.kind == "video": if track.kind == "video":
print("接收到视频轨道开始接收视频帧") print("接收到视频轨道开始接收视频帧")
while not stop_event.is_set(): # 检查是否需要停止 while True:
# 接收当前帧并累计计数 # 从轨道接收视频帧
frame = await track.recv() frame = await track.recv()
total_frames += 1 # 转换为BGR24格式的NumPy数组
frame_bgr24 = frame.to_ndarray(format='bgr24')
# 转换为cv2兼容的BGR格式numpy数组 # 检查队列是否为空,为空则加入,否则丢弃
frame_cv2 = frame.to_ndarray(format='bgr24') if frame_queue.empty():
try:
# 验证是否为cv2兼容格式 frame_queue.put_nowait(frame_bgr24)
if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3: print("帧已放入队列")
# 检查队列是否已满 except queue.Full:
if frame_queue.full(): print("队列已满,丢弃帧")
# 队列已满,丢弃当前帧
dropped_frames += 1
print(f"{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames}")
else: else:
# 队列未满,放入当前帧 print("队列非空,丢弃帧")
await frame_queue.put(frame_cv2)
print(f"{total_frames}帧:已放入队列")
else:
print("帧格式转换失败不是有效的cv2格式")
# 创建并设置本地offer async def main():
# 创建并发送SDP Offer
offer = await pc.createOffer() offer = await pc.createOffer()
print("已创建本地 SDP Offer") print("已创建本地SDP Offer")
await pc.setLocalDescription(offer) await pc.setLocalDescription(offer)
# 发送offer到服务器 # 发送Offer到服务器并接收Answer
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
print("开始向服务器发送 SDP Offer") print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
try:
async with session.post( async with session.post(
url, webrtc_url,
data=offer.sdp.encode(), data=offer.sdp.encode(),
headers={ headers={
"Content-Type": "application/sdp", "Content-Type": "application/sdp",
@ -66,107 +73,83 @@ async def rtc_frame_receiver(url, frame_queue, stop_event):
}, },
ssl=False ssl=False
) as response: ) as response:
if response.status == 200: print("已接收到服务器的响应")
print("已接收到服务器的响应、开始处理 SDP Answer")
answer_sdp = await response.text() answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
else:
print(f"服务器响应错误: {response.status}")
stop_event.set()
except Exception as e:
print(f"发送SDP Offer失败: {str(e)}")
stop_event.set()
# 保持连接
try: try:
# 保持连接,直到收到停止信号 while True:
while not stop_event.is_set(): await asyncio.sleep(0.1)
await asyncio.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
print("用户中断") pass
finally: finally:
print(f"开始关闭 RTCPeerConnection,共接收{total_frames}帧,丢弃{dropped_frames}") print("关闭RTCPeerConnection")
await pc.close() await pc.close()
print("已关闭 RTCPeerConnection")
async def frame_consumer(ip, frame_queue, stop_event):
"""
从队列中阻塞读取cv2帧并处理队列为空时阻塞等待
检测到违规内容后设置stop_event以终止所有任务
"""
# 创建OCR检测器实例
ocr_detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.5, )
while not stop_event.is_set(): # 检查是否需要停止
try:
# 阻塞等待队列中的帧
current_frame = await frame_queue.get()
# 进行OCR检测
has_violation, words, confidences = ocr_detector.detect(current_frame)
print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}")
print(f"检测到的词: {words}")
print(f"置信度: {confidences}")
# 输出所有检测到的违禁词
if has_violation:
print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
response_data = {
"status": "stop",
"timestamp": datetime.now().isoformat(),
"violations": [{"word": w, "confidence": c} for w, c in zip(words, confidences)]
}
await send_message_to_client(ip, response_data)
for word, conf in zip(words, confidences):
print(f"- {word}(置信度:{conf:.4f}")
# 检测到违规,设置停止事件
print("检测到违规内容准备关闭AI检测")
stop_event.set()
# 标记任务完成,允许生产者放入新的帧
frame_queue.task_done()
except Exception as e:
print(f"处理帧时发生错误: {str(e)}")
frame_queue.task_done()
def process_webrtc_stream(ip, webrtc_url):
"""
处理WEBRTC流并持续打印OCR检测结果检测到违规后关闭
队列大小为1满时直接丢弃新帧
Args:
ip: IP地址
webrtc_url: WEBRTC服务器地址
"""
# 创建队列大小为1和停止事件
frame_queue = asyncio.Queue(maxsize=1) # 只存储一帧
stop_event = asyncio.Event() # 用于控制任务停止的事件
# 定义事件循环中的主任务
async def main_task():
# 创建任务
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event))
consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event))
# 等待任务完成
await asyncio.gather(receiver_task, consumer_task)
# 确保队列处理完毕
await frame_queue.join()
try: try:
# 运行事件循环 loop.run_until_complete(main())
asyncio.run(main_task())
except KeyboardInterrupt:
print("用户中断处理流程")
stop_event.set()
finally: finally:
# 确保关闭所有cv2窗口 loop.close()
cv2.destroyAllWindows()
print("AI检测已关闭")
def frame_consumer():
    """Consume frames from the module-level ``frame_queue`` until interrupted.

    Each iteration blocks on ``frame_queue.get()``, prints the frame's
    ``shape``, sleeps 200 ms to simulate processing latency, and then marks
    the queue item as done.

    Takes no arguments and returns ``None``.

    NOTE: Python raises ``KeyboardInterrupt`` in the main thread only, so the
    handler below fires only when this function runs on the main thread.
    """
    print("消费者启动,开始等待帧...")
    try:
        while True:
            # Block until a frame is available in the queue.
            frame = frame_queue.get()
            try:
                print(f"消费帧,大小: {frame.shape}")
                # Simulated processing latency (200 ms).
                time.sleep(0.2)
            finally:
                # Pair every get() with exactly one task_done(), even when
                # handling the frame raises; otherwise the queue's
                # unfinished-task count never drops and join() would hang.
                frame_queue.task_done()
    except KeyboardInterrupt:
        print("消费者退出")
def start_webrtc_stream(webrtc_url):
    """Start the WebRTC producer and frame consumer threads and supervise them.

    Args:
        webrtc_url: Address of the WebRTC server; passed unchanged to
            ``webrtc_producer``.

    Returns:
        None. Returns once either worker thread has stopped, or when the
        calling thread receives ``KeyboardInterrupt``.
    """
    print(f"开始连接到WebRTC服务器: {webrtc_url}")

    # Both workers are daemon threads, so they do not block interpreter exit.
    producer_thread = threading.Thread(
        target=webrtc_producer,
        args=(webrtc_url,),
        daemon=True,
        name="webrtc-producer",
    )
    consumer_thread = threading.Thread(
        target=frame_consumer,
        daemon=True,
        name="frame-consumer",
    )

    producer_thread.start()
    consumer_thread.start()
    print("生产者和消费者线程已启动")

    try:
        # Supervise only while both workers are alive. An unconditional loop
        # would keep the process sleeping forever after a worker has died
        # (for example when the producer cannot reach the server).
        while producer_thread.is_alive() and consumer_thread.is_alive():
            time.sleep(1)
        print("工作线程已结束,主线程退出")
    except KeyboardInterrupt:
        print("程序正在退出...")
if __name__ == "__main__":
    # Example usage: replace this with the address of a real WebRTC server
    # before running against a live stream.
    webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
    start_webrtc_stream(webrtc_url=webrtc_server_url)