From 1816b5c5dda70b132636b8c829e14d3fe85b041d Mon Sep 17 00:00:00 2001 From: ZZX9599 <536509593@qq.com> Date: Wed, 3 Sep 2025 10:35:14 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AF=BB=E5=8F=96=E5=B8=A7=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rtc/rtc.py | 267 +++++++++++++++++++++++++---------------------------- 1 file changed, 125 insertions(+), 142 deletions(-) diff --git a/rtc/rtc.py b/rtc/rtc.py index 6cf40e3..2e7563b 100644 --- a/rtc/rtc.py +++ b/rtc/rtc.py @@ -1,64 +1,71 @@ +import queue import asyncio -from datetime import datetime - import aiohttp -import cv2 -import numpy as np +import threading +import time from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration -from ocr.ocr_violation_detector import OCRViolationDetector -from ws.ws import send_message_to_client +from aiortc.mediastreams import MediaStreamTrack + +# 创建一个长度为1的队列,用于生产者和消费者之间的通信 +frame_queue = queue.Queue(maxsize=1) -async def rtc_frame_receiver(url, frame_queue, stop_event): +class VideoTrack(MediaStreamTrack): + """自定义视频轨道类,继承自MediaStreamTrack""" + kind = "video" + + def __init__(self, max_frames=100): + super().__init__() + self.frames = queue.Queue(maxsize=max_frames) + + async def recv(self): + return await super().recv() + + +def webrtc_producer(webrtc_url): """ - 接收RTC帧并往队列放入cv2格式的帧数据 - 当队列已满时直接丢弃新帧,不阻塞等待 - 当stop_event被设置时停止接收 + 生产者方法:从WEBRTC读取视频帧并放入队列 + 仅当队列空时才放入新帧,否则丢弃 """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # 创建RTCPeerConnection对象,不使用ICE服务器 pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) - - # 累计帧计数器和丢弃帧计数器 - total_frames = 0 - dropped_frames = 0 + video_track = VideoTrack() + pc.addTrack(video_track) @pc.on("track") async def on_track(track): - nonlocal total_frames, dropped_frames if track.kind == "video": - print("接收到视频轨道、开始接收视频帧") - while not stop_event.is_set(): # 检查是否需要停止 - # 接收当前帧并累计计数 + print("接收到视频轨道,开始接收视频帧") + while True: + # 从轨道接收视频帧 frame = await track.recv() - total_frames += 1 + # 转换为BGR24格式的NumPy数组 + frame_bgr24 = frame.to_ndarray(format='bgr24') - # 转换为cv2兼容的BGR格式numpy数组 - frame_cv2 = frame.to_ndarray(format='bgr24') - - # 验证是否为cv2兼容格式 - if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3: - # 检查队列是否已满 - if frame_queue.full(): - # 队列已满,丢弃当前帧 - dropped_frames += 1 - print(f"第{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames})") - else: - # 队列未满,放入当前帧 - await frame_queue.put(frame_cv2) - print(f"第{total_frames}帧:已放入队列") + # 检查队列是否为空,为空则加入,否则丢弃 + if frame_queue.empty(): + try: + frame_queue.put_nowait(frame_bgr24) + print("帧已放入队列") + except queue.Full: + print("队列已满,丢弃帧") else: - print("帧格式转换失败,不是有效的cv2格式") + print("队列非空,丢弃帧") - # 创建并设置本地offer - offer = await pc.createOffer() - print("已创建本地 SDP Offer") - await pc.setLocalDescription(offer) + async def main(): + # 创建并发送SDP Offer + offer = await pc.createOffer() + print("已创建本地SDP Offer") + await pc.setLocalDescription(offer) - # 发送offer到服务器 - async with aiohttp.ClientSession() as session: - print("开始向服务器发送 SDP Offer") - try: + # 发送Offer到服务器并接收Answer + async with aiohttp.ClientSession() as session: + print(f"开始向服务器 {webrtc_url} 发送SDP Offer") async with session.post( - url, + webrtc_url, data=offer.sdp.encode(), headers={ "Content-Type": "application/sdp", @@ -66,107 +73,83 @@ async def rtc_frame_receiver(url, frame_queue, stop_event): }, ssl=False ) as response: - if response.status == 200: - print("已接收到服务器的响应、开始处理 SDP Answer") - answer_sdp = await response.text() - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) - else: - print(f"服务器响应错误: {response.status}") - stop_event.set() - except Exception as e: - print(f"发送SDP Offer失败: {str(e)}") - stop_event.set() + print("已接收到服务器的响应") + answer_sdp = await response.text() + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) - try: - # 保持连接,直到收到停止信号 - while not stop_event.is_set(): - await asyncio.sleep(1) - except KeyboardInterrupt: - print("用户中断") - finally: - print(f"开始关闭 RTCPeerConnection,共接收{total_frames}帧,丢弃{dropped_frames}帧") - await pc.close() - print("已关闭 RTCPeerConnection") - - -async def frame_consumer(ip, frame_queue, stop_event): - """ - 从队列中阻塞读取cv2帧并处理(队列为空时阻塞等待) - 检测到违规内容后设置stop_event以终止所有任务 - """ - # 创建OCR检测器实例 - ocr_detector = OCRViolationDetector( - forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", - ocr_confidence_threshold=0.5, ) - - while not stop_event.is_set(): # 检查是否需要停止 + # 保持连接 try: - # 阻塞等待队列中的帧 - current_frame = await frame_queue.get() - - # 进行OCR检测 - has_violation, words, confidences = ocr_detector.detect(current_frame) - print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}") - print(f"检测到的词: {words}") - print(f"置信度: {confidences}") - - # 输出所有检测到的违禁词 - if has_violation: - print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") - response_data = { - "status": "stop", - "timestamp": datetime.now().isoformat(), - "violations": [{"word": w, "confidence": c} for w, c in zip(words, confidences)] - } - await send_message_to_client(ip, response_data) - - for word, conf in zip(words, confidences): - print(f"- {word}(置信度:{conf:.4f})") - - # 检测到违规,设置停止事件 - print("检测到违规内容,准备关闭AI检测") - stop_event.set() - - # 标记任务完成,允许生产者放入新的帧 - frame_queue.task_done() - - except Exception as e: - print(f"处理帧时发生错误: {str(e)}") - frame_queue.task_done() - - -def process_webrtc_stream(ip, webrtc_url): - """ - 处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭 - 队列大小为1,满时直接丢弃新帧 - - Args: - ip: IP地址 - webrtc_url: WEBRTC服务器地址 - """ - # 创建队列(大小为1)和停止事件 - frame_queue = asyncio.Queue(maxsize=1) # 只存储一帧 - stop_event = asyncio.Event() # 用于控制任务停止的事件 - - # 定义事件循环中的主任务 - async def main_task(): - # 创建任务 - receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event)) - consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event)) - - # 等待任务完成 - await asyncio.gather(receiver_task, consumer_task) - - # 确保队列处理完毕 - await frame_queue.join() + while True: + await asyncio.sleep(0.1) + except KeyboardInterrupt: + pass + finally: + print("关闭RTCPeerConnection") + await pc.close() try: - # 运行事件循环 - asyncio.run(main_task()) - except KeyboardInterrupt: - print("用户中断处理流程") - stop_event.set() + loop.run_until_complete(main()) finally: - # 确保关闭所有cv2窗口 - cv2.destroyAllWindows() - print("AI检测已关闭") + loop.close() + + +def frame_consumer(): + """ + 消费者方法:从队列中读取帧并处理 + 每次处理后休眠200ms模拟延迟 + """ + print("消费者启动,开始等待帧...") + try: + while True: + # 阻塞等待队列中的帧 + frame = frame_queue.get() + print(f"消费帧,大小: {frame.shape}") + + # 模拟处理延迟 + time.sleep(0.2) # 200ms + + # 标记任务完成 + frame_queue.task_done() + except KeyboardInterrupt: + print("消费者退出") + + +def start_webrtc_stream(webrtc_url): + """ + 启动WebRTC视频流处理的主方法 + 参数: webrtc_url - WebRTC服务器地址 + """ + print(f"开始连接到WebRTC服务器: {webrtc_url}") + + # 启动生产者线程 + producer_thread = threading.Thread( + target=webrtc_producer, + args=(webrtc_url,), + daemon=True, + name="webrtc-producer" + ) + + # 启动消费者线程 + consumer_thread = threading.Thread( + target=frame_consumer, + daemon=True, + name="frame-consumer" + ) + + producer_thread.start() + consumer_thread.start() + print("生产者和消费者线程已启动") + + try: + # 保持主线程运行 + while True: + time.sleep(1) + except KeyboardInterrupt: + print("程序正在退出...") + + +if __name__ == "__main__": + # 示例用法 + # 实际使用时替换为真实的WebRTC服务器地址 + webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60" + start_webrtc_stream(webrtc_server_url)