Files
prehadler/抽取样本数量.py

217 lines
8.5 KiB
Python
Raw Permalink Normal View History

2025-10-10 11:39:23 +08:00
import os
import random
import shutil
from collections import defaultdict
def add_balanced_samples(source_dir, target_dir, add_samples):
"""
在已有目标数据集基础上添加新样本不重复并保持类别平衡
参数:
source_dir: 源数据集目录包含images和labels
target_dir: 目标数据集目录已有样本包含images和labels
add_samples: 需要新增的样本数量
"""
# 定义类别名称映射
class_names = {
0: "Abdomen",
1: "Hips",
2: "Chest",
3: "vulva",
4: "back",
5: "penis",
6: "Horror"
}
num_classes = len(class_names)
# 确保目标目录存在
os.makedirs(os.path.join(target_dir, "images"), exist_ok=True)
os.makedirs(os.path.join(target_dir, "labels"), exist_ok=True)
# --------------------------
# 关键步骤1获取目标目录中已有的文件避免重复
# --------------------------
existing_images = set()
target_image_dir = os.path.join(target_dir, "images")
for f in os.listdir(target_image_dir):
if os.path.isfile(os.path.join(target_image_dir, f)):
existing_images.add(f) # 记录已存在的图片文件名
print(f"目标目录中已存在 {len(existing_images)} 个样本,将避免重复添加")
# --------------------------
# 关键步骤2获取源目录中未被目标目录包含的可用文件
# --------------------------
source_image_dir = os.path.join(source_dir, "images")
source_label_dir = os.path.join(source_dir, "labels")
if not os.path.exists(source_image_dir) or not os.path.exists(source_label_dir):
print(f"错误: 源目录 {source_dir} 中缺少images或labels子目录")
return
# 源目录所有图片中,排除已存在于目标目录的文件
candidate_images = [
f for f in os.listdir(source_image_dir)
if os.path.isfile(os.path.join(source_image_dir, f)) and f not in existing_images
]
if not candidate_images:
print(f"错误: 源目录中没有可添加的新样本(所有样本已存在于目标目录)")
return
# --------------------------
# 分析候选文件的类别兼容浮点数ID
# --------------------------
file_classes = {} # 候选文件名 -> 包含的类别集合
class_files = defaultdict(list) # 类别 -> 包含该类别的候选文件列表
for img_file in candidate_images:
base_name = os.path.splitext(img_file)[0]
label_file = f"{base_name}.txt"
label_path = os.path.join(source_label_dir, label_file)
if not os.path.exists(label_path):
print(f"警告: 图片 {img_file} 对应的标签文件 {label_file} 不存在,已跳过")
continue
classes_in_file = set()
with open(label_path, 'r') as f:
for line in f:
parts = line.strip().split()
if not parts:
continue
try:
class_id = int(float(parts[0])) # 兼容浮点数类别ID
if class_id in class_names:
classes_in_file.add(class_id)
else:
print(f"警告: 标签文件 {label_file} 包含未知类别 {class_id}")
except (ValueError, IndexError):
print(f"警告: 标签文件 {label_file} 格式不正确,行内容: {line.strip()}")
if not classes_in_file:
print(f"警告: 标签文件 {label_file} 无有效类别,已跳过")
continue
file_classes[img_file] = classes_in_file
for class_id in classes_in_file:
class_files[class_id].append(img_file)
# 检查候选文件中是否有类别缺失
for class_id in class_names:
if class_id not in class_files or len(class_files[class_id]) == 0:
print(f"警告: 候选文件中没有类别 {class_id} ({class_names[class_id]}) 的样本")
# --------------------------
# 处理新增样本数量(不超过候选文件数量)
# --------------------------
available_candidates = len(file_classes)
if add_samples > available_candidates:
print(f"警告: 请求新增 {add_samples} 个样本,但候选文件仅 {available_candidates} 个,将全部添加")
add_samples = available_candidates
elif add_samples <= 0:
print("错误: 新增样本数量必须大于0")
return
# --------------------------
# 平衡抽取新增样本
# --------------------------
selected_files = set()
remaining = add_samples
# 每个类别理想的新增数量(基于总新增数量平均)
ideal_per_class = max(1, add_samples // num_classes)
# 先为每个类别抽取理想数量的样本
for class_id in class_names:
if class_id not in class_files:
continue
# 候选文件中未被选中的
available = [f for f in class_files[class_id] if f not in selected_files]
num_to_add = min(ideal_per_class, len(available), remaining)
if num_to_add > 0:
selected = random.sample(available, num_to_add)
selected_files.update(selected)
remaining -= num_to_add
if remaining <= 0:
break
# 补充剩余需要新增的样本
if remaining > 0:
all_available = [f for f in file_classes if f not in selected_files]
num_to_add = min(remaining, len(all_available))
if num_to_add > 0:
selected_additional = random.sample(all_available, num_to_add)
selected_files.update(selected_additional)
# --------------------------
# 复制新增样本到目标目录(确保不重复)
# --------------------------
for img_file in selected_files:
# 复制图片
src_img = os.path.join(source_image_dir, img_file)
dst_img = os.path.join(target_image_dir, img_file)
shutil.copy2(src_img, dst_img)
# 复制标签
base_name = os.path.splitext(img_file)[0]
label_file = f"{base_name}.txt"
src_label = os.path.join(source_label_dir, label_file)
dst_label = os.path.join(target_dir, "labels", label_file)
if os.path.exists(src_label):
shutil.copy2(src_label, dst_label)
# --------------------------
# 统计结果(包含目标目录原有+新增的总分布)
# --------------------------
# 1. 统计目标目录原有样本的类别
existing_class_counts = defaultdict(int)
for img_file in existing_images:
# 从源目录找原有样本的标签(因为目标目录标签可能已存在,但源目录更可靠)
base_name = os.path.splitext(img_file)[0]
label_file = f"{base_name}.txt"
label_path = os.path.join(source_label_dir, label_file)
if os.path.exists(label_path):
with open(label_path, 'r') as f:
for line in f:
parts = line.strip().split()
if parts:
try:
class_id = int(float(parts[0]))
if class_id in class_names:
existing_class_counts[class_id] += 1
except (ValueError, IndexError):
pass # 忽略格式错误的行
# 2. 统计新增样本的类别
new_class_counts = defaultdict(int)
for img_file in selected_files:
for class_id in file_classes[img_file]:
new_class_counts[class_id] += 1
# 3. 总分布(原有+新增)
total_class_counts = {
class_id: existing_class_counts.get(class_id, 0) + new_class_counts.get(class_id, 0)
for class_id in class_names
}
# 输出结果
print(f"\n已成功新增 {len(selected_files)} 个样本,目标目录总样本数: {len(existing_images) + len(selected_files)}")
print("目标目录总类别分布(原有+新增):")
for class_id in class_names:
print(f"类别 {class_id} ({class_names[class_id]}): {total_class_counts[class_id]} 次出现")
if __name__ == "__main__":
# 源数据集路径(从中抽取新样本)
source_directory = r"D:\DataPreHandler\data\train3"
# 目标数据集路径(已有样本,将新增到这里)
target_directory = r"D:\DataPreHandler\data\train2"
# 需要新增的样本数量(根据需求修改)
add_samples = 1200 # 例如再添加100个新样本
# 执行新增操作
add_balanced_samples(source_directory, target_directory, add_samples)