import os import random import shutil from collections import defaultdict def add_balanced_samples(source_dir, target_dir, add_samples): """ 在已有目标数据集基础上添加新样本(不重复),并保持类别平衡 参数: source_dir: 源数据集目录(包含images和labels) target_dir: 目标数据集目录(已有样本,包含images和labels) add_samples: 需要新增的样本数量 """ # 定义类别名称映射 class_names = { 0: "Abdomen", 1: "Hips", 2: "Chest", 3: "vulva", 4: "back", 5: "penis", 6: "Horror" } num_classes = len(class_names) # 确保目标目录存在 os.makedirs(os.path.join(target_dir, "images"), exist_ok=True) os.makedirs(os.path.join(target_dir, "labels"), exist_ok=True) # -------------------------- # 关键步骤1:获取目标目录中已有的文件(避免重复) # -------------------------- existing_images = set() target_image_dir = os.path.join(target_dir, "images") for f in os.listdir(target_image_dir): if os.path.isfile(os.path.join(target_image_dir, f)): existing_images.add(f) # 记录已存在的图片文件名 print(f"目标目录中已存在 {len(existing_images)} 个样本,将避免重复添加") # -------------------------- # 关键步骤2:获取源目录中未被目标目录包含的可用文件 # -------------------------- source_image_dir = os.path.join(source_dir, "images") source_label_dir = os.path.join(source_dir, "labels") if not os.path.exists(source_image_dir) or not os.path.exists(source_label_dir): print(f"错误: 源目录 {source_dir} 中缺少images或labels子目录") return # 源目录所有图片中,排除已存在于目标目录的文件 candidate_images = [ f for f in os.listdir(source_image_dir) if os.path.isfile(os.path.join(source_image_dir, f)) and f not in existing_images ] if not candidate_images: print(f"错误: 源目录中没有可添加的新样本(所有样本已存在于目标目录)") return # -------------------------- # 分析候选文件的类别(兼容浮点数ID) # -------------------------- file_classes = {} # 候选文件名 -> 包含的类别集合 class_files = defaultdict(list) # 类别 -> 包含该类别的候选文件列表 for img_file in candidate_images: base_name = os.path.splitext(img_file)[0] label_file = f"{base_name}.txt" label_path = os.path.join(source_label_dir, label_file) if not os.path.exists(label_path): print(f"警告: 图片 {img_file} 对应的标签文件 {label_file} 不存在,已跳过") continue classes_in_file = set() with open(label_path, 'r') as f: for line in f: parts = line.strip().split() if not parts: continue try: class_id = int(float(parts[0])) # 兼容浮点数类别ID if class_id in class_names: classes_in_file.add(class_id) else: print(f"警告: 标签文件 {label_file} 包含未知类别 {class_id}") except (ValueError, IndexError): print(f"警告: 标签文件 {label_file} 格式不正确,行内容: {line.strip()}") if not classes_in_file: print(f"警告: 标签文件 {label_file} 无有效类别,已跳过") continue file_classes[img_file] = classes_in_file for class_id in classes_in_file: class_files[class_id].append(img_file) # 检查候选文件中是否有类别缺失 for class_id in class_names: if class_id not in class_files or len(class_files[class_id]) == 0: print(f"警告: 候选文件中没有类别 {class_id} ({class_names[class_id]}) 的样本") # -------------------------- # 处理新增样本数量(不超过候选文件数量) # -------------------------- available_candidates = len(file_classes) if add_samples > available_candidates: print(f"警告: 请求新增 {add_samples} 个样本,但候选文件仅 {available_candidates} 个,将全部添加") add_samples = available_candidates elif add_samples <= 0: print("错误: 新增样本数量必须大于0") return # -------------------------- # 平衡抽取新增样本 # -------------------------- selected_files = set() remaining = add_samples # 每个类别理想的新增数量(基于总新增数量平均) ideal_per_class = max(1, add_samples // num_classes) # 先为每个类别抽取理想数量的样本 for class_id in class_names: if class_id not in class_files: continue # 候选文件中未被选中的 available = [f for f in class_files[class_id] if f not in selected_files] num_to_add = min(ideal_per_class, len(available), remaining) if num_to_add > 0: selected = random.sample(available, num_to_add) selected_files.update(selected) remaining -= num_to_add if remaining <= 0: break # 补充剩余需要新增的样本 if remaining > 0: all_available = [f for f in file_classes if f not in selected_files] num_to_add = min(remaining, len(all_available)) if num_to_add > 0: selected_additional = random.sample(all_available, num_to_add) selected_files.update(selected_additional) # -------------------------- # 复制新增样本到目标目录(确保不重复) # -------------------------- for img_file in selected_files: # 复制图片 src_img = os.path.join(source_image_dir, img_file) dst_img = os.path.join(target_image_dir, img_file) shutil.copy2(src_img, dst_img) # 复制标签 base_name = os.path.splitext(img_file)[0] label_file = f"{base_name}.txt" src_label = os.path.join(source_label_dir, label_file) dst_label = os.path.join(target_dir, "labels", label_file) if os.path.exists(src_label): shutil.copy2(src_label, dst_label) # -------------------------- # 统计结果(包含目标目录原有+新增的总分布) # -------------------------- # 1. 统计目标目录原有样本的类别 existing_class_counts = defaultdict(int) for img_file in existing_images: # 从源目录找原有样本的标签(因为目标目录标签可能已存在,但源目录更可靠) base_name = os.path.splitext(img_file)[0] label_file = f"{base_name}.txt" label_path = os.path.join(source_label_dir, label_file) if os.path.exists(label_path): with open(label_path, 'r') as f: for line in f: parts = line.strip().split() if parts: try: class_id = int(float(parts[0])) if class_id in class_names: existing_class_counts[class_id] += 1 except (ValueError, IndexError): pass # 忽略格式错误的行 # 2. 统计新增样本的类别 new_class_counts = defaultdict(int) for img_file in selected_files: for class_id in file_classes[img_file]: new_class_counts[class_id] += 1 # 3. 总分布(原有+新增) total_class_counts = { class_id: existing_class_counts.get(class_id, 0) + new_class_counts.get(class_id, 0) for class_id in class_names } # 输出结果 print(f"\n已成功新增 {len(selected_files)} 个样本,目标目录总样本数: {len(existing_images) + len(selected_files)}") print("目标目录总类别分布(原有+新增):") for class_id in class_names: print(f"类别 {class_id} ({class_names[class_id]}): {total_class_counts[class_id]} 次出现") if __name__ == "__main__": # 源数据集路径(从中抽取新样本) source_directory = r"D:\DataPreHandler\data\train3" # 目标数据集路径(已有样本,将新增到这里) target_directory = r"D:\DataPreHandler\data\train2" # 需要新增的样本数量(根据需求修改) add_samples = 1200 # 例如:再添加100个新样本 # 执行新增操作 add_balanced_samples(source_directory, target_directory, add_samples)