Files
video/rtc/rtc.py
2025-09-02 21:35:24 +08:00

140 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import aiohttp
import cv2 # 导入OpenCV库
import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
from ocr.ocr_violation_detector import OCRViolationDetector
class VideoTrack(MediaStreamTrack):
    """Local video track attached to the peer connection.

    ``recv`` hands out frames that have been placed on ``self.frames``. In the
    visible code nothing feeds that queue, so ``recv`` simply waits.
    """

    kind = "video"

    def __init__(self, max_frames=1):
        """Create the track.

        Args:
            max_frames: capacity of the internal ``self.frames`` queue.
        """
        super().__init__()
        self.frames = asyncio.Queue(maxsize=max_frames)

    async def recv(self):
        """Return the next frame queued on ``self.frames``.

        The base ``MediaStreamTrack.recv`` is abstract and produces no frame.
        Delegating to it used to return ``None`` to the sender. Waiting on the
        queue makes ``self.frames`` actually usable and never yields ``None``.
        """
        return await self.frames.get()
async def _post_sdp_offer(url, offer_sdp):
    """POST an SDP offer to a WHEP endpoint and return the SDP answer text.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error.
    """
    body = offer_sdp.encode()
    async with aiohttp.ClientSession() as session:
        print("开始向服务器发送 SDP Offer")
        async with session.post(
            url,
            data=body,
            headers={
                "Content-Type": "application/sdp",
                # Byte length of the encoded body, not the str's character count.
                "Content-Length": str(len(body)),
            },
            # NOTE: ssl=False disables certificate verification for https URLs.
            ssl=False,
        ) as response:
            print("已接收到服务器的响应、开始处理 SDP Answer")
            # Fail loudly instead of feeding an HTTP error page to the SDP parser.
            response.raise_for_status()
            return await response.text()


async def rtc_frame_receiver(url, frame_queue):
    """Pull video from a WHEP endpoint and feed cv2 frames into ``frame_queue``.

    Every received RTC video frame is counted. When ``frame_queue`` is empty the
    frame is converted to a BGR ``numpy`` array (cv2-compatible) and enqueued.
    Otherwise it is dropped, so the consumer always works on a recent frame.
    The peer connection is closed on every exit path.

    Args:
        url: WHEP endpoint that accepts an SDP offer via HTTP POST.
        frame_queue: ``asyncio.Queue`` shared with the frame consumer.

    Raises:
        aiohttp.ClientResponseError: if the server rejects the SDP offer.
    """
    from aiortc.mediastreams import MediaStreamError

    pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
    try:
        pc.addTrack(VideoTrack())
        # Running total of received video frames.
        total_frames = 0

        @pc.on("track")
        async def on_track(track):
            nonlocal total_frames
            if track.kind != "video":
                return
            print("接收到视频轨道、开始接收视频帧")
            while True:
                try:
                    frame = await track.recv()
                except MediaStreamError:
                    # The remote track has ended; stop reading from it.
                    return
                total_frames += 1
                if not frame_queue.empty():
                    # The previous frame is still pending: drop this one before
                    # paying for the BGR conversion.
                    print(f"{total_frames}帧:队列非空、跳过该帧")
                    continue
                frame_cv2 = frame.to_ndarray(format='bgr24')
                if (
                    isinstance(frame_cv2, np.ndarray)
                    and frame_cv2.ndim == 3
                    and frame_cv2.shape[2] == 3
                ):
                    # No await since the empty() check, so the slot is still free.
                    frame_queue.put_nowait(frame_cv2)
                else:
                    print("帧格式转换失败不是有效的cv2格式")

        offer = await pc.createOffer()
        print("已创建本地 SDP Offer")
        await pc.setLocalDescription(offer)
        # localDescription is regenerated after ICE gathering and carries the
        # local candidates; the object returned by createOffer does not.
        answer_sdp = await _post_sdp_offer(url, pc.localDescription.sdp)
        await pc.setRemoteDescription(
            RTCSessionDescription(sdp=answer_sdp, type='answer')
        )
        # Keep the connection alive; frames are delivered through on_track.
        while True:
            await asyncio.sleep(1)
    except (asyncio.CancelledError, KeyboardInterrupt):
        # Under asyncio.run, Ctrl+C reaches this task as a cancellation.
        print("用户中断")
        raise
    finally:
        print("开始关闭 RTCPeerConnection")
        await pc.close()
        print("已关闭 RTCPeerConnection")
async def frame_consumer(
    frame_queue,
    forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
    ocr_confidence_threshold=0.5,
):
    """Read cv2 frames from ``frame_queue`` and run OCR violation detection.

    The consumer waits asynchronously while the queue is empty. The detector is
    built and run on one dedicated worker thread, so the event loop (and with
    it the RTC receiver) keeps running while a frame is analysed. An OCR
    failure on one frame is reported and the consumer moves on.

    Args:
        frame_queue: ``asyncio.Queue`` of BGR ``numpy`` frames.
        forbidden_words_path: forbidden-words file passed to
            ``OCRViolationDetector``. The default keeps the original path.
        ocr_confidence_threshold: confidence threshold passed to the detector.
    """
    from concurrent.futures import ThreadPoolExecutor

    loop = asyncio.get_running_loop()
    # A single worker keeps OCR calls serialized. The detector is built and
    # used on the same thread, in case its backend holds thread-local state.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        ocr_detector = await loop.run_in_executor(
            executor,
            lambda: OCRViolationDetector(
                forbidden_words_path=forbidden_words_path,
                ocr_confidence_threshold=ocr_confidence_threshold,
            ),
        )
        while True:
            # Waits here while the queue is empty.
            current_frame = await frame_queue.get()
            try:
                has_violation, words, confidences = await loop.run_in_executor(
                    executor, ocr_detector.detect, current_frame
                )
            except Exception as exc:
                # One bad frame must not stop consumption; keep going.
                print(f"OCR 检测失败: {exc!r}")
                continue
            finally:
                frame_queue.task_done()
            if has_violation:
                print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
                for word, conf in zip(words, confidences):
                    print(f"- {word}(置信度:{conf:.4f})")
            else:
                print("测试结果:图片中未检测到违禁词")
    finally:
        # Don't block the event loop waiting for an in-flight OCR call.
        executor.shutdown(wait=False)
async def main(
    url="http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60",
):
    """Run the WHEP frame receiver and the OCR frame consumer concurrently.

    Args:
        url: WHEP playback endpoint. The default keeps the original stream URL.
    """
    # maxsize=1: at most one frame waits while the consumer is busy.
    frame_queue = asyncio.Queue(maxsize=1)
    receiver_task = asyncio.create_task(rtc_frame_receiver(url, frame_queue))
    consumer_task = asyncio.create_task(frame_consumer(frame_queue))
    tasks = (receiver_task, consumer_task)
    try:
        await asyncio.gather(*tasks)
    finally:
        # If one task fails (or main is cancelled), stop the other and let both
        # finish their cleanup (e.g. closing the peer connection).
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C: asyncio.run has already cancelled and awaited the tasks.
        pass
    finally:
        # Close any cv2 windows. Headless OpenCV builds raise cv2.error here,
        # which would otherwise mask the real exit reason.
        try:
            cv2.destroyAllWindows()
        except cv2.error:
            pass