This commit is contained in:
2025-09-02 23:06:36 +08:00
parent 2a59bdcffc
commit 062ee6c70d
4 changed files with 47 additions and 22 deletions

View File

@ -1,10 +1,13 @@
import asyncio
from datetime import datetime
import aiohttp
import cv2
import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
from ocr.ocr_violation_detector import OCRViolationDetector
from ws.ws import send_message_to_client
class VideoTrack(MediaStreamTrack):
@ -18,9 +21,10 @@ class VideoTrack(MediaStreamTrack):
return await super().recv()
async def rtc_frame_receiver(url, frame_queue):
async def rtc_frame_receiver(url, frame_queue, stop_event):
"""
对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据
当stop_event被设置时停止接收
"""
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack()
@ -34,7 +38,7 @@ async def rtc_frame_receiver(url, frame_queue):
nonlocal total_frames
if track.kind == "video":
print("接收到视频轨道、开始接收视频帧")
while True:
while not stop_event.is_set(): # 检查是否需要停止
# 接收当前帧并累计计数
frame = await track.recv()
# 转换为cv2兼容的BGR格式numpy数组
@ -45,13 +49,13 @@ async def rtc_frame_receiver(url, frame_queue):
total_frames += 1
# 对每帧都检查队列状态、队列为空则放入
if frame_queue.empty():
if frame_queue.empty() and not stop_event.is_set(): # 确保还未收到停止信号
# 队列为空、放入当前cv2帧
await frame_queue.put(frame_cv2)
# print(f"第{total_frames}帧队列为空、已放入新的cv2帧尺寸: {frame_cv2.shape}")
else:
# 队列非空、说明上一帧还未处理、跳过当前帧
print(f"{total_frames}帧:队列非空、跳过该帧")
# 队列非空或已收到停止信号、跳过当前帧
if not stop_event.is_set():
print(f"{total_frames}帧:队列非空、跳过该帧")
else:
print("帧格式转换失败不是有效的cv2格式")
@ -77,8 +81,8 @@ async def rtc_frame_receiver(url, frame_queue):
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
try:
# 保持连接
while True:
# 保持连接,直到收到停止信号
while not stop_event.is_set():
await asyncio.sleep(1)
except KeyboardInterrupt:
print("用户中断")
@ -88,52 +92,68 @@ async def rtc_frame_receiver(url, frame_queue):
print("已关闭 RTCPeerConnection")
async def frame_consumer(frame_queue):
async def frame_consumer(ip, frame_queue, stop_event):
"""
从队列中读取cv2帧并处理队列空时会阻塞等待
检测到违规内容后设置stop_event以终止所有任务
Args: frame_queue: 帧队列
Args:
ip: IP地址
frame_queue: 帧队列
stop_event: 用于控制任务停止的事件
"""
# 创建OCR检测器实例(请替换为实际的违禁词文件路径)
# 创建OCR检测器实例
ocr_detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径
ocr_confidence_threshold=0.5, )
while True:
while not stop_event.is_set(): # 检查是否需要停止
# 从队列中获取cv2帧队列为空时会阻塞等待新帧
current_frame = await frame_queue.get()
has_violation, words, confidences = ocr_detector.detect(current_frame)
print(has_violation)
print( words)
print( confidences)
# 输出所有检测到的违禁词
if has_violation:
print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
response_data = {
"status": "stop",
"timestamp": datetime.now().isoformat(),
}
await send_message_to_client(ip,response_data )
for word, conf in zip(words, confidences):
print(f"- {word}(置信度:{conf:.4f}")
# 检测到违规,设置停止事件
print("检测到违规内容准备关闭AI检测")
stop_event.set()
else:
print("测试结果:图片中未检测到违禁词")
# 标记任务完成
frame_queue.task_done()
# print("帧处理完成、队列已清空")
def process_webrtc_stream(ip, webrtc_url):
"""
处理WEBRTC流并持续打印OCR检测结果
处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭
Args:
ip: IP地址(预留参数)
ip: IP地址
webrtc_url: WEBRTC服务器地址
"""
# 创建队列
# 创建队列和停止事件
frame_queue = asyncio.Queue(maxsize=1)
stop_event = asyncio.Event() # 用于控制任务停止的事件
# 定义事件循环中的主任务
async def main_task():
# 创建任务
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue))
consumer_task = asyncio.create_task(frame_consumer(frame_queue))
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event))
consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event))
# 等待任务完成
# 等待任一任务完成当stop_event被设置时两个任务都会退出
await asyncio.gather(receiver_task, consumer_task)
try:
@ -144,3 +164,4 @@ def process_webrtc_stream(ip, webrtc_url):
finally:
# 确保关闭所有cv2窗口
cv2.destroyAllWindows()
print("AI检测已关闭")