diff --git a/core/rtc.py b/core/rtc.py index 1f5fc6f..823cd27 100644 --- a/core/rtc.py +++ b/core/rtc.py @@ -2,13 +2,24 @@ import asyncio import logging from aiortc import RTCPeerConnection, RTCSessionDescription import aiohttp +from ocr.ocr_violation_detector import OCRViolationDetector + +import logging + +# 创建检测器实例 +detector = OCRViolationDetector( + forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", + ocr_confidence_threshold=0.7, + log_level=logging.INFO, + log_file="ocr_detection.log" +) # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger("whep_video_puller") -async def whep_pull_video_stream(whep_url): +async def whep_pull_video_stream(ip,whep_url): """ 通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息 @@ -60,6 +71,15 @@ async def whep_pull_video_stream(whep_url): if hasattr(frame, 'pts'): print(f" 显示时间戳: {frame.pts}") + has_violation, violations, confidences = OCRViolationDetector.detect(frame) + + # 输出检测结果 + if has_violation: + detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") + for word, conf in zip(violations, confidences): + detector.logger.info(f"- {word} (置信度: {conf:.4f})") + else: + detector.logger.info("图片中未检测到违禁词") except Exception as e: print(f"接收帧时出错: {e}") # 等待一段时间后重试 diff --git a/core/rtmp.py b/core/rtmp.py index c200c04..02da436 100644 --- a/core/rtmp.py +++ b/core/rtmp.py @@ -2,6 +2,17 @@ import asyncio import logging import cv2 import time +from ocr.ocr_violation_detector import OCRViolationDetector + +import logging + +# 创建检测器实例 +detector = OCRViolationDetector( + forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", + ocr_confidence_threshold=0.7, + log_level=logging.INFO, + log_file="ocr_detection.log" +) # 配置日志(与WHEP代码保持一致的日志风格) logging.basicConfig(level=logging.INFO) @@ -67,6 +78,15 @@ async def rtmp_pull_video_stream(rtmp_url): print(f" 帧尺寸: {width}x{height}") print(f" 配置帧率: {fps:.2f} FPS") + has_violation, violations, confidences = OCRViolationDetector.detect(frame) + + # 输出检测结果 + if has_violation: + detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") + for word, conf in zip(violations, confidences): + detector.logger.info(f"- {word} (置信度: {conf:.4f})") + else: + detector.logger.info("图片中未检测到违禁词") # 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致) if frame_count % 100 == 0: elapsed_time = time.time() - start_time diff --git a/ocr/ocr_violation_detector.py b/ocr/ocr_violation_detector.py index 460f5ec..dbd4fee 100644 --- a/ocr/ocr_violation_detector.py +++ b/ocr/ocr_violation_detector.py @@ -157,3 +157,77 @@ class OCRViolationDetector: # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) return len(all_prohibited) > 0, all_prohibited, all_confidences + +# def test_image(self, image_path: str, show_image: bool = True) -> tuple: +# """ +# 对单张图片进行OCR违禁词检测并展示结果 +# +# Args: +# image_path (str): 图片文件路径 +# show_image (bool): 是否显示图片,默认为True +# +# Returns: +# tuple: (是否有违禁词, 违禁词列表, 对应的置信度列表) +# """ +# # 检查图片文件是否存在 +# if not os.path.exists(image_path): +# self.logger.error(f"图片文件不存在: {image_path}") +# return False, [], [] +# +# try: +# # 读取图片 +# frame = cv2.imread(image_path) +# if frame is None: +# self.logger.error(f"无法读取图片: {image_path}") +# return False, [], [] +# +# self.logger.info(f"开始处理图片: {image_path}") +# +# # 调用检测方法 +# has_violation, violations, confidences = self.detect(frame) +# +# # 输出检测结果 +# if has_violation: +# self.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") +# for word, conf in zip(violations, confidences): +# self.logger.info(f"- {word} (置信度: {conf:.4f})") +# else: +# self.logger.info("图片中未检测到违禁词") + +# # 显示图片(如果需要) +# if show_image: +# # 调整图片大小以便于显示(如果太大) +# height, width = frame.shape[:2] +# max_size = 800 +# if max(height, width) > max_size: +# scale = max_size / max(height, width) +# frame = cv2.resize(frame, None, fx=scale, fy=scale) +# +# cv2.imshow(f"OCR检测结果: {'发现违禁词' if has_violation else '未发现违禁词'}", frame) +# cv2.waitKey(0) # 等待用户按键 +# cv2.destroyAllWindows() +# +# return has_violation, violations, confidences +# +# except Exception as e: +# self.logger.error(f"处理图片时发生错误: {str(e)}", exc_info=True) +# return False, [], [] +# +# +# # 使用示例 +# if __name__ == "__main__": +# # 配置参数 +# forbidden_words_path = "forbidden_words.txt" # 违禁词文件路径 +# test_image_path = r"D:\Git\bin\video\ocr\images\img_7.png" # 测试图片路径 +# ocr_threshold = 0.6 # OCR置信度阈值 +# +# # 创建检测器实例 +# detector = OCRViolationDetector( +# forbidden_words_path=forbidden_words_path, +# ocr_confidence_threshold=ocr_threshold, +# log_level=logging.INFO, +# log_file="ocr_detection.log" +# ) +# +# # 测试图片 +# detector.test_image(test_image_path, show_image=True) \ No newline at end of file diff --git a/rtc/rtc.py b/rtc/rtc.py index 1f5fc6f..c47e5e3 100644 --- a/rtc/rtc.py +++ b/rtc/rtc.py @@ -1,117 +1,175 @@ +import queue import asyncio -import logging -from aiortc import RTCPeerConnection, RTCSessionDescription import aiohttp +import threading +import time +from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration +from aiortc.mediastreams import MediaStreamTrack -# 配置日志 -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("whep_video_puller") +from ocr.ocr_violation_detector import OCRViolationDetector +import logging + +# 创建检测器实例 +detector = OCRViolationDetector( + forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", + ocr_confidence_threshold=0.7, + log_level=logging.INFO, + log_file="ocr_detection.log" +) + +# 创建一个长度为1的队列,用于生产者和消费者之间的通信 +frame_queue = queue.Queue(maxsize=1) -async def whep_pull_video_stream(whep_url): +class VideoTrack(MediaStreamTrack): + """自定义视频轨道类,继承自MediaStreamTrack""" + kind = "video" + + def __init__(self, max_frames=100): + super().__init__() + self.frames = queue.Queue(maxsize=max_frames) + + async def recv(self): + return await super().recv() + + +def webrtc_producer(webrtc_url): """ - 通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息 - - Args: - whep_url: WHEP端点的URL + 生产者方法:从WEBRTC读取视频帧并放入队列 + 仅当队列空时才放入新帧,否则丢弃 """ - pc = RTCPeerConnection() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - # 添加连接状态变化监听 - @pc.on("connectionstatechange") - async def on_connectionstatechange(): - print(f"连接状态: {pc.connectionState}") + # 创建RTCPeerConnection对象,不使用ICE服务器 + pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) + video_track = VideoTrack() + pc.addTrack(video_track) - # 添加ICE连接状态变化监听 - @pc.on("iceconnectionstatechange") - async def on_iceconnectionstatechange(): - print(f"ICE连接状态: {pc.iceConnectionState}") - - # 添加视频接收器 - pc.addTransceiver("video", direction="recvonly") - - # 处理接收到的视频轨道 @pc.on("track") - def on_track(track): - print(f"接收到轨道: {track.kind}") + async def on_track(track): if track.kind == "video": - print(f"轨道ID: {track.id}") - print(f"轨道就绪状态: {track.readyState}") - # 创建异步任务来处理视频帧 - asyncio.ensure_future(handle_video_track(track)) - - async def handle_video_track(track): - """处理视频轨道,接收并打印每一帧""" - frame_count = 0 - print("开始处理视频轨道...") - - while True: - try: - # 尝试接收帧 + print("接收到视频轨道,开始接收视频帧") + while True: + # 从轨道接收视频帧 frame = await track.recv() - frame_count += 1 - print(f"收到原始帧 (第{frame_count}帧)") + # 转换为BGR24格式的NumPy数组 + frame_bgr24 = frame.to_ndarray(format='bgr24') - # 打印帧的基本信息 - if hasattr(frame, 'width') and hasattr(frame, 'height'): - print(f" 尺寸: {frame.width}x{frame.height}") - if hasattr(frame, 'time_base'): - print(f" 时间基准: {frame.time_base}") - if hasattr(frame, 'pts'): - print(f" 显示时间戳: {frame.pts}") + # 检查队列是否为空,为空则加入,否则丢弃 + if frame_queue.empty(): + try: + frame_queue.put_nowait(frame_bgr24) + print("帧已放入队列") + except queue.Full: + print("队列已满,丢弃帧") + else: + print("队列非空,丢弃帧") - except Exception as e: - print(f"接收帧时出错: {e}") - # 等待一段时间后重试 + async def main(): + # 创建并发送SDP Offer + offer = await pc.createOffer() + print("已创建本地SDP Offer") + await pc.setLocalDescription(offer) + + # 发送Offer到服务器并接收Answer + async with aiohttp.ClientSession() as session: + print(f"开始向服务器 {webrtc_url} 发送SDP Offer") + async with session.post( + webrtc_url, + data=offer.sdp.encode(), + headers={ + "Content-Type": "application/sdp", + "Content-Length": str(len(offer.sdp)) + }, + ssl=False + ) as response: + print("已接收到服务器的响应") + answer_sdp = await response.text() + await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer')) + + # 保持连接 + try: + while True: await asyncio.sleep(0.1) - continue + except KeyboardInterrupt: + pass + finally: + print("关闭RTCPeerConnection") + await pc.close() - # 创建offer - offer = await pc.createOffer() - await pc.setLocalDescription(offer) + try: + loop.run_until_complete(main()) + finally: + loop.close() - print(f"本地SDP信息:\n{offer.sdp}") - # 通过HTTP POST发送offer到WHEP端点 - async with aiohttp.ClientSession() as session: - async with session.post( - whep_url, - data=offer.sdp, - headers={"Content-Type": "application/sdp"} - ) as response: - if response.status != 201: - print(f"WHEP服务器返回错误: {response.status}") - print(f"响应内容: {await response.text()}") - raise Exception(f"WHEP服务器返回错误: {response.status}") - - # 获取answer SDP - answer_sdp = await response.text() - - # 创建RTCSessionDescription对象 - answer = RTCSessionDescription(sdp=answer_sdp, type="answer") - - print(f"收到远程SDP:\n{answer_sdp}") - - # 设置远程描述 - await pc.setRemoteDescription(answer) - - print("连接已建立,开始接收视频流...") - - # 保持连接,直到用户中断 +def frame_consumer(ip): + """ + 消费者方法:从队列中读取帧并处理 + 每次处理后休眠200ms模拟延迟 + """ + print("消费者启动,开始等待帧...") try: while True: - await asyncio.sleep(1) - # 检查连接状态 - print(f"当前连接状态: {pc.connectionState}") + # 阻塞等待队列中的帧 + frame = frame_queue.get() + print(f"消费帧,大小: {frame.shape}") + + has_violation, violations, confidences = OCRViolationDetector.detect(frame) + + + # 输出检测结果 + if has_violation: + detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:") + for word, conf in zip(violations, confidences): + detector.logger.info(f"- {word} (置信度: {conf:.4f})") + else: + detector.logger.info("图片中未检测到违禁词") + + + # 标记任务完成 + frame_queue.task_done() except KeyboardInterrupt: - print("用户中断,关闭连接...") - finally: - await pc.close() + print("消费者退出") + + +def start_webrtc_stream(ip, webrtc_url): + """ + 启动WebRTC视频流处理的主方法 + 参数: webrtc_url - WebRTC服务器地址 + """ + print(f"开始连接到WebRTC服务器: {webrtc_url}") + + # 启动生产者线程 + producer_thread = threading.Thread( + target=webrtc_producer, + args=(webrtc_url,), + daemon=True, + name="webrtc-producer" + ) + + # 启动消费者线程 + consumer_thread = threading.Thread( + target=frame_consumer(ip), + daemon=True, + name="frame-consumer" + ) + + producer_thread.start() + consumer_thread.start() + print("生产者和消费者线程已启动") + + try: + # 保持主线程运行 + while True: + time.sleep(1) + except KeyboardInterrupt: + print("程序正在退出...") if __name__ == "__main__": - # 替换为你的WHEP端点URL - WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416" - - # 运行拉流任务 - asyncio.run(whep_pull_video_stream(WHEP_URL)) + # 示例用法 + # 实际使用时替换为真实的WebRTC服务器地址 + webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60" + start_webrtc_stream(webrtc_server_url) diff --git a/service/device_service.py b/service/device_service.py index 65910db..798f4b7 100644 --- a/service/device_service.py +++ b/service/device_service.py @@ -1,5 +1,7 @@ import json import threading +import time + from fastapi import HTTPException, Query, APIRouter, Depends, Request from mysql.connector import Error as MySQLError @@ -17,7 +19,7 @@ from schema.response_schema import APIResponse from schema.user_schema import UserResponse # 导入之前封装的WEBRTC处理函数 -from rtc.rtc import process_webrtc_stream +from core.rtmp import rtmp_pull_video_stream router = APIRouter( prefix="/devices", @@ -29,7 +31,7 @@ router = APIRouter( def run_webrtc_processing(ip, webrtc_url): try: print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}") - process_webrtc_stream(ip, webrtc_url) + rtmp_pull_video_stream(webrtc_url) except Exception as e: print(f"WEBRTC处理出错: {str(e)}") @@ -52,7 +54,9 @@ async def create_device(request: Request, device_data: DeviceCreateRequest): # 设备创建成功后,在后台线程启动WEBRTC流处理 threading.Thread( target=run_webrtc_processing, - args=(device_data.ip, existing_device["live_webrtc_url"]), + # args=(device_data.ip, existing_device["live_webrtc_url"]), + args=(device_data.ip, existing_device["rtmp_push_url"]), + daemon=True # 设为守护线程,主程序退出时自动结束 ).start() # IP已存在时返回该设备信息