数据增强相关代码
This commit is contained in:
319
copypaste整张特征图.py
Normal file
319
copypaste整张特征图.py
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# --- 1. 核心配置 ---
|
||||||
|
# 特征素材来源目录(包含完整图片和对应标签)
|
||||||
|
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\images"
|
||||||
|
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\labels"
|
||||||
|
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\images"
|
||||||
|
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\labels"
|
||||||
|
SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\images"
|
||||||
|
SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\labels"
|
||||||
|
|
||||||
|
# 底图目录
|
||||||
|
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu\train"
|
||||||
|
BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen\train"
|
||||||
|
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\test"
|
||||||
|
|
||||||
|
# 输出目录
|
||||||
|
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train3\images"
|
||||||
|
OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\train3\labels"
|
||||||
|
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\val3\images"
|
||||||
|
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\val3\labels"
|
||||||
|
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\images"
|
||||||
|
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\test\da\labels"
|
||||||
|
|
||||||
|
# 底图数量
|
||||||
|
SELECT_BASE_IMAGE_COUNT = 0
|
||||||
|
|
||||||
|
|
||||||
|
AUGMENTATION_FACTOR = 1
|
||||||
|
TARGET_CLASSES = [0, 1, 2, 3, 4, 5, 6]
|
||||||
|
PASTE_COUNT_RANGE = (2, 3)
|
||||||
|
MIN_SCALED_SIZE = 50
|
||||||
|
MAX_OVERLAP_IOU = 0.0 # 仅限制“跨特征图”无重叠
|
||||||
|
MAX_RETRY_COUNT = 300
|
||||||
|
SCALE_FACTOR_RANGE = (0.3, 0.9)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 2. 工具函数 ---(核心修改:删除内部重叠检测)
|
||||||
|
def calculate_iou(box1, box2):
|
||||||
|
"""计算两个边界框的交并比(像素坐标)——保持不变"""
|
||||||
|
box1 = [int(round(x)) for x in box1]
|
||||||
|
box2 = [int(round(x)) for x in box2]
|
||||||
|
|
||||||
|
inter_x1 = max(box1[0], box2[0])
|
||||||
|
inter_y1 = max(box1[1], box2[1])
|
||||||
|
inter_x2 = min(box1[2], box2[2])
|
||||||
|
inter_y2 = min(box1[3], box2[3])
|
||||||
|
|
||||||
|
inter_width = max(0, inter_x2 - inter_x1)
|
||||||
|
inter_height = max(0, inter_y2 - inter_y1)
|
||||||
|
inter_area = inter_width * inter_height
|
||||||
|
|
||||||
|
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||||
|
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
||||||
|
union_area = box1_area + box2_area - inter_area
|
||||||
|
|
||||||
|
return inter_area / union_area if union_area > 0 else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def collect_feature_images(feature_image_dir, feature_label_dir, target_classes):
|
||||||
|
"""收集特征图片及其标签——核心修改:删除内部重叠过滤"""
|
||||||
|
print(f"从 {feature_image_dir} 收集特征图片及标签,目标类别 {target_classes}...")
|
||||||
|
feature_list = [] # 元素格式:(图片路径, 标签列表, 原始宽, 原始高)
|
||||||
|
|
||||||
|
feature_image_files = [f for f in os.listdir(feature_image_dir)
|
||||||
|
if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
|
||||||
|
if not feature_image_files:
|
||||||
|
raise FileNotFoundError(f"未找到特征图片!")
|
||||||
|
|
||||||
|
for filename in tqdm(feature_image_files, desc="收集特征图片"):
|
||||||
|
img_path = os.path.join(feature_image_dir, filename)
|
||||||
|
label_file = os.path.splitext(filename)[0] + ".txt"
|
||||||
|
label_path = os.path.join(feature_label_dir, label_file)
|
||||||
|
|
||||||
|
if not os.path.exists(label_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
if img is None:
|
||||||
|
tqdm.write(f"警告:无法读取图片 {filename},已跳过")
|
||||||
|
continue
|
||||||
|
orig_h, orig_w = img.shape[:2]
|
||||||
|
|
||||||
|
# 仅过滤缩小后可能过小的特征图(保留内部重叠的图)
|
||||||
|
min_required_orig_size = MIN_SCALED_SIZE / max(SCALE_FACTOR_RANGE)
|
||||||
|
if orig_w < min_required_orig_size or orig_h < min_required_orig_size:
|
||||||
|
tqdm.write(f"警告:特征图片 {filename} 原始尺寸过小,已跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 读取并过滤目标类别(不再检查内部重叠)
|
||||||
|
labels = []
|
||||||
|
with open(label_path, 'r') as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split()
|
||||||
|
cls_id = int(float(parts[0]))
|
||||||
|
if cls_id not in target_classes:
|
||||||
|
continue
|
||||||
|
xc, yc, w, h = [float(p) for p in parts[1:]]
|
||||||
|
if 0 <= xc <= 1 and 0 <= yc <= 1 and 0 < w <= 1 and 0 < h <= 1:
|
||||||
|
labels.append((cls_id, xc, yc, w, h))
|
||||||
|
|
||||||
|
# 只要有有效标签就保留(无论内部是否重叠)
|
||||||
|
if labels:
|
||||||
|
feature_list.append((img_path, labels, orig_w, orig_h))
|
||||||
|
|
||||||
|
if not feature_list:
|
||||||
|
raise ValueError(f"未收集到有效的特征图片及标签!")
|
||||||
|
|
||||||
|
print(f"特征图片收集完成:共 {len(feature_list)} 张有效特征图片(允许内部目标重叠)")
|
||||||
|
return feature_list
|
||||||
|
|
||||||
|
|
||||||
|
def paste_feature_images_to_base(base_image, feature_list):
|
||||||
|
"""将特征图粘贴到底图——保留“跨特征图无重叠”逻辑(核心)"""
|
||||||
|
base_h, base_w = base_image.shape[:2]
|
||||||
|
pasted_labels = [] # 最终输出的标签
|
||||||
|
pasted_target_boxes = [] # 已粘贴的目标框(像素坐标)
|
||||||
|
pasted_feature_regions = [] # 已粘贴的特征图整体区域(防跨图重叠)
|
||||||
|
pasted_feature_count = 0 # 成功粘贴的特征图数量
|
||||||
|
skipped_small = 0 # 因尺寸过小跳过的次数
|
||||||
|
skipped_region_overlap = 0 # 因特征图整体区域重叠跳过的次数
|
||||||
|
skipped_target_overlap = 0 # 因跨图目标重叠跳过的次数
|
||||||
|
retry_count = 0 # 重试次数
|
||||||
|
|
||||||
|
target_paste_count = random.randint(*PASTE_COUNT_RANGE)
|
||||||
|
print(f" 计划粘贴 {target_paste_count} 张特征图...")
|
||||||
|
|
||||||
|
while pasted_feature_count < target_paste_count and retry_count < MAX_RETRY_COUNT:
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
# 1. 随机选择一张特征图(允许内部重叠)
|
||||||
|
feature_img_path, feature_labels, orig_w, orig_h = random.choice(feature_list)
|
||||||
|
|
||||||
|
# 2. 随机缩放特征图,确保尺寸符合要求
|
||||||
|
scale_factor = random.uniform(*SCALE_FACTOR_RANGE)
|
||||||
|
scaled_w = int(round(orig_w * scale_factor))
|
||||||
|
scaled_h = int(round(orig_h * scale_factor))
|
||||||
|
|
||||||
|
if scaled_w < MIN_SCALED_SIZE or scaled_h < MIN_SCALED_SIZE:
|
||||||
|
skipped_small += 1
|
||||||
|
continue
|
||||||
|
if scaled_w >= base_w or scaled_h >= base_h:
|
||||||
|
skipped_small += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3. 随机生成粘贴位置(确保特征图完全在底图内)
|
||||||
|
paste_x1 = random.randint(0, base_w - scaled_w)
|
||||||
|
paste_y1 = random.randint(0, base_h - scaled_h)
|
||||||
|
paste_x2 = paste_x1 + scaled_w
|
||||||
|
paste_y2 = paste_y1 + scaled_h
|
||||||
|
current_feature_region = (paste_x1, paste_y1, paste_x2, paste_y2)
|
||||||
|
|
||||||
|
# 4. 关键:检测当前特征图整体区域与已粘贴区域是否重叠(防跨图重叠)
|
||||||
|
region_overlap = False
|
||||||
|
for existing_region in pasted_feature_regions:
|
||||||
|
if calculate_iou(current_feature_region, existing_region) > MAX_OVERLAP_IOU:
|
||||||
|
region_overlap = True
|
||||||
|
skipped_region_overlap += 1
|
||||||
|
break
|
||||||
|
if region_overlap:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 5. 计算当前特征图目标在底图上的像素坐标(保留原始内部重叠)
|
||||||
|
temp_target_boxes = []
|
||||||
|
valid = True
|
||||||
|
for (cls_id, xc, yc, w, h) in feature_labels:
|
||||||
|
# 特征图内目标坐标 → 缩放后 → 底图坐标(内部重叠会保留)
|
||||||
|
orig_x1 = (xc - w/2) * orig_w
|
||||||
|
orig_y1 = (yc - h/2) * orig_h
|
||||||
|
orig_x2 = (xc + w/2) * orig_w
|
||||||
|
orig_y2 = (yc + h/2) * orig_h
|
||||||
|
|
||||||
|
scaled_x1 = orig_x1 * scale_factor
|
||||||
|
scaled_y1 = orig_y1 * scale_factor
|
||||||
|
scaled_x2 = orig_x2 * scale_factor
|
||||||
|
scaled_y2 = orig_y2 * scale_factor
|
||||||
|
|
||||||
|
base_x1 = paste_x1 + scaled_x1
|
||||||
|
base_y1 = paste_y1 + scaled_y1
|
||||||
|
base_x2 = paste_x1 + scaled_x2
|
||||||
|
base_y2 = paste_y1 + scaled_y2
|
||||||
|
|
||||||
|
# 确保目标框完全在底图内(不考虑内部重叠)
|
||||||
|
if base_x1 < 0 or base_y1 < 0 or base_x2 > base_w or base_y2 > base_h:
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
temp_target_boxes.append((base_x1, base_y1, base_x2, base_y2))
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 6. 关键:检测当前特征图目标与已粘贴目标是否重叠(防跨图目标重叠)
|
||||||
|
target_overlap = False
|
||||||
|
for temp_box in temp_target_boxes:
|
||||||
|
for existing_box in pasted_target_boxes:
|
||||||
|
if calculate_iou(temp_box, existing_box) > MAX_OVERLAP_IOU:
|
||||||
|
target_overlap = True
|
||||||
|
skipped_target_overlap += 1
|
||||||
|
break
|
||||||
|
if target_overlap:
|
||||||
|
break
|
||||||
|
if target_overlap:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 7. 粘贴特征图并记录信息(保留内部重叠)
|
||||||
|
feature_img = cv2.imread(feature_img_path)
|
||||||
|
if feature_img is None:
|
||||||
|
continue
|
||||||
|
scaled_feature_img = cv2.resize(feature_img, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
base_image[paste_y1:paste_y2, paste_x1:paste_x2] = scaled_feature_img
|
||||||
|
|
||||||
|
# 记录标签、目标框、特征图整体区域
|
||||||
|
for (cls_id, xc, yc, w, h), temp_box in zip(feature_labels, temp_target_boxes):
|
||||||
|
base_xc = (temp_box[0] + temp_box[2]) / 2 / base_w
|
||||||
|
base_yc = (temp_box[1] + temp_box[3]) / 2 / base_h
|
||||||
|
base_w_box = (temp_box[2] - temp_box[0]) / base_w
|
||||||
|
base_h_box = (temp_box[3] - temp_box[1]) / base_h
|
||||||
|
|
||||||
|
pasted_labels.append((cls_id, base_xc, base_yc, base_w_box, base_h_box))
|
||||||
|
pasted_target_boxes.append(temp_box)
|
||||||
|
pasted_feature_regions.append(current_feature_region)
|
||||||
|
pasted_feature_count += 1
|
||||||
|
print(f" 已成功粘贴 {pasted_feature_count}/{target_paste_count} 张特征图")
|
||||||
|
|
||||||
|
# 输出统计(无内部重叠相关提示)
|
||||||
|
print(
|
||||||
|
f" 粘贴统计:成功{pasted_feature_count}张 → "
|
||||||
|
f"跳小尺寸{skipped_small} → 跳区域重叠{skipped_region_overlap} → "
|
||||||
|
f"跳目标重叠{skipped_target_overlap} → 总目标数{len(pasted_labels)}"
|
||||||
|
)
|
||||||
|
return base_image, pasted_labels
|
||||||
|
|
||||||
|
|
||||||
|
# --- 3. 主函数 ---(无修改)
|
||||||
|
def main():
|
||||||
|
os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
|
||||||
|
os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
feature_list = collect_feature_images(
|
||||||
|
SOURCE_FEATURE_IMAGE_DIR, SOURCE_FEATURE_LABEL_DIR, TARGET_CLASSES
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
print(f"错误:{e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 筛选指定数量的底图
|
||||||
|
base_files = [f for f in os.listdir(BASE_IMAGE_DIR)
|
||||||
|
if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
|
||||||
|
total_base_count = len(base_files)
|
||||||
|
if total_base_count == 0:
|
||||||
|
print(f"错误:底图目录 {BASE_IMAGE_DIR} 无有效图片!")
|
||||||
|
return
|
||||||
|
|
||||||
|
if SELECT_BASE_IMAGE_COUNT < 0:
|
||||||
|
print(f"错误:底图数量不能为负数(当前设置:{SELECT_BASE_IMAGE_COUNT})!")
|
||||||
|
return
|
||||||
|
elif SELECT_BASE_IMAGE_COUNT == 0:
|
||||||
|
selected_base_files = base_files
|
||||||
|
print(f"\n找到 {total_base_count} 张底图,使用全部底图...\n")
|
||||||
|
else:
|
||||||
|
if SELECT_BASE_IMAGE_COUNT > total_base_count:
|
||||||
|
selected_base_files = base_files
|
||||||
|
print(f"\n警告:设置底图数({SELECT_BASE_IMAGE_COUNT})超过可用数({total_base_count}),使用全部底图...\n")
|
||||||
|
else:
|
||||||
|
selected_base_files = random.sample(base_files, SELECT_BASE_IMAGE_COUNT)
|
||||||
|
print(f"\n找到 {total_base_count} 张底图,随机选择 {SELECT_BASE_IMAGE_COUNT} 张...\n")
|
||||||
|
|
||||||
|
# 处理底图
|
||||||
|
for base_filename in tqdm(selected_base_files, desc="处理底图"):
|
||||||
|
base_name, ext = os.path.splitext(base_filename)
|
||||||
|
base_path = os.path.join(BASE_IMAGE_DIR, base_filename)
|
||||||
|
base_img = cv2.imread(base_path)
|
||||||
|
if base_img is None:
|
||||||
|
tqdm.write(f"\n警告:无法读取底图 {base_filename},已跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for aug_idx in range(AUGMENTATION_FACTOR):
|
||||||
|
print(f"\n底图 {base_filename}(增强序号 {aug_idx}):", end="")
|
||||||
|
base_copy = base_img.copy()
|
||||||
|
pasted_img, labels = paste_feature_images_to_base(base_copy, feature_list)
|
||||||
|
|
||||||
|
if not labels:
|
||||||
|
tqdm.write(f"\n警告:未粘贴任何目标,已跳过该结果")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 保存图片和标签
|
||||||
|
output_img_name = f"{base_name}_feat_paste_{aug_idx}.jpg"
|
||||||
|
output_img_path = os.path.join(OUTPUT_IMAGE_DIR, output_img_name)
|
||||||
|
cv2.imwrite(output_img_path, pasted_img)
|
||||||
|
|
||||||
|
output_label_name = f"{base_name}_feat_paste_{aug_idx}.txt"
|
||||||
|
output_label_path = os.path.join(OUTPUT_LABEL_DIR, output_label_name)
|
||||||
|
with open(output_label_path, 'w') as f:
|
||||||
|
for (cls_id, x, y, w, h) in labels:
|
||||||
|
f.write(f"{cls_id} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")
|
||||||
|
|
||||||
|
print(f" → 已保存(目标数:{len(labels)})")
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
output_img_count = len([f for f in os.listdir(OUTPUT_IMAGE_DIR) if f.endswith(('.jpg', '.png'))])
|
||||||
|
output_label_count = len([f for f in os.listdir(OUTPUT_LABEL_DIR) if f.endswith('.txt')])
|
||||||
|
print(f"\n✅ 全部处理完成!")
|
||||||
|
print(f" - 实际处理底图数量:{len(selected_base_files)} 张")
|
||||||
|
print(f" - 生成图片:{output_img_count} 张")
|
||||||
|
print(f" - 生成标签:{output_label_count} 个")
|
||||||
|
print(f" - 输出路径:\n 图片 → {OUTPUT_IMAGE_DIR}\n 标签 → {OUTPUT_LABEL_DIR}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
65
扩大边缘像素.py
Normal file
65
扩大边缘像素.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def expand_bbox(label_path, k=1.1):
|
||||||
|
"""
|
||||||
|
扩大YOLO标签的边界框
|
||||||
|
:param label_path: 标签文件路径(.txt)
|
||||||
|
:param k: 扩大系数(k>1)
|
||||||
|
"""
|
||||||
|
with open(label_path, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
new_lines = []
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
# 解析标签
|
||||||
|
class_id, xc, yc, w, h = line.split()
|
||||||
|
xc = float(xc)
|
||||||
|
yc = float(yc)
|
||||||
|
w = float(w)
|
||||||
|
h = float(h)
|
||||||
|
|
||||||
|
# 计算新宽高
|
||||||
|
new_w = w * k
|
||||||
|
new_h = h * k
|
||||||
|
|
||||||
|
# 计算边界
|
||||||
|
x1 = xc - new_w / 2
|
||||||
|
y1 = yc - new_h / 2
|
||||||
|
x2 = xc + new_w / 2
|
||||||
|
y2 = yc + new_h / 2
|
||||||
|
|
||||||
|
# 截断超出图像的部分(0~1范围)
|
||||||
|
x1 = max(0.0, x1)
|
||||||
|
y1 = max(0.0, y1)
|
||||||
|
x2 = min(1.0, x2)
|
||||||
|
y2 = min(1.0, y2)
|
||||||
|
|
||||||
|
# 重新计算中心和宽高
|
||||||
|
new_xc = (x1 + x2) / 2
|
||||||
|
new_yc = (y1 + y2) / 2
|
||||||
|
new_w = x2 - x1
|
||||||
|
new_h = y2 - y1
|
||||||
|
|
||||||
|
# 保留6位小数,拼接新标签
|
||||||
|
new_line = f"{class_id} {new_xc:.6f} {new_yc:.6f} {new_w:.6f} {new_h:.6f}\n"
|
||||||
|
new_lines.append(new_line)
|
||||||
|
|
||||||
|
# 写入新标签(覆盖原文件,或改为新路径)
|
||||||
|
with open(label_path, 'w') as f:
|
||||||
|
f.writelines(new_lines)
|
||||||
|
|
||||||
|
|
||||||
|
# 批量处理文件夹中的所有标签
|
||||||
|
label_dir = r"D:\DataPreHandler\yuanshi_data\images\val\labels" # 替换为你的标签文件夹路径
|
||||||
|
k = 1.25 # 扩大系数,根据需求调整
|
||||||
|
|
||||||
|
for filename in os.listdir(label_dir):
|
||||||
|
if filename.endswith('.txt'):
|
||||||
|
label_path = os.path.join(label_dir, filename)
|
||||||
|
expand_bbox(label_path, k)
|
||||||
|
|
||||||
|
print("处理完成!")
|
||||||
217
抽取样本数量.py
Normal file
217
抽取样本数量.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
def add_balanced_samples(source_dir, target_dir, add_samples):
|
||||||
|
"""
|
||||||
|
在已有目标数据集基础上添加新样本(不重复),并保持类别平衡
|
||||||
|
参数:
|
||||||
|
source_dir: 源数据集目录(包含images和labels)
|
||||||
|
target_dir: 目标数据集目录(已有样本,包含images和labels)
|
||||||
|
add_samples: 需要新增的样本数量
|
||||||
|
"""
|
||||||
|
# 定义类别名称映射
|
||||||
|
class_names = {
|
||||||
|
0: "Abdomen",
|
||||||
|
1: "Hips",
|
||||||
|
2: "Chest",
|
||||||
|
3: "vulva",
|
||||||
|
4: "back",
|
||||||
|
5: "penis",
|
||||||
|
6: "Horror"
|
||||||
|
}
|
||||||
|
num_classes = len(class_names)
|
||||||
|
|
||||||
|
# 确保目标目录存在
|
||||||
|
os.makedirs(os.path.join(target_dir, "images"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(target_dir, "labels"), exist_ok=True)
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# 关键步骤1:获取目标目录中已有的文件(避免重复)
|
||||||
|
# --------------------------
|
||||||
|
existing_images = set()
|
||||||
|
target_image_dir = os.path.join(target_dir, "images")
|
||||||
|
for f in os.listdir(target_image_dir):
|
||||||
|
if os.path.isfile(os.path.join(target_image_dir, f)):
|
||||||
|
existing_images.add(f) # 记录已存在的图片文件名
|
||||||
|
print(f"目标目录中已存在 {len(existing_images)} 个样本,将避免重复添加")
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# 关键步骤2:获取源目录中未被目标目录包含的可用文件
|
||||||
|
# --------------------------
|
||||||
|
source_image_dir = os.path.join(source_dir, "images")
|
||||||
|
source_label_dir = os.path.join(source_dir, "labels")
|
||||||
|
|
||||||
|
if not os.path.exists(source_image_dir) or not os.path.exists(source_label_dir):
|
||||||
|
print(f"错误: 源目录 {source_dir} 中缺少images或labels子目录")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 源目录所有图片中,排除已存在于目标目录的文件
|
||||||
|
candidate_images = [
|
||||||
|
f for f in os.listdir(source_image_dir)
|
||||||
|
if os.path.isfile(os.path.join(source_image_dir, f)) and f not in existing_images
|
||||||
|
]
|
||||||
|
|
||||||
|
if not candidate_images:
|
||||||
|
print(f"错误: 源目录中没有可添加的新样本(所有样本已存在于目标目录)")
|
||||||
|
return
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# 分析候选文件的类别(兼容浮点数ID)
|
||||||
|
# --------------------------
|
||||||
|
file_classes = {} # 候选文件名 -> 包含的类别集合
|
||||||
|
class_files = defaultdict(list) # 类别 -> 包含该类别的候选文件列表
|
||||||
|
|
||||||
|
for img_file in candidate_images:
|
||||||
|
base_name = os.path.splitext(img_file)[0]
|
||||||
|
label_file = f"{base_name}.txt"
|
||||||
|
label_path = os.path.join(source_label_dir, label_file)
|
||||||
|
|
||||||
|
if not os.path.exists(label_path):
|
||||||
|
print(f"警告: 图片 {img_file} 对应的标签文件 {label_file} 不存在,已跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
classes_in_file = set()
|
||||||
|
with open(label_path, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if not parts:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
class_id = int(float(parts[0])) # 兼容浮点数类别ID
|
||||||
|
if class_id in class_names:
|
||||||
|
classes_in_file.add(class_id)
|
||||||
|
else:
|
||||||
|
print(f"警告: 标签文件 {label_file} 包含未知类别 {class_id}")
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
print(f"警告: 标签文件 {label_file} 格式不正确,行内容: {line.strip()}")
|
||||||
|
|
||||||
|
if not classes_in_file:
|
||||||
|
print(f"警告: 标签文件 {label_file} 无有效类别,已跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_classes[img_file] = classes_in_file
|
||||||
|
for class_id in classes_in_file:
|
||||||
|
class_files[class_id].append(img_file)
|
||||||
|
|
||||||
|
# 检查候选文件中是否有类别缺失
|
||||||
|
for class_id in class_names:
|
||||||
|
if class_id not in class_files or len(class_files[class_id]) == 0:
|
||||||
|
print(f"警告: 候选文件中没有类别 {class_id} ({class_names[class_id]}) 的样本")
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# 处理新增样本数量(不超过候选文件数量)
|
||||||
|
# --------------------------
|
||||||
|
available_candidates = len(file_classes)
|
||||||
|
if add_samples > available_candidates:
|
||||||
|
print(f"警告: 请求新增 {add_samples} 个样本,但候选文件仅 {available_candidates} 个,将全部添加")
|
||||||
|
add_samples = available_candidates
|
||||||
|
elif add_samples <= 0:
|
||||||
|
print("错误: 新增样本数量必须大于0")
|
||||||
|
return
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# 平衡抽取新增样本
|
||||||
|
# --------------------------
|
||||||
|
selected_files = set()
|
||||||
|
remaining = add_samples
|
||||||
|
|
||||||
|
# 每个类别理想的新增数量(基于总新增数量平均)
|
||||||
|
ideal_per_class = max(1, add_samples // num_classes)
|
||||||
|
|
||||||
|
# 先为每个类别抽取理想数量的样本
|
||||||
|
for class_id in class_names:
|
||||||
|
if class_id not in class_files:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 候选文件中未被选中的
|
||||||
|
available = [f for f in class_files[class_id] if f not in selected_files]
|
||||||
|
num_to_add = min(ideal_per_class, len(available), remaining)
|
||||||
|
|
||||||
|
if num_to_add > 0:
|
||||||
|
selected = random.sample(available, num_to_add)
|
||||||
|
selected_files.update(selected)
|
||||||
|
remaining -= num_to_add
|
||||||
|
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 补充剩余需要新增的样本
|
||||||
|
if remaining > 0:
|
||||||
|
all_available = [f for f in file_classes if f not in selected_files]
|
||||||
|
num_to_add = min(remaining, len(all_available))
|
||||||
|
if num_to_add > 0:
|
||||||
|
selected_additional = random.sample(all_available, num_to_add)
|
||||||
|
selected_files.update(selected_additional)
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# 复制新增样本到目标目录(确保不重复)
|
||||||
|
# --------------------------
|
||||||
|
for img_file in selected_files:
|
||||||
|
# 复制图片
|
||||||
|
src_img = os.path.join(source_image_dir, img_file)
|
||||||
|
dst_img = os.path.join(target_image_dir, img_file)
|
||||||
|
shutil.copy2(src_img, dst_img)
|
||||||
|
|
||||||
|
# 复制标签
|
||||||
|
base_name = os.path.splitext(img_file)[0]
|
||||||
|
label_file = f"{base_name}.txt"
|
||||||
|
src_label = os.path.join(source_label_dir, label_file)
|
||||||
|
dst_label = os.path.join(target_dir, "labels", label_file)
|
||||||
|
if os.path.exists(src_label):
|
||||||
|
shutil.copy2(src_label, dst_label)
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# 统计结果(包含目标目录原有+新增的总分布)
|
||||||
|
# --------------------------
|
||||||
|
# 1. 统计目标目录原有样本的类别
|
||||||
|
existing_class_counts = defaultdict(int)
|
||||||
|
for img_file in existing_images:
|
||||||
|
# 从源目录找原有样本的标签(因为目标目录标签可能已存在,但源目录更可靠)
|
||||||
|
base_name = os.path.splitext(img_file)[0]
|
||||||
|
label_file = f"{base_name}.txt"
|
||||||
|
label_path = os.path.join(source_label_dir, label_file)
|
||||||
|
if os.path.exists(label_path):
|
||||||
|
with open(label_path, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if parts:
|
||||||
|
try:
|
||||||
|
class_id = int(float(parts[0]))
|
||||||
|
if class_id in class_names:
|
||||||
|
existing_class_counts[class_id] += 1
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass # 忽略格式错误的行
|
||||||
|
|
||||||
|
# 2. 统计新增样本的类别
|
||||||
|
new_class_counts = defaultdict(int)
|
||||||
|
for img_file in selected_files:
|
||||||
|
for class_id in file_classes[img_file]:
|
||||||
|
new_class_counts[class_id] += 1
|
||||||
|
|
||||||
|
# 3. 总分布(原有+新增)
|
||||||
|
total_class_counts = {
|
||||||
|
class_id: existing_class_counts.get(class_id, 0) + new_class_counts.get(class_id, 0)
|
||||||
|
for class_id in class_names
|
||||||
|
}
|
||||||
|
|
||||||
|
# 输出结果
|
||||||
|
print(f"\n已成功新增 {len(selected_files)} 个样本,目标目录总样本数: {len(existing_images) + len(selected_files)}")
|
||||||
|
print("目标目录总类别分布(原有+新增):")
|
||||||
|
for class_id in class_names:
|
||||||
|
print(f"类别 {class_id} ({class_names[class_id]}): {total_class_counts[class_id]} 次出现")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 源数据集路径(从中抽取新样本)
|
||||||
|
source_directory = r"D:\DataPreHandler\data\train3"
|
||||||
|
# 目标数据集路径(已有样本,将新增到这里)
|
||||||
|
target_directory = r"D:\DataPreHandler\data\train2"
|
||||||
|
|
||||||
|
# 需要新增的样本数量(根据需求修改)
|
||||||
|
add_samples = 1200 # 例如:再添加100个新样本
|
||||||
|
|
||||||
|
# 执行新增操作
|
||||||
|
add_balanced_samples(source_directory, target_directory, add_samples)
|
||||||
199
根据位置画框.py
Normal file
199
根据位置画框.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# -------------------------- 1. 核心配置(请根据需求修改) --------------------------
|
||||||
|
# 输入目录(图片和标签需一一对应,文件名相同,仅后缀不同)
|
||||||
|
INPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train\images" # 原始图片目录
|
||||||
|
INPUT_LABEL_DIR = r"D:\DataPreHandler\data\train\labels" # 原始YOLO标签目录
|
||||||
|
# 输出目录(标注后的图片会保存在这里)
|
||||||
|
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\output2"
|
||||||
|
|
||||||
|
# 关键配置:类别ID与类别名称的映射(必须与你的YOLO训练类别顺序一致!)
|
||||||
|
CLASS_CONFIG = [
|
||||||
|
(0, "Abdomen", (0, 255, 0)),
|
||||||
|
(1, "Hips", (0, 255, 0)),
|
||||||
|
(2, "Chest", (0, 255, 0)),
|
||||||
|
(3, "vulva", (0, 255, 0)),
|
||||||
|
(4, "back", (0, 255, 0)),
|
||||||
|
(5, "penis", (0, 255, 0)),
|
||||||
|
(6, "Horror", (0, 255, 0))
|
||||||
|
]
|
||||||
|
|
||||||
|
# 绘制参数(可按需调整)
|
||||||
|
BOX_THICKNESS = 2 # 边界框线条厚度(像素)
|
||||||
|
FONT_FACE = cv2.FONT_HERSHEY_SIMPLEX # 字体类型
|
||||||
|
FONT_SCALE = 0.6 # 字体大小(根据图片尺寸调整)
|
||||||
|
FONT_THICKNESS = 1 # 字体线条厚度
|
||||||
|
TEXT_PADDING = 5 # 文字与边界框的间距(像素)
|
||||||
|
TEXT_BG_OPACITY = 0.7 # 文字背景的透明度(0-1,0为完全透明)
|
||||||
|
|
||||||
|
# 支持的图片格式(无需修改)
|
||||||
|
SUPPORTED_IMAGE_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------- 2. 工具函数 --------------------------
|
||||||
|
def yolo2pixel(yolo_coords, img_w, img_h):
|
||||||
|
"""
|
||||||
|
将YOLO相对坐标转换为图片像素坐标(边界框:x1, y1, x2, y2)
|
||||||
|
:param yolo_coords: YOLO坐标列表 [xc, yc, w, h](相对值,0-1)
|
||||||
|
:param img_w: 图片宽度(像素)
|
||||||
|
:param img_h: 图片高度(像素)
|
||||||
|
:return: 像素坐标元组 (x1, y1, x2, y2)
|
||||||
|
"""
|
||||||
|
xc, yc, w, h = yolo_coords
|
||||||
|
# 计算边界框左上角和右下角坐标
|
||||||
|
x1 = int((xc - w / 2) * img_w)
|
||||||
|
y1 = int((yc - h / 2) * img_h)
|
||||||
|
x2 = int((xc + w / 2) * img_w)
|
||||||
|
y2 = int((yc + h / 2) * img_h)
|
||||||
|
# 确保坐标不超出图片范围
|
||||||
|
x1 = max(0, x1)
|
||||||
|
y1 = max(0, y1)
|
||||||
|
x2 = min(img_w, x2)
|
||||||
|
y2 = min(img_h, y2)
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
|
||||||
|
|
||||||
|
def draw_annotation(img, bbox, class_name, color):
|
||||||
|
"""
|
||||||
|
在图片上绘制边界框和类别名称
|
||||||
|
:param img: 原始图片(OpenCV格式,BGR通道)
|
||||||
|
:param bbox: 像素坐标边界框 (x1, y1, x2, y2)
|
||||||
|
:param class_name: 类别名称(字符串)
|
||||||
|
:param color: 边界框和文字颜色(BGR元组,如 (0,255,0) 代表绿色)
|
||||||
|
:return: 标注后的图片
|
||||||
|
"""
|
||||||
|
img_h, img_w = img.shape[:2]
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
|
||||||
|
# 1. 绘制边界框
|
||||||
|
cv2.rectangle(img, (x1, y1), (x2, y2), color, BOX_THICKNESS)
|
||||||
|
|
||||||
|
# 2. 计算文字尺寸(用于创建文字背景)
|
||||||
|
text_size, _ = cv2.getTextSize(class_name, FONT_FACE, FONT_SCALE, FONT_THICKNESS)
|
||||||
|
text_w, text_h = text_size
|
||||||
|
|
||||||
|
# 3. 确定文字位置(避免超出图片范围)
|
||||||
|
# 文字默认放在边界框左上角,若左上角空间不足则放在右上角
|
||||||
|
text_x = x1 + TEXT_PADDING
|
||||||
|
text_y = y1 - TEXT_PADDING - text_h # 文字基线在y轴上方
|
||||||
|
if text_y < 0: # 左上角超出图片顶部,调整到右上角
|
||||||
|
text_x = x2 - TEXT_PADDING - text_w
|
||||||
|
text_y = y1 + TEXT_PADDING + text_h
|
||||||
|
|
||||||
|
# 4. 绘制文字背景(半透明矩形,避免遮挡图片内容)
|
||||||
|
bg_x1 = text_x - TEXT_PADDING
|
||||||
|
bg_y1 = text_y - text_h - TEXT_PADDING
|
||||||
|
bg_x2 = text_x + text_w + TEXT_PADDING
|
||||||
|
bg_y2 = text_y + TEXT_PADDING
|
||||||
|
# 确保背景不超出图片范围
|
||||||
|
bg_x1 = max(0, bg_x1)
|
||||||
|
bg_y1 = max(0, bg_y1)
|
||||||
|
bg_x2 = min(img_w, bg_x2)
|
||||||
|
bg_y2 = min(img_h, bg_y2)
|
||||||
|
|
||||||
|
# 半透明背景:先创建背景层,再与原图混合
|
||||||
|
bg = img[bg_y1:bg_y2, bg_x1:bg_x2].copy()
|
||||||
|
bg = cv2.rectangle(bg, (0, 0), (bg_x2 - bg_x1, bg_y2 - bg_y1), color, -1) # 实心矩形
|
||||||
|
img[bg_y1:bg_y2, bg_x1:bg_x2] = cv2.addWeighted(bg, TEXT_BG_OPACITY, img[bg_y1:bg_y2, bg_x1:bg_x2], 1 - TEXT_BG_OPACITY, 0)
|
||||||
|
|
||||||
|
# 5. 绘制类别名称
|
||||||
|
cv2.putText(img, class_name, (text_x, text_y), FONT_FACE, FONT_SCALE, (0, 0, 0), FONT_THICKNESS) # 白色文字
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------- 3. 主函数 --------------------------
|
||||||
|
def main():
|
||||||
|
# 1. 创建输出目录(若不存在)
|
||||||
|
os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
|
||||||
|
print(f"标注后的图片将保存到:{OUTPUT_IMAGE_DIR}\n")
|
||||||
|
|
||||||
|
# 2. 构建类别ID到(名称+颜色)的映射字典
|
||||||
|
class_map = {cls_id: (cls_name, cls_color) for cls_id, cls_name, cls_color in CLASS_CONFIG}
|
||||||
|
print("类别配置:")
|
||||||
|
for cls_id, cls_name, cls_color in CLASS_CONFIG:
|
||||||
|
print(f" ID {cls_id} → 名称:{cls_name},颜色:{cls_color}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 3. 获取所有图片文件(仅处理支持的格式)
|
||||||
|
image_files = [f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith(SUPPORTED_IMAGE_FORMATS)]
|
||||||
|
if not image_files:
|
||||||
|
raise FileNotFoundError(f"在 {INPUT_IMAGE_DIR} 中未找到任何支持的图片文件({SUPPORTED_IMAGE_FORMATS})")
|
||||||
|
print(f"找到 {len(image_files)} 张图片,开始标注...\n")
|
||||||
|
|
||||||
|
# 4. 遍历图片并标注
|
||||||
|
for img_filename in tqdm(image_files, desc="处理进度"):
|
||||||
|
# 4.1 构建图片和标签的路径
|
||||||
|
img_name, img_ext = os.path.splitext(img_filename)
|
||||||
|
img_path = os.path.join(INPUT_IMAGE_DIR, img_filename)
|
||||||
|
label_path = os.path.join(INPUT_LABEL_DIR, f"{img_name}.txt") # 标签文件与图片同名,后缀为txt
|
||||||
|
|
||||||
|
# 4.2 读取图片(OpenCV默认读取为BGR通道)
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
if img is None:
|
||||||
|
tqdm.write(f"⚠️ 跳过:无法读取图片 {img_filename}(可能损坏或格式不支持)")
|
||||||
|
continue
|
||||||
|
img_h, img_w = img.shape[:2]
|
||||||
|
|
||||||
|
# 4.3 读取标签文件(若不存在则跳过标注,直接保存原图)
|
||||||
|
if not os.path.exists(label_path):
|
||||||
|
tqdm.write(f"⚠️ 警告:图片 {img_filename} 无对应标签文件 {os.path.basename(label_path)},直接保存原图")
|
||||||
|
annotated_img = img.copy()
|
||||||
|
else:
|
||||||
|
# 复制原图用于标注(避免修改原始图片)
|
||||||
|
annotated_img = img.copy()
|
||||||
|
# 读取标签内容
|
||||||
|
with open(label_path, 'r', encoding='utf-8') as f:
|
||||||
|
label_lines = [line.strip() for line in f.readlines() if line.strip()] # 过滤空行
|
||||||
|
|
||||||
|
# 4.4 解析每个标签并绘制
|
||||||
|
for line_idx, line in enumerate(label_lines):
|
||||||
|
try:
|
||||||
|
# YOLO标签格式:class_id xc yc w h(空格分隔)
|
||||||
|
parts = line.split()
|
||||||
|
if len(parts) != 5:
|
||||||
|
raise ValueError(f"格式错误(需5个字段,实际{len(parts)}个)")
|
||||||
|
|
||||||
|
# 解析类别ID和坐标
|
||||||
|
cls_id = int(float(parts[0]))
|
||||||
|
yolo_coords = [float(p) for p in parts[1:]]
|
||||||
|
# 检查YOLO坐标有效性(必须在0-1范围内)
|
||||||
|
if not all(0 <= coord <= 1 for coord in yolo_coords):
|
||||||
|
raise ValueError(f"YOLO坐标超出0-1范围:{yolo_coords}")
|
||||||
|
|
||||||
|
# 4.5 转换坐标并绘制
|
||||||
|
# 检查类别ID是否在配置中
|
||||||
|
if cls_id not in class_map:
|
||||||
|
tqdm.write(f"⚠️ 跳过:图片 {img_filename} 标签第{line_idx+1}行,未知类别ID {cls_id}(未在CLASS_CONFIG中配置)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取类别名称和颜色
|
||||||
|
cls_name, cls_color = class_map[cls_id]
|
||||||
|
# 转换YOLO坐标为像素坐标
|
||||||
|
bbox = yolo2pixel(yolo_coords, img_w, img_h)
|
||||||
|
# 绘制标注
|
||||||
|
annotated_img = draw_annotation(annotated_img, bbox, cls_name, cls_color)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
tqdm.write(f"⚠️ 跳过:图片 {img_filename} 标签第{line_idx+1}行解析失败 → {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 4.6 保存标注后的图片
|
||||||
|
output_img_path = os.path.join(OUTPUT_IMAGE_DIR, f"{img_name}_annotated{img_ext}")
|
||||||
|
# 保存为JPG格式(若原始是PNG,也可改为img_ext保持原格式)
|
||||||
|
# 注:JPG不支持透明通道,若原始是PNG且有透明,建议保留img_ext
|
||||||
|
cv2.imwrite(output_img_path, annotated_img)
|
||||||
|
|
||||||
|
# 5. 完成提示
|
||||||
|
print(f"\n✅ 标注完成!共处理 {len(image_files)} 张图片,标注后的图片已保存到:")
|
||||||
|
print(f" {OUTPUT_IMAGE_DIR}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ 程序异常终止:{str(e)}")
|
||||||
28
视频抽帧.py
28
视频抽帧.py
@ -6,23 +6,23 @@ from pathlib import Path
|
|||||||
|
|
||||||
# -------------------------- 1. 用户配置(根据实际情况修改)--------------------------
|
# -------------------------- 1. 用户配置(根据实际情况修改)--------------------------
|
||||||
# 视频文件路径(绝对路径或相对路径均可,支持mp4/avi/mov等常见格式)
|
# 视频文件路径(绝对路径或相对路径均可,支持mp4/avi/mov等常见格式)
|
||||||
VIDEO_PATH = r"E:\geminicli\yolo\images\44.mp4" # 替换为你的视频路径
|
# VIDEO_PATH = r"E:\geminicli\yolo\images\屏幕操作视频.mp4" # 替换为你的视频路径
|
||||||
|
VIDEO_PATH = r"D:\DataPreHandler\MP4\花生底图.mp4"
|
||||||
# 输出根目录(脚本会自动在该目录下创建train/valid/test子文件夹)
|
# 输出根目录(脚本会自动在该目录下创建train/val/test子文件夹)
|
||||||
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images" # 替换为你的输出根路径
|
OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen" # 替换为你的输出根路径
|
||||||
|
|
||||||
# 各文件夹需抽取的图片数量(总数量=2800+1400+700=4900)
|
# 各文件夹需抽取的图片数量(总数量=2800+1400+700=4900)
|
||||||
FRAME_COUNTS = {
|
FRAME_COUNTS = {
|
||||||
"train": 28,
|
"train": 1200,
|
||||||
"valid": 14,
|
"val": 150,
|
||||||
"test": 7
|
"test": 50
|
||||||
}
|
}
|
||||||
|
|
||||||
# 图片保存格式(建议用jpg,兼容性更好;也可改为png)
|
# 图片保存格式(建议用jpg,兼容性更好;也可改为png)
|
||||||
SAVE_FORMAT = "jpg"
|
SAVE_FORMAT = "jpg"
|
||||||
|
|
||||||
# 图片文件名编号位数(如0001.jpg,避免文件名乱序)
|
# 图片文件名编号位数(如0001.jpg,避免文件名乱序)
|
||||||
FILE_NUM_DIGITS = 4 # 对应最大编号9999,满足4900张需求
|
FILE_NUM_DIGITS = 5 # 对应最大编号9999,满足4900张需求
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------------
|
||||||
@ -62,13 +62,13 @@ def check_video_validity(cap):
|
|||||||
print(f" - 帧率(FPS):{fps:.1f}")
|
print(f" - 帧率(FPS):{fps:.1f}")
|
||||||
print(f" - 总时长:{video_duration // 60:.0f}分{video_duration % 60:.1f}秒")
|
print(f" - 总时长:{video_duration // 60:.0f}分{video_duration % 60:.1f}秒")
|
||||||
print(
|
print(
|
||||||
f" - 需抽取总帧数:{required_total}(train:{FRAME_COUNTS['train']}, valid:{FRAME_COUNTS['valid']}, test:{FRAME_COUNTS['test']})")
|
f" - 需抽取总帧数:{required_total}(train:{FRAME_COUNTS['train']}, val:{FRAME_COUNTS['val']}, test:{FRAME_COUNTS['test']})")
|
||||||
|
|
||||||
return total_frames
|
return total_frames
|
||||||
|
|
||||||
|
|
||||||
def create_output_dirs():
|
def create_output_dirs():
|
||||||
"""创建输出根目录及train/valid/test子文件夹"""
|
"""创建输出根目录及train/val/test子文件夹"""
|
||||||
# 转换为Path对象,适配Windows/Linux/macOS路径格式
|
# 转换为Path对象,适配Windows/Linux/macOS路径格式
|
||||||
output_root = Path(OUTPUT_ROOT_DIR)
|
output_root = Path(OUTPUT_ROOT_DIR)
|
||||||
subdirs = FRAME_COUNTS.keys()
|
subdirs = FRAME_COUNTS.keys()
|
||||||
@ -97,14 +97,14 @@ def generate_random_frame_indices(total_frames, required_total):
|
|||||||
|
|
||||||
|
|
||||||
def split_indices_by_dataset(random_indices):
|
def split_indices_by_dataset(random_indices):
|
||||||
"""将总随机索引按train/valid/test的数量拆分"""
|
"""将总随机索引按train/val/test的数量拆分"""
|
||||||
train_count = FRAME_COUNTS["train"]
|
train_count = FRAME_COUNTS["train"]
|
||||||
valid_count = FRAME_COUNTS["valid"]
|
valid_count = FRAME_COUNTS["val"]
|
||||||
|
|
||||||
# 拆分逻辑:前N个给train,中间M个给valid,剩余给test
|
# 拆分逻辑:前N个给train,中间M个给valid,剩余给test
|
||||||
indices_split = {
|
indices_split = {
|
||||||
"train": random_indices[:train_count],
|
"train": random_indices[:train_count],
|
||||||
"valid": random_indices[train_count:train_count + valid_count],
|
"val": random_indices[train_count:train_count + valid_count],
|
||||||
"test": random_indices[train_count + valid_count:]
|
"test": random_indices[train_count + valid_count:]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ def extract_and_save_frames(cap, indices_split, output_dirs):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 生成文件名(如train_0001.jpg)
|
# 生成文件名(如train_0001.jpg)
|
||||||
file_name = f"{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}" # 0{FILE_NUM_DIGITS}d表示补零到指定位数
|
file_name = f"dths_{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}" # 0{FILE_NUM_DIGITS}d表示补零到指定位数
|
||||||
save_path = dataset_dir / file_name
|
save_path = dataset_dir / file_name
|
||||||
|
|
||||||
# 保存帧为图片(cv2.imwrite默认保存为BGR格式,符合图片存储标准)
|
# 保存帧为图片(cv2.imwrite默认保存为BGR格式,符合图片存储标准)
|
||||||
|
|||||||
Reference in New Issue
Block a user