ocr1.0
This commit is contained in:
		| @ -7,6 +7,11 @@ def setup_logger(): | |||||||
|     配置一个全局日志记录器,支持输出到控制台和文件。 |     配置一个全局日志记录器,支持输出到控制台和文件。 | ||||||
|     """ |     """ | ||||||
|     # 创建一个日志记录器 |     # 创建一个日志记录器 | ||||||
|  |  | ||||||
|  |     # 配置日志 | ||||||
|  |     logging.basicConfig(level=logging.INFO) | ||||||
|  |     logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|     logger = logging.getLogger("ViolationDetectorLogger") |     logger = logging.getLogger("ViolationDetectorLogger") | ||||||
|     logger.setLevel(logging.DEBUG)  # 设置最低级别为DEBUG |     logger.setLevel(logging.DEBUG)  # 设置最低级别为DEBUG | ||||||
|  |  | ||||||
|  | |||||||
| @ -99,7 +99,7 @@ class OCRViolationDetector: | |||||||
|         """初始化RapidOCR引擎""" |         """初始化RapidOCR引擎""" | ||||||
|         self.logger.info("正在初始化RapidOCR引擎...") |         self.logger.info("正在初始化RapidOCR引擎...") | ||||||
|  |  | ||||||
|         config_path = r"../ocr/config/1.yaml" |         config_path = r"D:\Git\bin\video\ocr\config\1.yaml" | ||||||
|         try: |         try: | ||||||
|             # 检查配置文件是否存在 |             # 检查配置文件是否存在 | ||||||
|             if not os.path.exists(config_path): |             if not os.path.exists(config_path): | ||||||
| @ -157,4 +157,3 @@ class OCRViolationDetector: | |||||||
|  |  | ||||||
|         # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) |         # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) | ||||||
|         return len(all_prohibited) > 0, all_prohibited, all_confidences |         return len(all_prohibited) > 0, all_prohibited, all_confidences | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										57
									
								
								rtc/rtc.py
									
									
									
									
									
								
							
							
						
						
									
										57
									
								
								rtc/rtc.py
									
									
									
									
									
								
							| @ -1,10 +1,13 @@ | |||||||
| import asyncio | import asyncio | ||||||
|  | from datetime import datetime | ||||||
|  |  | ||||||
| import aiohttp | import aiohttp | ||||||
| import cv2 | import cv2 | ||||||
| import numpy as np | import numpy as np | ||||||
| from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration | from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration | ||||||
| from aiortc.mediastreams import MediaStreamTrack | from aiortc.mediastreams import MediaStreamTrack | ||||||
| from ocr.ocr_violation_detector import OCRViolationDetector | from ocr.ocr_violation_detector import OCRViolationDetector | ||||||
|  | from ws.ws import send_message_to_client | ||||||
|  |  | ||||||
|  |  | ||||||
| class VideoTrack(MediaStreamTrack): | class VideoTrack(MediaStreamTrack): | ||||||
| @ -18,9 +21,10 @@ class VideoTrack(MediaStreamTrack): | |||||||
|         return await super().recv() |         return await super().recv() | ||||||
|  |  | ||||||
|  |  | ||||||
| async def rtc_frame_receiver(url, frame_queue): | async def rtc_frame_receiver(url, frame_queue, stop_event): | ||||||
|     """ |     """ | ||||||
|     对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据 |     对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据 | ||||||
|  |     当stop_event被设置时停止接收 | ||||||
|     """ |     """ | ||||||
|     pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) |     pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) | ||||||
|     video_track = VideoTrack() |     video_track = VideoTrack() | ||||||
| @ -34,7 +38,7 @@ async def rtc_frame_receiver(url, frame_queue): | |||||||
|         nonlocal total_frames |         nonlocal total_frames | ||||||
|         if track.kind == "video": |         if track.kind == "video": | ||||||
|             print("接收到视频轨道、开始接收视频帧") |             print("接收到视频轨道、开始接收视频帧") | ||||||
|             while True: |             while not stop_event.is_set():  # 检查是否需要停止 | ||||||
|                 # 接收当前帧并累计计数 |                 # 接收当前帧并累计计数 | ||||||
|                 frame = await track.recv() |                 frame = await track.recv() | ||||||
|                 # 转换为cv2兼容的BGR格式numpy数组 |                 # 转换为cv2兼容的BGR格式numpy数组 | ||||||
| @ -45,12 +49,12 @@ async def rtc_frame_receiver(url, frame_queue): | |||||||
|                     total_frames += 1 |                     total_frames += 1 | ||||||
|  |  | ||||||
|                     # 对每帧都检查队列状态、队列为空则放入 |                     # 对每帧都检查队列状态、队列为空则放入 | ||||||
|                     if frame_queue.empty(): |                     if frame_queue.empty() and not stop_event.is_set():  # 确保还未收到停止信号 | ||||||
|                         # 队列为空、放入当前cv2帧 |                         # 队列为空、放入当前cv2帧 | ||||||
|                         await frame_queue.put(frame_cv2) |                         await frame_queue.put(frame_cv2) | ||||||
|                         # print(f"第{total_frames}帧:队列为空、已放入新的cv2帧,尺寸: {frame_cv2.shape}") |  | ||||||
|                     else: |                     else: | ||||||
|                         # 队列非空、说明上一帧还未处理、跳过当前帧 |                         # 队列非空或已收到停止信号、跳过当前帧 | ||||||
|  |                         if not stop_event.is_set(): | ||||||
|                             print(f"第{total_frames}帧:队列非空、跳过该帧") |                             print(f"第{total_frames}帧:队列非空、跳过该帧") | ||||||
|                 else: |                 else: | ||||||
|                     print("帧格式转换失败,不是有效的cv2格式") |                     print("帧格式转换失败,不是有效的cv2格式") | ||||||
| @ -77,8 +81,8 @@ async def rtc_frame_receiver(url, frame_queue): | |||||||
|             await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) |             await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 保持连接 |         # 保持连接,直到收到停止信号 | ||||||
|         while True: |         while not stop_event.is_set(): | ||||||
|             await asyncio.sleep(1) |             await asyncio.sleep(1) | ||||||
|     except KeyboardInterrupt: |     except KeyboardInterrupt: | ||||||
|         print("用户中断") |         print("用户中断") | ||||||
| @ -88,52 +92,68 @@ async def rtc_frame_receiver(url, frame_queue): | |||||||
|         print("已关闭 RTCPeerConnection") |         print("已关闭 RTCPeerConnection") | ||||||
|  |  | ||||||
|  |  | ||||||
| async def frame_consumer(frame_queue): | async def frame_consumer(ip, frame_queue, stop_event): | ||||||
|     """ |     """ | ||||||
|     从队列中读取cv2帧并处理(队列空时会阻塞等待) |     从队列中读取cv2帧并处理(队列空时会阻塞等待) | ||||||
|  |     检测到违规内容后设置stop_event以终止所有任务 | ||||||
|  |  | ||||||
|     Args: frame_queue: 帧队列 |     Args: | ||||||
|  |         ip: IP地址 | ||||||
|  |         frame_queue: 帧队列 | ||||||
|  |         stop_event: 用于控制任务停止的事件 | ||||||
|     """ |     """ | ||||||
|     # 创建OCR检测器实例(请替换为实际的违禁词文件路径) |     # 创建OCR检测器实例 | ||||||
|     ocr_detector = OCRViolationDetector( |     ocr_detector = OCRViolationDetector( | ||||||
|         forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",  # 替换为实际路径 |         forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",  # 替换为实际路径 | ||||||
|         ocr_confidence_threshold=0.5, ) |         ocr_confidence_threshold=0.5, ) | ||||||
|  |  | ||||||
|     while True: |     while not stop_event.is_set():  # 检查是否需要停止 | ||||||
|         # 从队列中获取cv2帧(队列为空时会阻塞等待新帧) |         # 从队列中获取cv2帧(队列为空时会阻塞等待新帧) | ||||||
|         current_frame = await frame_queue.get() |         current_frame = await frame_queue.get() | ||||||
|         has_violation, words, confidences = ocr_detector.detect(current_frame) |         has_violation, words, confidences = ocr_detector.detect(current_frame) | ||||||
|  |         print(has_violation) | ||||||
|  |         print( words) | ||||||
|  |         print( confidences) | ||||||
|         # 输出所有检测到的违禁词 |         # 输出所有检测到的违禁词 | ||||||
|         if has_violation: |         if has_violation: | ||||||
|             print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") |             print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") | ||||||
|  |             response_data = { | ||||||
|  |                 "status": "stop", | ||||||
|  |                 "timestamp": datetime.now().isoformat(), | ||||||
|  |             } | ||||||
|  |             await send_message_to_client(ip,response_data ) | ||||||
|             for word, conf in zip(words, confidences): |             for word, conf in zip(words, confidences): | ||||||
|                 print(f"- {word}(置信度:{conf:.4f})") |                 print(f"- {word}(置信度:{conf:.4f})") | ||||||
|  |  | ||||||
|  |             # 检测到违规,设置停止事件 | ||||||
|  |             print("检测到违规内容,准备关闭AI检测") | ||||||
|  |             stop_event.set() | ||||||
|         else: |         else: | ||||||
|             print("测试结果:图片中未检测到违禁词") |             print("测试结果:图片中未检测到违禁词") | ||||||
|  |  | ||||||
|         # 标记任务完成 |         # 标记任务完成 | ||||||
|         frame_queue.task_done() |         frame_queue.task_done() | ||||||
|         # print("帧处理完成、队列已清空") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def process_webrtc_stream(ip, webrtc_url): | def process_webrtc_stream(ip, webrtc_url): | ||||||
|     """ |     """ | ||||||
|     处理WEBRTC流并持续打印OCR检测结果 |     处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭 | ||||||
|  |  | ||||||
|     Args: |     Args: | ||||||
|         ip: IP地址(预留参数) |         ip: IP地址 | ||||||
|         webrtc_url: WEBRTC服务器地址 |         webrtc_url: WEBRTC服务器地址 | ||||||
|     """ |     """ | ||||||
|     # 创建队列 |     # 创建队列和停止事件 | ||||||
|     frame_queue = asyncio.Queue(maxsize=1) |     frame_queue = asyncio.Queue(maxsize=1) | ||||||
|  |     stop_event = asyncio.Event()  # 用于控制任务停止的事件 | ||||||
|  |  | ||||||
|     # 定义事件循环中的主任务 |     # 定义事件循环中的主任务 | ||||||
|     async def main_task(): |     async def main_task(): | ||||||
|         # 创建任务 |         # 创建任务 | ||||||
|         receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue)) |         receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event)) | ||||||
|         consumer_task = asyncio.create_task(frame_consumer(frame_queue)) |         consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event)) | ||||||
|  |  | ||||||
|         # 等待任务完成 |         # 等待任一任务完成(当stop_event被设置时,两个任务都会退出) | ||||||
|         await asyncio.gather(receiver_task, consumer_task) |         await asyncio.gather(receiver_task, consumer_task) | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
| @ -144,3 +164,4 @@ def process_webrtc_stream(ip, webrtc_url): | |||||||
|     finally: |     finally: | ||||||
|         # 确保关闭所有cv2窗口 |         # 确保关闭所有cv2窗口 | ||||||
|         cv2.destroyAllWindows() |         cv2.destroyAllWindows() | ||||||
|  |         print("AI检测已关闭") | ||||||
|  | |||||||
| @ -52,7 +52,7 @@ async def create_device(request: Request, device_data: DeviceCreateRequest): | |||||||
|             # 设备创建成功后,在后台线程启动WEBRTC流处理 |             # 设备创建成功后,在后台线程启动WEBRTC流处理 | ||||||
|             threading.Thread( |             threading.Thread( | ||||||
|                 target=run_webrtc_processing, |                 target=run_webrtc_processing, | ||||||
|                 args=(device_data.ip, full_webrtc_url), |                 args=(device_data.ip, existing_device["live_webrtc_url"]), | ||||||
|                 daemon=True  # 设为守护线程,主程序退出时自动结束 |                 daemon=True  # 设为守护线程,主程序退出时自动结束 | ||||||
|             ).start() |             ).start() | ||||||
|             # IP已存在时返回该设备信息 |             # IP已存在时返回该设备信息 | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user