数据增强相关代码

This commit is contained in:
2025-10-10 11:39:23 +08:00
parent e7c35a017b
commit e5ef7ec066
5 changed files with 814 additions and 14 deletions

View File

@ -6,23 +6,23 @@ from pathlib import Path
# -------------------------- 1. 用户配置(根据实际情况修改)--------------------------
# 视频文件路径绝对路径或相对路径均可支持mp4/avi/mov等常见格式
VIDEO_PATH = r"E:\geminicli\yolo\images\44.mp4" # 替换为你的视频路径
# 输出根目录脚本会自动在该目录下创建train/valid/test子文件夹
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images" # 替换为你的输出根路径
# VIDEO_PATH = r"E:\geminicli\yolo\images\屏幕操作视频.mp4" # 替换为你的视频路径
VIDEO_PATH = r"D:\DataPreHandler\MP4\花生底图.mp4"
# 输出根目录脚本会自动在该目录下创建train/val/test子文件夹
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen" # 替换为你的输出根路径
# 各文件夹需抽取的图片数量(总数量=1200+150+50=1400)
FRAME_COUNTS = {
"train": 28,
"valid": 14,
"test": 7
"train": 1200,
"val": 150,
"test": 50
}
# 图片保存格式建议用jpg兼容性更好也可改为png
SAVE_FORMAT = "jpg"
# 图片文件名编号位数如0001.jpg避免文件名乱序
FILE_NUM_DIGITS = 4 # 对应最大编号9999满足4900张需求
FILE_NUM_DIGITS = 5 # 对应最大编号99999,满足1400张需求
# -----------------------------------------------------------------------------------
@ -62,13 +62,13 @@ def check_video_validity(cap):
print(f" - 帧率(FPS){fps:.1f}")
print(f" - 总时长:{video_duration // 60:.0f}{video_duration % 60:.1f}")
print(
f" - 需抽取总帧数:{required_total}train:{FRAME_COUNTS['train']}, valid:{FRAME_COUNTS['valid']}, test:{FRAME_COUNTS['test']}")
f" - 需抽取总帧数:{required_total}train:{FRAME_COUNTS['train']}, val:{FRAME_COUNTS['val']}, test:{FRAME_COUNTS['test']}")
return total_frames
def create_output_dirs():
"""创建输出根目录及train/valid/test子文件夹"""
"""创建输出根目录及train/val/test子文件夹"""
# 转换为Path对象适配Windows/Linux/macOS路径格式
output_root = Path(OUTPUT_ROOT_DIR)
subdirs = FRAME_COUNTS.keys()
@ -97,14 +97,14 @@ def generate_random_frame_indices(total_frames, required_total):
def split_indices_by_dataset(random_indices):
"""将总随机索引按train/valid/test的数量拆分"""
"""将总随机索引按train/val/test的数量拆分"""
train_count = FRAME_COUNTS["train"]
valid_count = FRAME_COUNTS["valid"]
valid_count = FRAME_COUNTS["val"]
# 拆分逻辑:前N个给train,中间M个给val,剩余给test
indices_split = {
"train": random_indices[:train_count],
"valid": random_indices[train_count:train_count + valid_count],
"val": random_indices[train_count:train_count + valid_count],
"test": random_indices[train_count + valid_count:]
}
@ -140,7 +140,7 @@ def extract_and_save_frames(cap, indices_split, output_dirs):
continue
# 生成文件名(如dths_train_00001.jpg)
file_name = f"{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}" # 0{FILE_NUM_DIGITS}d表示补零到指定位数
file_name = f"dths_{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}" # 0{FILE_NUM_DIGITS}d表示补零到指定位数
save_path = dataset_dir / file_name
# 保存帧为图片cv2.imwrite默认保存为BGR格式符合图片存储标准