217 lines
8.5 KiB
Python
217 lines
8.5 KiB
Python
|
|
import os
|
|||
|
|
import random
|
|||
|
|
import shutil
|
|||
|
|
from collections import defaultdict
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Default YOLO class-id -> class-name mapping used by add_balanced_samples.
_DEFAULT_CLASS_NAMES = {
    0: "Abdomen",
    1: "Hips",
    2: "Chest",
    3: "vulva",
    4: "back",
    5: "penis",
    6: "Horror",
}


def _label_name(image_name):
    """Return the label file name paired with *image_name* (same stem + ``.txt``)."""
    return os.path.splitext(image_name)[0] + ".txt"


def _parse_label_classes(label_path, class_names, verbose=True):
    """Return the set of known class ids referenced by one label file.

    Only the first whitespace-separated token of each non-empty line is read
    and interpreted as the class id; float-formatted ids such as ``"2.0"`` are
    accepted. Ids not present in *class_names* and unparseable tokens are
    skipped, with a printed warning when *verbose* is true.
    """
    label_file = os.path.basename(label_path)
    classes = set()
    with open(label_path, "r", encoding="utf-8", errors="replace") as fh:
        for line in fh:
            parts = line.strip().split()
            if not parts:
                continue
            try:
                # int(float(...)) tolerates "3.0"; OverflowError covers "inf".
                class_id = int(float(parts[0]))
            except (ValueError, OverflowError):
                if verbose:
                    print(f"警告: 标签文件 {label_file} 格式不正确,行内容: {line.strip()}")
                continue
            if class_id in class_names:
                classes.add(class_id)
            elif verbose:
                print(f"警告: 标签文件 {label_file} 包含未知类别 {class_id}")
    return classes


def add_balanced_samples(source_dir, target_dir, add_samples, class_names=None, seed=None):
    """Add new, non-duplicate samples from a source dataset to a target dataset.

    Both datasets use an ``images/`` + ``labels/`` layout where each image
    ``<stem>.<ext>`` is paired with ``labels/<stem>.txt``, whose lines start
    with a class id. Samples are first drawn per class (about
    ``add_samples // len(class_names)`` images per class, in class-id order),
    then topped up at random from the remaining candidates.

    A source image is skipped when its stem already exists in the target (as
    an image or as a label file), so existing target labels are never
    overwritten, even when extensions differ (``a.png`` vs ``a.jpg``).

    Args:
        source_dir: Source dataset directory containing ``images`` and ``labels``.
        target_dir: Target dataset directory; its ``images`` and ``labels``
            subdirectories are created if missing and receive the new samples.
        add_samples: Number of new samples to add. Must be positive; capped at
            the number of candidates that have a usable label.
        class_names: Optional mapping of class id -> display name. Defaults to
            the built-in 7-class mapping.
        seed: Optional seed for a private ``random.Random``. When ``None`` the
            global ``random`` module is used.

    Returns:
        None. Progress, warnings and the final per-class distribution (number
        of images containing each class, existing + new) are printed.
    """
    # Validate before touching the filesystem.
    if add_samples <= 0:
        print("错误: 新增样本数量必须大于0")
        return

    if class_names is None:
        class_names = _DEFAULT_CLASS_NAMES
    num_classes = len(class_names)
    rng = random if seed is None else random.Random(seed)

    target_image_dir = os.path.join(target_dir, "images")
    target_label_dir = os.path.join(target_dir, "labels")
    os.makedirs(target_image_dir, exist_ok=True)
    os.makedirs(target_label_dir, exist_ok=True)

    # Step 1: record what the target already holds.
    existing_images = {
        name
        for name in os.listdir(target_image_dir)
        if os.path.isfile(os.path.join(target_image_dir, name))
    }
    # Images and labels are paired by stem, so duplicates must be detected by
    # stem rather than full file name: copying source "a.jpg" next to target
    # "a.png" would overwrite the target's "labels/a.txt". Stems of label files
    # without an image are protected the same way.
    existing_stems = {os.path.splitext(name)[0] for name in existing_images}
    existing_stems.update(
        os.path.splitext(name)[0]
        for name in os.listdir(target_label_dir)
        if os.path.isfile(os.path.join(target_label_dir, name))
    )
    print(f"目标目录中已存在 {len(existing_images)} 个样本,将避免重复添加")

    # Step 2: collect source images that are not already in the target.
    source_image_dir = os.path.join(source_dir, "images")
    source_label_dir = os.path.join(source_dir, "labels")
    if not os.path.isdir(source_image_dir) or not os.path.isdir(source_label_dir):
        print(f"错误: 源目录 {source_dir} 中缺少images或labels子目录")
        return

    # Sorted so that a given seed yields the same selection on every platform
    # (os.listdir order is arbitrary).
    candidate_images = [
        name
        for name in sorted(os.listdir(source_image_dir))
        if os.path.isfile(os.path.join(source_image_dir, name))
        and os.path.splitext(name)[0] not in existing_stems
    ]
    if not candidate_images:
        print("错误: 源目录中没有可添加的新样本(所有样本已存在于目标目录)")
        return

    # Step 3: find which classes each candidate contains.
    file_classes = {}  # candidate image name -> set of class ids in its label
    class_files = defaultdict(list)  # class id -> candidate images containing it
    for img_file in candidate_images:
        label_file = _label_name(img_file)
        label_path = os.path.join(source_label_dir, label_file)
        if not os.path.isfile(label_path):
            print(f"警告: 图片 {img_file} 对应的标签文件 {label_file} 不存在,已跳过")
            continue

        classes_in_file = _parse_label_classes(label_path, class_names)
        if not classes_in_file:
            print(f"警告: 标签文件 {label_file} 无有效类别,已跳过")
            continue

        file_classes[img_file] = classes_in_file
        for class_id in classes_in_file:
            class_files[class_id].append(img_file)

    for class_id in class_names:
        # Membership test first: indexing the defaultdict would insert a key.
        if class_id not in class_files or not class_files[class_id]:
            print(f"警告: 候选文件中没有类别 {class_id} ({class_names[class_id]}) 的样本")

    # Step 4: never request more samples than there are usable candidates.
    available_candidates = len(file_classes)
    if add_samples > available_candidates:
        print(f"警告: 请求新增 {add_samples} 个样本,但候选文件仅 {available_candidates} 个,将全部添加")
        add_samples = available_candidates

    # Step 5: balanced selection, first per class, then a random top-up.
    selected_files = set()
    remaining = add_samples
    # max(1, num_classes) guards against an empty class_names mapping.
    ideal_per_class = max(1, add_samples // max(1, num_classes))

    for class_id in class_names:
        if class_id not in class_files:
            continue
        available = [f for f in class_files[class_id] if f not in selected_files]
        num_to_add = min(ideal_per_class, len(available), remaining)
        if num_to_add > 0:
            selected_files.update(rng.sample(available, num_to_add))
            remaining -= num_to_add
        if remaining <= 0:
            break

    if remaining > 0:
        all_available = [f for f in file_classes if f not in selected_files]
        num_to_add = min(remaining, len(all_available))
        if num_to_add > 0:
            selected_files.update(rng.sample(all_available, num_to_add))

    # Step 6: copy the selected images and their labels into the target.
    for img_file in selected_files:
        shutil.copy2(
            os.path.join(source_image_dir, img_file),
            os.path.join(target_image_dir, img_file),
        )
        label_file = _label_name(img_file)
        src_label = os.path.join(source_label_dir, label_file)
        if os.path.exists(src_label):
            shutil.copy2(src_label, os.path.join(target_label_dir, label_file))

    # Step 7: per-class distribution over existing + new samples. Both parts
    # count images containing a class, the same unit the selection balances on.
    total_class_counts = {class_id: 0 for class_id in class_names}
    for img_file in existing_images:
        label_file = _label_name(img_file)
        # Prefer the target's own label; fall back to the source copy.
        for label_dir in (target_label_dir, source_label_dir):
            label_path = os.path.join(label_dir, label_file)
            if os.path.isfile(label_path):
                for class_id in _parse_label_classes(label_path, class_names, verbose=False):
                    total_class_counts[class_id] += 1
                break
    for img_file in selected_files:
        for class_id in file_classes[img_file]:
            total_class_counts[class_id] += 1

    print(f"\n已成功新增 {len(selected_files)} 个样本,目标目录总样本数: {len(existing_images) + len(selected_files)}")
    print("目标目录总类别分布(原有+新增):")
    for class_id in class_names:
        print(f"类别 {class_id} ({class_names[class_id]}): {total_class_counts[class_id]} 次出现")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Source dataset to draw new samples from.
    source_directory = r"D:\DataPreHandler\data\train3"
    # Existing target dataset; the new samples are added here.
    target_directory = r"D:\DataPreHandler\data\train2"

    # Number of new samples to add (adjust as needed).
    add_samples = 1200

    # Run the augmentation.
    add_balanced_samples(source_directory, target_directory, add_samples)
|