"""Randomly sample frames from a video and save them as train/val/test image sets."""
import cv2
import numpy as np
import os
from tqdm import tqdm
from pathlib import Path

# -------------------------- 1. User configuration (edit to match your setup) --------------------------
# Path to the source video (absolute or relative; common formats such as mp4/avi/mov work)
# VIDEO_PATH = r"E:\geminicli\yolo\images\屏幕操作视频.mp4"  # replace with your video path
VIDEO_PATH = r"D:\DataPreHandler\MP4\花生底图.mp4"
# Output root directory; one sub-folder per key of FRAME_COUNTS is created underneath it
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen"  # replace with your output root
# Number of images to extract per sub-folder
# (total = sum of the values, currently 1200 + 150 + 50 = 1400)
FRAME_COUNTS = {
    "train": 1200,
    "val": 150,
    "test": 50,
}
# Image file format (jpg is recommended for compatibility; png also works)
SAVE_FORMAT = "jpg"
# Number of zero-padded digits in image file names so they sort in numeric order
FILE_NUM_DIGITS = 5  # 5 digits -> indices up to 99999
# Seed for frame sampling: an int gives a reproducible selection, None gives a fresh random draw
RANDOM_SEED = None
# -----------------------------------------------------------------------------------------------------


def check_video_validity(cap):
    """Verify that the video is readable and return its total frame count.

    Also prints basic video information (path, frame count, FPS, duration and
    the number of frames to extract).

    Args:
        cap: A ``cv2.VideoCapture`` created from ``VIDEO_PATH``.

    Returns:
        int: Total number of frames in the video.

    Raises:
        ValueError: If the video cannot be opened, its frame count cannot be
            determined, or it has fewer frames than ``sum(FRAME_COUNTS.values())``.
    """
    if not cap.isOpened():
        raise ValueError(f"无法打开视频文件!请检查路径:{VIDEO_PATH}")

    # Some backends report -1 or 0 when the frame count is unavailable;
    # fall back to seeking to the end and reading the position there.
    total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    if total_frames <= 0:
        cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)  # jump to the end of the video
        total_frames = cap.get(cv2.CAP_PROP_POS_FRAMES)
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to the first frame
    total_frames = int(total_frames)
    if total_frames <= 0:
        raise ValueError(f"无法获取视频总帧数!请检查视频文件:{VIDEO_PATH}")

    required_total = sum(FRAME_COUNTS.values())
    if total_frames < required_total:
        raise ValueError(
            f"视频总帧数不足!\n"
            f"视频实际帧数:{total_frames},需抽取帧数:{required_total}\n"
            f"建议更换更长的视频,或减少各文件夹的抽取数量。"
        )

    # FPS is informational only; it does not affect which frames are extracted.
    fps = cap.get(cv2.CAP_PROP_FPS)
    per_dataset = ", ".join(f"{name}:{count}" for name, count in FRAME_COUNTS.items())
    print("✅ 视频信息读取成功:")
    print(f" - 视频路径:{VIDEO_PATH}")
    print(f" - 总帧数:{total_frames}")
    print(f" - 帧率(FPS):{fps:.1f}")
    if fps > 0:
        video_duration = total_frames / fps  # duration in seconds
        print(f" - 总时长:{video_duration // 60:.0f}分{video_duration % 60:.1f}秒")
    else:
        # A zero/negative FPS would make the duration undefined (division by zero)
        print(" - 总时长:未知(无法获取有效帧率)")
    print(f" - 需抽取总帧数:{required_total}({per_dataset})")
    return total_frames


def create_output_dirs():
    """Create ``OUTPUT_ROOT_DIR`` and one sub-folder per key of ``FRAME_COUNTS``.

    Returns:
        dict[str, Path]: Maps each dataset name to its output folder.
    """
    # Path handles Windows/Linux/macOS separators uniformly
    output_root = Path(OUTPUT_ROOT_DIR)
    output_dirs = {}
    for subdir in FRAME_COUNTS:
        subdir_path = output_root / subdir
        # parents=True creates missing parents; exist_ok=True tolerates re-runs
        subdir_path.mkdir(parents=True, exist_ok=True)
        print(f"📂 输出文件夹已创建/确认:{subdir_path}")
        output_dirs[subdir] = subdir_path
    return output_dirs


def generate_random_frame_indices(total_frames, required_total, seed=None):
    """Draw ``required_total`` distinct frame indices from ``[0, total_frames)``.

    Args:
        total_frames: Number of frames in the video.
        required_total: How many distinct indices to draw.
        seed: Optional int for a reproducible draw. When None, NumPy's global
            random state is used (the previous behaviour).

    Returns:
        numpy.ndarray: The sampled indices, sorted ascending.

    Raises:
        ValueError: If ``required_total`` exceeds ``total_frames`` (raised by NumPy).
    """
    print(f"\n🎲 正在生成{required_total}个无重复随机帧索引...")
    # Both the np.random module and a Generator expose choice(a, size=, replace=);
    # an int ``a`` samples from range(a) without materializing the range.
    rng = np.random if seed is None else np.random.default_rng(seed)
    random_indices = rng.choice(total_frames, size=required_total, replace=False)
    # Sort so frames are read and saved in chronological order
    random_indices.sort()
    print(f"✅ 随机帧索引生成完成(共{len(random_indices)}个)")
    return random_indices


def split_indices_by_dataset(random_indices):
    """Split the sampled indices into per-dataset groups sized by ``FRAME_COUNTS``.

    Groups are consecutive slices taken in ``FRAME_COUNTS`` order. With the sorted
    output of ``generate_random_frame_indices`` this is a chronological split:
    the first dataset (train) gets the earliest sampled frames, the last (test)
    the latest. NOTE(review): if a non-temporal split is intended, shuffle
    ``random_indices`` before calling this.

    Args:
        random_indices: Sampled frame indices; length must equal
            ``sum(FRAME_COUNTS.values())``.

    Returns:
        dict[str, numpy.ndarray]: Dataset name -> its frame indices.

    Raises:
        ValueError: If the number of indices does not match the configuration.
    """
    required_total = sum(FRAME_COUNTS.values())
    if len(random_indices) != required_total:
        raise ValueError(
            f"索引拆分错误!预期共{required_total}个索引,实际{len(random_indices)}个"
        )

    indices_split = {}
    start = 0
    for dataset, count in FRAME_COUNTS.items():
        indices_split[dataset] = random_indices[start:start + count]
        start += count

    print("\n📊 索引拆分完成:")
    for dataset, indices in indices_split.items():
        if len(indices) == 0:
            # Avoid indexing [0]/[-1] on an empty group
            print(f" - {dataset}:0个帧索引")
        else:
            print(f" - {dataset}:{len(indices)}个帧索引(范围:{indices[0]} ~ {indices[-1]})")
    return indices_split


def _save_image(save_path, frame):
    """Encode ``frame`` as ``SAVE_FORMAT`` and write it to ``save_path``.

    Encoding in memory and writing through ``ndarray.tofile`` (Python file I/O)
    works around ``cv2.imwrite``, which reportedly fails on non-ASCII paths on
    some Windows builds and only signals failure via its return value.

    Returns:
        bool: True if the image was encoded and written, False if encoding failed.

    Raises:
        OSError: If the file cannot be written.
    """
    ok, encoded = cv2.imencode(f".{SAVE_FORMAT}", frame)
    if not ok:
        return False
    encoded.tofile(str(save_path))
    return True


def extract_and_save_frames(cap, indices_split, output_dirs):
    """Read every indexed frame from ``cap`` and save it into its dataset folder.

    Files are named ``dths_<dataset>_<n>.<SAVE_FORMAT>``, where ``n`` counts the
    images saved for that dataset starting at 1, zero-padded to
    ``FILE_NUM_DIGITS``. Frames that cannot be read or encoded are skipped with
    a warning. A summary of saved counts is printed at the end.

    Args:
        cap: Opened ``cv2.VideoCapture`` for the source video.
        indices_split: Dataset name -> frame indices to extract.
        output_dirs: Dataset name -> folder ``Path`` to write into.
    """
    print("\n🚀 开始抽取并保存视频帧...")
    saved_counts = {}
    for dataset, indices in indices_split.items():
        dataset_dir = output_dirs[dataset]
        print(f"\n--- 正在处理 {dataset} 集(共{len(indices)}张)---")
        saved = 0
        for frame_idx in tqdm(indices, desc=f"{dataset}进度"):
            # Seek to the requested frame before reading it.
            # NOTE(review): frame-accurate seeking depends on the codec/backend.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if not ret:
                # tqdm.write keeps the progress bar from being broken by the message
                tqdm.write(f"⚠️ 警告:无法读取帧索引{frame_idx},已跳过该帧")
                continue
            # Number by saved images so names stay contiguous even if frames are skipped
            file_name = f"dths_{dataset}_{saved + 1:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}"
            save_path = dataset_dir / file_name
            # Frames from cap.read() are BGR, which is what OpenCV's encoder expects
            if not _save_image(save_path, frame):
                tqdm.write(f"⚠️ 警告:无法保存帧索引{frame_idx}到{save_path},已跳过该帧")
                continue
            saved += 1
        saved_counts[dataset] = saved

    print("\n🎉 所有帧抽取与保存完成!")
    print("\n📋 结果汇总:")
    for dataset, dataset_dir in output_dirs.items():
        saved = saved_counts.get(dataset, 0)
        print(f" - {dataset}集:预期{FRAME_COUNTS[dataset]}张,实际保存{saved}张,路径:{dataset_dir}")
        # Files not written by this run (e.g. from an earlier, larger run) would be
        # silently mixed into the dataset because existing folders are reused.
        on_disk = sum(1 for _ in dataset_dir.glob(f"*.{SAVE_FORMAT}"))
        if on_disk > saved:
            print(f"   ⚠️ 注意:该文件夹共有{on_disk}张图片,其中{on_disk - saved}张不是本次生成的")


def main():
    """Open the video, validate it, sample and split frames, then save them.

    Any exception is reported as a one-line message; the capture is always released.
    """
    cap = None
    try:
        # 1. Open the video
        cap = cv2.VideoCapture(VIDEO_PATH)
        # 2. Validate it and get the total frame count
        total_frames = check_video_validity(cap)
        # 3. Create the output folders
        output_dirs = create_output_dirs()
        # 4. Sample distinct random frame indices
        required_total = sum(FRAME_COUNTS.values())
        random_indices = generate_random_frame_indices(
            total_frames, required_total, seed=RANDOM_SEED
        )
        # 5. Split the indices per dataset
        indices_split = split_indices_by_dataset(random_indices)
        # 6. Extract and save the frames
        extract_and_save_frames(cap, indices_split, output_dirs)
    except Exception as e:
        # Top-level boundary: report the failure in a friendly way
        print(f"\n❌ 脚本执行失败:{e}")
    finally:
        # Release the capture whether or not an error occurred
        if cap is not None and cap.isOpened():
            cap.release()
            print("\n🔌 视频资源已释放")


if __name__ == "__main__":
    main()