Files
prehadler/视频抽帧.py
2025-09-26 10:23:45 +08:00

190 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
from tqdm import tqdm
from pathlib import Path
# -------------------------- 1. User configuration (edit to match your setup) --------------------------
# Path of the source video, absolute or relative. Common formats such as
# mp4/avi/mov are supported.
VIDEO_PATH = r"E:\geminicli\yolo\images\44.mp4"  # replace with your video path
# Output root directory; the script creates train/valid/test sub-folders under it.
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images"  # replace with your output root path
# Number of frames to extract into each sub-folder.
# Total = 28 + 14 + 7 = 49 frames; the video must contain at least this many.
FRAME_COUNTS = {
    "train": 28,
    "valid": 14,
    "test": 7,
}
# Image format / file extension for saved frames (jpg recommended for
# compatibility; png also works).
SAVE_FORMAT = "jpg"
# Zero-padding width of the per-dataset image number (e.g. train_0001.jpg) so
# file names sort in numeric order. 4 digits keep ordering correct up to 9999
# images per dataset.
FILE_NUM_DIGITS = 4
# -----------------------------------------------------------------------------------
def check_video_validity(cap):
    """Check that the video is readable and long enough; return its frame count.

    Uses the module-level ``VIDEO_PATH`` (for messages) and ``FRAME_COUNTS``
    (to compute how many frames are required), and prints a short summary
    of the video.

    Args:
        cap: A ``cv2.VideoCapture`` created for ``VIDEO_PATH``.

    Returns:
        int: Total number of frames in the video.

    Raises:
        ValueError: If the video cannot be opened, or it has fewer frames than
            ``sum(FRAME_COUNTS.values())``.
    """
    if not cap.isOpened():
        raise ValueError(f"无法打开视频文件!请检查路径:{VIDEO_PATH}")

    # CAP_PROP_FRAME_COUNT is returned as a float. Besides -1, OpenCV returns
    # 0 when the backend does not support the property, so treat any
    # non-positive value as "unknown".
    total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    if total_frames <= 0:
        # Fallback: seek to the end of the video and read the position there.
        cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)
        total_frames = cap.get(cv2.CAP_PROP_POS_FRAMES)
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind to the first frame
    total_frames = int(total_frames)

    required_total = sum(FRAME_COUNTS.values())  # total frames to extract
    if total_frames < required_total:
        raise ValueError(
            f"视频总帧数不足!\n"
            f"视频实际帧数:{total_frames},需抽取帧数:{required_total}\n"
            f"建议更换更长的视频,或减少各文件夹的抽取数量。"
        )

    # FPS is informational only. It can also come back as 0 (unsupported
    # property), so guard the division instead of aborting the whole run.
    fps = cap.get(cv2.CAP_PROP_FPS)
    print(f"✅ 视频信息读取成功:")
    print(f" - 视频路径:{VIDEO_PATH}")
    print(f" - 总帧数:{total_frames}")
    print(f" - 帧率(FPS){fps:.1f}")
    if fps > 0:
        minutes, seconds = divmod(total_frames / fps, 60)
        # Label the units explicitly: printing the two bare numbers back to
        # back made e.g. 1 min 2.5 s read as "12.5".
        print(f" - 总时长:{minutes:.0f}分{seconds:.1f}秒")
    else:
        print(f" - 总时长:未知(无法获取帧率)")
    print(
        f" - 需抽取总帧数:{required_total}train:{FRAME_COUNTS['train']}, valid:{FRAME_COUNTS['valid']}, test:{FRAME_COUNTS['test']}")
    return total_frames
def create_output_dirs():
    """Make sure ``OUTPUT_ROOT_DIR/<name>`` exists for every name in FRAME_COUNTS.

    Returns:
        dict: Maps each dataset name (train/valid/test) to its ``pathlib.Path``.
    """
    # Path handles Windows/Linux/macOS separators uniformly.
    root = Path(OUTPUT_ROOT_DIR)
    folders = {}
    for name in FRAME_COUNTS:
        target = root / name
        # parents=True creates missing ancestors; exist_ok=True tolerates reruns.
        target.mkdir(parents=True, exist_ok=True)
        print(f"📂 输出文件夹已创建/确认:{target}")
        folders[name] = target
    return folders
def generate_random_frame_indices(total_frames, required_total, seed=None):
    """Draw ``required_total`` distinct frame indices from ``[0, total_frames)``.

    Args:
        total_frames: Number of frames in the video (the population size).
        required_total: How many distinct indices to draw.
        seed: Optional seed. ``None`` (the default) draws from NumPy's global
            random state as before; any other value uses a dedicated
            ``numpy.random.Generator`` so the same split can be reproduced.

    Returns:
        numpy.ndarray: The drawn indices, sorted ascending so frames are
        saved in chronological order.

    Raises:
        ValueError: Raised by NumPy if ``required_total`` exceeds
            ``total_frames`` (sampling without replacement).
    """
    print(f"\n🎲 正在生成{required_total}个无重复随机帧索引...")
    rng = np.random if seed is None else np.random.default_rng(seed)
    # An int population is equivalent to sampling from range(total_frames)
    # without first converting the range into an array; replace=False keeps
    # indices unique.
    random_indices = rng.choice(total_frames, size=required_total, replace=False)
    # Sort in place so frames come out in time order.
    random_indices.sort()
    print(f"✅ 随机帧索引生成完成(共{len(random_indices)}个)")
    return random_indices
def split_indices_by_dataset(random_indices, frame_counts=None):
    """Split a flat index sequence into consecutive per-dataset slices.

    Datasets are taken in the iteration order of the counts mapping
    (train, valid, test for the default config). Each gets the next
    ``count`` indices; the last one receives whatever remains, and every
    slice is then checked against its expected count.

    Args:
        random_indices: Sliceable sequence of frame indices, e.g. the sorted
            array returned by ``generate_random_frame_indices``.
        frame_counts: Mapping of dataset name -> expected number of indices.
            Defaults to the module-level ``FRAME_COUNTS``.

    Returns:
        dict: Dataset name -> slice of ``random_indices``.

    Raises:
        ValueError: If any slice does not have exactly its expected length.
    """
    counts = FRAME_COUNTS if frame_counts is None else frame_counts
    names = list(counts)
    indices_split = {}
    offset = 0
    for position, name in enumerate(names):
        # The last dataset takes the remainder so a surplus is detected below.
        is_last = position == len(names) - 1
        end = len(random_indices) if is_last else offset + counts[name]
        indices_split[name] = random_indices[offset:end]
        offset = end

    # Validate explicitly: an `assert` would be stripped under `python -O`.
    for dataset, indices in indices_split.items():
        if len(indices) != counts[dataset]:
            raise ValueError(
                f"{dataset}索引拆分错误!预期{counts[dataset]}个,实际{len(indices)}"
            )

    print(f"\n📊 索引拆分完成:")
    for dataset, indices in indices_split.items():
        if len(indices) == 0:
            # No range to show; indices[0] would raise IndexError.
            print(f" - {dataset}:0个帧索引")
        else:
            print(f" - {dataset}{len(indices)}个帧索引(范围:{indices[0]} ~ {indices[-1]}")
    return indices_split
def _write_image(path, frame):
    """Encode *frame* in the format given by ``path.suffix`` and write it; return True on success.

    The image is encoded in memory with ``cv2.imencode`` and written with
    ``ndarray.tofile``, which opens the file through Python's built-in
    ``open()``. ``cv2.imwrite`` is widely reported to fail on non-ASCII
    paths on Windows.
    """
    ok, encoded = cv2.imencode(path.suffix, frame)
    if not ok:
        return False
    try:
        encoded.tofile(str(path))
    except OSError:
        return False
    return True


def extract_and_save_frames(cap, indices_split, output_dirs):
    """Read each selected frame from *cap* and save it into its dataset folder.

    Images are named ``<dataset>_<n>.<SAVE_FORMAT>``, where ``n`` is the
    1-based position of the frame within its dataset, zero-padded to
    ``FILE_NUM_DIGITS`` digits. Frames that cannot be read or written are
    reported and skipped, so the numbering may contain gaps.

    Args:
        cap: An opened ``cv2.VideoCapture``.
        indices_split: Mapping of dataset name -> sequence of frame indices.
        output_dirs: Mapping of dataset name -> ``pathlib.Path`` of its folder.
    """
    print(f"\n🚀 开始抽取并保存视频帧...")
    # Count successful writes from this run. Globbing the folder would also
    # count images left over from earlier runs.
    saved_counts = dict.fromkeys(output_dirs, 0)
    for dataset, indices in indices_split.items():
        dataset_dir = output_dirs[dataset]
        print(f"\n--- 正在处理 {dataset} 集(共{len(indices)}张)---")
        for idx, frame_idx in tqdm(enumerate(indices, 1), total=len(indices), desc=f"{dataset}进度"):
            # Seek to the target frame before reading it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                print(f"⚠️ 警告:无法读取帧索引{frame_idx},已跳过该帧")
                continue
            # e.g. train_0001.jpg
            file_name = f"{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}"
            save_path = dataset_dir / file_name
            if _write_image(save_path, frame):
                saved_counts[dataset] = saved_counts.get(dataset, 0) + 1
            else:
                print(f"⚠️ 警告:无法保存图片{save_path},已跳过该帧")
    print(f"\n🎉 所有帧抽取与保存完成!")
    # Final summary: expected vs. actually written this run.
    print(f"\n📋 结果汇总:")
    for dataset, dataset_dir in output_dirs.items():
        print(f" - {dataset}集:预期{FRAME_COUNTS[dataset]}张,实际保存{saved_counts[dataset]}张,路径:{dataset_dir}")
def main():
    """Run the pipeline: open video, validate, create folders, sample, split, save.

    Any exception is reported as a one-line message rather than a traceback,
    and the video capture is always released.

    Returns:
        int: 0 on success, 1 if any step raised (usable as a process exit status).
    """
    # Bind before `try` so `finally` can always inspect it, even if the
    # VideoCapture constructor itself raises.
    cap = None
    try:
        # 1. Open the video.
        cap = cv2.VideoCapture(VIDEO_PATH)
        # 2. Validate it and get the total frame count.
        total_frames = check_video_validity(cap)
        # 3. Create the output folders.
        output_dirs = create_output_dirs()
        # 4. Draw distinct random frame indices.
        required_total = sum(FRAME_COUNTS.values())
        random_indices = generate_random_frame_indices(total_frames, required_total)
        # 5. Split the indices per dataset.
        indices_split = split_indices_by_dataset(random_indices)
        # 6. Extract and save the frames.
        extract_and_save_frames(cap, indices_split, output_dirs)
    except Exception as e:
        # Top-level boundary: report the failure in a friendly way.
        print(f"\n❌ 脚本执行失败:{str(e)}")
        return 1
    finally:
        # Release the capture whether or not an error occurred.
        if cap is not None and cap.isOpened():
            cap.release()
            print(f"\n🔌 视频资源已释放")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells/CI can detect failures.
    raise SystemExit(main())