ocr1.0
This commit is contained in:
22
core/rtc.py
22
core/rtc.py
@ -2,13 +2,24 @@ import asyncio
|
||||
import logging
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
import aiohttp
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
|
||||
import logging
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("whep_video_puller")
|
||||
|
||||
|
||||
async def whep_pull_video_stream(whep_url):
|
||||
async def whep_pull_video_stream(ip,whep_url):
|
||||
"""
|
||||
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
|
||||
@ -60,6 +71,15 @@ async def whep_pull_video_stream(whep_url):
|
||||
if hasattr(frame, 'pts'):
|
||||
print(f" 显示时间戳: {frame.pts}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
except Exception as e:
|
||||
print(f"接收帧时出错: {e}")
|
||||
# 等待一段时间后重试
|
||||
|
20
core/rtmp.py
20
core/rtmp.py
@ -2,6 +2,17 @@ import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
|
||||
import logging
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
)
|
||||
|
||||
# 配置日志(与WHEP代码保持一致的日志风格)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -67,6 +78,15 @@ async def rtmp_pull_video_stream(rtmp_url):
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
# 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致)
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
@ -157,3 +157,77 @@ class OCRViolationDetector:
|
||||
|
||||
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
|
||||
return len(all_prohibited) > 0, all_prohibited, all_confidences
|
||||
|
||||
# def test_image(self, image_path: str, show_image: bool = True) -> tuple:
|
||||
# """
|
||||
# 对单张图片进行OCR违禁词检测并展示结果
|
||||
#
|
||||
# Args:
|
||||
# image_path (str): 图片文件路径
|
||||
# show_image (bool): 是否显示图片,默认为True
|
||||
#
|
||||
# Returns:
|
||||
# tuple: (是否有违禁词, 违禁词列表, 对应的置信度列表)
|
||||
# """
|
||||
# # 检查图片文件是否存在
|
||||
# if not os.path.exists(image_path):
|
||||
# self.logger.error(f"图片文件不存在: {image_path}")
|
||||
# return False, [], []
|
||||
#
|
||||
# try:
|
||||
# # 读取图片
|
||||
# frame = cv2.imread(image_path)
|
||||
# if frame is None:
|
||||
# self.logger.error(f"无法读取图片: {image_path}")
|
||||
# return False, [], []
|
||||
#
|
||||
# self.logger.info(f"开始处理图片: {image_path}")
|
||||
#
|
||||
# # 调用检测方法
|
||||
# has_violation, violations, confidences = self.detect(frame)
|
||||
#
|
||||
# # 输出检测结果
|
||||
# if has_violation:
|
||||
# self.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
# for word, conf in zip(violations, confidences):
|
||||
# self.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
# else:
|
||||
# self.logger.info("图片中未检测到违禁词")
|
||||
|
||||
# # 显示图片(如果需要)
|
||||
# if show_image:
|
||||
# # 调整图片大小以便于显示(如果太大)
|
||||
# height, width = frame.shape[:2]
|
||||
# max_size = 800
|
||||
# if max(height, width) > max_size:
|
||||
# scale = max_size / max(height, width)
|
||||
# frame = cv2.resize(frame, None, fx=scale, fy=scale)
|
||||
#
|
||||
# cv2.imshow(f"OCR检测结果: {'发现违禁词' if has_violation else '未发现违禁词'}", frame)
|
||||
# cv2.waitKey(0) # 等待用户按键
|
||||
# cv2.destroyAllWindows()
|
||||
#
|
||||
# return has_violation, violations, confidences
|
||||
#
|
||||
# except Exception as e:
|
||||
# self.logger.error(f"处理图片时发生错误: {str(e)}", exc_info=True)
|
||||
# return False, [], []
|
||||
#
|
||||
#
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 配置参数
|
||||
# forbidden_words_path = "forbidden_words.txt" # 违禁词文件路径
|
||||
# test_image_path = r"D:\Git\bin\video\ocr\images\img_7.png" # 测试图片路径
|
||||
# ocr_threshold = 0.6 # OCR置信度阈值
|
||||
#
|
||||
# # 创建检测器实例
|
||||
# detector = OCRViolationDetector(
|
||||
# forbidden_words_path=forbidden_words_path,
|
||||
# ocr_confidence_threshold=ocr_threshold,
|
||||
# log_level=logging.INFO,
|
||||
# log_file="ocr_detection.log"
|
||||
# )
|
||||
#
|
||||
# # 测试图片
|
||||
# detector.test_image(test_image_path, show_image=True)
|
244
rtc/rtc.py
244
rtc/rtc.py
@ -1,117 +1,175 @@
|
||||
import queue
|
||||
import asyncio
|
||||
import logging
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
import aiohttp
|
||||
import threading
|
||||
import time
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
||||
from aiortc.mediastreams import MediaStreamTrack
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("whep_video_puller")
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
import logging
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
)
|
||||
|
||||
# 创建一个长度为1的队列,用于生产者和消费者之间的通信
|
||||
frame_queue = queue.Queue(maxsize=1)
|
||||
|
||||
|
||||
async def whep_pull_video_stream(whep_url):
|
||||
class VideoTrack(MediaStreamTrack):
|
||||
"""自定义视频轨道类,继承自MediaStreamTrack"""
|
||||
kind = "video"
|
||||
|
||||
def __init__(self, max_frames=100):
|
||||
super().__init__()
|
||||
self.frames = queue.Queue(maxsize=max_frames)
|
||||
|
||||
async def recv(self):
|
||||
return await super().recv()
|
||||
|
||||
|
||||
def webrtc_producer(webrtc_url):
|
||||
"""
|
||||
通过WHEP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
|
||||
Args:
|
||||
whep_url: WHEP端点的URL
|
||||
生产者方法:从WEBRTC读取视频帧并放入队列
|
||||
仅当队列空时才放入新帧,否则丢弃
|
||||
"""
|
||||
pc = RTCPeerConnection()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 添加连接状态变化监听
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
print(f"连接状态: {pc.connectionState}")
|
||||
# 创建RTCPeerConnection对象,不使用ICE服务器
|
||||
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
||||
video_track = VideoTrack()
|
||||
pc.addTrack(video_track)
|
||||
|
||||
# 添加ICE连接状态变化监听
|
||||
@pc.on("iceconnectionstatechange")
|
||||
async def on_iceconnectionstatechange():
|
||||
print(f"ICE连接状态: {pc.iceConnectionState}")
|
||||
|
||||
# 添加视频接收器
|
||||
pc.addTransceiver("video", direction="recvonly")
|
||||
|
||||
# 处理接收到的视频轨道
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
print(f"接收到轨道: {track.kind}")
|
||||
async def on_track(track):
|
||||
if track.kind == "video":
|
||||
print(f"轨道ID: {track.id}")
|
||||
print(f"轨道就绪状态: {track.readyState}")
|
||||
# 创建异步任务来处理视频帧
|
||||
asyncio.ensure_future(handle_video_track(track))
|
||||
|
||||
async def handle_video_track(track):
|
||||
"""处理视频轨道,接收并打印每一帧"""
|
||||
frame_count = 0
|
||||
print("开始处理视频轨道...")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 尝试接收帧
|
||||
print("接收到视频轨道,开始接收视频帧")
|
||||
while True:
|
||||
# 从轨道接收视频帧
|
||||
frame = await track.recv()
|
||||
frame_count += 1
|
||||
print(f"收到原始帧 (第{frame_count}帧)")
|
||||
# 转换为BGR24格式的NumPy数组
|
||||
frame_bgr24 = frame.to_ndarray(format='bgr24')
|
||||
|
||||
# 打印帧的基本信息
|
||||
if hasattr(frame, 'width') and hasattr(frame, 'height'):
|
||||
print(f" 尺寸: {frame.width}x{frame.height}")
|
||||
if hasattr(frame, 'time_base'):
|
||||
print(f" 时间基准: {frame.time_base}")
|
||||
if hasattr(frame, 'pts'):
|
||||
print(f" 显示时间戳: {frame.pts}")
|
||||
# 检查队列是否为空,为空则加入,否则丢弃
|
||||
if frame_queue.empty():
|
||||
try:
|
||||
frame_queue.put_nowait(frame_bgr24)
|
||||
print("帧已放入队列")
|
||||
except queue.Full:
|
||||
print("队列已满,丢弃帧")
|
||||
else:
|
||||
print("队列非空,丢弃帧")
|
||||
|
||||
except Exception as e:
|
||||
print(f"接收帧时出错: {e}")
|
||||
# 等待一段时间后重试
|
||||
async def main():
|
||||
# 创建并发送SDP Offer
|
||||
offer = await pc.createOffer()
|
||||
print("已创建本地SDP Offer")
|
||||
await pc.setLocalDescription(offer)
|
||||
|
||||
# 发送Offer到服务器并接收Answer
|
||||
async with aiohttp.ClientSession() as session:
|
||||
print(f"开始向服务器 {webrtc_url} 发送SDP Offer")
|
||||
async with session.post(
|
||||
webrtc_url,
|
||||
data=offer.sdp.encode(),
|
||||
headers={
|
||||
"Content-Type": "application/sdp",
|
||||
"Content-Length": str(len(offer.sdp))
|
||||
},
|
||||
ssl=False
|
||||
) as response:
|
||||
print("已接收到服务器的响应")
|
||||
answer_sdp = await response.text()
|
||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
||||
|
||||
# 保持连接
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
print("关闭RTCPeerConnection")
|
||||
await pc.close()
|
||||
|
||||
# 创建offer
|
||||
offer = await pc.createOffer()
|
||||
await pc.setLocalDescription(offer)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
print(f"本地SDP信息:\n{offer.sdp}")
|
||||
|
||||
# 通过HTTP POST发送offer到WHEP端点
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
whep_url,
|
||||
data=offer.sdp,
|
||||
headers={"Content-Type": "application/sdp"}
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
print(f"WHEP服务器返回错误: {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
raise Exception(f"WHEP服务器返回错误: {response.status}")
|
||||
|
||||
# 获取answer SDP
|
||||
answer_sdp = await response.text()
|
||||
|
||||
# 创建RTCSessionDescription对象
|
||||
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
|
||||
|
||||
print(f"收到远程SDP:\n{answer_sdp}")
|
||||
|
||||
# 设置远程描述
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
print("连接已建立,开始接收视频流...")
|
||||
|
||||
# 保持连接,直到用户中断
|
||||
def frame_consumer(ip):
|
||||
"""
|
||||
消费者方法:从队列中读取帧并处理
|
||||
每次处理后休眠200ms模拟延迟
|
||||
"""
|
||||
print("消费者启动,开始等待帧...")
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
# 检查连接状态
|
||||
print(f"当前连接状态: {pc.connectionState}")
|
||||
# 阻塞等待队列中的帧
|
||||
frame = frame_queue.get()
|
||||
print(f"消费帧,大小: {frame.shape}")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
|
||||
# 输出检测结果
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
|
||||
|
||||
# 标记任务完成
|
||||
frame_queue.task_done()
|
||||
except KeyboardInterrupt:
|
||||
print("用户中断,关闭连接...")
|
||||
finally:
|
||||
await pc.close()
|
||||
print("消费者退出")
|
||||
|
||||
|
||||
def start_webrtc_stream(ip, webrtc_url):
|
||||
"""
|
||||
启动WebRTC视频流处理的主方法
|
||||
参数: webrtc_url - WebRTC服务器地址
|
||||
"""
|
||||
print(f"开始连接到WebRTC服务器: {webrtc_url}")
|
||||
|
||||
# 启动生产者线程
|
||||
producer_thread = threading.Thread(
|
||||
target=webrtc_producer,
|
||||
args=(webrtc_url,),
|
||||
daemon=True,
|
||||
name="webrtc-producer"
|
||||
)
|
||||
|
||||
# 启动消费者线程
|
||||
consumer_thread = threading.Thread(
|
||||
target=frame_consumer(ip),
|
||||
daemon=True,
|
||||
name="frame-consumer"
|
||||
)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
print("生产者和消费者线程已启动")
|
||||
|
||||
try:
|
||||
# 保持主线程运行
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("程序正在退出...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 替换为你的WHEP端点URL
|
||||
WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行拉流任务
|
||||
asyncio.run(whep_pull_video_stream(WHEP_URL))
|
||||
# 示例用法
|
||||
# 实际使用时替换为真实的WebRTC服务器地址
|
||||
webrtc_server_url = "http://192.168.110.65:1985/rtc/v1/whep/?app=live&stream=677a4845aa48cb8526c811ad56fc5e60"
|
||||
start_webrtc_stream(webrtc_server_url)
|
||||
|
@ -1,5 +1,7 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
|
||||
from fastapi import HTTPException, Query, APIRouter, Depends, Request
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
@ -17,7 +19,7 @@ from schema.response_schema import APIResponse
|
||||
from schema.user_schema import UserResponse
|
||||
|
||||
# 导入之前封装的WEBRTC处理函数
|
||||
from rtc.rtc import process_webrtc_stream
|
||||
from core.rtmp import rtmp_pull_video_stream
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/devices",
|
||||
@ -29,7 +31,7 @@ router = APIRouter(
|
||||
def run_webrtc_processing(ip, webrtc_url):
|
||||
try:
|
||||
print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}")
|
||||
process_webrtc_stream(ip, webrtc_url)
|
||||
rtmp_pull_video_stream(webrtc_url)
|
||||
except Exception as e:
|
||||
print(f"WEBRTC处理出错: {str(e)}")
|
||||
|
||||
@ -52,7 +54,9 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
|
||||
# 设备创建成功后,在后台线程启动WEBRTC流处理
|
||||
threading.Thread(
|
||||
target=run_webrtc_processing,
|
||||
args=(device_data.ip, existing_device["live_webrtc_url"]),
|
||||
# args=(device_data.ip, existing_device["live_webrtc_url"]),
|
||||
args=(device_data.ip, existing_device["rtmp_push_url"]),
|
||||
|
||||
daemon=True # 设为守护线程,主程序退出时自动结束
|
||||
).start()
|
||||
# IP已存在时返回该设备信息
|
||||
|
Reference in New Issue
Block a user