This commit is contained in:
2025-09-03 13:52:24 +08:00
parent 8cb8e5f935
commit eb5cf715ec
5 changed files with 273 additions and 97 deletions

View File

@ -2,13 +2,24 @@ import asyncio
import logging import logging
from aiortc import RTCPeerConnection, RTCSessionDescription from aiortc import RTCPeerConnection, RTCSessionDescription
import aiohttp import aiohttp
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 配置日志 # 配置日志
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whep_video_puller") logger = logging.getLogger("whep_video_puller")
async def whep_pull_video_stream(whep_url): async def whep_pull_video_stream(ip,whep_url):
""" """
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息 通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
@ -60,6 +71,15 @@ async def whep_pull_video_stream(whep_url):
if hasattr(frame, 'pts'): if hasattr(frame, 'pts'):
print(f" 显示时间戳: {frame.pts}") print(f" 显示时间戳: {frame.pts}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
except Exception as e: except Exception as e:
print(f"接收帧时出错: {e}") print(f"接收帧时出错: {e}")
# 等待一段时间后重试 # 等待一段时间后重试

View File

@ -2,6 +2,17 @@ import asyncio
import logging import logging
import cv2 import cv2
import time import time
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 配置日志与WHEP代码保持一致的日志风格 # 配置日志与WHEP代码保持一致的日志风格
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -67,6 +78,15 @@ async def rtmp_pull_video_stream(rtmp_url):
print(f" 帧尺寸: {width}x{height}") print(f" 帧尺寸: {width}x{height}")
print(f" 配置帧率: {fps:.2f} FPS") print(f" 配置帧率: {fps:.2f} FPS")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
# 7. 每100帧统计一次实际接收帧率补充性能监控与原RTMP示例逻辑一致 # 7. 每100帧统计一次实际接收帧率补充性能监控与原RTMP示例逻辑一致
if frame_count % 100 == 0: if frame_count % 100 == 0:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time

View File

@ -157,3 +157,77 @@ class OCRViolationDetector:
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
return len(all_prohibited) > 0, all_prohibited, all_confidences return len(all_prohibited) > 0, all_prohibited, all_confidences
# def test_image(self, image_path: str, show_image: bool = True) -> tuple:
# """
# 对单张图片进行OCR违禁词检测并展示结果
#
# Args:
# image_path (str): 图片文件路径
# show_image (bool): 是否显示图片默认为True
#
# Returns:
# tuple: (是否有违禁词, 违禁词列表, 对应的置信度列表)
# """
# # 检查图片文件是否存在
# if not os.path.exists(image_path):
# self.logger.error(f"图片文件不存在: {image_path}")
# return False, [], []
#
# try:
# # 读取图片
# frame = cv2.imread(image_path)
# if frame is None:
# self.logger.error(f"无法读取图片: {image_path}")
# return False, [], []
#
# self.logger.info(f"开始处理图片: {image_path}")
#
# # 调用检测方法
# has_violation, violations, confidences = self.detect(frame)
#
# # 输出检测结果
# if has_violation:
# self.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
# for word, conf in zip(violations, confidences):
# self.logger.info(f"- {word} (置信度: {conf:.4f})")
# else:
# self.logger.info("图片中未检测到违禁词")
# # 显示图片(如果需要)
# if show_image:
# # 调整图片大小以便于显示(如果太大)
# height, width = frame.shape[:2]
# max_size = 800
# if max(height, width) > max_size:
# scale = max_size / max(height, width)
# frame = cv2.resize(frame, None, fx=scale, fy=scale)
#
# cv2.imshow(f"OCR检测结果: {'发现违禁词' if has_violation else '未发现违禁词'}", frame)
# cv2.waitKey(0) # 等待用户按键
# cv2.destroyAllWindows()
#
# return has_violation, violations, confidences
#
# except Exception as e:
# self.logger.error(f"处理图片时发生错误: {str(e)}", exc_info=True)
# return False, [], []
#
#
# # 使用示例
# if __name__ == "__main__":
# # 配置参数
# forbidden_words_path = "forbidden_words.txt" # 违禁词文件路径
# test_image_path = r"D:\Git\bin\video\ocr\images\img_7.png" # 测试图片路径
# ocr_threshold = 0.6 # OCR置信度阈值
#
# # 创建检测器实例
# detector = OCRViolationDetector(
# forbidden_words_path=forbidden_words_path,
# ocr_confidence_threshold=ocr_threshold,
# log_level=logging.INFO,
# log_file="ocr_detection.log"
# )
#
# # 测试图片
# detector.test_image(test_image_path, show_image=True)

View File

@ -1,117 +1,175 @@
import queue
import asyncio import asyncio
import logging
from aiortc import RTCPeerConnection, RTCSessionDescription
import aiohttp import aiohttp
import threading
import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack
# 配置日志 from ocr.ocr_violation_detector import OCRViolationDetector
logging.basicConfig(level=logging.INFO) import logging
logger = logging.getLogger("whep_video_puller")
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1)
async def whep_pull_video_stream(whep_url): class VideoTrack(MediaStreamTrack):
"""自定义视频轨道类继承自MediaStreamTrack"""
kind = "video"
def __init__(self, max_frames=100):
super().__init__()
self.frames = queue.Queue(maxsize=max_frames)
async def recv(self):
return await super().recv()
def webrtc_producer(webrtc_url):
""" """
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息 生产者方法从WEBRTC读取视频帧并放入队列
仅当队列空时才放入新帧,否则丢弃
Args:
whep_url: WHEP端点的URL
""" """
pc = RTCPeerConnection() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 添加连接状态变化监听 # 创建RTCPeerConnection对象不使用ICE服务器
@pc.on("connectionstatechange") pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
async def on_connectionstatechange(): video_track = VideoTrack()
print(f"连接状态: {pc.connectionState}") pc.addTrack(video_track)
# 添加ICE连接状态变化监听
@pc.on("iceconnectionstatechange")
async def on_iceconnectionstatechange():
print(f"ICE连接状态: {pc.iceConnectionState}")
# 添加视频接收器
pc.addTransceiver("video", direction="recvonly")
# 处理接收到的视频轨道
@pc.on("track") @pc.on("track")
def on_track(track): async def on_track(track):
print(f"接收到轨道: {track.kind}")
if track.kind == "video": if track.kind == "video":
print(f"轨道ID: {track.id}") print("接收到视频轨道,开始接收视频帧")
print(f"轨道就绪状态: {track.readyState}")
# 创建异步任务来处理视频帧
asyncio.ensure_future(handle_video_track(track))
async def handle_video_track(track):
"""处理视频轨道,接收并打印每一帧"""
frame_count = 0
print("开始处理视频轨道...")
while True: while True:
try: # 从轨道接收视频帧
# 尝试接收帧
frame = await track.recv() frame = await track.recv()
frame_count += 1 # 转换为BGR24格式的NumPy数组
print(f"收到原始帧 (第{frame_count}帧)") frame_bgr24 = frame.to_ndarray(format='bgr24')
# 打印帧的基本信息 # 检查队列是否为空,为空则加入,否则丢弃
if hasattr(frame, 'width') and hasattr(frame, 'height'): if frame_queue.empty():
print(f" 尺寸: {frame.width}x{frame.height}") try:
if hasattr(frame, 'time_base'): frame_queue.put_nowait(frame_bgr24)
print(f" 时间基准: {frame.time_base}") print("帧已放入队列")
if hasattr(frame, 'pts'): except queue.Full:
print(f" 显示时间戳: {frame.pts}") print("队列已满,丢弃帧")
else:
print("队列非空,丢弃帧")
except Exception as e: async def main():
print(f"接收帧时出错: {e}") # 创建并发送SDP Offer
# 等待一段时间后重试
await asyncio.sleep(0.1)
continue
# 创建offer
offer = await pc.createOffer() offer = await pc.createOffer()
print("已创建本地SDP Offer")
await pc.setLocalDescription(offer) await pc.setLocalDescription(offer)
print(f"本地SDP信息:\n{offer.sdp}") # 发送Offer到服务器并接收Answer
# 通过HTTP POST发送offer到WHEP端点
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
async with session.post( async with session.post(
whep_url, webrtc_url,
data=offer.sdp, data=offer.sdp.encode(),
headers={"Content-Type": "application/sdp"} headers={
"Content-Type": "application/sdp",
"Content-Length": str(len(offer.sdp))
},
ssl=False
) as response: ) as response:
if response.status != 201: print("已接收到服务器的响应")
print(f"WHEP服务器返回错误: {response.status}")
print(f"响应内容: {await response.text()}")
raise Exception(f"WHEP服务器返回错误: {response.status}")
# 获取answer SDP
answer_sdp = await response.text() answer_sdp = await response.text()
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
# 创建RTCSessionDescription对象 # 保持连接
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
print(f"收到远程SDP:\n{answer_sdp}")
# 设置远程描述
await pc.setRemoteDescription(answer)
print("连接已建立,开始接收视频流...")
# 保持连接,直到用户中断
try: try:
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(0.1)
# 检查连接状态
print(f"当前连接状态: {pc.connectionState}")
except KeyboardInterrupt: except KeyboardInterrupt:
print("用户中断,关闭连接...") pass
finally: finally:
print("关闭RTCPeerConnection")
await pc.close() await pc.close()
try:
loop.run_until_complete(main())
finally:
loop.close()
def frame_consumer(ip):
"""
消费者方法:从队列中读取帧并处理
每次处理后休眠200ms模拟延迟
"""
print("消费者启动,开始等待帧...")
try:
while True:
# 阻塞等待队列中的帧
frame = frame_queue.get()
print(f"消费帧,大小: {frame.shape}")
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
# 输出检测结果
if has_violation:
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
for word, conf in zip(violations, confidences):
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
else:
detector.logger.info("图片中未检测到违禁词")
# 标记任务完成
frame_queue.task_done()
except KeyboardInterrupt:
print("消费者退出")
def start_webrtc_stream(ip, webrtc_url):
"""
启动WebRTC视频流处理的主方法
参数: webrtc_url - WebRTC服务器地址
"""
print(f"开始连接到WebRTC服务器: {webrtc_url}")
# 启动生产者线程
producer_thread = threading.Thread(
target=webrtc_producer,
args=(webrtc_url,),
daemon=True,
name="webrtc-producer"
)
# 启动消费者线程
consumer_thread = threading.Thread(
target=frame_consumer(ip),
daemon=True,
name="frame-consumer"
)
producer_thread.start()
consumer_thread.start()
print("生产者和消费者线程已启动")
try:
# 保持主线程运行
while True:
time.sleep(1)
except KeyboardInterrupt:
print("程序正在退出...")
if __name__ == "__main__": if __name__ == "__main__":
# 替换为你的WHEP端点URL # 示例用法
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416" # 实际使用时替换为真实的WebRTC服务器地址
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
# 运行拉流任务 start_webrtc_stream(webrtc_server_url)
asyncio.run(whep_pull_video_stream(WHEP_URL))

View File

@ -1,5 +1,7 @@
import json import json
import threading import threading
import time
from fastapi import HTTPException, Query, APIRouter, Depends, Request from fastapi import HTTPException, Query, APIRouter, Depends, Request
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
@ -17,7 +19,7 @@ from schema.response_schema import APIResponse
from schema.user_schema import UserResponse from schema.user_schema import UserResponse
# 导入之前封装的WEBRTC处理函数 # 导入之前封装的WEBRTC处理函数
from rtc.rtc import process_webrtc_stream from core.rtmp import rtmp_pull_video_stream
router = APIRouter( router = APIRouter(
prefix="/devices", prefix="/devices",
@ -29,7 +31,7 @@ router = APIRouter(
def run_webrtc_processing(ip, webrtc_url): def run_webrtc_processing(ip, webrtc_url):
try: try:
print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}") print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}")
process_webrtc_stream(ip, webrtc_url) rtmp_pull_video_stream(webrtc_url)
except Exception as e: except Exception as e:
print(f"WEBRTC处理出错: {str(e)}") print(f"WEBRTC处理出错: {str(e)}")
@ -52,7 +54,9 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
# 设备创建成功后在后台线程启动WEBRTC流处理 # 设备创建成功后在后台线程启动WEBRTC流处理
threading.Thread( threading.Thread(
target=run_webrtc_processing, target=run_webrtc_processing,
args=(device_data.ip, existing_device["live_webrtc_url"]), # args=(device_data.ip, existing_device["live_webrtc_url"]),
args=(device_data.ip, existing_device["rtmp_push_url"]),
daemon=True # 设为守护线程,主程序退出时自动结束 daemon=True # 设为守护线程,主程序退出时自动结束
).start() ).start()
# IP已存在时返回该设备信息 # IP已存在时返回该设备信息