ocr1.0
This commit is contained in:
79
core/rtmp.py
79
core/rtmp.py
@ -2,109 +2,101 @@ import asyncio
|
||||
import logging
|
||||
import cv2
|
||||
import time
|
||||
from ocr.ocr_violation_detector import OCRViolationDetector
|
||||
from ocr.model_violation_detector import MultiModelViolationDetector
|
||||
|
||||
import logging
|
||||
|
||||
# 配置文件相对路径(根据实际目录结构调整)
|
||||
YOLO_MODEL_PATH = "../ocr/models/best.pt" # 关键修正:从core目录向上一级找ocr文件夹
|
||||
FORBIDDEN_WORDS_PATH = "../ocr/forbidden_words.txt"
|
||||
OCR_CONFIG_PATH = "../ocr/config/1.yaml"
|
||||
KNOWN_FACES_DIR = "../ocr/known_faces"
|
||||
|
||||
# 创建检测器实例
|
||||
detector = OCRViolationDetector(
|
||||
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
|
||||
ocr_confidence_threshold=0.7,
|
||||
log_level=logging.INFO,
|
||||
log_file="ocr_detection.log"
|
||||
detector = MultiModelViolationDetector(
|
||||
forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
ocr_config_path=OCR_CONFIG_PATH,
|
||||
yolo_model_path=YOLO_MODEL_PATH,
|
||||
known_faces_dir=KNOWN_FACES_DIR,
|
||||
ocr_confidence_threshold=0.5
|
||||
)
|
||||
|
||||
# 配置日志(与WHEP代码保持一致的日志风格)
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("rtmp_video_puller")
|
||||
|
||||
|
||||
async def rtmp_pull_video_stream(rtmp_url):
|
||||
"""
|
||||
通过RTMP从指定URL拉取视频流并在收到每一帧时打印消息
|
||||
功能与WHEP拉流函数对齐:流状态反馈、帧信息打印、帧率统计、异常处理
|
||||
|
||||
Args:
|
||||
rtmp_url: RTMP流的URL地址(如 rtmp://xxx/live/stream_key)
|
||||
通过RTMP从指定URL拉取视频流并进行违规检测
|
||||
"""
|
||||
cap = None # 初始化视频捕获对象
|
||||
try:
|
||||
# 1. 异步打开RTMP流(指定FFmpeg后端确保RTMP兼容性,同步操作通过to_thread避免阻塞事件循环)
|
||||
# 异步打开RTMP流
|
||||
cap = await asyncio.to_thread(
|
||||
cv2.VideoCapture,
|
||||
rtmp_url,
|
||||
cv2.CAP_FFMPEG # 必须指定FFmpeg后端,RTMP协议依赖该后端解析
|
||||
cv2.CAP_FFMPEG # 指定FFmpeg后端确保RTMP兼容性
|
||||
)
|
||||
|
||||
# 2. 检查RTMP流是否成功打开
|
||||
# 检查RTMP流是否成功打开
|
||||
is_opened = await asyncio.to_thread(cap.isOpened)
|
||||
if not is_opened:
|
||||
raise Exception(f"RTMP流打开失败: {rtmp_url}(请检查URL有效性和FFmpeg环境)")
|
||||
|
||||
# 3. 异步获取RTMP流基础信息(分辨率、帧率)
|
||||
# 获取RTMP流基础信息
|
||||
width = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_WIDTH)
|
||||
height = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
fps = await asyncio.to_thread(cap.get, cv2.CAP_PROP_FPS)
|
||||
|
||||
# 处理异常情况:部分RTMP流未返回帧率时默认30FPS
|
||||
# 处理异常情况
|
||||
fps = fps if fps > 0 else 30.0
|
||||
# 分辨率转为整数(视频尺寸必然是整数)
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# 打印流初始化成功信息(与WHEP连接成功信息风格一致)
|
||||
# 打印流初始化成功信息
|
||||
print(f"RTMP流状态: 已成功连接")
|
||||
print(f"流基础信息: 分辨率 {width}x{height} | 配置帧率 {fps:.2f} FPS")
|
||||
print("开始接收视频帧...(按 Ctrl+C 中断)")
|
||||
|
||||
# 4. 初始化帧统计参数
|
||||
frame_count = 0 # 总接收帧数
|
||||
start_time = time.time() # 统计起始时间
|
||||
# 初始化帧统计参数
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
# 5. 循环异步读取视频帧(核心逻辑)
|
||||
# 循环读取视频帧
|
||||
while True:
|
||||
# 异步读取一帧(cv2.read是同步操作,用to_thread适配异步环境)
|
||||
ret, frame = await asyncio.to_thread(cap.read)
|
||||
|
||||
# 检查帧是否读取成功(流中断/结束时ret为False)
|
||||
if not ret:
|
||||
print(f"RTMP流状态: 帧读取失败(可能流已中断或结束)")
|
||||
break
|
||||
|
||||
# 帧计数累加
|
||||
frame_count += 1
|
||||
|
||||
# 6. 打印当前帧基础信息(与WHEP帧信息打印风格对齐)
|
||||
# 打印当前帧信息
|
||||
print(f"收到帧 (第{frame_count}帧)")
|
||||
print(f" 帧尺寸: {width}x{height}")
|
||||
print(f" 配置帧率: {fps:.2f} FPS")
|
||||
|
||||
has_violation, violations, confidences = OCRViolationDetector.detect(frame)
|
||||
|
||||
# 输出检测结果
|
||||
if frame is not None:
|
||||
has_violation, violation_type, details = detector.detect_violations(frame)
|
||||
if has_violation:
|
||||
detector.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
|
||||
for word, conf in zip(violations, confidences):
|
||||
detector.logger.info(f"- {word} (置信度: {conf:.4f})")
|
||||
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
|
||||
else:
|
||||
detector.logger.info("图片中未检测到违禁词")
|
||||
# 7. 每100帧统计一次实际接收帧率(补充性能监控,与原RTMP示例逻辑一致)
|
||||
print("未检测到任何违规内容")
|
||||
else:
|
||||
print(f"无法读取测试图像")
|
||||
|
||||
# 每100帧统计一次实际接收帧率
|
||||
if frame_count % 100 == 0:
|
||||
elapsed_time = time.time() - start_time
|
||||
actual_fps = frame_count / elapsed_time # 实际接收帧率(可能低于配置帧率)
|
||||
actual_fps = frame_count / elapsed_time
|
||||
print(f"---- 帧统计: 累计{frame_count}帧 | 实际平均帧率 {actual_fps:.2f} FPS ----")
|
||||
|
||||
# (可选)帧数据处理入口:如需处理帧(如推流、分析),可在此处添加逻辑
|
||||
# 示例:yield frame (若需生成器模式,可调整函数为异步生成器)
|
||||
|
||||
# 8. 异常处理(覆盖用户中断、通用错误)
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n用户操作: 已通过 Ctrl+C 中断程序")
|
||||
except Exception as e:
|
||||
# 日志记录详细错误(便于问题排查),同时打印用户可见信息
|
||||
logger.error(f"RTMP流处理异常: {str(e)}", exc_info=True)
|
||||
print(f"错误信息: {str(e)}")
|
||||
finally:
|
||||
# 9. 资源释放(无论成功/失败都确保释放,避免内存泄漏)
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
print(f"\n资源释放: RTMP流已关闭")
|
||||
@ -114,7 +106,6 @@ async def rtmp_pull_video_stream(rtmp_url):
|
||||
if __name__ == "__main__":
|
||||
RTMP_URL = "rtmp://192.168.110.25:1935/live/473b95a47e338301cbd96809ea7ac416"
|
||||
|
||||
# 运行RTMP拉流任务(与WHEP一致的异步执行方式)
|
||||
try:
|
||||
asyncio.run(rtmp_pull_video_stream(RTMP_URL))
|
||||
except Exception as e:
|
||||
|
@ -4,10 +4,12 @@ import numpy as np
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
|
||||
class FaceRecognizer:
|
||||
"""
|
||||
封装InsightFace人脸识别功能,支持从文件夹加载已知人脸。
|
||||
"""
|
||||
|
||||
def __init__(self, known_faces_dir: str):
|
||||
self.known_faces_dir = known_faces_dir
|
||||
self.app = self._initialize_insightface()
|
||||
@ -16,40 +18,30 @@ class FaceRecognizer:
|
||||
self._load_known_faces()
|
||||
|
||||
def _initialize_insightface(self):
|
||||
"""
|
||||
初始化InsightFace FaceAnalysis应用。
|
||||
默认使用CPU,如果检测到CUDA,会自动使用GPU。
|
||||
"""
|
||||
print("正在初始化InsightFace人脸识别引擎...")
|
||||
"""初始化InsightFace FaceAnalysis应用"""
|
||||
print("初始化InsightFace引擎...")
|
||||
try:
|
||||
# 默认模型是 'buffalo_l',包含检测、对齐、识别功能
|
||||
# 如果需要更小的模型,可以尝试 'buffalo_s' 或 'buffalo_m'
|
||||
# ctx_id=0 表示使用GPU,ctx_id=-1 表示使用CPU
|
||||
# InsightFace会自动检测CUDA并选择GPU,所以通常不需要手动设置ctx_id
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface') # 模型下载到用户目录
|
||||
app.prepare(ctx_id=0, det_size=(640, 640)) # det_size影响检测性能和精度
|
||||
print("InsightFace人脸识别引擎初始化成功。")
|
||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print("InsightFace引擎初始化完成")
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace人脸识别引擎初始化失败: {e}")
|
||||
print("请确保已安装insightface和onnxruntime,并且模型文件已下载或可访问。")
|
||||
print(f"InsightFace初始化失败: {e}")
|
||||
print("请检查依赖是否安装及模型是否可访问")
|
||||
return None
|
||||
|
||||
def _load_known_faces(self):
|
||||
"""
|
||||
扫描已知人脸目录,加载每个人的照片并计算人脸特征。
|
||||
"""
|
||||
"""加载已知人脸特征"""
|
||||
if not os.path.exists(self.known_faces_dir):
|
||||
print(f"警告: 已知人脸目录 '{self.known_faces_dir}' 不存在。请创建并放入照片。")
|
||||
print(f"已知人脸目录不存在,已创建: {self.known_faces_dir}")
|
||||
os.makedirs(self.known_faces_dir, exist_ok=True)
|
||||
return
|
||||
|
||||
print(f"正在加载已知人脸特征从: '{self.known_faces_dir}'...")
|
||||
print(f"从目录加载人脸特征: {self.known_faces_dir}")
|
||||
for person_name in os.listdir(self.known_faces_dir):
|
||||
|
||||
person_dir = os.path.join(self.known_faces_dir, person_name)
|
||||
if os.path.isdir(person_dir):
|
||||
print(f" 加载人物: {person_name}")
|
||||
print(f"处理人物: {person_name}")
|
||||
embeddings = []
|
||||
for filename in os.listdir(person_dir):
|
||||
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
||||
@ -57,131 +49,91 @@ class FaceRecognizer:
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
print(f" 警告: 无法读取图片 '{image_path}',已跳过。")
|
||||
print(f"无法读取图片: {image_path},已跳过")
|
||||
continue
|
||||
|
||||
# 查找人脸并提取特征
|
||||
faces = self.app.get(img)
|
||||
if faces:
|
||||
# 通常一张照片只有一个人脸,取第一个
|
||||
embeddings.append(faces[0].embedding)
|
||||
print(f" 成功提取 '{filename}' 的人脸特征。")
|
||||
print(f"提取特征成功: {filename}")
|
||||
else:
|
||||
print(f" 警告: 在图片 '{filename}' 中未检测到人脸,已跳过。")
|
||||
print(f"未检测到人脸: {filename},已跳过")
|
||||
except Exception as e:
|
||||
print(f" 处理图片 '{image_path}' 时发生错误: {e}")
|
||||
print(f"处理图片出错 {image_path}: {e}")
|
||||
|
||||
if embeddings:
|
||||
# 将多张照片的特征取平均,作为该人物的最终特征
|
||||
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
|
||||
self.known_faces_names.append(person_name)
|
||||
print(f" 人物 '{person_name}' 加载完成,共 {len(embeddings)} 张照片。")
|
||||
print(f"人物 {person_name} 加载完成,共 {len(embeddings)} 张照片")
|
||||
else:
|
||||
print(f" 警告: 人物 '{person_name}' 没有有效的人脸特征,已跳过。")
|
||||
print(f"已知人脸加载完成。共 {len(self.known_faces_names)} 个人物。")
|
||||
print(f"人物 {person_name} 无有效特征,已跳过")
|
||||
print(f"人脸加载完成,共 {len(self.known_faces_names)} 人")
|
||||
|
||||
def recognize(self, frame, threshold=0.4):
|
||||
"""
|
||||
在视频帧中识别人脸。
|
||||
|
||||
Args:
|
||||
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
||||
threshold (float): 识别相似度阈值。0.0到1.0,越高越严格。
|
||||
|
||||
Returns:
|
||||
tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度。
|
||||
"""
|
||||
"""识别人脸并返回结果"""
|
||||
if not self.app or not self.known_faces_names:
|
||||
return False, None, None
|
||||
|
||||
faces = self.app.get(frame) # 在帧中检测并提取所有人的脸
|
||||
faces = self.app.get(frame)
|
||||
if not faces:
|
||||
return False, None, None
|
||||
|
||||
for face in faces:
|
||||
# 遍历已知人脸,进行比对
|
||||
for known_name in self.known_faces_names:
|
||||
known_embedding = self.known_faces_embeddings[known_name]
|
||||
|
||||
# --- 关键修改:手动计算余弦相似度 ---
|
||||
# 确保embedding是float32类型,避免潜在的类型不匹配问题
|
||||
embedding1 = face.embedding.astype(np.float32)
|
||||
embedding2 = known_embedding.astype(np.float32)
|
||||
|
||||
# 计算点积
|
||||
dot_product = np.dot(embedding1, embedding2)
|
||||
# 计算L2范数(向量长度)
|
||||
norm_embedding1 = np.linalg.norm(embedding1)
|
||||
norm_embedding2 = np.linalg.norm(embedding2)
|
||||
|
||||
# 避免除以零
|
||||
if norm_embedding1 == 0 or norm_embedding2 == 0:
|
||||
similarity = 0.0
|
||||
else:
|
||||
similarity = dot_product / (norm_embedding1 * norm_embedding2)
|
||||
# -------------------------------------
|
||||
similarity = 0.0 if (norm_embedding1 == 0 or norm_embedding2 == 0) else (
|
||||
dot_product / (norm_embedding1 * norm_embedding2)
|
||||
)
|
||||
|
||||
if similarity >= threshold:
|
||||
print(f"!!! 人脸识别检测到已知人物: '{known_name}' (相似度: {similarity:.4f}) !!!")
|
||||
return True, known_name, similarity # 只要检测到一个就立即返回
|
||||
print(f"检测到已知人物: {known_name} (相似度: {similarity:.4f})")
|
||||
return True, known_name, similarity
|
||||
|
||||
return False, None, None # 没有检测到已知人脸
|
||||
return False, None, None
|
||||
|
||||
def test_single_image(self, image_path: str, threshold=0.4):
|
||||
"""测试单张图片识别"""
|
||||
if not os.path.exists(image_path):
|
||||
print(f"图片不存在: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
# def test_single_image(self, image_path: str, threshold=0.4):
|
||||
# """
|
||||
# 测试单张图片的人脸识别效果
|
||||
#
|
||||
# Args:
|
||||
# image_path: 图片路径
|
||||
# threshold: 识别阈值
|
||||
#
|
||||
# Returns:
|
||||
# tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度
|
||||
# """
|
||||
# if not os.path.exists(image_path):
|
||||
# print(f"错误: 图片 '{image_path}' 不存在")
|
||||
# return False, None, None
|
||||
#
|
||||
# # 读取图片
|
||||
# frame = cv2.imread(image_path)
|
||||
# if frame is None:
|
||||
# print(f"错误: 无法读取图片 '{image_path}'")
|
||||
# return False, None, None
|
||||
#
|
||||
# # 调用识别方法
|
||||
# result, name, similarity = self.recognize(frame, threshold)
|
||||
#
|
||||
# # 显示结果
|
||||
# if result:
|
||||
# print(f"测试结果: 在图片中识别到 {name},相似度: {similarity:.4f}")
|
||||
#
|
||||
# # 绘制识别结果并显示图片
|
||||
# faces = self.app.get(frame)
|
||||
# for face in faces:
|
||||
# bbox = face.bbox.astype(int)
|
||||
# # 绘制 bounding box
|
||||
# cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
||||
# # 绘制姓名和相似度
|
||||
# text = f"{name}: {similarity:.2f}"
|
||||
# cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
|
||||
# cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||
#
|
||||
# # 显示图片
|
||||
# cv2.imshow('Recognition Result', frame)
|
||||
# print("按任意键关闭图片窗口...")
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
# else:
|
||||
# print("测试结果: 未在图片中识别到已知人脸")
|
||||
#
|
||||
# return result, name, similarity
|
||||
frame = cv2.imread(image_path)
|
||||
if frame is None:
|
||||
print(f"无法读取图片: {image_path}")
|
||||
return False, None, None
|
||||
|
||||
result, name, similarity = self.recognize(frame, threshold)
|
||||
|
||||
if result:
|
||||
print(f"识别结果: {name} (相似度: {similarity:.4f})")
|
||||
|
||||
faces = self.app.get(frame)
|
||||
for face in faces:
|
||||
bbox = face.bbox.astype(int)
|
||||
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
||||
text = f"{name}: {similarity:.2f}"
|
||||
cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
||||
|
||||
cv2.imshow('识别结果', frame)
|
||||
print("按任意键关闭窗口...")
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
print("未识别到已知人脸")
|
||||
|
||||
return result, name, similarity
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # 初始化人脸识别器,指定已知人脸目录
|
||||
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
|
||||
#
|
||||
# # 测试单张图片
|
||||
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg" # 替换为你的测试图片路径
|
||||
# if __name__ == "__main__":
|
||||
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
|
||||
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg"
|
||||
# recognizer.test_single_image(test_image_path, threshold=0.4)
|
@ -1,35 +1,30 @@
|
||||
import cv2
|
||||
from logger_config import logger
|
||||
from ocr_violation_detector import OCRViolationDetector
|
||||
from yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
||||
from face_recognizer import FaceRecognizer
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from .ocr_violation_detector import OCRViolationDetector
|
||||
from .yolo_violation_detector import ViolationDetector as YoloViolationDetector
|
||||
from .face_recognizer import FaceRecognizer
|
||||
|
||||
class MultiModelViolationDetector:
|
||||
"""
|
||||
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型(调整为YOLO最后检测),任一模型检测到违规即返回结果
|
||||
多模型违规检测封装类,串行调用OCR、人脸识别和YOLO模型,任一模型检测到违规即返回结果
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str, # 新增OCR配置文件路径参数
|
||||
ocr_config_path: str,
|
||||
yolo_model_path: str,
|
||||
known_faces_dir: str,
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化所有检测模型
|
||||
|
||||
Args:
|
||||
forbidden_words_path: 违禁词文件路径
|
||||
ocr_config_path: OCR配置文件(1.yaml)路径
|
||||
yolo_model_path: YOLO模型文件路径
|
||||
known_faces_dir: 已知人脸目录路径
|
||||
ocr_confidence_threshold: OCR置信度阈值
|
||||
"""
|
||||
# 初始化OCR检测器(传入配置文件路径)
|
||||
# 初始化OCR检测器
|
||||
self.ocr_detector = OCRViolationDetector(
|
||||
forbidden_words_path=forbidden_words_path,
|
||||
ocr_config_path=ocr_config_path, # 传递配置文件路径
|
||||
ocr_config_path=ocr_config_path,
|
||||
ocr_confidence_threshold=ocr_confidence_threshold
|
||||
)
|
||||
|
||||
@ -38,22 +33,16 @@ class MultiModelViolationDetector:
|
||||
known_faces_dir=known_faces_dir
|
||||
)
|
||||
|
||||
# 初始化YOLO检测器(调整为最后初始化)
|
||||
# 初始化YOLO检测器
|
||||
self.yolo_detector = YoloViolationDetector(
|
||||
model_path=yolo_model_path
|
||||
)
|
||||
|
||||
logger.info("多模型违规检测器初始化完成")
|
||||
print("多模型违规检测器初始化完成")
|
||||
|
||||
def detect_violations(self, frame):
|
||||
"""
|
||||
串行调用三个检测模型(OCR → 人脸识别 → YOLO),任一检测到违规即返回结果
|
||||
Args:
|
||||
frame: 输入视频帧 (NumPy数组, BGR格式)
|
||||
Returns:
|
||||
tuple: (是否有违规, 违规类型, 违规详情)
|
||||
违规类型: 'ocr' | 'yolo' | 'face' | None
|
||||
违规详情: 对应模型的检测结果
|
||||
"""
|
||||
# 1. 首先进行OCR违禁词检测
|
||||
try:
|
||||
@ -63,10 +52,10 @@ class MultiModelViolationDetector:
|
||||
"words": ocr_words,
|
||||
"confidences": ocr_confs
|
||||
}
|
||||
logger.warning(f"OCR检测到违禁内容: {details}")
|
||||
print(f"警告: OCR检测到违禁内容: {details}")
|
||||
return (True, "ocr", details)
|
||||
except Exception as e:
|
||||
logger.error(f"OCR检测出错: {str(e)}", exc_info=True)
|
||||
print(f"错误: OCR检测出错: {str(e)}")
|
||||
|
||||
# 2. 接着进行人脸识别检测
|
||||
try:
|
||||
@ -76,51 +65,63 @@ class MultiModelViolationDetector:
|
||||
"name": face_name,
|
||||
"similarity": face_similarity
|
||||
}
|
||||
logger.warning(f"人脸识别到违规人员: {details}")
|
||||
print(f"警告: 人脸识别到违规人员: {details}")
|
||||
return (True, "face", details)
|
||||
except Exception as e:
|
||||
logger.error(f"人脸识别出错: {str(e)}", exc_info=True)
|
||||
print(f"错误: 人脸识别出错: {str(e)}")
|
||||
|
||||
# 3. 最后进行YOLO目标检测(调整为最后检测)
|
||||
# 3. 最后进行YOLO目标检测
|
||||
try:
|
||||
yolo_results = self.yolo_detector.detect(frame)
|
||||
# 检查是否有检测结果(根据实际业务定义何为违规目标)
|
||||
if len(yolo_results.boxes) > 0:
|
||||
# 提取检测到的目标信息
|
||||
details = {
|
||||
"classes": yolo_results.names,
|
||||
"boxes": yolo_results.boxes.xyxy.tolist(), # 边界框坐标
|
||||
"confidences": yolo_results.boxes.conf.tolist(), # 置信度
|
||||
"class_ids": yolo_results.boxes.cls.tolist() # 类别ID
|
||||
"boxes": yolo_results.boxes.xyxy.tolist(),
|
||||
"confidences": yolo_results.boxes.conf.tolist(),
|
||||
"class_ids": yolo_results.boxes.cls.tolist()
|
||||
}
|
||||
logger.warning(f"YOLO检测到违规目标: {details}")
|
||||
print(f"警告: YOLO检测到违规目标: {details}")
|
||||
return (True, "yolo", details)
|
||||
except Exception as e:
|
||||
logger.error(f"YOLO检测出错: {str(e)}", exc_info=True)
|
||||
print(f"错误: YOLO检测出错: {str(e)}")
|
||||
|
||||
# 所有检测均未发现违规
|
||||
return (False, None, None)
|
||||
|
||||
|
||||
# # 使用示例
|
||||
def load_config(config_path: str) -> dict:
|
||||
"""加载YAML配置文件"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"错误: 配置文件未找到: {config_path}")
|
||||
raise
|
||||
except yaml.YAMLError as e:
|
||||
print(f"错误: 配置文件格式错误: {config_path}, 错误: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"错误: 加载配置文件出错: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 配置文件路径(根据实际情况修改)
|
||||
# FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
|
||||
# OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" # 新增OCR配置文件路径
|
||||
# YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
|
||||
# KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
|
||||
# # 加载配置文件
|
||||
# config = load_config("config.yaml") # 配置文件路径,可根据实际情况修改
|
||||
#
|
||||
# # 初始化多模型检测器
|
||||
# detector = MultiModelViolationDetector(
|
||||
# forbidden_words_path=FORBIDDEN_WORDS_PATH,
|
||||
# ocr_config_path=OCR_CONFIG_PATH, # 传入OCR配置文件路径
|
||||
# yolo_model_path=YOLO_MODEL_PATH,
|
||||
# known_faces_dir=KNOWN_FACES_DIR,
|
||||
# ocr_confidence_threshold=0.5
|
||||
# forbidden_words_path=config["forbidden_words_path"],
|
||||
# ocr_config_path=config["ocr_config_path"],
|
||||
# yolo_model_path=config["yolo_model_path"],
|
||||
# known_faces_dir=config["known_faces_dir"],
|
||||
# ocr_confidence_threshold=config.get("ocr_confidence_threshold", 0.5)
|
||||
# )
|
||||
#
|
||||
# # 读取测试图像(可替换为视频帧读取逻辑)
|
||||
# test_image_path = r"D:\Git\bin\video\ocr\images\img.png"
|
||||
# test_image_path = config.get("test_image_path") # 从配置文件获取测试图片路径
|
||||
# if test_image_path:
|
||||
# frame = cv2.imread(test_image_path)
|
||||
#
|
||||
# if frame is not None:
|
||||
@ -131,3 +132,5 @@ class MultiModelViolationDetector:
|
||||
# print("未检测到任何违规内容")
|
||||
# else:
|
||||
# print(f"无法读取测试图像: {test_image_path}")
|
||||
# else:
|
||||
# print("配置文件中未指定测试图像路径")
|
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import cv2
|
||||
import logging
|
||||
from rapidocr import RapidOCR
|
||||
|
||||
|
||||
@ -13,153 +12,85 @@ class OCRViolationDetector:
|
||||
def __init__(self,
|
||||
forbidden_words_path: str,
|
||||
ocr_config_path: str,
|
||||
ocr_confidence_threshold: float = 0.5,
|
||||
log_level: int = logging.INFO,
|
||||
log_file: str = None):
|
||||
ocr_confidence_threshold: float = 0.5):
|
||||
"""
|
||||
初始化OCR引擎、违禁词列表和日志配置。
|
||||
初始化OCR引擎和违禁词列表。
|
||||
|
||||
Args:
|
||||
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
|
||||
ocr_config_path (str): OCR配置文件(如1.yaml)的路径。
|
||||
ocr_confidence_threshold (float): OCR识别结果的置信度阈值(0~1)。
|
||||
log_level (int): 日志级别,默认为logging.INFO。
|
||||
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。
|
||||
"""
|
||||
# 初始化日志(确保先初始化日志,后续操作可正常打日志)
|
||||
self.logger = self._setup_logger(log_level, log_file)
|
||||
|
||||
# 加载违禁词(优先级:先加载配置,再初始化引擎)
|
||||
# 加载违禁词
|
||||
self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
|
||||
|
||||
# 初始化RapidOCR引擎(传入配置文件路径)
|
||||
# 初始化RapidOCR引擎
|
||||
self.ocr_engine = self._initialize_ocr(ocr_config_path)
|
||||
|
||||
# 校验核心依赖是否就绪
|
||||
self._check_dependencies()
|
||||
|
||||
# 设置置信度阈值(限制在0~1范围,避免非法值)
|
||||
# 设置置信度阈值(限制在0~1范围)
|
||||
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
|
||||
self.logger.info(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
|
||||
|
||||
def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
|
||||
"""
|
||||
配置日志系统(避免重复添加处理器,支持控制台+文件双输出)
|
||||
|
||||
Args:
|
||||
log_level: 日志级别(如logging.DEBUG、logging.INFO)。
|
||||
log_file: 日志文件路径,为None时仅输出到控制台。
|
||||
|
||||
Returns:
|
||||
logging.Logger: 配置好的日志实例。
|
||||
"""
|
||||
logger = logging.getLogger('OCRViolationDetector')
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# 避免重复添加处理器(防止日志重复输出)
|
||||
if logger.handlers:
|
||||
return logger
|
||||
|
||||
# 定义日志格式(包含时间、模块名、级别、内容)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
# 1. 添加控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# 2. 若指定日志文件,添加文件处理器(自动创建目录)
|
||||
if log_file:
|
||||
try:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
# 若日志目录不存在,自动创建
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
self.logger.debug(f"自动创建日志目录: {log_dir}")
|
||||
|
||||
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
logger.info(f"日志文件已配置: {log_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}")
|
||||
|
||||
return logger
|
||||
print(f"OCR置信度阈值已设置(范围0~1): {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
|
||||
|
||||
def _load_forbidden_words(self, path: str) -> set:
|
||||
"""
|
||||
从TXT文件加载违禁词(去重、过滤空行,支持UTF-8编码)
|
||||
|
||||
Args:
|
||||
path (str): 违禁词TXT文件路径。
|
||||
|
||||
Returns:
|
||||
set: 去重后的违禁词集合(空集合表示加载失败)。
|
||||
"""
|
||||
forbidden_words = set()
|
||||
|
||||
# 第一步:检查文件是否存在
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(path):
|
||||
self.logger.error(f"违禁词文件不存在: {path}")
|
||||
print(f"错误:违禁词文件不存在: {path}")
|
||||
return forbidden_words
|
||||
|
||||
# 第二步:读取文件并处理内容
|
||||
# 读取文件并处理内容
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
# 过滤空行、去除首尾空格、去重
|
||||
forbidden_words = {
|
||||
line.strip() for line in f
|
||||
if line.strip() # 跳过空行或纯空格行
|
||||
}
|
||||
self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
|
||||
self.logger.debug(f"违禁词列表: {forbidden_words}")
|
||||
print(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
|
||||
except UnicodeDecodeError:
|
||||
self.logger.error(f"违禁词文件编码错误(需UTF-8): {path}")
|
||||
print(f"错误:违禁词文件编码错误(需UTF-8): {path}")
|
||||
except PermissionError:
|
||||
self.logger.error(f"无权限读取违禁词文件: {path}")
|
||||
print(f"错误:无权限读取违禁词文件: {path}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True)
|
||||
print(f"错误:加载违禁词失败: {str(e)}")
|
||||
|
||||
return forbidden_words
|
||||
|
||||
def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
|
||||
"""
|
||||
初始化RapidOCR引擎(校验配置文件、捕获初始化异常)
|
||||
|
||||
Args:
|
||||
config_path (str): RapidOCR配置文件(如1.yaml)路径。
|
||||
|
||||
Returns:
|
||||
RapidOCR | None: OCR引擎实例(None表示初始化失败)。
|
||||
"""
|
||||
self.logger.info("开始初始化RapidOCR引擎...")
|
||||
print("开始初始化RapidOCR引擎...")
|
||||
|
||||
# 第一步:检查配置文件是否存在
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
self.logger.error(f"OCR配置文件不存在: {config_path}")
|
||||
print(f"错误:OCR配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
# 第二步:初始化OCR引擎(捕获RapidOCR相关异常)
|
||||
# 初始化OCR引擎
|
||||
try:
|
||||
ocr_engine = RapidOCR(config_path=config_path)
|
||||
self.logger.info("RapidOCR引擎初始化成功")
|
||||
print("RapidOCR引擎初始化成功")
|
||||
return ocr_engine
|
||||
except ImportError:
|
||||
self.logger.error("RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
|
||||
print("错误:RapidOCR依赖未安装(需执行:pip install rapidocr-onnxruntime)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True)
|
||||
print(f"错误:RapidOCR初始化失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
def _check_dependencies(self) -> None:
|
||||
"""校验OCR引擎和违禁词列表是否就绪(输出警告日志)"""
|
||||
"""校验OCR引擎和违禁词列表是否就绪"""
|
||||
if not self.ocr_engine:
|
||||
self.logger.warning("⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
|
||||
print("警告:⚠️ OCR引擎未就绪,违禁词检测功能将禁用")
|
||||
if not self.forbidden_words:
|
||||
self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用")
|
||||
print("警告:⚠️ 违禁词列表为空,违禁词检测功能将禁用")
|
||||
|
||||
def detect(self, frame) -> tuple[bool, list, list]:
|
||||
"""
|
||||
@ -179,76 +110,69 @@ class OCRViolationDetector:
|
||||
violation_words = []
|
||||
violation_confs = []
|
||||
|
||||
# 前置校验:1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在
|
||||
# 前置校验
|
||||
if frame is None or frame.size == 0:
|
||||
self.logger.warning("输入图像帧为空或无效,跳过OCR检测")
|
||||
print("警告:输入图像帧为空或无效,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if not self.ocr_engine or not self.forbidden_words:
|
||||
self.logger.debug("OCR引擎未就绪或违禁词为空,跳过OCR检测")
|
||||
print("OCR引擎未就绪或违禁词为空,跳过OCR检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
try:
|
||||
# 1. 执行OCR识别(获取RapidOCR原始结果)
|
||||
self.logger.debug("开始执行OCR识别...")
|
||||
# 执行OCR识别
|
||||
print("开始执行OCR识别...")
|
||||
ocr_result = self.ocr_engine(frame)
|
||||
self.logger.debug(f"RapidOCR原始结果: {ocr_result}")
|
||||
print(f"RapidOCR原始结果: {ocr_result}")
|
||||
|
||||
# 2. 校验OCR结果是否有效(避免None或格式异常)
|
||||
# 校验OCR结果是否有效
|
||||
if ocr_result is None:
|
||||
self.logger.debug("OCR识别未返回任何结果(图像无文本或识别失败)")
|
||||
print("OCR识别未返回任何结果(图像无文本或识别失败)")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 3. 检查txts和scores是否存在且不为None
|
||||
# 检查txts和scores是否存在且不为None
|
||||
if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
|
||||
self.logger.warning("OCR结果中txts为None或不存在")
|
||||
print("警告:OCR结果中txts为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
|
||||
self.logger.warning("OCR结果中scores为None或不存在")
|
||||
print("警告:OCR结果中scores为None或不存在")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 4. 转为列表并去None(防止单个元素为None)
|
||||
# 确保txts是可迭代的,如果不是则转为空列表
|
||||
# 转为列表并去None
|
||||
if not isinstance(ocr_result.txts, (list, tuple)):
|
||||
self.logger.warning(f"OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}")
|
||||
print(f"警告:OCR txts不是可迭代类型,实际类型: {type(ocr_result.txts)}")
|
||||
texts = []
|
||||
else:
|
||||
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
|
||||
|
||||
# 确保scores是可迭代的,如果不是则转为空列表
|
||||
if not isinstance(ocr_result.scores, (list, tuple)):
|
||||
self.logger.warning(f"OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}")
|
||||
print(f"警告:OCR scores不是可迭代类型,实际类型: {type(ocr_result.scores)}")
|
||||
confidences = []
|
||||
else:
|
||||
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
|
||||
|
||||
# 5. 校验文本和置信度列表长度是否一致(避免zip迭代错误)
|
||||
# 校验文本和置信度列表长度是否一致
|
||||
if len(texts) != len(confidences):
|
||||
self.logger.warning(
|
||||
f"OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
|
||||
print(f"警告:OCR文本与置信度数量不匹配(文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
|
||||
return has_violation, violation_words, violation_confs
|
||||
if len(texts) == 0:
|
||||
self.logger.debug("OCR未识别到任何有效文本")
|
||||
print("OCR未识别到任何有效文本")
|
||||
return has_violation, violation_words, violation_confs
|
||||
|
||||
# 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤)
|
||||
self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
|
||||
# 遍历识别结果,筛选违禁词
|
||||
print(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f})")
|
||||
for text, conf in zip(texts, confidences):
|
||||
# 过滤低置信度结果
|
||||
if conf < self.OCR_CONFIDENCE_THRESHOLD:
|
||||
self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
|
||||
print(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
|
||||
continue
|
||||
# 检查当前文本是否包含违禁词(支持一个文本含多个违禁词)
|
||||
matched_words = [word for word in self.forbidden_words if word in text]
|
||||
if matched_words:
|
||||
has_violation = True
|
||||
# 记录所有匹配的违禁词和对应置信度
|
||||
violation_words.extend(matched_words)
|
||||
violation_confs.extend([conf] * len(matched_words)) # 一个文本对应多个违禁词时,置信度复用
|
||||
self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")
|
||||
violation_confs.extend([conf] * len(matched_words))
|
||||
print(f"警告:检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
# 捕获所有异常,确保不中断上层调用
|
||||
self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True)
|
||||
print(f"错误:OCR检测过程异常: {str(e)}")
|
||||
|
||||
return has_violation, violation_words, violation_confs
|
@ -1,6 +1,5 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
from logger_config import logger
|
||||
|
||||
class ViolationDetector:
|
||||
"""
|
||||
@ -13,9 +12,9 @@ class ViolationDetector:
|
||||
Args:
|
||||
model_path (str): YOLO .pt模型的路径。
|
||||
"""
|
||||
logger.info(f"正在从 '{model_path}' 加载YOLO模型...")
|
||||
print(f"正在从 '{model_path}' 加载YOLO模型...")
|
||||
self.model = YOLO(model_path)
|
||||
logger.info("YOLO模型加载成功。")
|
||||
print("YOLO模型加载成功。")
|
||||
|
||||
def detect(self, frame):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user