RTC提交
This commit is contained in:
		
							
								
								
									
										95
									
								
								rtc/rtc.py
									
									
									
									
									
								
							
							
						
						
									
										95
									
								
								rtc/rtc.py
									
									
									
									
									
								
							| @ -5,57 +5,46 @@ import aiohttp | ||||
| import cv2 | ||||
| import numpy as np | ||||
| from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration | ||||
| from aiortc.mediastreams import MediaStreamTrack | ||||
| from ocr.ocr_violation_detector import OCRViolationDetector | ||||
| from ws.ws import send_message_to_client | ||||
|  | ||||
|  | ||||
| class VideoTrack(MediaStreamTrack): | ||||
|     kind = "video" | ||||
|  | ||||
|     def __init__(self, max_frames=1): | ||||
|         super().__init__() | ||||
|         self.frames = asyncio.Queue(maxsize=max_frames) | ||||
|  | ||||
|     async def recv(self): | ||||
|         return await super().recv() | ||||
|  | ||||
|  | ||||
| async def rtc_frame_receiver(url, frame_queue, stop_event): | ||||
|     """ | ||||
|     对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据 | ||||
|     接收RTC帧并往队列放入cv2格式的帧数据 | ||||
|     当队列已满时直接丢弃新帧,不阻塞等待 | ||||
|     当stop_event被设置时停止接收 | ||||
|     """ | ||||
|     pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) | ||||
|     video_track = VideoTrack() | ||||
|     pc.addTrack(video_track) | ||||
|  | ||||
|     # 累计帧计数器 | ||||
|     # 累计帧计数器和丢弃帧计数器 | ||||
|     total_frames = 0 | ||||
|     dropped_frames = 0 | ||||
|  | ||||
|     @pc.on("track") | ||||
|     async def on_track(track): | ||||
|         nonlocal total_frames | ||||
|         nonlocal total_frames, dropped_frames | ||||
|         if track.kind == "video": | ||||
|             print("接收到视频轨道、开始接收视频帧") | ||||
|             while not stop_event.is_set():  # 检查是否需要停止 | ||||
|                 # 接收当前帧并累计计数 | ||||
|                 frame = await track.recv() | ||||
|                 total_frames += 1 | ||||
|  | ||||
|                 # 转换为cv2兼容的BGR格式numpy数组 | ||||
|                 frame_cv2 = frame.to_ndarray(format='bgr24') | ||||
|  | ||||
|                 # 验证是否为cv2兼容格式 | ||||
|                 if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3: | ||||
|                     total_frames += 1 | ||||
|  | ||||
|                     # 对每帧都检查队列状态、队列为空则放入 | ||||
|                     if frame_queue.empty() and not stop_event.is_set():  # 确保还未收到停止信号 | ||||
|                         # 队列为空、放入当前cv2帧 | ||||
|                         await frame_queue.put(frame_cv2) | ||||
|                     # 检查队列是否已满 | ||||
|                     if frame_queue.full(): | ||||
|                         # 队列已满,丢弃当前帧 | ||||
|                         dropped_frames += 1 | ||||
|                         print(f"第{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames})") | ||||
|                     else: | ||||
|                         # 队列非空或已收到停止信号、跳过当前帧 | ||||
|                         if not stop_event.is_set(): | ||||
|                             print(f"第{total_frames}帧:队列非空、跳过该帧") | ||||
|                         # 队列未满,放入当前帧 | ||||
|                         await frame_queue.put(frame_cv2) | ||||
|                         print(f"第{total_frames}帧:已放入队列") | ||||
|                 else: | ||||
|                     print("帧格式转换失败,不是有效的cv2格式") | ||||
|  | ||||
| @ -67,6 +56,7 @@ async def rtc_frame_receiver(url, frame_queue, stop_event): | ||||
|     # 发送offer到服务器 | ||||
|     async with aiohttp.ClientSession() as session: | ||||
|         print("开始向服务器发送 SDP Offer") | ||||
|         try: | ||||
|             async with session.post( | ||||
|                     url, | ||||
|                     data=offer.sdp.encode(), | ||||
| @ -76,9 +66,16 @@ async def rtc_frame_receiver(url, frame_queue, stop_event): | ||||
|                     }, | ||||
|                     ssl=False | ||||
|             ) as response: | ||||
|                 if response.status == 200: | ||||
|                     print("已接收到服务器的响应、开始处理 SDP Answer") | ||||
|                     answer_sdp = await response.text() | ||||
|                     await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) | ||||
|                 else: | ||||
|                     print(f"服务器响应错误: {response.status}") | ||||
|                     stop_event.set() | ||||
|         except Exception as e: | ||||
|             print(f"发送SDP Offer失败: {str(e)}") | ||||
|             stop_event.set() | ||||
|  | ||||
|     try: | ||||
|         # 保持连接,直到收到停止信号 | ||||
| @ -87,64 +84,68 @@ async def rtc_frame_receiver(url, frame_queue, stop_event): | ||||
|     except KeyboardInterrupt: | ||||
|         print("用户中断") | ||||
|     finally: | ||||
|         print("开始关闭 RTCPeerConnection") | ||||
|         print(f"开始关闭 RTCPeerConnection,共接收{total_frames}帧,丢弃{dropped_frames}帧") | ||||
|         await pc.close() | ||||
|         print("已关闭 RTCPeerConnection") | ||||
|  | ||||
|  | ||||
| async def frame_consumer(ip, frame_queue, stop_event): | ||||
|     """ | ||||
|     从队列中读取cv2帧并处理(队列空时会阻塞等待) | ||||
|     从队列中阻塞读取cv2帧并处理(队列为空时阻塞等待) | ||||
|     检测到违规内容后设置stop_event以终止所有任务 | ||||
|  | ||||
|     Args: | ||||
|         ip: IP地址 | ||||
|         frame_queue: 帧队列 | ||||
|         stop_event: 用于控制任务停止的事件 | ||||
|     """ | ||||
|     # 创建OCR检测器实例 | ||||
|     ocr_detector = OCRViolationDetector( | ||||
|         forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",  # 替换为实际路径 | ||||
|         forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", | ||||
|         ocr_confidence_threshold=0.5, ) | ||||
|  | ||||
|     while not stop_event.is_set():  # 检查是否需要停止 | ||||
|         # 从队列中获取cv2帧(队列为空时会阻塞等待新帧) | ||||
|         try: | ||||
|             # 阻塞等待队列中的帧 | ||||
|             current_frame = await frame_queue.get() | ||||
|  | ||||
|             # 进行OCR检测 | ||||
|             has_violation, words, confidences = ocr_detector.detect(current_frame) | ||||
|         print(has_violation) | ||||
|         print( words) | ||||
|         print( confidences) | ||||
|             print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}") | ||||
|             print(f"检测到的词: {words}") | ||||
|             print(f"置信度: {confidences}") | ||||
|  | ||||
|             # 输出所有检测到的违禁词 | ||||
|             if has_violation: | ||||
|                 print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") | ||||
|                 response_data = { | ||||
|                     "status": "stop", | ||||
|                     "timestamp": datetime.now().isoformat(), | ||||
|                     "violations": [{"word": w, "confidence": c} for w, c in zip(words, confidences)] | ||||
|                 } | ||||
|             await send_message_to_client(ip,response_data ) | ||||
|                 await send_message_to_client(ip, response_data) | ||||
|  | ||||
|                 for word, conf in zip(words, confidences): | ||||
|                     print(f"- {word}(置信度:{conf:.4f})") | ||||
|  | ||||
|                 # 检测到违规,设置停止事件 | ||||
|                 print("检测到违规内容,准备关闭AI检测") | ||||
|                 stop_event.set() | ||||
|         else: | ||||
|             print("测试结果:图片中未检测到违禁词") | ||||
|  | ||||
|         # 标记任务完成 | ||||
|             # 标记任务完成,允许生产者放入新的帧 | ||||
|             frame_queue.task_done() | ||||
|  | ||||
|         except Exception as e: | ||||
|             print(f"处理帧时发生错误: {str(e)}") | ||||
|             frame_queue.task_done() | ||||
|  | ||||
|  | ||||
| def process_webrtc_stream(ip, webrtc_url): | ||||
|     """ | ||||
|     处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭 | ||||
|     队列大小为1,满时直接丢弃新帧 | ||||
|  | ||||
|     Args: | ||||
|         ip: IP地址 | ||||
|         webrtc_url: WEBRTC服务器地址 | ||||
|     """ | ||||
|     # 创建队列和停止事件 | ||||
|     frame_queue = asyncio.Queue(maxsize=1) | ||||
|     # 创建队列(大小为1)和停止事件 | ||||
|     frame_queue = asyncio.Queue(maxsize=1)  # 只存储一帧 | ||||
|     stop_event = asyncio.Event()  # 用于控制任务停止的事件 | ||||
|  | ||||
|     # 定义事件循环中的主任务 | ||||
| @ -153,14 +154,18 @@ def process_webrtc_stream(ip, webrtc_url): | ||||
|         receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event)) | ||||
|         consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event)) | ||||
|  | ||||
|         # 等待任一任务完成(当stop_event被设置时,两个任务都会退出) | ||||
|         # 等待任务完成 | ||||
|         await asyncio.gather(receiver_task, consumer_task) | ||||
|  | ||||
|         # 确保队列处理完毕 | ||||
|         await frame_queue.join() | ||||
|  | ||||
|     try: | ||||
|         # 运行事件循环 | ||||
|         asyncio.run(main_task()) | ||||
|     except KeyboardInterrupt: | ||||
|         print("用户中断处理流程") | ||||
|         stop_event.set() | ||||
|     finally: | ||||
|         # 确保关闭所有cv2窗口 | ||||
|         cv2.destroyAllWindows() | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 ZZX9599
					ZZX9599