读取帧优化
This commit is contained in:
		
							
								
								
									
										267
									
								
								rtc/rtc.py
									
									
									
									
									
								
							
							
						
						
									
										267
									
								
								rtc/rtc.py
									
									
									
									
									
								
							| @ -1,64 +1,71 @@ | |||||||
|  | import queue | ||||||
| import asyncio | import asyncio | ||||||
| from datetime import datetime |  | ||||||
|  |  | ||||||
| import aiohttp | import aiohttp | ||||||
| import cv2 | import threading | ||||||
| import numpy as np | import time | ||||||
| from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration | from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration | ||||||
| from ocr.ocr_violation_detector import OCRViolationDetector | from aiortc.mediastreams import MediaStreamTrack | ||||||
| from ws.ws import send_message_to_client |  | ||||||
|  | # 创建一个长度为1的队列,用于生产者和消费者之间的通信 | ||||||
|  | frame_queue = queue.Queue(maxsize=1) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def rtc_frame_receiver(url, frame_queue, stop_event): | class VideoTrack(MediaStreamTrack): | ||||||
|  |     """自定义视频轨道类,继承自MediaStreamTrack""" | ||||||
|  |     kind = "video" | ||||||
|  |  | ||||||
|  |     def __init__(self, max_frames=100): | ||||||
|  |         super().__init__() | ||||||
|  |         self.frames = queue.Queue(maxsize=max_frames) | ||||||
|  |  | ||||||
|  |     async def recv(self): | ||||||
|  |         return await super().recv() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def webrtc_producer(webrtc_url): | ||||||
|     """ |     """ | ||||||
|     接收RTC帧并往队列放入cv2格式的帧数据 |     生产者方法:从WEBRTC读取视频帧并放入队列 | ||||||
|     当队列已满时直接丢弃新帧,不阻塞等待 |     仅当队列空时才放入新帧,否则丢弃 | ||||||
|     当stop_event被设置时停止接收 |  | ||||||
|     """ |     """ | ||||||
|  |     loop = asyncio.new_event_loop() | ||||||
|  |     asyncio.set_event_loop(loop) | ||||||
|  |  | ||||||
|  |     # 创建RTCPeerConnection对象,不使用ICE服务器 | ||||||
|     pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) |     pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) | ||||||
|  |     video_track = VideoTrack() | ||||||
|     # 累计帧计数器和丢弃帧计数器 |     pc.addTrack(video_track) | ||||||
|     total_frames = 0 |  | ||||||
|     dropped_frames = 0 |  | ||||||
|  |  | ||||||
|     @pc.on("track") |     @pc.on("track") | ||||||
|     async def on_track(track): |     async def on_track(track): | ||||||
|         nonlocal total_frames, dropped_frames |  | ||||||
|         if track.kind == "video": |         if track.kind == "video": | ||||||
|             print("接收到视频轨道、开始接收视频帧") |             print("接收到视频轨道,开始接收视频帧") | ||||||
|             while not stop_event.is_set():  # 检查是否需要停止 |             while True: | ||||||
|                 # 接收当前帧并累计计数 |                 # 从轨道接收视频帧 | ||||||
|                 frame = await track.recv() |                 frame = await track.recv() | ||||||
|                 total_frames += 1 |                 # 转换为BGR24格式的NumPy数组 | ||||||
|  |                 frame_bgr24 = frame.to_ndarray(format='bgr24') | ||||||
|  |  | ||||||
|                 # 转换为cv2兼容的BGR格式numpy数组 |                 # 检查队列是否为空,为空则加入,否则丢弃 | ||||||
|                 frame_cv2 = frame.to_ndarray(format='bgr24') |                 if frame_queue.empty(): | ||||||
|  |                     try: | ||||||
|                 # 验证是否为cv2兼容格式 |                         frame_queue.put_nowait(frame_bgr24) | ||||||
|                 if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3: |                         print("帧已放入队列") | ||||||
|                     # 检查队列是否已满 |                     except queue.Full: | ||||||
|                     if frame_queue.full(): |                         print("队列已满,丢弃帧") | ||||||
|                         # 队列已满,丢弃当前帧 |  | ||||||
|                         dropped_frames += 1 |  | ||||||
|                         print(f"第{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames})") |  | ||||||
|                     else: |  | ||||||
|                         # 队列未满,放入当前帧 |  | ||||||
|                         await frame_queue.put(frame_cv2) |  | ||||||
|                         print(f"第{total_frames}帧:已放入队列") |  | ||||||
|                 else: |                 else: | ||||||
|                     print("帧格式转换失败,不是有效的cv2格式") |                     print("队列非空,丢弃帧") | ||||||
|  |  | ||||||
|     # 创建并设置本地offer |     async def main(): | ||||||
|     offer = await pc.createOffer() |         # 创建并发送SDP Offer | ||||||
|     print("已创建本地 SDP Offer") |         offer = await pc.createOffer() | ||||||
|     await pc.setLocalDescription(offer) |         print("已创建本地SDP Offer") | ||||||
|  |         await pc.setLocalDescription(offer) | ||||||
|  |  | ||||||
|     # 发送offer到服务器 |         # 发送Offer到服务器并接收Answer | ||||||
|     async with aiohttp.ClientSession() as session: |         async with aiohttp.ClientSession() as session: | ||||||
|         print("开始向服务器发送 SDP Offer") |             print(f"开始向服务器 {webrtc_url} 发送SDP Offer") | ||||||
|         try: |  | ||||||
|             async with session.post( |             async with session.post( | ||||||
|                     url, |                     webrtc_url, | ||||||
|                     data=offer.sdp.encode(), |                     data=offer.sdp.encode(), | ||||||
|                     headers={ |                     headers={ | ||||||
|                         "Content-Type": "application/sdp", |                         "Content-Type": "application/sdp", | ||||||
| @ -66,107 +73,83 @@ async def rtc_frame_receiver(url, frame_queue, stop_event): | |||||||
|                     }, |                     }, | ||||||
|                     ssl=False |                     ssl=False | ||||||
|             ) as response: |             ) as response: | ||||||
|                 if response.status == 200: |                 print("已接收到服务器的响应") | ||||||
|                     print("已接收到服务器的响应、开始处理 SDP Answer") |                 answer_sdp = await response.text() | ||||||
|                     answer_sdp = await response.text() |                 await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) | ||||||
|                     await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) |  | ||||||
|                 else: |  | ||||||
|                     print(f"服务器响应错误: {response.status}") |  | ||||||
|                     stop_event.set() |  | ||||||
|         except Exception as e: |  | ||||||
|             print(f"发送SDP Offer失败: {str(e)}") |  | ||||||
|             stop_event.set() |  | ||||||
|  |  | ||||||
|     try: |         # 保持连接 | ||||||
|         # 保持连接,直到收到停止信号 |  | ||||||
|         while not stop_event.is_set(): |  | ||||||
|             await asyncio.sleep(1) |  | ||||||
|     except KeyboardInterrupt: |  | ||||||
|         print("用户中断") |  | ||||||
|     finally: |  | ||||||
|         print(f"开始关闭 RTCPeerConnection,共接收{total_frames}帧,丢弃{dropped_frames}帧") |  | ||||||
|         await pc.close() |  | ||||||
|         print("已关闭 RTCPeerConnection") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| async def frame_consumer(ip, frame_queue, stop_event): |  | ||||||
|     """ |  | ||||||
|     从队列中阻塞读取cv2帧并处理(队列为空时阻塞等待) |  | ||||||
|     检测到违规内容后设置stop_event以终止所有任务 |  | ||||||
|     """ |  | ||||||
|     # 创建OCR检测器实例 |  | ||||||
|     ocr_detector = OCRViolationDetector( |  | ||||||
|         forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", |  | ||||||
|         ocr_confidence_threshold=0.5, ) |  | ||||||
|  |  | ||||||
|     while not stop_event.is_set():  # 检查是否需要停止 |  | ||||||
|         try: |         try: | ||||||
|             # 阻塞等待队列中的帧 |             while True: | ||||||
|             current_frame = await frame_queue.get() |                 await asyncio.sleep(0.1) | ||||||
|  |         except KeyboardInterrupt: | ||||||
|             # 进行OCR检测 |             pass | ||||||
|             has_violation, words, confidences = ocr_detector.detect(current_frame) |         finally: | ||||||
|             print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}") |             print("关闭RTCPeerConnection") | ||||||
|             print(f"检测到的词: {words}") |             await pc.close() | ||||||
|             print(f"置信度: {confidences}") |  | ||||||
|  |  | ||||||
|             # 输出所有检测到的违禁词 |  | ||||||
|             if has_violation: |  | ||||||
|                 print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") |  | ||||||
|                 response_data = { |  | ||||||
|                     "status": "stop", |  | ||||||
|                     "timestamp": datetime.now().isoformat(), |  | ||||||
|                     "violations": [{"word": w, "confidence": c} for w, c in zip(words, confidences)] |  | ||||||
|                 } |  | ||||||
|                 await send_message_to_client(ip, response_data) |  | ||||||
|  |  | ||||||
|                 for word, conf in zip(words, confidences): |  | ||||||
|                     print(f"- {word}(置信度:{conf:.4f})") |  | ||||||
|  |  | ||||||
|                 # 检测到违规,设置停止事件 |  | ||||||
|                 print("检测到违规内容,准备关闭AI检测") |  | ||||||
|                 stop_event.set() |  | ||||||
|  |  | ||||||
|             # 标记任务完成,允许生产者放入新的帧 |  | ||||||
|             frame_queue.task_done() |  | ||||||
|  |  | ||||||
|         except Exception as e: |  | ||||||
|             print(f"处理帧时发生错误: {str(e)}") |  | ||||||
|             frame_queue.task_done() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def process_webrtc_stream(ip, webrtc_url): |  | ||||||
|     """ |  | ||||||
|     处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭 |  | ||||||
|     队列大小为1,满时直接丢弃新帧 |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         ip: IP地址 |  | ||||||
|         webrtc_url: WEBRTC服务器地址 |  | ||||||
|     """ |  | ||||||
|     # 创建队列(大小为1)和停止事件 |  | ||||||
|     frame_queue = asyncio.Queue(maxsize=1)  # 只存储一帧 |  | ||||||
|     stop_event = asyncio.Event()  # 用于控制任务停止的事件 |  | ||||||
|  |  | ||||||
|     # 定义事件循环中的主任务 |  | ||||||
|     async def main_task(): |  | ||||||
|         # 创建任务 |  | ||||||
|         receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event)) |  | ||||||
|         consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event)) |  | ||||||
|  |  | ||||||
|         # 等待任务完成 |  | ||||||
|         await asyncio.gather(receiver_task, consumer_task) |  | ||||||
|  |  | ||||||
|         # 确保队列处理完毕 |  | ||||||
|         await frame_queue.join() |  | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 运行事件循环 |         loop.run_until_complete(main()) | ||||||
|         asyncio.run(main_task()) |  | ||||||
|     except KeyboardInterrupt: |  | ||||||
|         print("用户中断处理流程") |  | ||||||
|         stop_event.set() |  | ||||||
|     finally: |     finally: | ||||||
|         # 确保关闭所有cv2窗口 |         loop.close() | ||||||
|         cv2.destroyAllWindows() |  | ||||||
|         print("AI检测已关闭") |  | ||||||
|  | def frame_consumer(): | ||||||
|  |     """ | ||||||
|  |     消费者方法:从队列中读取帧并处理 | ||||||
|  |     每次处理后休眠200ms模拟延迟 | ||||||
|  |     """ | ||||||
|  |     print("消费者启动,开始等待帧...") | ||||||
|  |     try: | ||||||
|  |         while True: | ||||||
|  |             # 阻塞等待队列中的帧 | ||||||
|  |             frame = frame_queue.get() | ||||||
|  |             print(f"消费帧,大小: {frame.shape}") | ||||||
|  |  | ||||||
|  |             # 模拟处理延迟 | ||||||
|  |             time.sleep(0.2)  # 200ms | ||||||
|  |  | ||||||
|  |             # 标记任务完成 | ||||||
|  |             frame_queue.task_done() | ||||||
|  |     except KeyboardInterrupt: | ||||||
|  |         print("消费者退出") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def start_webrtc_stream(webrtc_url): | ||||||
|  |     """ | ||||||
|  |     启动WebRTC视频流处理的主方法 | ||||||
|  |     参数: webrtc_url - WebRTC服务器地址 | ||||||
|  |     """ | ||||||
|  |     print(f"开始连接到WebRTC服务器: {webrtc_url}") | ||||||
|  |  | ||||||
|  |     # 启动生产者线程 | ||||||
|  |     producer_thread = threading.Thread( | ||||||
|  |         target=webrtc_producer, | ||||||
|  |         args=(webrtc_url,), | ||||||
|  |         daemon=True, | ||||||
|  |         name="webrtc-producer" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # 启动消费者线程 | ||||||
|  |     consumer_thread = threading.Thread( | ||||||
|  |         target=frame_consumer, | ||||||
|  |         daemon=True, | ||||||
|  |         name="frame-consumer" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     producer_thread.start() | ||||||
|  |     consumer_thread.start() | ||||||
|  |     print("生产者和消费者线程已启动") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         # 保持主线程运行 | ||||||
|  |         while True: | ||||||
|  |             time.sleep(1) | ||||||
|  |     except KeyboardInterrupt: | ||||||
|  |         print("程序正在退出...") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # 示例用法 | ||||||
|  |     # 实际使用时替换为真实的WebRTC服务器地址 | ||||||
|  |     webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60" | ||||||
|  |     start_webrtc_stream(webrtc_server_url) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 ZZX9599
					ZZX9599