Files
prehadler/抽取样本数量.py
2025-10-10 11:39:23 +08:00

217 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import random
import shutil
from collections import defaultdict
def _parse_label_classes(label_path, valid_classes, label_name=None):
    """Return the set of valid class IDs annotated in one label file.

    The first whitespace-separated token of every non-empty line is taken as
    the class ID; float spellings such as ``"1.0"`` are accepted and truncated
    to ``int``.

    Args:
        label_path: Path of the label text file to read.
        valid_classes: Container of accepted class IDs (membership test only).
        label_name: File name used in warning messages. When ``None`` the
            file is parsed silently (malformed lines and unknown classes are
            skipped without printing).

    Returns:
        A ``set`` of the class IDs from ``valid_classes`` found in the file.
    """
    found = set()
    # An explicit encoding keeps parsing independent of the platform locale;
    # "utf-8-sig" also tolerates a leading BOM, and undecodable bytes are
    # replaced so the affected line is handled as malformed instead of
    # aborting the whole run.
    with open(label_path, "r", encoding="utf-8-sig", errors="replace") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            try:
                class_id = int(float(parts[0]))
            except (ValueError, OverflowError):
                # OverflowError covers tokens such as "inf".
                if label_name is not None:
                    print(f"警告: 标签文件 {label_name} 格式不正确,行内容: {line.strip()}")
                continue
            if class_id in valid_classes:
                found.add(class_id)
            elif label_name is not None:
                print(f"警告: 标签文件 {label_name} 包含未知类别 {class_id}")
    return found


def add_balanced_samples(source_dir, target_dir, add_samples):
    """Copy new, non-duplicate samples from a source dataset into a target dataset.

    Both datasets are directories holding an ``images`` and a ``labels``
    sub-directory; the label of an image is the ``.txt`` file with the same
    stem. Images whose file name already exists in ``target_dir/images`` are
    never selected. Selection first draws up to ``add_samples // 7`` (at
    least 1) random images per class, then fills any remaining quota with
    random images from the rest of the candidates.

    Progress, warnings and the final per-class distribution are printed; the
    distribution counts, for each class, the number of images (existing plus
    newly added) that contain it.

    Args:
        source_dir: Source dataset directory to draw new samples from.
        target_dir: Target dataset directory; created if missing.
        add_samples: Number of samples to add. Capped at the number of
            usable candidates; must be greater than 0.

    Returns:
        None. The function returns early (after printing an error) when the
        source sub-directories are missing, when there is no new candidate
        image, or when ``add_samples`` is not positive.
    """
    # Class ID -> display name.
    class_names = {
        0: "Abdomen",
        1: "Hips",
        2: "Chest",
        3: "vulva",
        4: "back",
        5: "penis",
        6: "Horror"
    }
    num_classes = len(class_names)

    # Make sure the target layout exists.
    target_image_dir = os.path.join(target_dir, "images")
    target_label_dir = os.path.join(target_dir, "labels")
    os.makedirs(target_image_dir, exist_ok=True)
    os.makedirs(target_label_dir, exist_ok=True)

    # Step 1: record the image file names already in the target so they are
    # not added again.
    existing_images = {
        f for f in os.listdir(target_image_dir)
        if os.path.isfile(os.path.join(target_image_dir, f))
    }
    print(f"目标目录中已存在 {len(existing_images)} 个样本,将避免重复添加")

    # Step 2: collect source images that are not yet in the target.
    source_image_dir = os.path.join(source_dir, "images")
    source_label_dir = os.path.join(source_dir, "labels")
    if not os.path.exists(source_image_dir) or not os.path.exists(source_label_dir):
        print(f"错误: 源目录 {source_dir} 中缺少images或labels子目录")
        return

    candidate_images = [
        f for f in os.listdir(source_image_dir)
        if os.path.isfile(os.path.join(source_image_dir, f)) and f not in existing_images
    ]
    if not candidate_images:
        print("错误: 源目录中没有可添加的新样本(所有样本已存在于目标目录)")
        return

    # Step 3: work out which classes each candidate contains.
    file_classes = {}  # candidate image name -> set of class IDs it contains
    class_files = defaultdict(list)  # class ID -> candidate images containing it
    for img_file in candidate_images:
        label_file = f"{os.path.splitext(img_file)[0]}.txt"
        label_path = os.path.join(source_label_dir, label_file)
        if not os.path.exists(label_path):
            print(f"警告: 图片 {img_file} 对应的标签文件 {label_file} 不存在,已跳过")
            continue

        classes_in_file = _parse_label_classes(label_path, class_names, label_file)
        if not classes_in_file:
            print(f"警告: 标签文件 {label_file} 无有效类别,已跳过")
            continue

        file_classes[img_file] = classes_in_file
        for class_id in classes_in_file:
            class_files[class_id].append(img_file)

    # Warn about classes that no candidate provides. ``.get`` is used so the
    # defaultdict does not grow empty entries as a side effect.
    for class_id, class_name in class_names.items():
        if not class_files.get(class_id):
            print(f"警告: 候选文件中没有类别 {class_id} ({class_name}) 的样本")

    # Step 4: clamp the requested amount to what is actually available.
    available_candidates = len(file_classes)
    if add_samples > available_candidates:
        print(f"警告: 请求新增 {add_samples} 个样本,但候选文件仅 {available_candidates} 个,将全部添加")
        add_samples = available_candidates
    elif add_samples <= 0:
        print("错误: 新增样本数量必须大于0")
        return

    # Step 5: balanced selection.
    selected_files = set()
    remaining = add_samples
    # Per-class quota: an even share of the total, but never less than one.
    ideal_per_class = max(1, add_samples // num_classes)
    for class_id in class_names:
        if class_id not in class_files:
            continue
        # An image may contain several classes, so skip ones already taken.
        available = [f for f in class_files[class_id] if f not in selected_files]
        num_to_add = min(ideal_per_class, len(available), remaining)
        if num_to_add > 0:
            selected_files.update(random.sample(available, num_to_add))
            remaining -= num_to_add
        if remaining <= 0:
            break

    # Fill whatever quota is left from the rest of the candidates.
    if remaining > 0:
        all_available = [f for f in file_classes if f not in selected_files]
        num_to_add = min(remaining, len(all_available))
        if num_to_add > 0:
            selected_files.update(random.sample(all_available, num_to_add))

    # Step 6: class statistics of the samples that were already in the target.
    # Gathered before copying so freshly copied labels cannot influence them.
    # The target's own label is authoritative; the source label is only a
    # fallback for target images that have no label file of their own.
    existing_class_counts = defaultdict(int)
    for img_file in existing_images:
        label_file = f"{os.path.splitext(img_file)[0]}.txt"
        label_path = os.path.join(target_label_dir, label_file)
        if not os.path.exists(label_path):
            label_path = os.path.join(source_label_dir, label_file)
            if not os.path.exists(label_path):
                continue
        # Counted per image (a set of classes), the same unit used for the
        # newly added samples below, so the two can be summed meaningfully.
        for class_id in _parse_label_classes(label_path, class_names):
            existing_class_counts[class_id] += 1

    # Step 7: copy the selected images and their labels into the target.
    for img_file in selected_files:
        shutil.copy2(
            os.path.join(source_image_dir, img_file),
            os.path.join(target_image_dir, img_file),
        )
        label_file = f"{os.path.splitext(img_file)[0]}.txt"
        src_label = os.path.join(source_label_dir, label_file)
        if os.path.exists(src_label):
            shutil.copy2(src_label, os.path.join(target_label_dir, label_file))

    # Step 8: combined distribution (existing + newly added).
    new_class_counts = defaultdict(int)
    for img_file in selected_files:
        for class_id in file_classes[img_file]:
            new_class_counts[class_id] += 1

    total_class_counts = {
        class_id: existing_class_counts.get(class_id, 0) + new_class_counts.get(class_id, 0)
        for class_id in class_names
    }

    print(f"\n已成功新增 {len(selected_files)} 个样本,目标目录总样本数: {len(existing_images) + len(selected_files)}")
    print("目标目录总类别分布(原有+新增):")
    for class_id in class_names:
        print(f"类别 {class_id} ({class_names[class_id]}): {total_class_counts[class_id]} 次出现")
if __name__ == "__main__":
    # Dataset that new samples are drawn from.
    source_directory = r"D:\DataPreHandler\data\train3"
    # Dataset that already holds samples and receives the additions.
    target_directory = r"D:\DataPreHandler\data\train2"
    # Number of new samples to add; adjust as needed.
    add_samples = 1200

    # Run the top-up with the settings above.
    add_balanced_samples(
        source_dir=source_directory,
        target_dir=target_directory,
        add_samples=add_samples,
    )