diff --git a/rtc/rtc.py b/rtc/rtc.py index 08412cd..6cf40e3 100644 --- a/rtc/rtc.py +++ b/rtc/rtc.py @@ -5,57 +5,46 @@ import aiohttp import cv2 import numpy as np from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration -from aiortc.mediastreams import MediaStreamTrack from ocr.ocr_violation_detector import OCRViolationDetector from ws.ws import send_message_to_client -class VideoTrack(MediaStreamTrack): - kind = "video" - - def __init__(self, max_frames=1): - super().__init__() - self.frames = asyncio.Queue(maxsize=max_frames) - - async def recv(self): - return await super().recv() - - async def rtc_frame_receiver(url, frame_queue, stop_event): """ - 对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据 + 接收RTC帧并往队列放入cv2格式的帧数据 + 当队列已满时直接丢弃新帧,不阻塞等待 当stop_event被设置时停止接收 """ pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) - video_track = VideoTrack() - pc.addTrack(video_track) - # 累计帧计数器 + # 累计帧计数器和丢弃帧计数器 total_frames = 0 + dropped_frames = 0 @pc.on("track") async def on_track(track): - nonlocal total_frames + nonlocal total_frames, dropped_frames if track.kind == "video": print("接收到视频轨道、开始接收视频帧") while not stop_event.is_set(): # 检查是否需要停止 # 接收当前帧并累计计数 frame = await track.recv() + total_frames += 1 + # 转换为cv2兼容的BGR格式numpy数组 frame_cv2 = frame.to_ndarray(format='bgr24') # 验证是否为cv2兼容格式 if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3: - total_frames += 1 - - # 对每帧都检查队列状态、队列为空则放入 - if frame_queue.empty() and not stop_event.is_set(): # 确保还未收到停止信号 - # 队列为空、放入当前cv2帧 - await frame_queue.put(frame_cv2) + # 检查队列是否已满 + if frame_queue.full(): + # 队列已满,丢弃当前帧 + dropped_frames += 1 + print(f"第{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames})") else: - # 队列非空或已收到停止信号、跳过当前帧 - if not stop_event.is_set(): - print(f"第{total_frames}帧:队列非空、跳过该帧") + # 队列未满,放入当前帧 + await frame_queue.put(frame_cv2) + print(f"第{total_frames}帧:已放入队列") else: print("帧格式转换失败,不是有效的cv2格式") @@ -67,18 +56,26 @@ async def rtc_frame_receiver(url, frame_queue, stop_event): # 发送offer到服务器 async with aiohttp.ClientSession() as session: print("开始向服务器发送 SDP Offer") - async with session.post( - url, - data=offer.sdp.encode(), - headers={ - "Content-Type": "application/sdp", - "Content-Length": str(len(offer.sdp)) - }, - ssl=False - ) as response: - print("已接收到服务器的响应、开始处理 SDP Answer") - answer_sdp = await response.text() - await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) + try: + async with session.post( + url, + data=offer.sdp.encode(), + headers={ + "Content-Type": "application/sdp", + "Content-Length": str(len(offer.sdp)) + }, + ssl=False + ) as response: + if response.status == 200: + print("已接收到服务器的响应、开始处理 SDP Answer") + answer_sdp = await response.text() + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) + else: + print(f"服务器响应错误: {response.status}") + stop_event.set() + except Exception as e: + print(f"发送SDP Offer失败: {str(e)}") + stop_event.set() try: # 保持连接,直到收到停止信号 @@ -87,64 +84,68 @@ async def rtc_frame_receiver(url, frame_queue, stop_event): except KeyboardInterrupt: print("用户中断") finally: - print("开始关闭 RTCPeerConnection") + print(f"开始关闭 RTCPeerConnection,共接收{total_frames}帧,丢弃{dropped_frames}帧") await pc.close() print("已关闭 RTCPeerConnection") async def frame_consumer(ip, frame_queue, stop_event): """ - 从队列中读取cv2帧并处理(队列空时会阻塞等待) + 从队列中阻塞读取cv2帧并处理(队列为空时阻塞等待) 检测到违规内容后设置stop_event以终止所有任务 - - Args: - ip: IP地址 - frame_queue: 帧队列 - stop_event: 用于控制任务停止的事件 """ # 创建OCR检测器实例 ocr_detector = OCRViolationDetector( - forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径 + forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", ocr_confidence_threshold=0.5, ) while not stop_event.is_set(): # 检查是否需要停止 - # 从队列中获取cv2帧(队列为空时会阻塞等待新帧) - current_frame = await frame_queue.get() - has_violation, words, confidences = ocr_detector.detect(current_frame) - print(has_violation) - print( words) - print( confidences) - # 输出所有检测到的违禁词 - if has_violation: - print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") - response_data = { - "status": "stop", - "timestamp": datetime.now().isoformat(), - } - await send_message_to_client(ip,response_data ) - for word, conf in zip(words, confidences): - print(f"- {word}(置信度:{conf:.4f})") + try: + # 阻塞等待队列中的帧 + current_frame = await frame_queue.get() - # 检测到违规,设置停止事件 - print("检测到违规内容,准备关闭AI检测") - stop_event.set() - else: - print("测试结果:图片中未检测到违禁词") + # 进行OCR检测 + has_violation, words, confidences = ocr_detector.detect(current_frame) + print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}") + print(f"检测到的词: {words}") + print(f"置信度: {confidences}") - # 标记任务完成 - frame_queue.task_done() + # 输出所有检测到的违禁词 + if has_violation: + print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") + response_data = { + "status": "stop", + "timestamp": datetime.now().isoformat(), + "violations": [{"word": w, "confidence": c} for w, c in zip(words, confidences)] + } + await send_message_to_client(ip, response_data) + + for word, conf in zip(words, confidences): + print(f"- {word}(置信度:{conf:.4f})") + + # 检测到违规,设置停止事件 + print("检测到违规内容,准备关闭AI检测") + stop_event.set() + + # 标记任务完成,允许生产者放入新的帧 + frame_queue.task_done() + + except Exception as e: + print(f"处理帧时发生错误: {str(e)}") + frame_queue.task_done() def process_webrtc_stream(ip, webrtc_url): """ 处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭 + 队列大小为1,满时直接丢弃新帧 Args: ip: IP地址 webrtc_url: WEBRTC服务器地址 """ - # 创建队列和停止事件 - frame_queue = asyncio.Queue(maxsize=1) + # 创建队列(大小为1)和停止事件 + frame_queue = asyncio.Queue(maxsize=1) # 只存储一帧 stop_event = asyncio.Event() # 用于控制任务停止的事件 # 定义事件循环中的主任务 @@ -153,14 +154,18 @@ def process_webrtc_stream(ip, webrtc_url): receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event)) consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event)) - # 等待任一任务完成(当stop_event被设置时,两个任务都会退出) + # 等待任务完成 await asyncio.gather(receiver_task, consumer_task) + # 确保队列处理完毕 + await frame_queue.join() + try: # 运行事件循环 asyncio.run(main_task()) except KeyboardInterrupt: print("用户中断处理流程") + stop_event.set() finally: # 确保关闭所有cv2窗口 cv2.destroyAllWindows()