import asyncio

import aiohttp
import cv2
import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamError, MediaStreamTrack

from ocr.ocr_violation_detector import OCRViolationDetector

# Default location of the forbidden-words list used by frame_consumer().
# Callers can override it via the ``forbidden_words_path`` parameter.
DEFAULT_FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"


class VideoTrack(MediaStreamTrack):
    """Placeholder video track that is attached to the peer connection.

    NOTE(review): ``recv`` only delegates to the base-class ``recv``, which is
    abstract in aiortc, so this track cannot produce frames itself. It looks
    like it exists only so the SDP offer contains a video section — confirm.
    """

    kind = "video"

    def __init__(self, max_frames=1):
        super().__init__()
        # Bounded frame buffer; not read or written anywhere in this file.
        self.frames = asyncio.Queue(maxsize=max_frames)

    async def recv(self):
        return await super().recv()


async def rtc_frame_receiver(url, frame_queue):
    """Negotiate a WebRTC session and feed decoded frames into ``frame_queue``.

    Every received video frame is converted to a BGR ``numpy`` array. The
    array is queued only when the queue is currently empty; otherwise the
    frame is dropped so the consumer never falls behind the stream.

    Args:
        url: HTTP(S) endpoint that accepts an SDP offer (``application/sdp``)
            and answers with an SDP answer in the response body.
        frame_queue: ``asyncio.Queue`` receiving BGR frames.

    Raises:
        aiohttp.ClientResponseError: if the signalling server answers with an
            HTTP error status.
    """
    pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))

    # Everything after construction is guarded so the peer connection is
    # closed even when negotiation fails half-way.
    try:
        # NOTE(review): for a receive-only session,
        # pc.addTransceiver("video", direction="recvonly") may be the more
        # appropriate call — confirm against the server before changing.
        video_track = VideoTrack()
        pc.addTrack(video_track)

        # Running count of successfully converted frames.
        total_frames = 0

        @pc.on("track")
        async def on_track(track):
            nonlocal total_frames
            if track.kind != "video":
                return

            print("接收到视频轨道、开始接收视频帧")
            while True:
                try:
                    frame = await track.recv()
                except MediaStreamError:
                    # Raised by aiortc once the remote track has ended.
                    print("视频轨道已结束、停止接收视频帧")
                    return

                # Convert to the BGR layout OpenCV expects.
                frame_cv2 = frame.to_ndarray(format='bgr24')

                is_bgr_image = (
                    isinstance(frame_cv2, np.ndarray)
                    and frame_cv2.ndim == 3
                    and frame_cv2.shape[2] == 3
                )
                if not is_bgr_image:
                    print("帧格式转换失败,不是有效的cv2格式")
                    continue

                total_frames += 1
                if frame_queue.empty():
                    # No await between the check and the insert, so this
                    # cannot race with another coroutine on the same loop.
                    frame_queue.put_nowait(frame_cv2)
                else:
                    # A frame is already waiting for the consumer: drop this one.
                    print(f"第{total_frames}帧:队列非空、跳过该帧")

        # Create and apply the local offer.
        offer = await pc.createOffer()
        print("已创建本地 SDP Offer")
        await pc.setLocalDescription(offer)

        # setLocalDescription() performs ICE gathering; the description held
        # by the peer connection afterwards is the one to send to the server.
        offer_payload = pc.localDescription.sdp.encode()

        async with aiohttp.ClientSession() as session:
            print("开始向服务器发送 SDP Offer")
            async with session.post(
                url,
                data=offer_payload,
                headers={
                    "Content-Type": "application/sdp",
                    # Length of the encoded body in bytes, not in characters.
                    "Content-Length": str(len(offer_payload)),
                },
                # SECURITY: certificate verification is disabled here.
                ssl=False,
            ) as response:
                print("已接收到服务器的响应、开始处理 SDP Answer")
                # Fail with a clear HTTP error instead of trying to parse an
                # error page as SDP.
                response.raise_for_status()
                answer_sdp = await response.text()

        await pc.setRemoteDescription(
            RTCSessionDescription(sdp=answer_sdp, type='answer')
        )

        try:
            # Keep the connection alive until cancelled or interrupted.
            while True:
                await asyncio.sleep(1)
        except KeyboardInterrupt:
            print("用户中断")
    finally:
        print("开始关闭 RTCPeerConnection")
        await pc.close()
        print("已关闭 RTCPeerConnection")


def _report_detection(has_violation, words, confidences):
    """Print the OCR result for one frame."""
    if not has_violation:
        print("测试结果:图片中未检测到违禁词")
        return

    print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
    for word, conf in zip(words, confidences):
        print(f"- {word}(置信度:{conf:.4f})")


async def frame_consumer(
    frame_queue,
    forbidden_words_path=DEFAULT_FORBIDDEN_WORDS_PATH,
    ocr_confidence_threshold=0.5,
):
    """Read frames from ``frame_queue`` and run forbidden-word OCR on each.

    Waits (without blocking the event loop) while the queue is empty. The OCR
    call itself is synchronous, so it is run in the default executor to keep
    the WebRTC receiver responsive while a frame is being analysed.

    Args:
        frame_queue: ``asyncio.Queue`` of frames produced by the receiver.
        forbidden_words_path: Path to the forbidden-words file passed to
            ``OCRViolationDetector``.
        ocr_confidence_threshold: Confidence threshold passed to
            ``OCRViolationDetector``.
    """
    ocr_detector = OCRViolationDetector(
        forbidden_words_path=forbidden_words_path,
        ocr_confidence_threshold=ocr_confidence_threshold,
    )
    loop = asyncio.get_running_loop()

    while True:
        current_frame = await frame_queue.get()
        try:
            # NOTE(review): assumes detect() may be called from a worker
            # thread; calls are strictly sequential — confirm.
            has_violation, words, confidences = await loop.run_in_executor(
                None, ocr_detector.detect, current_frame
            )
            _report_detection(has_violation, words, confidences)
        finally:
            # Always balance get() so the queue's unfinished-task count stays
            # correct even if detection raised.
            frame_queue.task_done()


def process_webrtc_stream(ip, webrtc_url):
    """Consume a WebRTC stream and continuously print OCR detection results.

    Args:
        ip: IP address (reserved; currently unused).
        webrtc_url: Signalling URL of the WebRTC server.
    """

    async def main_task():
        # The queue holds a single frame and is shared by producer and consumer.
        frame_queue = asyncio.Queue(maxsize=1)

        receiver_task = asyncio.create_task(
            rtc_frame_receiver(webrtc_url, frame_queue)
        )
        consumer_task = asyncio.create_task(frame_consumer(frame_queue))

        await asyncio.gather(receiver_task, consumer_task)

    try:
        asyncio.run(main_task())
    except KeyboardInterrupt:
        print("用户中断处理流程")
    finally:
        # Make sure any OpenCV windows are closed.
        cv2.destroyAllWindows()