Data preprocessing
README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
### Dataset Preprocessing Notes

- Read a raw video and slice it into frames.
- Then paste objects from the original dataset onto the raw video frames, to simulate the complex scenes encountered in real-world recognition.
- Then generate the new image files and annotation files.
- The newly generated files can then be used directly for training (a sketch of the label format they follow appears below).
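Both scripts in this commit read and write labels in the plain YOLO txt format: one `class x_center y_center width height` line per object, with coordinates normalized to [0, 1] relative to the image size. As a quick reference, here is a minimal sketch of that conversion; the helper names `yolo_to_pixel` and `pixel_to_yolo` are illustrative only and do not appear in the scripts below.

```python
def yolo_to_pixel(box, img_w, img_h):
    """Convert a normalized YOLO box (cx, cy, w, h) to pixel corners (x1, y1, x2, y2)."""
    x_c, y_c, w, h = box
    x1 = int((x_c - w / 2) * img_w)
    y1 = int((y_c - h / 2) * img_h)
    x2 = int((x_c + w / 2) * img_w)
    y2 = int((y_c + h / 2) * img_h)
    return x1, y1, x2, y2

def pixel_to_yolo(x1, y1, x2, y2, img_w, img_h):
    """Convert pixel corners back to a normalized YOLO box (cx, cy, w, h)."""
    return ((x1 + x2) / 2 / img_w, (y1 + y2) / 2 / img_h,
            (x2 - x1) / img_w, (y2 - y1) / img_h)

# Example: a 100x50 px object with top-left corner at (200, 150) in a 640x480 image
print(pixel_to_yolo(200, 150, 300, 200, 640, 480))  # ≈ (0.39, 0.36, 0.16, 0.10)
```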
数据增强底图版.py (new file, 272 lines)
@@ -0,0 +1,272 @@
import cv2
import numpy as np
import os
import random
import albumentations as A
from tqdm import tqdm

# --- 1. User configuration (edit this first!) ---
# Adjust the paths to your environment; keep the three core directories distinct:
# 1. Feature-source directories: images and labels that contain the objects to paste (e.g. no_helmet), used to build the paste library
SOURCE_FEATURE_IMAGE_DIR = r"E:\NSFW-Detection-YOLO\data\images\val\images"  # source images containing objects
SOURCE_FEATURE_LABEL_DIR = r"E:\NSFW-Detection-YOLO\data\images\val\labels"  # labels for those source images
# 2. Independent base-image directory: the blank/background images the objects are pasted onto (no labels required)
BASE_IMAGE_DIR = r"D:\DataPreHandler\images\valid"  # your base-image folder
# 3. Output directories: where the final augmented images and labels are saved
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\valid\images"
OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\valid\labels"

# Augmentation parameters
AUGMENTATION_FACTOR = 1  # number of augmented images generated per base image (e.g. 40)

# --- Copy-Paste core configuration ---
SMALL_OBJECT_CLASSES_TO_PASTE = [0, 1, 2, 3, 4, 5, 6]  # class IDs of the objects to paste (e.g. no_helmet is 2)
PASTE_COUNT_RANGE = (5, 10)  # number of objects pasted onto each augmented image (random, 5-10)

# --- 2. Regular augmentation pipelines (Albumentations parameters updated) ---
transform_geometric = A.Compose([
    A.HorizontalFlip(p=0.5),
    # Change 1: A.Affine parameters: rotate_limit → rotate, cval → pad_val, border_mode added
    A.Affine(scale=(0.8, 1.2), shear=(-10, 10), translate_percent=0.1,
             rotate=30, border_mode=cv2.BORDER_CONSTANT, pad_val=0, p=0.8),
    A.Perspective(scale=(0.02, 0.05), p=0.4),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.25))

transform_quality = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.8),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.7),
    # Change 2: A.GaussNoise parameters: var_limit → std_limit (variance converted to standard deviation, roughly its square root)
    A.OneOf([A.GaussNoise(std_limit=(3.0, 8.0), p=1.0), A.ISONoise(p=1.0)], p=0.6),
    A.OneOf([A.Blur(blur_limit=(3, 7), p=1.0), A.MotionBlur(blur_limit=(3, 7), p=1.0)], p=0.5),
    # Change 3: A.ImageCompression parameters: quality_lower/upper → quality_range (merged into one tuple)
    A.ImageCompression(quality_range=(70, 95), p=0.3),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.25))

transform_mixed = A.Compose([
    A.HorizontalFlip(p=0.5),
    # Change 4: A.Rotate parameter: value → pad_val
    A.Rotate(limit=15, p=0.5, border_mode=cv2.BORDER_CONSTANT, pad_val=0),
    A.RandomBrightnessContrast(p=0.6),
    A.GaussNoise(std_limit=(2.0, 6.0), p=0.4),  # GaussNoise parameter updated here as well
    A.Blur(blur_limit=3, p=0.3),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.25))

base_transforms = [transform_geometric, transform_quality, transform_mixed]  # one strategy is chosen at random per image


# --- 3. Core utility functions ---
def harvest_objects_for_pasting(feature_image_dir, feature_label_dir, target_classes):
    """
    Extract objects from the feature-source directories and build a library of pasteable assets.
    :param feature_image_dir: directory of images containing the objects (e.g. originals with no_helmet)
    :param feature_label_dir: directory of the corresponding labels
    :param target_classes: object classes to extract (e.g. [2])
    :return: asset library {class ID: [object image 1, object image 2, ...]}
    """
    print(f"Extracting target classes {target_classes} from {feature_image_dir}...")
    asset_library = {cls_id: [] for cls_id in target_classes}

    # Only read image files from the feature-source directory
    feature_image_files = [f for f in os.listdir(feature_image_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
    if not feature_image_files:
        raise FileNotFoundError(f"No images found in feature-source directory {feature_image_dir}!")

    for filename in tqdm(feature_image_files, desc="Extracting object assets"):
        label_file = os.path.splitext(filename)[0] + ".txt"
        label_path = os.path.join(feature_label_dir, label_file)
        if not os.path.exists(label_path):
            continue  # skip images without a label file

        # Read the image and get its size
        img = cv2.imread(os.path.join(feature_image_dir, filename))
        if img is None:
            tqdm.write(f"Warning: could not read image {filename}, skipped")
            continue
        img_h, img_w, _ = img.shape

        # Parse the label file and crop out the objects
        with open(label_path, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                # Change 5: handle class IDs stored as floats (e.g. 6.0 → 6): convert to float first, then to int
                cls_id = int(float(parts[0]))
                if cls_id not in target_classes:
                    continue  # keep only the target classes

                # Convert YOLO normalized coordinates to pixel coordinates (x1, y1: top-left; x2, y2: bottom-right)
                x_center, y_center, box_w, box_h = [float(p) for p in parts[1:]]
                x1 = int((x_center - box_w / 2) * img_w)
                y1 = int((y_center - box_h / 2) * img_h)
                x2 = int((x_center + box_w / 2) * img_w)
                y2 = int((y_center + box_h / 2) * img_h)

                # Clamp the coordinates to the image bounds to avoid cropping errors
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(img_w, x2), min(img_h, y2)

                # Crop the object and add it to the asset library (discard empty crops)
                if x1 < x2 and y1 < y2:
                    cropped_obj = img[y1:y2, x1:x2]
                    if cropped_obj.size > 0:
                        asset_library[cls_id].append(cropped_obj)

    # Make sure the asset library is not empty
    total_assets = sum(len(v) for v in asset_library.values())
    if total_assets == 0:
        raise ValueError(f"No objects were extracted from the feature-source directory! Check that the class IDs {target_classes} are correct")

    print(f"Asset library built! Extracted {total_assets} objects (classes: {target_classes})")
    return asset_library


def paste_objects_to_base(base_image, asset_library):
    """
    Paste objects from the asset library onto a single base image.
    :param base_image: input base image (BGR image read with cv2)
    :param asset_library: object asset library
    :return: the pasted image and the corresponding YOLO-format labels (bboxes + labels)
    """
    base_h, base_w, _ = base_image.shape
    pasted_bboxes = []  # YOLO bboxes of the pasted objects
    pasted_labels = []  # class IDs of the pasted objects

    # Randomly decide how many objects to paste onto this image
    num_to_paste = random.randint(*PASTE_COUNT_RANGE)

    for _ in range(num_to_paste):
        # Choose a class to paste (only from classes that actually have assets)
        valid_classes = [cls for cls, assets in asset_library.items() if len(assets) > 0]
        if not valid_classes:
            break  # edge case: the asset library is temporarily empty (almost never happens)

        # Randomly pick a class and one asset of that class
        target_cls = random.choice(valid_classes)
        target_obj = random.choice(asset_library[target_cls])
        obj_h, obj_w, _ = target_obj.shape

        # Skip objects larger than the base image (to avoid pasting out of bounds)
        if obj_h >= base_h or obj_w >= base_w:
            continue

        # Randomly choose the paste position (top-left corner, keeping the object fully inside the base image)
        paste_x1 = random.randint(0, base_w - obj_w)
        paste_y1 = random.randint(0, base_h - obj_h)
        paste_x2 = paste_x1 + obj_w
        paste_y2 = paste_y1 + obj_h

        # Paste the object with a plain NumPy slice (overwriting the corresponding region of the base image)
        base_image[paste_y1:paste_y2, paste_x1:paste_x2] = target_obj

        # Compute the YOLO normalized coordinates of the pasted object (x_center, y_center, w, h)
        yolo_x_center = (paste_x1 + obj_w / 2) / base_w
        yolo_y_center = (paste_y1 + obj_h / 2) / base_h
        yolo_w = obj_w / base_w
        yolo_h = obj_h / base_h

        # Record the label
        pasted_bboxes.append([yolo_x_center, yolo_y_center, yolo_w, yolo_h])
        pasted_labels.append(target_cls)

    return base_image, pasted_bboxes, pasted_labels


def main():
    # 1. Initialization: create the output directories
    os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
    os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True)

    # 2. Step one: build the object asset library (extract pasteable objects from the feature-source directories)
    try:
        asset_library = harvest_objects_for_pasting(
            feature_image_dir=SOURCE_FEATURE_IMAGE_DIR,
            feature_label_dir=SOURCE_FEATURE_LABEL_DIR,
            target_classes=SMALL_OBJECT_CLASSES_TO_PASTE
        )
    except (FileNotFoundError, ValueError) as e:
        print(f"Error: {e}")
        return

    # 3. Step two: collect all base images (image files only)
    base_image_files = [f for f in os.listdir(BASE_IMAGE_DIR) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
    if not base_image_files:
        print(f"Error: no images found in base-image directory {BASE_IMAGE_DIR}!")
        return
    print(f"\nFound {len(base_image_files)} base images, generating augmented data ({AUGMENTATION_FACTOR} per base image)")

    # 4. Main loop: iterate over every base image and generate augmented data
    for base_filename in tqdm(base_image_files, desc="Processing base images"):
        base_name, base_ext = os.path.splitext(base_filename)
        base_image_path = os.path.join(BASE_IMAGE_DIR, base_filename)

        # Read the base image (skip it if reading fails)
        base_image = cv2.imread(base_image_path)
        if base_image is None:
            tqdm.write(f"\nWarning: could not read base image {base_filename}, skipped")
            continue

        # Generate AUGMENTATION_FACTOR augmented images for the current base image
        for aug_idx in range(AUGMENTATION_FACTOR):
            # Step 1: copy the base image (so the original is not modified) and paste objects onto it
            base_image_copy = base_image.copy()
            pasted_image, pasted_bboxes, pasted_labels = paste_objects_to_base(
                base_image=base_image_copy,
                asset_library=asset_library
            )

            # Step 2: apply a regular augmentation to the pasted image (Albumentations expects RGB)
            pasted_image_rgb = cv2.cvtColor(pasted_image, cv2.COLOR_BGR2RGB)
            chosen_transform = random.choice(base_transforms)  # pick an augmentation strategy at random

            try:
                # Apply the augmentation (bboxes and labels are transformed together)
                augmented_result = chosen_transform(
                    image=pasted_image_rgb,
                    bboxes=pasted_bboxes,
                    class_labels=pasted_labels
                )
                final_image_rgb = augmented_result['image']
                final_bboxes = augmented_result['bboxes']
                final_labels = augmented_result['class_labels']
            except Exception as e:
                tqdm.write(f"\nWarning: augmentation failed for base image {base_filename} (index {aug_idx}): {str(e)}")
                continue

            # Step 3: save the augmented image and its label
            # Image name: <base name>_aug_<index>.jpg (always saved as jpg to keep formats consistent)
            output_img_name = f"{base_name}_aug_{aug_idx}.jpg"
            output_img_path = os.path.join(OUTPUT_IMAGE_DIR, output_img_name)
            # Convert RGB back to BGR (cv2 expects BGR when writing)
            cv2.imwrite(output_img_path, cv2.cvtColor(final_image_rgb, cv2.COLOR_RGB2BGR))

            # Label name: same as the image, with a .txt extension (YOLO format)
            output_label_name = f"{base_name}_aug_{aug_idx}.txt"
            output_label_path = os.path.join(OUTPUT_LABEL_DIR, output_label_name)

            with open(output_label_path, 'w') as f:
                for bbox, label in zip(final_bboxes, final_labels):
                    x_c, y_c, w, h = bbox
                    # Bounds check: drop bboxes that fall outside the 0-1 range after augmentation (avoids training errors)
                    if 0 <= x_c <= 1 and 0 <= y_c <= 1 and 0 <= w <= 1 and 0 <= h <= 1:
                        f.write(f"{label} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}\n")

    # 5. Final summary
    total_generated = len(base_image_files) * AUGMENTATION_FACTOR
    print(f"\n✅ Data augmentation finished!")
    print(f"📊 Generation statistics:")
    print(f"  - Base images: {len(base_image_files)}")
    print(f"  - Augmentations per base image: {AUGMENTATION_FACTOR}")
    print(f"  - Image/label pairs generated: {total_generated}")
    print(f"  - Output paths:")
    print(f"    Images → {OUTPUT_IMAGE_DIR}")
    print(f"    Labels → {OUTPUT_LABEL_DIR}")


if __name__ == "__main__":
    # Before running, make sure that:
    # 1. SOURCE_FEATURE_IMAGE_DIR / SOURCE_FEATURE_LABEL_DIR point to the directories that contain the objects
    # 2. BASE_IMAGE_DIR is your blank base-image directory
    # 3. SMALL_OBJECT_CLASSES_TO_PASTE lists the class IDs to paste (e.g. no_helmet = 2)
    main()
视频抽帧.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import cv2
import numpy as np
import os
from tqdm import tqdm
from pathlib import Path

# -------------------------- 1. User configuration (adjust to your setup) --------------------------
# Path to the video file (absolute or relative; common formats such as mp4/avi/mov are supported)
VIDEO_PATH = r"E:\geminicli\yolo\images\44.mp4"  # replace with your video path

# Output root directory (the script automatically creates train/valid/test subfolders under it)
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images"  # replace with your output root

# Number of frames to extract for each split (total = 28 + 14 + 7 = 49 with the values below)
FRAME_COUNTS = {
    "train": 28,
    "valid": 14,
    "test": 7
}

# Image format for saved frames (jpg is recommended for compatibility; png also works)
SAVE_FORMAT = "jpg"

# Number of digits in the file-name index (e.g. 0001.jpg, keeps file names in order)
FILE_NUM_DIGITS = 4  # supports indices up to 9999, comfortably more than the required frame count


# -----------------------------------------------------------------------------------


def check_video_validity(cap):
    """Check that the video can be read and return its total frame count."""
    if not cap.isOpened():
        raise ValueError(f"Could not open the video file! Check the path: {VIDEO_PATH}")

    # Get the total frame count (note: some videos report -1, which needs special handling)
    total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    if total_frames == -1:
        # If the frame count is not available directly, derive it by seeking to the last frame
        cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)  # jump to the end of the video
        total_frames = cap.get(cv2.CAP_PROP_POS_FRAMES)
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # jump back to the start

    total_frames = int(total_frames)
    required_total = sum(FRAME_COUNTS.values())  # total number of frames to extract

    # Check that the video has enough frames
    if total_frames < required_total:
        raise ValueError(
            f"The video does not have enough frames!\n"
            f"Actual frames: {total_frames}, frames required: {required_total}\n"
            f"Use a longer video or reduce the per-split frame counts."
        )

    # Get the frame rate (only used for the info printout; it does not affect extraction)
    fps = cap.get(cv2.CAP_PROP_FPS)
    video_duration = total_frames / fps  # total duration in seconds

    print(f"✅ Video information read successfully:")
    print(f"  - Video path: {VIDEO_PATH}")
    print(f"  - Total frames: {total_frames}")
    print(f"  - Frame rate (FPS): {fps:.1f}")
    print(f"  - Duration: {video_duration // 60:.0f} min {video_duration % 60:.1f} s")
    print(
        f"  - Frames to extract: {required_total} (train: {FRAME_COUNTS['train']}, valid: {FRAME_COUNTS['valid']}, test: {FRAME_COUNTS['test']})")

    return total_frames


def create_output_dirs():
    """Create the output root directory and the train/valid/test subfolders."""
    # Use a Path object so the paths work on Windows/Linux/macOS
    output_root = Path(OUTPUT_ROOT_DIR)
    subdirs = FRAME_COUNTS.keys()

    for subdir in subdirs:
        subdir_path = output_root / subdir
        subdir_path.mkdir(parents=True, exist_ok=True)  # parents=True creates parent dirs, exist_ok=True avoids errors if they already exist
        print(f"📂 Output folder created/confirmed: {subdir_path}")

    return {subdir: output_root / subdir for subdir in subdirs}  # dictionary of subfolder paths


def generate_random_frame_indices(total_frames, required_total):
    """Generate non-repeating random frame indices (range: 0 ~ total_frames - 1)."""
    print(f"\n🎲 Generating {required_total} non-repeating random frame indices...")
    # Use numpy to draw random integers without replacement (replace=False guarantees no duplicates)
    random_indices = np.random.choice(
        a=range(total_frames),
        size=required_total,
        replace=False
    )
    # Sort (optional: saves the extracted frames in temporal order; leave unsorted for a fully random order)
    random_indices.sort()
    print(f"✅ Random frame indices generated ({len(random_indices)} in total)")
    return random_indices


def split_indices_by_dataset(random_indices):
    """Split the random indices according to the train/valid/test counts."""
    train_count = FRAME_COUNTS["train"]
    valid_count = FRAME_COUNTS["valid"]

    # Split logic: the first N indices go to train, the next M to valid, the remainder to test
    indices_split = {
        "train": random_indices[:train_count],
        "valid": random_indices[train_count:train_count + valid_count],
        "test": random_indices[train_count + valid_count:]
    }

    # Verify the split sizes (guards against configuration mistakes)
    for dataset, indices in indices_split.items():
        assert len(indices) == FRAME_COUNTS[dataset], \
            f"{dataset} index split error! Expected {FRAME_COUNTS[dataset]}, got {len(indices)}"

    print(f"\n📊 Index split complete:")
    for dataset, indices in indices_split.items():
        print(f"  - {dataset}: {len(indices)} frame indices (range: {indices[0]} ~ {indices[-1]})")

    return indices_split


def extract_and_save_frames(cap, indices_split, output_dirs):
    """Extract video frames according to the split indices and save them to the matching folders."""
    print(f"\n🚀 Extracting and saving video frames...")

    for dataset, indices in indices_split.items():
        dataset_dir = output_dirs[dataset]
        print(f"\n--- Processing the {dataset} split ({len(indices)} frames) ---")

        # Show progress with tqdm
        for idx, frame_idx in tqdm(enumerate(indices, 1), total=len(indices), desc=f"{dataset} progress"):
            # Seek to the target frame (key step: makes sure the correct frame is read)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)

            # Read the frame (ret is True on success, frame holds the image data)
            ret, frame = cap.read()
            if not ret:
                print(f"⚠️ Warning: could not read frame index {frame_idx}, skipped")
                continue

            # Build the file name (e.g. train_0001.jpg)
            file_name = f"{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}"  # 0{FILE_NUM_DIGITS}d pads the index with zeros
            save_path = dataset_dir / file_name

            # Save the frame as an image (cv2.imwrite expects BGR, which is what cap.read() returns)
            cv2.imwrite(str(save_path), frame)

    print(f"\n🎉 All frames extracted and saved!")
    # Print the final summary
    print(f"\n📋 Summary:")
    for dataset, dataset_dir in output_dirs.items():
        # Count the images actually saved (some reads may have failed)
        actual_count = len([f for f in dataset_dir.glob(f"*.{SAVE_FORMAT}")])
        print(f"  - {dataset} split: expected {FRAME_COUNTS[dataset]}, saved {actual_count}, path: {dataset_dir}")


def main():
    try:
        # 1. Initialize the video reader
        cap = cv2.VideoCapture(VIDEO_PATH)

        # 2. Validate the video and get the total frame count
        total_frames = check_video_validity(cap)

        # 3. Create the output folders
        output_dirs = create_output_dirs()

        # 4. Generate non-repeating random frame indices
        required_total = sum(FRAME_COUNTS.values())
        random_indices = generate_random_frame_indices(total_frames, required_total)

        # 5. Split the indices by dataset
        indices_split = split_indices_by_dataset(random_indices)

        # 6. Extract and save the frames
        extract_and_save_frames(cap, indices_split, output_dirs)

    except Exception as e:
        # Catch all exceptions and report them cleanly
        print(f"\n❌ Script failed: {str(e)}")
    finally:
        # Release the video reader whether or not an error occurred
        if 'cap' in locals() and cap.isOpened():
            cap.release()
            print(f"\n🔌 Video resource released")


if __name__ == "__main__":
    main()