190 lines
7.6 KiB
Python
190 lines
7.6 KiB
Python
import cv2
|
||
import numpy as np
|
||
import os
|
||
from tqdm import tqdm
|
||
from pathlib import Path
|
||
|
||
# -------------------------- 1. User configuration (edit to match your setup) --------------------------
# Path to the source video (absolute or relative). Any container that
# cv2.VideoCapture can open works, e.g. mp4/avi/mov.
VIDEO_PATH = r"E:\geminicli\yolo\images\44.mp4"  # replace with your video path

# Output root. create_output_dirs() creates one sub-folder per key of
# FRAME_COUNTS (train/valid/test) underneath it.
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images"  # replace with your output root

# Number of frames to extract for each split. Total = 28 + 14 + 7 = 49;
# check_video_validity() rejects videos with fewer frames than this total.
# Key order matters: split_indices_by_dataset() slices in this order.
FRAME_COUNTS = {
    "train": 28,
    "valid": 14,
    "test": 7,
}

# Image file extension used for the saved frames (jpg is recommended for
# compatibility; png also works).
SAVE_FORMAT = "jpg"

# Zero-padding width of the per-split frame number in file names
# (e.g. train_0001.jpg) so names sort in numeric order. 4 digits keeps the
# ordering correct up to 9999 images per split; increase it if any
# FRAME_COUNTS value exceeds that.
FILE_NUM_DIGITS = 4
# -----------------------------------------------------------------------------------
|
||
|
||
|
||
def check_video_validity(cap):
    """Verify that ``cap`` is readable and long enough, and print a summary.

    Args:
        cap: A ``cv2.VideoCapture`` opened on ``VIDEO_PATH``.

    Returns:
        int: total number of frames in the video.

    Raises:
        ValueError: if the video cannot be opened, or if it has fewer frames
            than ``sum(FRAME_COUNTS.values())``.
    """
    if not cap.isOpened():
        raise ValueError(f"无法打开视频文件!请检查路径:{VIDEO_PATH}")

    total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    # Some videos do not report a usable frame count here. Treat any
    # non-positive value (not just exactly -1) as "unknown" and fall back to
    # seeking to the end and reading the position there, then rewind.
    if total_frames <= 0:
        cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)  # jump to the end of the video
        total_frames = cap.get(cv2.CAP_PROP_POS_FRAMES)
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # back to the first frame

    total_frames = int(total_frames)
    required_total = sum(FRAME_COUNTS.values())  # frames to extract overall

    if total_frames < required_total:
        raise ValueError(
            f"视频总帧数不足!\n"
            f"视频实际帧数:{total_frames},需抽取帧数:{required_total}\n"
            f"建议更换更长的视频,或减少各文件夹的抽取数量。"
        )

    # FPS is informational only (it does not affect sampling). It can come
    # back as 0, so the duration is computed only when it is positive instead
    # of dividing by zero.
    fps = cap.get(cv2.CAP_PROP_FPS)

    print(f"✅ 视频信息读取成功:")
    print(f" - 视频路径:{VIDEO_PATH}")
    print(f" - 总帧数:{total_frames}")
    print(f" - 帧率(FPS):{fps:.1f}")
    if fps > 0:
        video_duration = total_frames / fps  # seconds
        print(f" - 总时长:{video_duration // 60:.0f}分{video_duration % 60:.1f}秒")
    else:
        print(f" - 总时长:未知")
    print(
        f" - 需抽取总帧数:{required_total}(train:{FRAME_COUNTS['train']}, valid:{FRAME_COUNTS['valid']}, test:{FRAME_COUNTS['test']})")

    return total_frames
|
||
|
||
|
||
def create_output_dirs():
    """Make sure the output root and one sub-folder per split exist.

    The sub-folder names are the keys of ``FRAME_COUNTS`` (in dict order),
    placed under ``OUTPUT_ROOT_DIR``. Folders that already exist are kept.

    Returns:
        dict mapping each split name to its ``pathlib.Path``.
    """
    root = Path(OUTPUT_ROOT_DIR)  # Path handles Windows/POSIX separators
    dirs_by_split = {}
    for split_name in FRAME_COUNTS:
        split_dir = root / split_name
        # parents=True also creates the root if it is missing;
        # exist_ok=True makes re-running the script harmless.
        split_dir.mkdir(parents=True, exist_ok=True)
        print(f"📂 输出文件夹已创建/确认:{split_dir}")
        dirs_by_split[split_name] = split_dir
    return dirs_by_split
|
||
|
||
|
||
def generate_random_frame_indices(total_frames, required_total, seed=None):
    """Draw ``required_total`` distinct frame indices from ``0 .. total_frames-1``.

    Args:
        total_frames: Number of frames in the video (population size).
        required_total: How many distinct indices to draw.
        seed: Optional seed for reproducible sampling. ``None`` (the default)
            draws from NumPy's global random state as before; an int uses an
            independent ``numpy.random.Generator`` seeded with it.

    Returns:
        numpy.ndarray of ``required_total`` unique integer indices, sorted
        ascending.

    Raises:
        ValueError: (from NumPy) if ``required_total`` exceeds ``total_frames``,
            since sampling is without replacement.
    """
    print(f"\n🎲 正在生成{required_total}个无重复随机帧索引...")

    # Passing the population size as an int samples as if from
    # np.arange(total_frames) without first converting a Python range into an
    # array. replace=False guarantees the indices are distinct.
    if seed is None:
        random_indices = np.random.choice(total_frames, size=required_total, replace=False)
    else:
        rng = np.random.default_rng(seed)
        random_indices = rng.choice(total_frames, size=required_total, replace=False)

    # Sort so frames are read and numbered in playback order.
    # NOTE(review): split_indices_by_dataset slices this sorted array in order,
    # so train/valid/test each cover a contiguous stretch of the video rather
    # than a random mix. Shuffle before splitting if a random assignment is
    # intended.
    random_indices.sort()

    print(f"✅ 随机帧索引生成完成(共{len(random_indices)}个)")
    return random_indices
|
||
|
||
|
||
def split_indices_by_dataset(random_indices, frame_counts=None):
    """Slice ``random_indices`` into consecutive chunks, one per dataset split.

    Splits are filled in the key order of ``frame_counts`` (by default the
    module-level ``FRAME_COUNTS``: train, then valid, then test). Every split
    but the last gets exactly its configured count; the last split takes
    whatever remains, so a surplus or shortfall shows up as a length mismatch.

    Args:
        random_indices: Sliceable sequence of frame indices (list or NumPy
            array, e.g. the output of ``generate_random_frame_indices``).
        frame_counts: Optional mapping ``split name -> count``; defaults to
            ``FRAME_COUNTS``.

    Returns:
        dict mapping each split name to its slice of ``random_indices``.

    Raises:
        ValueError: if any split does not receive exactly its configured
            number of indices.
    """
    counts = FRAME_COUNTS if frame_counts is None else frame_counts
    names = list(counts)

    indices_split = {}
    start = 0
    for position, dataset in enumerate(names):
        if position == len(names) - 1:
            # The last split takes the remainder so extra indices are reported
            # by the check below instead of being silently dropped.
            indices_split[dataset] = random_indices[start:]
        else:
            stop = start + counts[dataset]
            indices_split[dataset] = random_indices[start:stop]
            start = stop

    # Explicit check rather than `assert`, which `python -O` strips.
    for dataset, indices in indices_split.items():
        if len(indices) != counts[dataset]:
            raise ValueError(
                f"{dataset}索引拆分错误!预期{counts[dataset]}个,实际{len(indices)}个"
            )

    print(f"\n📊 索引拆分完成:")
    for dataset, indices in indices_split.items():
        # A split configured with 0 frames is empty; indexing [0] would fail.
        if len(indices) > 0:
            print(f" - {dataset}:{len(indices)}个帧索引(范围:{indices[0]} ~ {indices[-1]})")
        else:
            print(f" - {dataset}:0个帧索引")

    return indices_split
|
||
|
||
|
||
def extract_and_save_frames(cap, indices_split, output_dirs):
    """Read each selected frame from ``cap`` and write it to its split's folder.

    Args:
        cap: An opened ``cv2.VideoCapture``.
        indices_split: dict ``split name -> sequence of frame indices``.
        output_dirs: dict ``split name -> pathlib.Path`` of the target folder.

    Files are named ``<split>_<n>.<SAVE_FORMAT>``, where ``n`` is the 1-based
    position of the index within its split, zero-padded to ``FILE_NUM_DIGITS``.
    A frame that cannot be read or written is reported and skipped, leaving a
    gap in the numbering. The final summary reports how many images this run
    actually wrote per split.
    """
    print(f"\n🚀 开始抽取并保存视频帧...")
    saved_counts = {}

    for dataset, indices in indices_split.items():
        dataset_dir = output_dirs[dataset]
        print(f"\n--- 正在处理 {dataset} 集(共{len(indices)}张)---")
        saved = 0

        for idx, frame_idx in tqdm(enumerate(indices, 1), total=len(indices), desc=f"{dataset}进度"):
            # Seek to the requested frame, then read it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                # tqdm.write prints without corrupting the active progress bar.
                tqdm.write(f"⚠️ 警告:无法读取帧索引{frame_idx},已跳过该帧")
                continue

            # e.g. train_0001.jpg
            file_name = f"{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}"
            save_path = dataset_dir / file_name

            # cv2.imwrite signals failure through its boolean return value,
            # so check it instead of assuming the write succeeded.
            if not cv2.imwrite(str(save_path), frame):
                tqdm.write(f"⚠️ 警告:图片保存失败:{save_path}")
                continue
            saved += 1

        saved_counts[dataset] = saved

    print(f"\n🎉 所有帧抽取与保存完成!")
    print(f"\n📋 结果汇总:")
    for dataset, dataset_dir in output_dirs.items():
        # Count what this run wrote; globbing the folder would also count
        # images left over from earlier runs.
        actual_count = saved_counts.get(dataset, 0)
        print(f" - {dataset}集:预期{FRAME_COUNTS[dataset]}张,实际保存{actual_count}张,路径:{dataset_dir}")
|
||
|
||
|
||
def main():
    """Run the whole pipeline and always release the video capture.

    Steps: open ``VIDEO_PATH``, validate it, create the output folders, draw
    the random frame indices, split them per dataset, then extract and save
    the frames. Any ``Exception`` is reported as a one-line message.
    """
    # Bind up front so the cleanup below never has to inspect locals().
    cap = None
    try:
        # 1. Open the video.
        cap = cv2.VideoCapture(VIDEO_PATH)

        # 2. Validate it and get the total frame count.
        total_frames = check_video_validity(cap)

        # 3. Create the output folders.
        output_dirs = create_output_dirs()

        # 4. Draw distinct random frame indices.
        required_total = sum(FRAME_COUNTS.values())
        random_indices = generate_random_frame_indices(total_frames, required_total)

        # 5. Split the indices per dataset.
        indices_split = split_indices_by_dataset(random_indices)

        # 6. Extract and save the frames.
        extract_and_save_frames(cap, indices_split, output_dirs)

    except Exception as e:
        # Top-level boundary: report the failure instead of a traceback.
        print(f"\n❌ 脚本执行失败:{str(e)}")
    finally:
        # Release the capture whether or not an error occurred.
        if cap is not None and cap.isOpened():
            cap.release()
            print(f"\n🔌 视频资源已释放")
|
||
|
||
|
||
# Script entry point: run the extraction only when executed directly.
if __name__ == "__main__":
    main()