RTC提交

This commit is contained in:
ZZX9599
2025-09-02 23:15:07 +08:00
parent 062ee6c70d
commit be5383d752

View File

@ -5,57 +5,46 @@ import aiohttp
import cv2 import cv2
import numpy as np import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
from ocr.ocr_violation_detector import OCRViolationDetector from ocr.ocr_violation_detector import OCRViolationDetector
from ws.ws import send_message_to_client from ws.ws import send_message_to_client
class VideoTrack(MediaStreamTrack):
    """Video-kind media track that owns a bounded asyncio frame queue.

    The queue is created in ``__init__`` and stored on ``self.frames``;
    nothing in this class reads from or writes to it. ``recv`` simply
    forwards to the parent class implementation.

    Args:
        max_frames: Maximum number of items ``self.frames`` may hold.
    """

    kind = "video"

    def __init__(self, max_frames=1):
        # Initialise the base track first, then attach the frame buffer.
        super().__init__()
        self.frames = asyncio.Queue(max_frames)

    async def recv(self):
        """Await the parent class's ``recv`` and return whatever it yields."""
        received = await super().recv()
        return received
async def rtc_frame_receiver(url, frame_queue, stop_event): async def rtc_frame_receiver(url, frame_queue, stop_event):
""" """
对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据 接收RTC帧并往队列放入cv2格式的帧数据
当队列已满时直接丢弃新帧,不阻塞等待
当stop_event被设置时停止接收 当stop_event被设置时停止接收
""" """
pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
pc.addTrack(video_track)
# 累计帧计数器 # 累计帧计数器和丢弃帧计数器
total_frames = 0 total_frames = 0
dropped_frames = 0
@pc.on("track") @pc.on("track")
async def on_track(track): async def on_track(track):
nonlocal total_frames nonlocal total_frames, dropped_frames
if track.kind == "video": if track.kind == "video":
print("接收到视频轨道、开始接收视频帧") print("接收到视频轨道、开始接收视频帧")
while not stop_event.is_set(): # 检查是否需要停止 while not stop_event.is_set(): # 检查是否需要停止
# 接收当前帧并累计计数 # 接收当前帧并累计计数
frame = await track.recv() frame = await track.recv()
total_frames += 1
# 转换为cv2兼容的BGR格式numpy数组 # 转换为cv2兼容的BGR格式numpy数组
frame_cv2 = frame.to_ndarray(format='bgr24') frame_cv2 = frame.to_ndarray(format='bgr24')
# 验证是否为cv2兼容格式 # 验证是否为cv2兼容格式
if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3: if isinstance(frame_cv2, np.ndarray) and frame_cv2.ndim == 3 and frame_cv2.shape[2] == 3:
total_frames += 1 # 检查队列是否已满
if frame_queue.full():
# 对每帧都检查队列状态、队列为空则放入 # 队列已满,丢弃当前帧
if frame_queue.empty() and not stop_event.is_set(): # 确保还未收到停止信号 dropped_frames += 1
# 队列为空、放入当前cv2帧 print(f"{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames}")
await frame_queue.put(frame_cv2)
else: else:
# 队列非空或已收到停止信号、跳过当前帧 # 队列未满,放入当前帧
if not stop_event.is_set(): await frame_queue.put(frame_cv2)
print(f"{total_frames}帧:队列非空、跳过该帧") print(f"{total_frames}帧:已放入队列")
else: else:
print("帧格式转换失败不是有效的cv2格式") print("帧格式转换失败不是有效的cv2格式")
@ -67,18 +56,26 @@ async def rtc_frame_receiver(url, frame_queue, stop_event):
# 发送offer到服务器 # 发送offer到服务器
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
print("开始向服务器发送 SDP Offer") print("开始向服务器发送 SDP Offer")
async with session.post( try:
url, async with session.post(
data=offer.sdp.encode(), url,
headers={ data=offer.sdp.encode(),
"Content-Type": "application/sdp", headers={
"Content-Length": str(len(offer.sdp)) "Content-Type": "application/sdp",
}, "Content-Length": str(len(offer.sdp))
ssl=False },
) as response: ssl=False
print("已接收到服务器的响应、开始处理 SDP Answer") ) as response:
answer_sdp = await response.text() if response.status == 200:
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) print("已接收到服务器的响应、开始处理 SDP Answer")
answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
else:
print(f"服务器响应错误: {response.status}")
stop_event.set()
except Exception as e:
print(f"发送SDP Offer失败: {str(e)}")
stop_event.set()
try: try:
# 保持连接,直到收到停止信号 # 保持连接,直到收到停止信号
@ -87,64 +84,68 @@ async def rtc_frame_receiver(url, frame_queue, stop_event):
except KeyboardInterrupt: except KeyboardInterrupt:
print("用户中断") print("用户中断")
finally: finally:
print("开始关闭 RTCPeerConnection") print(f"开始关闭 RTCPeerConnection,共接收{total_frames}帧,丢弃{dropped_frames}")
await pc.close() await pc.close()
print("已关闭 RTCPeerConnection") print("已关闭 RTCPeerConnection")
# Default word list location used when the caller does not supply one.
_DEFAULT_FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"

# Sentinel returned by _next_frame_or_stop when the stop event wins the race.
_STOP_REQUESTED = object()


async def _next_frame_or_stop(frame_queue, stop_event):
    """Return the next queued frame, or ``_STOP_REQUESTED`` once stop is set."""
    get_task = asyncio.ensure_future(frame_queue.get())
    stop_task = asyncio.ensure_future(stop_event.wait())
    try:
        await asyncio.wait(
            {get_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
        )
        if get_task.done():
            return get_task.result()
        return _STOP_REQUESTED
    finally:
        # Cancelling a finished task is a no-op; a still-pending get() is
        # withdrawn so it cannot swallow a frame after we have returned.
        get_task.cancel()
        stop_task.cancel()


async def frame_consumer(
    ip,
    frame_queue,
    stop_event,
    *,
    detector=None,
    forbidden_words_path=_DEFAULT_FORBIDDEN_WORDS_PATH,
):
    """Consume cv2 frames from ``frame_queue`` and run OCR violation detection.

    The coroutine waits for frames until ``stop_event`` is set. Waiting is
    raced against the stop event, so a stop requested by another task (for
    example after a failed SDP exchange) ends the consumer even when no
    further frame ever arrives. When a violation is found, ``stop_event`` is
    set and a ``{"status": "stop", ...}`` message is sent to the client.

    Args:
        ip: Client address handed to ``send_message_to_client``.
        frame_queue: ``asyncio.Queue`` the receiver puts frames into.
        stop_event: ``asyncio.Event`` shared by all tasks; set to stop them.
        detector: Optional ready-made object exposing
            ``detect(frame) -> (has_violation, words, confidences)``. When
            omitted, an ``OCRViolationDetector`` is created here.
        forbidden_words_path: Word list path used only when ``detector`` is
            not supplied.
    """
    ocr_detector = detector
    if ocr_detector is None:
        ocr_detector = OCRViolationDetector(
            forbidden_words_path=forbidden_words_path,
            ocr_confidence_threshold=0.5,
        )

    while not stop_event.is_set():
        current_frame = await _next_frame_or_stop(frame_queue, stop_event)
        if current_frame is _STOP_REQUESTED:
            break

        try:
            # NOTE(review): detect() is called synchronously (no await), so
            # it runs on the event-loop thread while the receiver waits.
            has_violation, words, confidences = ocr_detector.detect(current_frame)
            print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}")
            print(f"检测到的词: {words}")
            print(f"置信度: {confidences}")

            if has_violation:
                print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
                for word, conf in zip(words, confidences):
                    print(f"- {word}(置信度:{conf:.4f}")
                print("检测到违规内容准备关闭AI检测")

                # Signal the stop before notifying the client, so a failed
                # send cannot leave detection running after a violation.
                stop_event.set()

                response_data = {
                    "status": "stop",
                    "timestamp": datetime.now().isoformat(),
                    # float() keeps the payload made of plain Python numbers.
                    "violations": [
                        {"word": w, "confidence": float(c)}
                        for w, c in zip(words, confidences)
                    ],
                }
                await send_message_to_client(ip, response_data)
        except Exception as e:
            # Best effort: one bad frame must not end the detection loop.
            print(f"处理帧时发生错误: {str(e)}")
        finally:
            # Exactly one task_done() per successful get().
            frame_queue.task_done()

    # Discard and acknowledge frames still queued at shutdown so that
    # frame_queue.join() cannot block on work nobody will process.
    while not frame_queue.empty():
        frame_queue.get_nowait()
        frame_queue.task_done()
def process_webrtc_stream(ip, webrtc_url): def process_webrtc_stream(ip, webrtc_url):
""" """
处理WEBRTC流并持续打印OCR检测结果检测到违规后关闭 处理WEBRTC流并持续打印OCR检测结果检测到违规后关闭
队列大小为1满时直接丢弃新帧
Args: Args:
ip: IP地址 ip: IP地址
webrtc_url: WEBRTC服务器地址 webrtc_url: WEBRTC服务器地址
""" """
# 创建队列和停止事件 # 创建队列大小为1和停止事件
frame_queue = asyncio.Queue(maxsize=1) frame_queue = asyncio.Queue(maxsize=1) # 只存储一帧
stop_event = asyncio.Event() # 用于控制任务停止的事件 stop_event = asyncio.Event() # 用于控制任务停止的事件
# 定义事件循环中的主任务 # 定义事件循环中的主任务
@ -153,14 +154,18 @@ def process_webrtc_stream(ip, webrtc_url):
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event)) receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event))
consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event)) consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event))
# 等待任一任务完成当stop_event被设置时两个任务都会退出 # 等待任务完成
await asyncio.gather(receiver_task, consumer_task) await asyncio.gather(receiver_task, consumer_task)
# 确保队列处理完毕
await frame_queue.join()
try: try:
# 运行事件循环 # 运行事件循环
asyncio.run(main_task()) asyncio.run(main_task())
except KeyboardInterrupt: except KeyboardInterrupt:
print("用户中断处理流程") print("用户中断处理流程")
stop_event.set()
finally: finally:
# 确保关闭所有cv2窗口 # 确保关闭所有cv2窗口
cv2.destroyAllWindows() cv2.destroyAllWindows()