ocr1.0
This commit is contained in:
@ -7,6 +7,11 @@ def setup_logger():
|
|||||||
配置一个全局日志记录器,支持输出到控制台和文件。
|
配置一个全局日志记录器,支持输出到控制台和文件。
|
||||||
"""
|
"""
|
||||||
# 创建一个日志记录器
|
# 创建一个日志记录器
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
logger = logging.getLogger("ViolationDetectorLogger")
|
logger = logging.getLogger("ViolationDetectorLogger")
|
||||||
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
|
logger.setLevel(logging.DEBUG) # 设置最低级别为DEBUG
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ class OCRViolationDetector:
|
|||||||
"""初始化RapidOCR引擎"""
|
"""初始化RapidOCR引擎"""
|
||||||
self.logger.info("正在初始化RapidOCR引擎...")
|
self.logger.info("正在初始化RapidOCR引擎...")
|
||||||
|
|
||||||
config_path = r"../ocr/config/1.yaml"
|
config_path = r"D:\Git\bin\video\ocr\config\1.yaml"
|
||||||
try:
|
try:
|
||||||
# 检查配置文件是否存在
|
# 检查配置文件是否存在
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
@ -157,4 +157,3 @@ class OCRViolationDetector:
|
|||||||
|
|
||||||
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
|
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表)
|
||||||
return len(all_prohibited) > 0, all_prohibited, all_confidences
|
return len(all_prohibited) > 0, all_prohibited, all_confidences
|
||||||
|
|
||||||
|
59
rtc/rtc.py
59
rtc/rtc.py
@ -1,10 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
|
||||||
from aiortc.mediastreams import MediaStreamTrack
|
from aiortc.mediastreams import MediaStreamTrack
|
||||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||||
|
from ws.ws import send_message_to_client
|
||||||
|
|
||||||
|
|
||||||
class VideoTrack(MediaStreamTrack):
|
class VideoTrack(MediaStreamTrack):
|
||||||
@ -18,9 +21,10 @@ class VideoTrack(MediaStreamTrack):
|
|||||||
return await super().recv()
|
return await super().recv()
|
||||||
|
|
||||||
|
|
||||||
async def rtc_frame_receiver(url, frame_queue):
|
async def rtc_frame_receiver(url, frame_queue, stop_event):
|
||||||
"""
|
"""
|
||||||
对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据
|
对每帧进行检查、只要接收到 RTC 帧且队列为空、就往队列放入cv2格式的帧数据
|
||||||
|
当stop_event被设置时停止接收
|
||||||
"""
|
"""
|
||||||
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
pc = RTCPeerConnection(RTCConfiguration(iceServers=[]))
|
||||||
video_track = VideoTrack()
|
video_track = VideoTrack()
|
||||||
@ -34,7 +38,7 @@ async def rtc_frame_receiver(url, frame_queue):
|
|||||||
nonlocal total_frames
|
nonlocal total_frames
|
||||||
if track.kind == "video":
|
if track.kind == "video":
|
||||||
print("接收到视频轨道、开始接收视频帧")
|
print("接收到视频轨道、开始接收视频帧")
|
||||||
while True:
|
while not stop_event.is_set(): # 检查是否需要停止
|
||||||
# 接收当前帧并累计计数
|
# 接收当前帧并累计计数
|
||||||
frame = await track.recv()
|
frame = await track.recv()
|
||||||
# 转换为cv2兼容的BGR格式numpy数组
|
# 转换为cv2兼容的BGR格式numpy数组
|
||||||
@ -45,13 +49,13 @@ async def rtc_frame_receiver(url, frame_queue):
|
|||||||
total_frames += 1
|
total_frames += 1
|
||||||
|
|
||||||
# 对每帧都检查队列状态、队列为空则放入
|
# 对每帧都检查队列状态、队列为空则放入
|
||||||
if frame_queue.empty():
|
if frame_queue.empty() and not stop_event.is_set(): # 确保还未收到停止信号
|
||||||
# 队列为空、放入当前cv2帧
|
# 队列为空、放入当前cv2帧
|
||||||
await frame_queue.put(frame_cv2)
|
await frame_queue.put(frame_cv2)
|
||||||
# print(f"第{total_frames}帧:队列为空、已放入新的cv2帧,尺寸: {frame_cv2.shape}")
|
|
||||||
else:
|
else:
|
||||||
# 队列非空、说明上一帧还未处理、跳过当前帧
|
# 队列非空或已收到停止信号、跳过当前帧
|
||||||
print(f"第{total_frames}帧:队列非空、跳过该帧")
|
if not stop_event.is_set():
|
||||||
|
print(f"第{total_frames}帧:队列非空、跳过该帧")
|
||||||
else:
|
else:
|
||||||
print("帧格式转换失败,不是有效的cv2格式")
|
print("帧格式转换失败,不是有效的cv2格式")
|
||||||
|
|
||||||
@ -77,8 +81,8 @@ async def rtc_frame_receiver(url, frame_queue):
|
|||||||
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
await pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type='answer'))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 保持连接
|
# 保持连接,直到收到停止信号
|
||||||
while True:
|
while not stop_event.is_set():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("用户中断")
|
print("用户中断")
|
||||||
@ -88,52 +92,68 @@ async def rtc_frame_receiver(url, frame_queue):
|
|||||||
print("已关闭 RTCPeerConnection")
|
print("已关闭 RTCPeerConnection")
|
||||||
|
|
||||||
|
|
||||||
async def frame_consumer(frame_queue):
|
async def frame_consumer(ip, frame_queue, stop_event):
|
||||||
"""
|
"""
|
||||||
从队列中读取cv2帧并处理(队列空时会阻塞等待)
|
从队列中读取cv2帧并处理(队列空时会阻塞等待)
|
||||||
|
检测到违规内容后设置stop_event以终止所有任务
|
||||||
|
|
||||||
Args: frame_queue: 帧队列
|
Args:
|
||||||
|
ip: IP地址
|
||||||
|
frame_queue: 帧队列
|
||||||
|
stop_event: 用于控制任务停止的事件
|
||||||
"""
|
"""
|
||||||
# 创建OCR检测器实例(请替换为实际的违禁词文件路径)
|
# 创建OCR检测器实例
|
||||||
ocr_detector = OCRViolationDetector(
|
ocr_detector = OCRViolationDetector(
|
||||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径
|
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径
|
||||||
ocr_confidence_threshold=0.5, )
|
ocr_confidence_threshold=0.5, )
|
||||||
|
|
||||||
while True:
|
while not stop_event.is_set(): # 检查是否需要停止
|
||||||
# 从队列中获取cv2帧(队列为空时会阻塞等待新帧)
|
# 从队列中获取cv2帧(队列为空时会阻塞等待新帧)
|
||||||
current_frame = await frame_queue.get()
|
current_frame = await frame_queue.get()
|
||||||
has_violation, words, confidences = ocr_detector.detect(current_frame)
|
has_violation, words, confidences = ocr_detector.detect(current_frame)
|
||||||
|
print(has_violation)
|
||||||
|
print( words)
|
||||||
|
print( confidences)
|
||||||
# 输出所有检测到的违禁词
|
# 输出所有检测到的违禁词
|
||||||
if has_violation:
|
if has_violation:
|
||||||
print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
|
print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
|
||||||
|
response_data = {
|
||||||
|
"status": "stop",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
await send_message_to_client(ip,response_data )
|
||||||
for word, conf in zip(words, confidences):
|
for word, conf in zip(words, confidences):
|
||||||
print(f"- {word}(置信度:{conf:.4f})")
|
print(f"- {word}(置信度:{conf:.4f})")
|
||||||
|
|
||||||
|
# 检测到违规,设置停止事件
|
||||||
|
print("检测到违规内容,准备关闭AI检测")
|
||||||
|
stop_event.set()
|
||||||
else:
|
else:
|
||||||
print("测试结果:图片中未检测到违禁词")
|
print("测试结果:图片中未检测到违禁词")
|
||||||
|
|
||||||
# 标记任务完成
|
# 标记任务完成
|
||||||
frame_queue.task_done()
|
frame_queue.task_done()
|
||||||
# print("帧处理完成、队列已清空")
|
|
||||||
|
|
||||||
|
|
||||||
def process_webrtc_stream(ip, webrtc_url):
|
def process_webrtc_stream(ip, webrtc_url):
|
||||||
"""
|
"""
|
||||||
处理WEBRTC流并持续打印OCR检测结果
|
处理WEBRTC流并持续打印OCR检测结果,检测到违规后关闭
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ip: IP地址(预留参数)
|
ip: IP地址
|
||||||
webrtc_url: WEBRTC服务器地址
|
webrtc_url: WEBRTC服务器地址
|
||||||
"""
|
"""
|
||||||
# 创建队列
|
# 创建队列和停止事件
|
||||||
frame_queue = asyncio.Queue(maxsize=1)
|
frame_queue = asyncio.Queue(maxsize=1)
|
||||||
|
stop_event = asyncio.Event() # 用于控制任务停止的事件
|
||||||
|
|
||||||
# 定义事件循环中的主任务
|
# 定义事件循环中的主任务
|
||||||
async def main_task():
|
async def main_task():
|
||||||
# 创建任务
|
# 创建任务
|
||||||
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue))
|
receiver_task = asyncio.create_task(rtc_frame_receiver(webrtc_url, frame_queue, stop_event))
|
||||||
consumer_task = asyncio.create_task(frame_consumer(frame_queue))
|
consumer_task = asyncio.create_task(frame_consumer(ip, frame_queue, stop_event))
|
||||||
|
|
||||||
# 等待任务完成
|
# 等待任一任务完成(当stop_event被设置时,两个任务都会退出)
|
||||||
await asyncio.gather(receiver_task, consumer_task)
|
await asyncio.gather(receiver_task, consumer_task)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -144,3 +164,4 @@ def process_webrtc_stream(ip, webrtc_url):
|
|||||||
finally:
|
finally:
|
||||||
# 确保关闭所有cv2窗口
|
# 确保关闭所有cv2窗口
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
print("AI检测已关闭")
|
||||||
|
@ -52,7 +52,7 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
|
|||||||
# 设备创建成功后,在后台线程启动WEBRTC流处理
|
# 设备创建成功后,在后台线程启动WEBRTC流处理
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=run_webrtc_processing,
|
target=run_webrtc_processing,
|
||||||
args=(device_data.ip, full_webrtc_url),
|
args=(device_data.ip, existing_device["live_webrtc_url"]),
|
||||||
daemon=True # 设为守护线程,主程序退出时自动结束
|
daemon=True # 设为守护线程,主程序退出时自动结束
|
||||||
).start()
|
).start()
|
||||||
# IP已存在时返回该设备信息
|
# IP已存在时返回该设备信息
|
||||||
|
Reference in New Issue
Block a user