This commit is contained in:
2025-09-02 23:06:36 +08:00
parent 2a59bdcffc
commit 062ee6c70d
4 changed files with 47 additions and 22 deletions

View File

@ -7,6 +7,11 @@ def setup_logger():
配置一个全局日志记录器,支持输出到控制台和文件。 配置一个全局日志记录器,支持输出到控制台和文件。
""" """
# 创建一个日志记录器 # 创建一个日志记录器
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = logging.getLogger("ViolationDetectorLogger") logger = logging.getLogger("ViolationDetectorLogger")
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG

View File

@ -99,7 +99,7 @@ class OCRViolationDetector:
"""初始化RapidOCR引擎""" """初始化RapidOCR引擎"""
self.logger.info("正在初始化RapidOCR引擎...") self.logger.info("正在初始化RapidOCR引擎...")
config_path = r"../ocr/config/1.yaml" config_path = r"D:\Git\bin\video\ocr\config\1.yaml"
try: try:
# 检查配置文件是否存在 # 检查配置文件是否存在
if not os.path.exists(config_path): if not os.path.exists(config_path):
@ -157,4 +157,3 @@ class OCRViolationDetector:
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
return len(all_prohibited) > 0, all_prohibited, all_confidences return len(all_prohibited) > 0, all_prohibited, all_confidences

View File

@ -1,10 +1,13 @@
import asyncio import asyncio
from datetime import datetime
import aiohttp import aiohttp
import cv2 import cv2
import numpy as np import numpy as np
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack from aiortc.mediastreams import MediaStreamTrack
from ocr.ocr_violation_detector import OCRViolationDetector from ocr.ocr_violation_detector import OCRViolationDetector
from ws.ws import send_message_to_client
class VideoTrack(MediaStreamTrack): class VideoTrack(MediaStreamTrack):
@ -18,9 +21,10 @@ class VideoTrack(MediaStreamTrack):
return await super().recv() return await super().recv()
async def rtc_frame_receiver(url, frame_queue): async def rtc_frame_receiver(url, frame_queue, stop_event):
""" """
对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据 对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据
当stop_event被设置时停止接收
""" """
pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
video_track = VideoTrack() video_track = VideoTrack()
@ -34,7 +38,7 @@ async def rtc_frame_receiver(url, frame_queue):
nonlocal total_frames nonlocal total_frames
if track.kind == "video": if track.kind == "video":
print("接收到视频轨道、开始接收视频帧") print("接收到视频轨道、开始接收视频帧")
while True: while not stop_event.is_set(): # 检查是否需要停止
# 接收当前帧并累计计数 # 接收当前帧并累计计数
frame = await track.recv() frame = await track.recv()
# 转换为cv2兼容的BGR格式numpy数组 # 转换为cv2兼容的BGR格式numpy数组
@ -45,13 +49,13 @@ async def rtc_frame_receiver(url, frame_queue):
total_frames += 1 total_frames += 1
# 对每帧都检查队列状态、队列为空则放入 # 对每帧都检查队列状态、队列为空则放入
if frame_queue.empty(): if frame_queue.empty() and not stop_event.is_set(): # 确保还未收到停止信号
# 队列为空、放入当前cv2帧 # 队列为空、放入当前cv2帧
await frame_queue.put(frame_cv2) await frame_queue.put(frame_cv2)
# print(f"第{total_frames}帧队列为空、已放入新的cv2帧尺寸: {frame_cv2.shape}")
else: else:
# 队列非空、说明上一帧还未处理、跳过当前帧 # 队列非空或已收到停止信号、跳过当前帧
print(f"{total_frames}帧:队列非空、跳过该帧") if not stop_event.is_set():
print(f"{total_frames}帧:队列非空、跳过该帧")
else: else:
print("帧格式转换失败不是有效的cv2格式") print("帧格式转换失败不是有效的cv2格式")
@ -77,8 +81,8 @@ async def rtc_frame_receiver(url, frame_queue):
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
try: try:
# 保持连接 # 保持连接,直到收到停止信号
while True: while not stop_event.is_set():
await asyncio.sleep(1) await asyncio.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
print("用户中断") print("用户中断")
@ -88,52 +92,68 @@ async def rtc_frame_receiver(url, frame_queue):
print("已关闭 RTCPeerConnection") print("已关闭 RTCPeerConnection")
async def frame_consumer(frame_queue): async def frame_consumer(ip, frame_queue, stop_event):
""" """
从队列中读取cv2帧并处理队列空时会阻塞等待 从队列中读取cv2帧并处理队列空时会阻塞等待
检测到违规内容后设置stop_event以终止所有任务
Args: frame_queue: 帧队列 Args:
ip: IP地址
frame_queue: 帧队列
stop_event: 用于控制任务停止的事件
""" """
# 创建OCR检测器实例(请替换为实际的违禁词文件路径) # 创建OCR检测器实例
ocr_detector = OCRViolationDetector( ocr_detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径 forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径
ocr_confidence_threshold=0.5, ) ocr_confidence_threshold=0.5, )
while True: while not stop_event.is_set(): # 检查是否需要停止
# 从队列中获取cv2帧队列为空时会阻塞等待新帧 # 从队列中获取cv2帧队列为空时会阻塞等待新帧
current_frame = await frame_queue.get() current_frame = await frame_queue.get()
has_violation, words, confidences = ocr_detector.detect(current_frame) has_violation, words, confidences = ocr_detector.detect(current_frame)
print(has_violation)
print( words)
print( confidences)
# 输出所有检测到的违禁词 # 输出所有检测到的违禁词
if has_violation: if has_violation:
print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
response_data = {
"status": "stop",
"timestamp": datetime.now().isoformat(),
}
await send_message_to_client(ip,response_data )
for word, conf in zip(words, confidences): for word, conf in zip(words, confidences):
print(f"- {word}(置信度:{conf:.4f}") print(f"- {word}(置信度:{conf:.4f}")
# 检测到违规,设置停止事件
print("检测到违规内容准备关闭AI检测")
stop_event.set()
else: else:
print("测试结果:图片中未检测到违禁词") print("测试结果:图片中未检测到违禁词")
# 标记任务完成 # 标记任务完成
frame_queue.task_done() frame_queue.task_done()
# print("帧处理完成、队列已清空")
def process_webrtc_stream(ip, webrtc_url): def process_webrtc_stream(ip, webrtc_url):
""" """
处理WEBRTC流并持续打印OCR检测结果 处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭
Args: Args:
ip: IP地址(预留参数) ip: IP地址
webrtc_url: WEBRTC服务器地址 webrtc_url: WEBRTC服务器地址
""" """
# 创建队列 # 创建队列和停止事件
frame_queue = asyncio.Queue(maxsize=1) frame_queue = asyncio.Queue(maxsize=1)
stop_event = asyncio.Event() # 用于控制任务停止的事件
# 定义事件循环中的主任务 # 定义事件循环中的主任务
async def main_task(): async def main_task():
# 创建任务 # 创建任务
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue)) receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event))
consumer_task = asyncio.create_task(frame_consumer(frame_queue)) consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event))
# 等待任务完成 # 等待任一任务完成当stop_event被设置时两个任务都会退出
await asyncio.gather(receiver_task, consumer_task) await asyncio.gather(receiver_task, consumer_task)
try: try:
@ -144,3 +164,4 @@ def process_webrtc_stream(ip, webrtc_url):
finally: finally:
# 确保关闭所有cv2窗口 # 确保关闭所有cv2窗口
cv2.destroyAllWindows() cv2.destroyAllWindows()
print("AI检测已关闭")

View File

@ -52,7 +52,7 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
# 设备创建成功后在后台线程启动WEBRTC流处理 # 设备创建成功后在后台线程启动WEBRTC流处理
threading.Thread( threading.Thread(
target=run_webrtc_processing, target=run_webrtc_processing,
args=(device_data.ip, full_webrtc_url), args=(device_data.ip, existing_device["live_webrtc_url"]),
daemon=True # 设为守护线程,主程序退出时自动结束 daemon=True # 设为守护线程,主程序退出时自动结束
).start() ).start()
# IP已存在时返回该设备信息 # IP已存在时返回该设备信息