Files
video/rtc/rtc.py
2025-09-03 13:52:24 +08:00

176 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import queue
import asyncio
import aiohttp
import threading
import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1)
class VideoTrack(MediaStreamTrack):
"""自定义视频轨道类继承自MediaStreamTrack"""
kind = "video"
def __init__(self, max_frames=100):
super().__init__()
self.frames = queue.Queue(maxsize=max_frames)
async def recv(self):
return await super().recv()
def webrtc_producer(webrtc_url):
"""
生产者方法从WEBRTC读取视频帧并放入队列
仅当队列空时才放入新帧,否则丢弃
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 创建RTCPeerConnection对象不使用ICE服务器
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
pc.addTrack(video_track)
@pc.on("track")
async def on_track(track):
if track.kind == "video":
print("接收到视频轨道,开始接收视频帧")
while True:
# 从轨道接收视频帧
frame = await track.recv()
# 转换为BGR24格式的NumPy数组
frame_bgr24 = frame.to_ndarray(format='bgr24')
# 检查队列是否为空,为空则加入,否则丢弃
if frame_queue.empty():
try:
frame_queue.put_nowait(frame_bgr24)
print("帧已放入队列")
except queue.Full:
print("队列已满,丢弃帧")
else:
print("队列非空,丢弃帧")
async def main():
# 创建并发送SDP Offer
offer = await pc.createOffer()
print("已创建本地SDP Offer")
await pc.setLocalDescription(offer)
# 发送Offer到服务器并接收Answer
async with aiohttp.ClientSession() as session:
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
async with session.post(
webrtc_url,
data=offer.sdp.encode(),
headers={
"Content-Type": "application/sdp",
"Content-Length": str(len(offer.sdp))
},
ssl=False
) as response:
print("已接收到服务器的响应")
answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
# 保持连接
try:
while True:
await asyncio.sleep(0.1)
except KeyboardInterrupt:
pass
finally:
print("关闭RTCPeerConnection")
await pc.close()
try:
loop.run_until_complete(main())
finally:
loop.close()
def frame_consumer(ip):
"""
消费者方法:从队列中读取帧并处理
每次处理后休眠200ms模拟延迟
"""
print("消费者启动,开始等待帧...")
try:
while True:
# 阻塞等待队列中的帧
frame = frame_queue.get()
print(f"消费帧,大小: {frame.shape}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
# 标记任务完成
frame_queue.task_done()
except KeyboardInterrupt:
print("消费者退出")
def start_webrtc_stream(ip, webrtc_url):
"""
启动WebRTC视频流处理的主方法
参数: webrtc_url - WebRTC服务器地址
"""
print(f"开始连接到WebRTC服务器: {webrtc_url}")
# 启动生产者线程
producer_thread = threading.Thread(
target=webrtc_producer,
args=(webrtc_url,),
daemon=True,
name="webrtc-producer"
)
# 启动消费者线程
consumer_thread = threading.Thread(
target=frame_consumer(ip),
daemon=True,
name="frame-consumer"
)
producer_thread.start()
consumer_thread.start()
print("生产者和消费者线程已启动")
try:
# 保持主线程运行
while True:
time.sleep(1)
except KeyboardInterrupt:
print("程序正在退出...")
if __name__ == "__main__":
# 示例用法
# 实际使用时替换为真实的WebRTC服务器地址
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
start_webrtc_stream(webrtc_server_url)