最新可用

This commit is contained in:
2025-09-04 22:59:27 +08:00
parent ec6dbfde90
commit 30bf7c9fcb
42 changed files with 746 additions and 1967 deletions

45
core/all.py Normal file
View File

@ -0,0 +1,45 @@
from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
from core.face import load_model as faceLoadModel, detect as faceDetect
from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
# Guard flag: becomes True once load_model() has run to completion.
_model_loaded = False


def load_model():
    """Load the OCR, face and YOLO models, at most once per process.

    The first call invokes the three loaders in that order and then sets the
    module-level guard flag; every later call returns immediately. The
    loaders' return values are not inspected.
    """
    global _model_loaded
    if not _model_loaded:
        for loader in (ocrLoadModel, faceLoadModel, yoloLoadModel):
            loader()
        # Only reached when none of the loaders raised.
        _model_loaded = True
def detect(frame):
    """Run the detectors in sequence and return the first positive result.

    YOLO runs first, then face recognition, then OCR. Each detector returns a
    ``(flag, result)`` pair; the result is printed, and the first detector
    whose flag is truthy ends the cascade.

    Args:
        frame: The image passed unchanged to every detector.

    Returns:
        tuple: ``(True, result, source)`` where ``source`` is ``"yolo"``,
        ``"face"`` or ``"ocr"``, or ``(False, "未检测到任何内容", "none")``
        when no detector reported a hit.
    """
    stages = (
        (yoloDetect, "YOLO检测结果", "yolo"),
        (faceDetect, "人脸检测结果:", "face"),
        (ocrDetect, "OCR检测结果", "ocr"),
    )
    for run_stage, label, source in stages:
        hit, outcome = run_stage(frame)
        print(label, outcome)
        if hit:
            return (True, outcome, source)
    # Every stage ran and none reported a hit.
    return (False, "未检测到任何内容", "none")

119
core/config/config.yaml Normal file
View File

@ -0,0 +1,119 @@
Global:
text_score: 0.5
use_det: true
use_cls: true
use_rec: true
min_height: 30
width_height_ratio: 8
max_side_len: 2000
min_side_len: 30
return_word_box: false
return_single_char_box: false
font_path: null
EngineConfig:
onnxruntime:
intra_op_num_threads: -1
inter_op_num_threads: -1
enable_cpu_mem_arena: false
cpu_ep_cfg:
arena_extend_strategy: "kSameAsRequested"
use_cuda: true # set to true to enable CUDA
cuda_ep_cfg:
device_id: 0
arena_extend_strategy: "kNextPowerOfTwo"
cudnn_conv_algo_search: "EXHAUSTIVE"
do_copy_in_default_stream: true
use_dml: false
dm_ep_cfg: null
use_cann: false
cann_ep_cfg:
device_id: 0
arena_extend_strategy: "kNextPowerOfTwo"
npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024
op_select_impl_mode: "high_performance"
optypelist_for_implmode: "Gelu"
enable_cann_graph: true
openvino:
inference_num_threads: -1
performance_hint: null
performance_num_requests: -1
enable_cpu_pinning: null
num_streams: -1
enable_hyper_threading: null
scheduling_core_type: null
paddle:
cpu_math_library_num_threads: -1
use_npu: false
npu_id: 0
use_cuda: true # set to true to enable CUDA
gpu_id: 0
gpu_mem: 500
torch:
use_cuda: true # CUDA enabled
gpu_id: 0
Det:
engine_type: "torch"
lang_type: "ch"
model_type: "mobile"
ocr_version: "PP-OCRv4"
task_type: "det"
model_path: null
model_dir: null
limit_side_len: 736
limit_type: min
std: [ 0.5, 0.5, 0.5 ]
mean: [ 0.5, 0.5, 0.5 ]
thresh: 0.3
box_thresh: 0.5
max_candidates: 1000
unclip_ratio: 1.6
use_dilation: true
score_mode: fast
Cls:
engine_type: "torch"
lang_type: "ch"
model_type: "mobile"
ocr_version: "PP-OCRv4"
task_type: "cls"
model_path: null
model_dir: null
cls_image_shape: [3, 48, 192]
cls_batch_num: 6
cls_thresh: 0.9
label_list: ["0", "180"]
Rec:
engine_type: "torch"
lang_type: "ch"
model_type: "mobile"
ocr_version: "PP-OCRv4"
task_type: "rec"
model_path: null
model_dir: null
rec_keys_path: null
rec_img_shape: [3, 48, 320]
rec_batch_num: 6

113
core/face.py Normal file
View File

@ -0,0 +1,113 @@
import os
import numpy as np
import cv2
from PIL import Image # 确保正确导入Image类
from insightface.app import FaceAnalysis
# 导入获取人脸信息的服务
from service.face_service import get_all_face_name_with_eigenvalue
# Module-level state shared by load_model() and detect()
_face_app = None
_known_faces_embeddings = {}  # maps person name -> stored feature vector
_known_faces_names = []  # names of every known person


def _parse_eigenvalue(raw):
    """Turn one stored feature value into a float32 array; None if the type is unsupported."""
    if isinstance(raw, np.ndarray):
        return raw.astype(np.float32)
    if isinstance(raw, (list, tuple)):
        return np.asarray(raw, dtype=np.float32)
    if isinstance(raw, str):
        # Brackets, commas and newlines all become spaces, so "[0.1 0.2]",
        # "[0.1, 0.2]" and multi-line dumps split into clean number tokens
        # without two numbers ever being glued together.
        for ch in "[],\n":
            raw = raw.replace(ch, " ")
        return np.array([float(tok) for tok in raw.split()], dtype=np.float32)
    return None


def load_model():
    """Load the face recognition model and the gallery of known faces.

    The InsightFace model is created and prepared first. The known names and
    feature values are then fetched from the face service, converted to
    float32 arrays and L2-normalised (zero-norm vectors are stored as-is).
    A record that cannot be parsed is reported and skipped without affecting
    the other records. A successful fetch replaces the previous gallery, so
    calling this function again does not create duplicate names.

    Returns:
        bool: False when the face model could not be initialised, True
        otherwise (including when the face service could not be read, in
        which case the existing gallery is left untouched).
    """
    global _face_app, _known_faces_embeddings, _known_faces_names

    # Publish the model only after prepare() succeeded, so detect() never
    # sees a half-initialised instance.
    try:
        app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface'))
        app.prepare(ctx_id=0, det_size=(640, 640))
    except Exception as e:
        print(f"Face model load failed: {e}")
        return False
    _face_app = app

    try:
        face_data = get_all_face_name_with_eigenvalue()
        records = list(face_data.items())
    except Exception as e:
        print(f"Error loading face data from service: {e}")
        return True

    loaded = {}
    for person_name, eigenvalue_data in records:
        try:
            eigenvalue = _parse_eigenvalue(eigenvalue_data)
            if eigenvalue is None:
                print(f"Unsupported eigenvalue type for {person_name}")
                continue
            if eigenvalue.size == 0:
                # An empty vector cannot be compared with a face embedding.
                print(f"Empty eigenvalue for {person_name}")
                continue
            norm = np.linalg.norm(eigenvalue)
            if norm != 0:
                eigenvalue = eigenvalue / norm
        except Exception as e:
            # One malformed record must not abort the remaining ones.
            print(f"Invalid eigenvalue for {person_name}: {e}")
            continue
        loaded[person_name] = eigenvalue

    _known_faces_embeddings = loaded
    _known_faces_names = list(loaded)
    return True
def detect(frame, threshold=0.4):
    """Detect faces in a frame and compare each with the known-face gallery.

    Every detected face embedding is L2-normalised and compared, by dot
    product, with each stored embedding; the best-scoring name is reported
    together with the similarity and the bounding box.

    Args:
        frame: Image handed to the face model's ``get`` method.
        threshold: Minimum similarity for a face to count as a match.

    Returns:
        tuple: ``(matched, text)`` where ``matched`` is True when at least one
        face reached the threshold and ``text`` is the per-face report, or a
        status message when nothing could be analysed.
    """
    if not (_face_app and _known_faces_names and frame is not None):
        return (False, "未初始化或无效帧")

    try:
        detected = _face_app.get(frame)
    except Exception as exc:
        print(f"Face detect error: {exc}")
        return (False, f"检测错误: {str(exc)}")

    reports = []
    any_known = False
    for item in detected:
        vec = item.embedding.astype(np.float32)
        length = np.linalg.norm(vec)
        if length == 0:
            # A zero vector cannot be normalised; ignore this face.
            continue
        vec = vec / length

        # Linear scan for the most similar stored embedding.
        top_score = -1.0
        top_name = "Unknown"
        for candidate in _known_faces_names:
            score = np.dot(vec, _known_faces_embeddings[candidate])
            if score > top_score:
                top_score, top_name = score, candidate

        if top_score >= threshold:
            any_known = True
            verdict = '匹配'
        else:
            verdict = '不匹配'
        reports.append(
            f"{verdict}: {top_name} (相似度: {top_score:.2f}, 边界框: {item.bbox})"
        )

    summary = "; ".join(reports) if reports else "未检测到人脸"
    return (any_known, summary)

BIN
core/models/best.pt Normal file

Binary file not shown.

76
core/ocr.py Normal file
View File

@ -0,0 +1,76 @@
import os
import cv2
from rapidocr import RapidOCR
from service.sensitive_service import get_all_sensitive_words
# Module-level state shared by load_model() and detect()
_ocr_engine = None
_forbidden_words = set()
_conf_threshold = 0.5
ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml")


def load_model():
    """Load the forbidden-word list and create the OCR engine.

    A failure to fetch the forbidden words is printed and does not stop the
    engine from being created.

    Returns:
        bool: True when the OCR engine was created, False when the config
        file is missing or the engine constructor raised.
    """
    global _ocr_engine, _forbidden_words, _conf_threshold

    try:
        _forbidden_words = get_all_sensitive_words()
    except Exception as words_error:
        print(f"Forbidden words load error: {words_error}")

    if os.path.exists(ocr_config_path):
        try:
            _ocr_engine = RapidOCR(config_path=ocr_config_path)
        except Exception as engine_error:
            print(f"OCR model load failed: {engine_error}")
            return False
        return bool(_ocr_engine)

    print(f"OCR config not found: {ocr_config_path}")
    return False
def detect(frame):
    """Run OCR on a frame and report recognised text that contains forbidden words.

    Texts and scores are filtered as pairs, so a dropped entry can never shift
    a score onto the wrong text. Scores are accepted whenever ``float()`` can
    convert them, which keeps non-``float`` numeric scores and a score of 0.0.

    Args:
        frame: Image passed to the OCR engine; must expose ``size``.

    Returns:
        tuple: ``(True, details)`` when at least one sufficiently confident
        text contains a forbidden word, otherwise ``(False, reason)``.
    """
    if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0:
        return (False, "未初始化或无效帧")

    try:
        ocr_res = _ocr_engine(frame)
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")

    if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'):
        return (False, "无OCR结果")
    if ocr_res.txts is None or ocr_res.scores is None:
        # The attributes exist but hold nothing to iterate over.
        return (False, "无OCR结果")

    raw_texts = list(ocr_res.txts)
    raw_scores = list(ocr_res.scores)
    if len(raw_texts) != len(raw_scores):
        return (False, "OCR结果格式异常")

    # Keep each text together with its own score.
    pairs = []
    for txt, conf in zip(raw_texts, raw_scores):
        if not txt or not isinstance(txt, str):
            continue
        try:
            conf = float(conf)
        except (TypeError, ValueError):
            continue
        pairs.append((txt.strip(), conf))

    if not pairs:
        return (False, "未识别到文本")

    vio_info = []
    for txt, conf in pairs:
        if conf < _conf_threshold:
            continue
        # Skip empty or non-string entries: "" is a substring of every text.
        matched = [w for w in _forbidden_words if isinstance(w, str) and w and w in txt]
        if matched:
            vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})")

    if vio_info:
        return (True, "; ".join(vio_info))
    return (False, "未检测到违禁词")

View File

@ -1,137 +0,0 @@
import asyncio
import logging
from aiortc import RTCPeerConnection, RTCSessionDescription
import aiohttp
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# Create the detector instance used by the frame handler below.
# NOTE(review): forbidden_words_path is an absolute path on one Windows
# machine — confirm it exists wherever this script is run.
detector = OCRViolationDetector(
    forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
    ocr_confidence_threshold=0.7,
    log_level=logging.INFO,
    log_file="ocr_detection.log"
)
# Configure logging for this script
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whep_video_puller")
async def whep_pull_video_stream(ip, whep_url):
    """Pull a video stream over WHEP and run violation detection on every frame.

    A receive-only peer connection is created, its SDP offer is POSTed to the
    WHEP endpoint and the returned SDP answer is applied. Each received video
    frame is described on stdout and handed to the module-level ``detector``.
    The coroutine then stays alive, printing the connection state once per
    second, until it is interrupted. The peer connection is closed on every
    exit path, including a failed signalling exchange.

    Args:
        ip: Not used by this function; kept so existing callers keep working.
        whep_url: URL of the WHEP endpoint.

    Raises:
        Exception: If the WHEP endpoint answers with a status other than 201.
    """
    # Local import: raised by track.recv() once the remote track has ended.
    from aiortc.mediastreams import MediaStreamError

    pc = RTCPeerConnection()

    # Report connection state changes
    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        print(f"连接状态: {pc.connectionState}")

    # Report ICE connection state changes
    @pc.on("iceconnectionstatechange")
    async def on_iceconnectionstatechange():
        print(f"ICE连接状态: {pc.iceConnectionState}")

    # Receive-only video transceiver
    pc.addTransceiver("video", direction="recvonly")

    # Start a frame-processing task for every incoming video track
    @pc.on("track")
    def on_track(track):
        print(f"接收到轨道: {track.kind}")
        if track.kind == "video":
            print(f"轨道ID: {track.id}")
            print(f"轨道就绪状态: {track.readyState}")
            asyncio.ensure_future(handle_video_track(track))

    async def handle_video_track(track):
        """Receive frames from the track, print their details and run detection."""
        frame_count = 0
        print("开始处理视频轨道...")
        while True:
            try:
                frame = await track.recv()
                frame_count += 1
                print(f"收到原始帧 (第{frame_count}帧)")
                # Print basic frame information
                if hasattr(frame, 'width') and hasattr(frame, 'height'):
                    print(f" 尺寸: {frame.width}x{frame.height}")
                if hasattr(frame, 'time_base'):
                    print(f" 时间基准: {frame.time_base}")
                if hasattr(frame, 'pts'):
                    print(f" 显示时间戳: {frame.pts}")
                # Call the configured instance, not the class.
                # NOTE(review): the received frame object is passed through
                # unchanged — confirm detect() accepts it in this form.
                has_violation, violations, confidences = detector.detect(frame)
                if has_violation:
                    detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
                    for word, conf in zip(violations, confidences):
                        detector.logger.info(f"- {word} (置信度: {conf:.4f})")
                else:
                    detector.logger.info("图片中未检测到违禁词")
            except MediaStreamError:
                # The track is finished; retrying would only spin forever.
                print("视频轨道已结束")
                break
            except Exception as e:
                print(f"接收帧时出错: {e}")
                # Wait briefly before retrying
                await asyncio.sleep(0.1)
                continue

    try:
        # Create the offer; after setLocalDescription the local description
        # is the SDP that should be signalled to the server.
        offer = await pc.createOffer()
        await pc.setLocalDescription(offer)
        local_sdp = pc.localDescription.sdp
        print(f"本地SDP信息:\n{local_sdp}")

        # POST the offer to the WHEP endpoint
        async with aiohttp.ClientSession() as session:
            async with session.post(
                whep_url,
                data=local_sdp,
                headers={"Content-Type": "application/sdp"}
            ) as response:
                if response.status != 201:
                    print(f"WHEP服务器返回错误: {response.status}")
                    print(f"响应内容: {await response.text()}")
                    raise Exception(f"WHEP服务器返回错误: {response.status}")
                answer_sdp = await response.text()

        answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
        print(f"收到远程SDP:\n{answer_sdp}")
        await pc.setRemoteDescription(answer)
        print("连接已建立,开始接收视频流...")

        # Keep the connection alive until interrupted
        while True:
            await asyncio.sleep(1)
            print(f"当前连接状态: {pc.connectionState}")
    except KeyboardInterrupt:
        print("用户中断,关闭连接...")
    finally:
        await pc.close()
if __name__ == "__main__":
    from urllib.parse import urlsplit

    # Replace with your WHEP endpoint URL
    WHEP_URL = "http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=473b95a47e338301cbd96809ea7ac416"
    # whep_pull_video_stream() takes (ip, whep_url); calling it with the URL
    # alone raises TypeError. The function does not read `ip`, so the
    # endpoint host is passed.
    # NOTE(review): confirm this is the intended meaning of `ip`.
    WHEP_IP = urlsplit(WHEP_URL).hostname
    # Run the pull task
    asyncio.run(whep_pull_video_stream(WHEP_IP, WHEP_URL))

View File

@ -1,112 +0,0 @@
import asyncio
import logging
import cv2
import time
from ocr.model_violation_detector import MultiModelViolationDetector
# Relative paths to the model and configuration files (adjust to the actual
# directory layout).
# NOTE(review): these are plain relative strings — presumably resolved against
# the process working directory rather than this file; verify in the detector.
YOLO_MODEL_PATH = "../ocr/models/best.pt" # one level up from the core directory, then into the ocr folder
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
KNOWN_FACES_DIR = "../ocr/known_faces"
# Create the detector instance used by the frame loop below
detector = MultiModelViolationDetector(
    forbidden_words_path=FORBIDDEN_WORDS_PATH,
    ocr_config_path=OCR_CONFIG_PATH,
    yolo_model_path=YOLO_MODEL_PATH,
    known_faces_dir=KNOWN_FACES_DIR,
    ocr_confidence_threshold=0.5
)
# Configure logging for this script
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rtmp_video_puller")
async def rtmp_pull_video_stream(rtmp_url):
    """Pull an RTMP stream and run violation detection on every frame.

    The blocking OpenCV calls (open, property reads, frame reads, release)
    run in worker threads via ``asyncio.to_thread``. Each frame is passed to
    the module-level ``detector``; the average receive rate is printed every
    100 frames. Errors are logged and printed rather than propagated, and the
    capture is always released.

    Args:
        rtmp_url: URL of the RTMP stream to open with the FFmpeg backend.
    """
    cap = None  # video capture object, released in the finally block
    # Initialised before the try block so the final summary can always read it.
    frame_count = 0
    try:
        # Open the RTMP stream with the FFmpeg backend
        cap = await asyncio.to_thread(
            cv2.VideoCapture,
            rtmp_url,
            cv2.CAP_FFMPEG
        )
        if not await asyncio.to_thread(cap.isOpened):
            raise Exception(f"RTMP流打开失败: {rtmp_url}请检查URL有效性和FFmpeg环境")

        # Basic stream information
        width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
        height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
        fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
        # Fall back to 30 FPS when the stream reports no usable frame rate
        fps = fps if fps > 0 else 30.0
        width, height = int(width), int(height)

        print("RTMP流状态: 已成功连接")
        print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
        print("开始接收视频帧...(按 Ctrl+C 中断)")

        # A monotonic clock keeps the measured rate immune to wall-clock jumps.
        start_time = time.monotonic()

        while True:
            ret, frame = await asyncio.to_thread(cap.read)
            if not ret:
                print("RTMP流状态: 帧读取失败(可能流已中断或结束)")
                break
            frame_count += 1

            print(f"收到帧 (第{frame_count}帧)")
            print(f" 帧尺寸: {width}x{height}")
            print(f" 配置帧率: {fps:.2f} FPS")

            if frame is not None:
                has_violation, violation_type, details = detector.detect_violations(frame)
                if has_violation:
                    print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
                else:
                    print("未检测到任何违规内容")
            else:
                print("无法读取测试图像")

            # Report the measured average frame rate every 100 frames
            if frame_count % 100 == 0:
                elapsed_time = time.monotonic() - start_time
                if elapsed_time > 0:
                    actual_fps = frame_count / elapsed_time
                    print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
    except KeyboardInterrupt:
        print("\n用户操作: 已通过 Ctrl+C 中断程序")
    except Exception as e:
        logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
        print(f"错误信息: {str(e)}")
    finally:
        if cap is not None:
            await asyncio.to_thread(cap.release)
        print("\n资源释放: RTMP流已关闭")
        print(f"最终统计: 共接收 {frame_count}")
if __name__ == "__main__":
    # Stream pulled when this file is run as a script.
    RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
    try:
        asyncio.run(rtmp_pull_video_stream(RTMP_URL))
    except Exception as startup_error:
        # Report the failure instead of letting the traceback end the script.
        print(f"程序启动失败: {str(startup_error)}")

55
core/yolo.py Normal file
View File

@ -0,0 +1,55 @@
import os
import cv2
from ultralytics import YOLO
# Module-level state shared by load_model() and detect()
_yolo_model = None
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")


def load_model():
    """Create the YOLO model from the weights file next to this module.

    Returns:
        bool: False when the constructor raised (the error is printed),
        otherwise the truthiness of the created model.
    """
    global _yolo_model
    try:
        _yolo_model = YOLO(model_path)
    except Exception as load_error:
        print(f"YOLO model load failed: {load_error}")
        return False
    else:
        return bool(_yolo_model)
def detect(frame, conf_threshold=0.2):
    """Run YOLO on a frame and describe every detected box.

    Args:
        frame: Image passed to the model.
        conf_threshold: Value forwarded to the model as ``conf``.

    Returns:
        tuple: ``(True, text)`` listing class name, confidence and bounding
        box for each detection, or ``(False, reason)`` when the model is not
        loaded, the frame is None, nothing was detected or an error occurred.
    """
    if frame is None or not _yolo_model:
        return (False, "未初始化或无效帧")

    def describe(box):
        """Format one detection box as a readable string."""
        cls = int(box.cls[0])
        score = float(box.conf[0])
        corners = [float(v) for v in box.xyxy[0]]
        if hasattr(_yolo_model, 'names'):
            label = _yolo_model.names[cls]
        else:
            label = f"类别{cls}"
        return f"{label} (置信度: {score:.2f}, 边界框: {corners})"

    try:
        outputs = _yolo_model(frame, conf=conf_threshold)
        # Only the first result is inspected.
        if not outputs or len(outputs[0].boxes) == 0:
            return (False, "未检测到目标")
        return (True, "; ".join([describe(box) for box in outputs[0].boxes]))
    except Exception as detect_error:
        print(f"YOLO detect error: {detect_error}")
        return (False, f"检测错误: {str(detect_error)}")