Files
video/rtc/rtc.py

173 lines
6.4 KiB
Python
Raw Normal View History

2025-09-02 19:46:34 +08:00
import asyncio
2025-09-02 23:06:36 +08:00
from datetime import datetime
2025-09-02 19:46:34 +08:00
import aiohttp
2025-09-02 21:42:09 +08:00
import cv2
2025-09-02 20:14:40 +08:00
import numpy as np
2025-09-02 19:46:34 +08:00
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
2025-09-02 21:30:28 +08:00
from ocr.ocr_violation_detector import OCRViolationDetector
2025-09-02 23:06:36 +08:00
from ws.ws import send_message_to_client
2025-09-02 19:46:34 +08:00
2025-09-02 23:06:36 +08:00
async def rtc_frame_receiver(url, frame_queue, stop_event):
    """
    Receive video frames over WebRTC and queue them as cv2-compatible arrays.

    A local SDP offer is POSTed to ``url``; the response body is applied as
    the SDP answer. Every received video frame is converted to a BGR
    ``numpy`` array and offered to ``frame_queue``. When the queue is full
    the new frame is dropped instead of waiting for space.

    Args:
        url: HTTP(S) endpoint that accepts an ``application/sdp`` offer and
            replies with the SDP answer.
        frame_queue: ``asyncio.Queue`` that receives the converted frames.
        stop_event: ``asyncio.Event``; receiving stops once it is set. This
            coroutine sets it itself when signalling fails, when the remote
            track ends, and when it exits.
    """
    # Imported here so the file-level import block does not need to change.
    from aiortc.mediastreams import MediaStreamError

    pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
    # Running totals of received and dropped frames.
    total_frames = 0
    dropped_frames = 0

    @pc.on("track")
    async def on_track(track):
        nonlocal total_frames, dropped_frames
        if track.kind != "video":
            return
        print("接收到视频轨道、开始接收视频帧")
        while not stop_event.is_set():
            try:
                frame = await track.recv()
            except MediaStreamError:
                # The track ended: no more frames will ever arrive, so stop
                # the pipeline instead of idling forever.
                stop_event.set()
                break
            if stop_event.is_set():
                # Stop was requested while waiting for this frame; queueing
                # it now would leave an item nobody will consume.
                break
            total_frames += 1
            # Convert to a BGR numpy array, the layout cv2 works with.
            frame_cv2 = frame.to_ndarray(format='bgr24')
            is_bgr_image = (
                isinstance(frame_cv2, np.ndarray)
                and frame_cv2.ndim == 3
                and frame_cv2.shape[2] == 3
            )
            if not is_bgr_image:
                print("帧格式转换失败不是有效的cv2格式")
                continue
            try:
                # Never block: a full queue means the consumer is still busy.
                frame_queue.put_nowait(frame_cv2)
            except asyncio.QueueFull:
                dropped_frames += 1
                print(f"{total_frames}帧:队列已满,丢弃该帧(累计丢弃: {dropped_frames}")
            else:
                print(f"{total_frames}帧:已放入队列")

    try:
        # Create and apply the local offer.
        offer = await pc.createOffer()
        print("已创建本地 SDP Offer")
        await pc.setLocalDescription(offer)

        # Encode once so Content-Length counts bytes, not characters.
        sdp_bytes = offer.sdp.encode()
        async with aiohttp.ClientSession() as session:
            print("开始向服务器发送 SDP Offer")
            try:
                # NOTE(review): ssl=False disables TLS certificate
                # verification; confirm this is intended for this endpoint.
                async with session.post(
                    url,
                    data=sdp_bytes,
                    headers={
                        "Content-Type": "application/sdp",
                        "Content-Length": str(len(sdp_bytes)),
                    },
                    ssl=False,
                ) as response:
                    if response.status == 200:
                        print("已接收到服务器的响应、开始处理 SDP Answer")
                        answer_sdp = await response.text()
                        await pc.setRemoteDescription(
                            RTCSessionDescription(sdp=answer_sdp, type='answer')
                        )
                    else:
                        print(f"服务器响应错误: {response.status}")
                        stop_event.set()
            except Exception as e:
                # Signalling failed: report it and shut the pipeline down.
                print(f"发送SDP Offer失败: {str(e)}")
                stop_event.set()

        # Keep the connection alive until a stop is requested.
        while not stop_event.is_set():
            await asyncio.sleep(1)
    except KeyboardInterrupt:
        print("用户中断")
    finally:
        # Runs on every exit path so the peer connection is always released,
        # and so tasks sharing stop_event learn that the receiver is gone.
        stop_event.set()
        print(f"开始关闭 RTCPeerConnection共接收{total_frames}帧,丢弃{dropped_frames}")
        await pc.close()
        print("已关闭 RTCPeerConnection")
2025-09-02 23:06:36 +08:00
async def frame_consumer(
    ip,
    frame_queue,
    stop_event,
    forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
    poll_interval=1.0,
):
    """
    Take cv2 frames from the queue and run OCR violation detection on them.

    When a violation is found, a ``"stop"`` message listing the offending
    words is sent to the client identified by ``ip`` and ``stop_event`` is
    set so the other tasks sharing it terminate.

    Args:
        ip: Client identifier passed to ``send_message_to_client``.
        frame_queue: ``asyncio.Queue`` the frames are read from.
        stop_event: ``asyncio.Event``; the loop ends once it is set.
        forbidden_words_path: Path of the forbidden-words file handed to
            ``OCRViolationDetector``. Defaults to the previously hard-coded
            location.
        poll_interval: Seconds to wait for a frame before re-checking
            ``stop_event``. Without this the coroutine would wait on an
            empty queue forever after another task requested a stop.
    """
    # Build the OCR detector once; it is reused for every frame.
    ocr_detector = OCRViolationDetector(
        forbidden_words_path=forbidden_words_path,
        ocr_confidence_threshold=0.5,
    )

    while not stop_event.is_set():
        try:
            current_frame = await asyncio.wait_for(
                frame_queue.get(), timeout=poll_interval
            )
        except asyncio.TimeoutError:
            # No frame yet; loop around to re-check stop_event.
            continue

        try:
            # NOTE(review): detect() is called synchronously, so it runs on
            # the event loop thread while a frame is being analysed.
            has_violation, words, confidences = ocr_detector.detect(current_frame)
            print(f"检测结果: {'有违规内容' if has_violation else '无违规内容'}")
            print(f"检测到的词: {words}")
            print(f"置信度: {confidences}")

            if has_violation:
                print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
                response_data = {
                    "status": "stop",
                    "timestamp": datetime.now().isoformat(),
                    "violations": [
                        {"word": w, "confidence": c}
                        for w, c in zip(words, confidences)
                    ],
                }
                try:
                    await send_message_to_client(ip, response_data)
                    for word, conf in zip(words, confidences):
                        print(f"- {word}(置信度:{conf:.4f}")
                finally:
                    # A detected violation must stop the pipeline even when
                    # notifying the client fails.
                    print("检测到违规内容准备关闭AI检测")
                    stop_event.set()
        except Exception as e:
            print(f"处理帧时发生错误: {str(e)}")
        finally:
            # Exactly one task_done() per successful get(), on every path.
            frame_queue.task_done()
2025-09-02 19:46:34 +08:00
2025-09-02 21:42:09 +08:00
def process_webrtc_stream(ip, webrtc_url):
    """
    Process a WebRTC stream, printing OCR results until a violation is found.

    Runs ``rtc_frame_receiver`` and ``frame_consumer`` concurrently on a
    fresh event loop. They share a queue of size 1 (new frames are dropped
    while it is full) and a stop event.

    Args:
        ip: Client identifier forwarded to the frame consumer.
        webrtc_url: Address of the WebRTC signalling server.
    """

    async def main_task():
        # Created inside the coroutine so both primitives belong to the loop
        # that asyncio.run() starts; before Python 3.10 they bind to the loop
        # that is current at construction time.
        frame_queue = asyncio.Queue(maxsize=1)  # holds at most one frame
        stop_event = asyncio.Event()  # shared shutdown signal

        receiver_task = asyncio.create_task(
            rtc_frame_receiver(webrtc_url, frame_queue, stop_event)
        )
        consumer_task = asyncio.create_task(
            frame_consumer(ip, frame_queue, stop_event)
        )

        # Wait for both tasks. The queue is deliberately not join()ed
        # afterwards: a frame queued after the consumer has exited would
        # never be marked done, and join() would block forever.
        await asyncio.gather(receiver_task, consumer_task)

    try:
        # Run the event loop until both tasks finish.
        asyncio.run(main_task())
    except KeyboardInterrupt:
        # asyncio.run() has already cancelled the tasks at this point.
        print("用户中断处理流程")
    finally:
        # Make sure every cv2 window is closed.
        cv2.destroyAllWindows()
        print("AI检测已关闭")