diff --git a/ocr/logger_config.py b/ocr/logger_config.py index 038d657..09052a5 100644 --- a/ocr/logger_config.py +++ b/ocr/logger_config.py @@ -7,6 +7,11 @@ def setup_logger(): 配置一个全局日志记录器,支持输出到控制台和文件。 """ # 创建一个日志记录器 + + # 配置日志 + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + logger = logging.getLogger("ViolationDetectorLogger") logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG diff --git a/ocr/ocr_violation_detector.py b/ocr/ocr_violation_detector.py index 8bf5d7e..460f5ec 100644 --- a/ocr/ocr_violation_detector.py +++ b/ocr/ocr_violation_detector.py @@ -99,7 +99,7 @@ class OCRViolationDetector: """初始化RapidOCR引擎""" self.logger.info("正在初始化RapidOCR引擎...") - config_path = r"../ocr/config/1.yaml" + config_path = r"D:\Git\bin\video\ocr\config\1.yaml" try: # 检查配置文件是否存在 if not os.path.exists(config_path): @@ -157,4 +157,3 @@ class OCRViolationDetector: # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) return len(all_prohibited) > 0, all_prohibited, all_confidences - diff --git a/rtc/rtc.py b/rtc/rtc.py index 938d912..08412cd 100644 --- a/rtc/rtc.py +++ b/rtc/rtc.py @@ -1,10 +1,13 @@ import asyncio +from datetime import datetime + import aiohttp import cv2 import numpy as np from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration from aiortc.mediastreams import MediaStreamTrack from ocr.ocr_violation_detector import OCRViolationDetector +from ws.ws import send_message_to_client class VideoTrack(MediaStreamTrack): @@ -18,9 +21,10 @@ class VideoTrack(MediaStreamTrack): return await super().recv() -async def rtc_frame_receiver(url, frame_queue): +async def rtc_frame_receiver(url, frame_queue, stop_event): """ 对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据 + 当stop_event被设置时停止接收 """ pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) video_track = VideoTrack() @@ -34,7 +38,7 @@ async def rtc_frame_receiver(url, frame_queue): nonlocal total_frames if track.kind == "video": print("接收到视频轨道、开始接收视频帧") - while True: + while not stop_event.is_set(): # 检查是否需要停止 # 接收当前帧并累计计数 frame = await track.recv() # 转换为cv2兼容的BGR格式numpy数组 @@ -45,13 +49,13 @@ async def rtc_frame_receiver(url, frame_queue): total_frames += 1 # 对每帧都检查队列状态、队列为空则放入 - if frame_queue.empty(): + if frame_queue.empty() and not stop_event.is_set(): # 确保还未收到停止信号 # 队列为空、放入当前cv2帧 await frame_queue.put(frame_cv2) - # print(f"第{total_frames}帧:队列为空、已放入新的cv2帧,尺寸: {frame_cv2.shape}") else: - # 队列非空、说明上一帧还未处理、跳过当前帧 - print(f"第{total_frames}帧:队列非空、跳过该帧") + # 队列非空或已收到停止信号、跳过当前帧 + if not stop_event.is_set(): + print(f"第{total_frames}帧:队列非空、跳过该帧") else: print("帧格式转换失败,不是有效的cv2格式") @@ -77,8 +81,8 @@ async def rtc_frame_receiver(url, frame_queue): await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) try: - # 保持连接 - while True: + # 保持连接,直到收到停止信号 + while not stop_event.is_set(): await asyncio.sleep(1) except KeyboardInterrupt: print("用户中断") @@ -88,52 +92,68 @@ async def rtc_frame_receiver(url, frame_queue): print("已关闭 RTCPeerConnection") -async def frame_consumer(frame_queue): +async def frame_consumer(ip, frame_queue, stop_event): """ 从队列中读取cv2帧并处理(队列空时会阻塞等待) + 检测到违规内容后设置stop_event以终止所有任务 - Args: frame_queue: 帧队列 + Args: + ip: IP地址 + frame_queue: 帧队列 + stop_event: 用于控制任务停止的事件 """ - # 创建OCR检测器实例(请替换为实际的违禁词文件路径) + # 创建OCR检测器实例 ocr_detector = OCRViolationDetector( forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径 ocr_confidence_threshold=0.5, ) - while True: + while not stop_event.is_set(): # 检查是否需要停止 # 从队列中获取cv2帧(队列为空时会阻塞等待新帧) current_frame = await frame_queue.get() has_violation, words, confidences = ocr_detector.detect(current_frame) + print(has_violation) + print( words) + print( confidences) # 输出所有检测到的违禁词 if has_violation: print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") + response_data = { + "status": "stop", + "timestamp": datetime.now().isoformat(), + } + await send_message_to_client(ip,response_data ) for word, conf in zip(words, confidences): print(f"- {word}(置信度:{conf:.4f})") + + # 检测到违规,设置停止事件 + print("检测到违规内容,准备关闭AI检测") + stop_event.set() else: print("测试结果:图片中未检测到违禁词") # 标记任务完成 frame_queue.task_done() - # print("帧处理完成、队列已清空") def process_webrtc_stream(ip, webrtc_url): """ - 处理WEBRTC流并持续打印OCR检测结果 + 处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭 Args: - ip: IP地址(预留参数) + ip: IP地址 webrtc_url: WEBRTC服务器地址 """ - # 创建队列 + # 创建队列和停止事件 frame_queue = asyncio.Queue(maxsize=1) + stop_event = asyncio.Event() # 用于控制任务停止的事件 # 定义事件循环中的主任务 async def main_task(): # 创建任务 - receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue)) - consumer_task = asyncio.create_task(frame_consumer(frame_queue)) + receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event)) + consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event)) - # 等待任务完成 + # 等待任一任务完成(当stop_event被设置时,两个任务都会退出) await asyncio.gather(receiver_task, consumer_task) try: @@ -144,3 +164,4 @@ def process_webrtc_stream(ip, webrtc_url): finally: # 确保关闭所有cv2窗口 cv2.destroyAllWindows() + print("AI检测已关闭") diff --git a/service/device_service.py b/service/device_service.py index 115aa50..65910db 100644 --- a/service/device_service.py +++ b/service/device_service.py @@ -52,7 +52,7 @@ async def create_device(request: Request, device_data: DeviceCreateRequest): # 设备创建成功后,在后台线程启动WEBRTC流处理 threading.Thread( target=run_webrtc_processing, - args=(device_data.ip, full_webrtc_url), + args=(device_data.ip, existing_device["live_webrtc_url"]), daemon=True # 设为守护线程,主程序退出时自动结束 ).start() # IP已存在时返回该设备信息