import os
import random

import albumentations as A
import cv2
import numpy as np
from tqdm import tqdm

# --- 1. User configuration (edit these before running!) ---
# Keep the three kinds of directories distinct:
# 1. Feature source directory: images + YOLO labels that contain the objects to
#    paste (e.g. no_helmet); the objects are cropped out of these images.
SOURCE_FEATURE_IMAGE_DIR = r"E:\NSFW-Detection-YOLO\data\images\val\images"  # source images containing targets
SOURCE_FEATURE_LABEL_DIR = r"E:\NSFW-Detection-YOLO\data\images\val\labels"  # labels of the source images
# 2. Base image directory: background images the objects are pasted onto
#    (no labels needed).
BASE_IMAGE_DIR = r"D:\DataPreHandler\images\valid"
# 3. Output directories for the augmented images and their labels.
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\valid\images"
OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\valid\labels"

# Number of augmented images generated per base image.
AUGMENTATION_FACTOR = 1

# --- Copy-Paste settings ---
SMALL_OBJECT_CLASSES_TO_PASTE = [0, 1, 2, 3, 4, 5, 6]  # class IDs to harvest and paste
PASTE_COUNT_RANGE = (5, 10)  # objects attempted per image (inclusive range)
# Random positions tried per object before it is skipped. Objects are only
# placed where they do not overlap an object already pasted on the same image,
# so no pasted object is hidden under a later one while keeping its label.
PASTE_MAX_ATTEMPTS = 20

# File extensions (lower-case) treated as images in every input directory.
IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')


# --- 2. Standard augmentation pipelines ---
# NOTE: parameter names follow the Albumentations 2.x API (`border_mode`,
# GaussNoise `std_range`, ImageCompression `quality_range`). With
# BORDER_CONSTANT the padding fill defaults to 0 (black), so it is not passed.
def _make_bbox_params():
    """Return the YOLO bbox settings shared by every pipeline."""
    return A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.25)


transform_geometric = A.Compose([
    A.HorizontalFlip(p=0.5),
    # A.Affine uses a scalar `rotate`/`translate_percent` as a fixed amount for
    # every image; explicit symmetric ranges give random ±30° / ±10% instead.
    A.Affine(scale=(0.8, 1.2), shear=(-10, 10), translate_percent=(-0.1, 0.1),
             rotate=(-30, 30), border_mode=cv2.BORDER_CONSTANT, p=0.8),
    A.Perspective(scale=(0.02, 0.05), p=0.4),
], bbox_params=_make_bbox_params())

transform_quality = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.8),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.7),
    # std_range is a fraction of the max pixel value: 3-8 px std on uint8.
    A.OneOf([A.GaussNoise(std_range=(3.0 / 255, 8.0 / 255), p=1.0), A.ISONoise(p=1.0)], p=0.6),
    A.OneOf([A.Blur(blur_limit=(3, 7), p=1.0), A.MotionBlur(blur_limit=(3, 7), p=1.0)], p=0.5),
    A.ImageCompression(quality_range=(70, 95), p=0.3),
], bbox_params=_make_bbox_params())

transform_mixed = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, border_mode=cv2.BORDER_CONSTANT, p=0.5),
    A.RandomBrightnessContrast(p=0.6),
    A.GaussNoise(std_range=(2.0 / 255, 6.0 / 255), p=0.4),  # 2-6 px std on uint8
    A.Blur(blur_limit=3, p=0.3),
], bbox_params=_make_bbox_params())

base_transforms = [transform_geometric, transform_quality, transform_mixed]  # one is picked at random per image


# --- 3. Helpers ---
def _imread_unicode(path):
    """Read an image as 3-channel BGR, like cv2.imread, but path-encoding safe.

    cv2.imread fails on non-ASCII (e.g. Chinese) paths on Windows, so the raw
    bytes are read with NumPy and decoded in memory. Returns None when the file
    is missing, empty or not a decodable image (cv2.imread's failure value).
    """
    try:
        data = np.fromfile(path, dtype=np.uint8)
    except OSError:
        return None
    if data.size == 0:
        return None  # cv2.imdecode asserts on an empty buffer
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def _imwrite_unicode(path, image_bgr):
    """Encode ``image_bgr`` by the extension of ``path`` and write it.

    Path-encoding safe counterpart of cv2.imwrite. Returns True on success and
    False if encoding or writing failed.
    """
    ext = os.path.splitext(path)[1] or ".jpg"
    ok, encoded = cv2.imencode(ext, image_bgr)
    if not ok:
        return False
    try:
        encoded.tofile(path)
    except OSError:
        return False
    return True


def harvest_objects_for_pasting(feature_image_dir, feature_label_dir, target_classes):
    """Build the library of pasteable object crops from a labelled image folder.

    Every image in ``feature_image_dir`` that has a same-stem ``.txt`` label in
    ``feature_label_dir`` is read, and each YOLO box whose class is in
    ``target_classes`` is cropped out (clipped to the image bounds).

    :param feature_image_dir: directory of images that contain the targets
    :param feature_label_dir: directory of YOLO detection labels, one row per
        box: ``class x_center y_center width height`` (normalised)
    :param target_classes: class IDs to harvest, e.g. ``[2]``
    :return: ``{class_id: [crop, ...]}`` with one (possibly empty) list per class
    :raises FileNotFoundError: if ``feature_image_dir`` holds no images
    :raises ValueError: if no object of the requested classes was harvested
    """
    print(f"正在从 {feature_image_dir} 提取目标类别 {target_classes}...")
    asset_library = {cls_id: [] for cls_id in target_classes}
    wanted = set(target_classes)

    feature_image_files = sorted(
        name for name in os.listdir(feature_image_dir) if name.lower().endswith(IMAGE_EXTENSIONS)
    )
    if not feature_image_files:
        raise FileNotFoundError(f"特征素材目录 {feature_image_dir} 中未找到图片!")

    malformed_lines = 0
    for filename in tqdm(feature_image_files, desc="提取目标素材"):
        label_path = os.path.join(feature_label_dir, os.path.splitext(filename)[0] + ".txt")
        if not os.path.exists(label_path):
            continue  # unlabelled image: nothing to harvest

        img = _imread_unicode(os.path.join(feature_image_dir, filename))
        if img is None:
            tqdm.write(f"警告:无法读取图片 {filename},已跳过")
            continue
        img_h, img_w = img.shape[:2]

        # utf-8-sig tolerates a BOM written by some Windows editors.
        with open(label_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue
                # Only plain detection rows (5 fields) are supported. Malformed
                # rows (or segmentation/pose rows) are skipped and counted
                # instead of aborting the whole harvest with a ValueError.
                if len(parts) != 5:
                    malformed_lines += 1
                    continue
                try:
                    # Class IDs are sometimes written as floats, e.g. "6.0".
                    cls_id = int(float(parts[0]))
                    x_center, y_center, box_w, box_h = (float(p) for p in parts[1:])
                except ValueError:
                    malformed_lines += 1
                    continue
                if cls_id not in wanted:
                    continue

                # Normalised centre/size -> pixel corners, clipped to the image.
                x1 = max(0, int((x_center - box_w / 2) * img_w))
                y1 = max(0, int((y_center - box_h / 2) * img_h))
                x2 = min(img_w, int((x_center + box_w / 2) * img_w))
                y2 = min(img_h, int((y_center + box_h / 2) * img_h))
                if x1 < x2 and y1 < y2:
                    # Copy so the crop does not keep the whole source image alive.
                    asset_library[cls_id].append(img[y1:y2, x1:x2].copy())

    if malformed_lines:
        print(f"警告:共跳过 {malformed_lines} 行格式不正确的标签")

    total_assets = sum(len(crops) for crops in asset_library.values())
    if total_assets == 0:
        raise ValueError(f"未从特征素材目录提取到任何目标!请检查类别ID {target_classes} 是否正确")
    print(f"素材库创建完成!共提取 {total_assets} 个目标(类别:{target_classes})")
    return asset_library


def _find_free_position(base_w, base_h, obj_w, obj_h, occupied):
    """Return a random top-left ``(x, y)`` for an ``obj_w`` x ``obj_h`` box that
    lies fully inside the base image and intersects none of the ``occupied``
    pixel boxes ``(x1, y1, x2, y2)``; None if PASTE_MAX_ATTEMPTS tries fail."""
    for _ in range(PASTE_MAX_ATTEMPTS):
        x1 = random.randint(0, base_w - obj_w)
        y1 = random.randint(0, base_h - obj_h)
        x2, y2 = x1 + obj_w, y1 + obj_h
        # Half-open boxes: touching edges do not count as overlap.
        if all(x2 <= ox1 or ox2 <= x1 or y2 <= oy1 or oy2 <= y1
               for ox1, oy1, ox2, oy2 in occupied):
            return x1, y1
    return None


def paste_objects_to_base(base_image, asset_library):
    """Paste randomly chosen library objects onto ``base_image`` in place.

    A random count in PASTE_COUNT_RANGE is attempted. Each object is placed at
    a random position fully inside the image that does not overlap an object
    already pasted in this call; objects larger than the image, or for which no
    free spot is found, are skipped, so fewer objects than attempted may appear.

    :param base_image: BGR image as read by OpenCV; modified in place
    :param asset_library: ``{class_id: [crop, ...]}`` from harvest_objects_for_pasting
    :return: ``(image, bboxes, labels)``: the same image object, YOLO-normalised
        ``[x_center, y_center, w, h]`` boxes and their class IDs
    """
    base_h, base_w = base_image.shape[:2]
    pasted_bboxes = []
    pasted_labels = []
    occupied = []  # pixel boxes already pasted on this image

    # The library is not modified here, so this is computed once.
    valid_classes = [cls for cls, assets in asset_library.items() if assets]
    if not valid_classes:
        return base_image, pasted_bboxes, pasted_labels

    for _ in range(random.randint(*PASTE_COUNT_RANGE)):
        target_cls = random.choice(valid_classes)
        target_obj = random.choice(asset_library[target_cls])
        obj_h, obj_w = target_obj.shape[:2]
        if obj_h > base_h or obj_w > base_w:
            continue  # cannot be placed entirely inside the base image

        position = _find_free_position(base_w, base_h, obj_w, obj_h, occupied)
        if position is None:
            continue
        x1, y1 = position
        x2, y2 = x1 + obj_w, y1 + obj_h

        base_image[y1:y2, x1:x2] = target_obj  # hard paste, no blending
        occupied.append((x1, y1, x2, y2))
        pasted_bboxes.append([
            (x1 + obj_w / 2) / base_w,
            (y1 + obj_h / 2) / base_h,
            obj_w / base_w,
            obj_h / base_h,
        ])
        pasted_labels.append(target_cls)

    return base_image, pasted_bboxes, pasted_labels


def main():
    """Generate copy-paste augmented images and YOLO labels for every base image.

    For each readable image in BASE_IMAGE_DIR, AUGMENTATION_FACTOR samples are
    produced: harvested objects are pasted on a copy of the base image, one
    randomly chosen pipeline from ``base_transforms`` is applied, and the result
    is saved as ``<stem>_aug_<i>.jpg`` plus ``<stem>_aug_<i>.txt``. A label file
    is only written when its image was saved successfully.
    """
    os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
    os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True)

    # Step 1: build the object library from the feature source directory.
    try:
        asset_library = harvest_objects_for_pasting(
            feature_image_dir=SOURCE_FEATURE_IMAGE_DIR,
            feature_label_dir=SOURCE_FEATURE_LABEL_DIR,
            target_classes=SMALL_OBJECT_CLASSES_TO_PASTE,
        )
    except (FileNotFoundError, ValueError) as e:
        print(f"错误:{e}")
        return

    # Step 2: collect the base images.
    try:
        base_image_files = sorted(
            name for name in os.listdir(BASE_IMAGE_DIR) if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    except FileNotFoundError as e:
        print(f"错误:{e}")
        return
    if not base_image_files:
        print(f"错误:底图目录 {BASE_IMAGE_DIR} 中未找到任何图片!")
        return

    print(f"\n找到 {len(base_image_files)} 张底图,开始生成增强数据(每张底图生成 {AUGMENTATION_FACTOR} 张)")

    generated = 0
    for base_filename in tqdm(base_image_files, desc="处理底图"):
        base_name = os.path.splitext(base_filename)[0]
        base_image = _imread_unicode(os.path.join(BASE_IMAGE_DIR, base_filename))
        if base_image is None:
            tqdm.write(f"\n警告:无法读取底图 {base_filename},已跳过")
            continue

        for aug_idx in range(AUGMENTATION_FACTOR):
            # Paste onto a copy so every sample starts from the clean base image.
            pasted_image, pasted_bboxes, pasted_labels = paste_objects_to_base(
                base_image=base_image.copy(),
                asset_library=asset_library,
            )

            # Albumentations pipelines are fed RGB.
            pasted_image_rgb = cv2.cvtColor(pasted_image, cv2.COLOR_BGR2RGB)
            chosen_transform = random.choice(base_transforms)
            try:
                augmented = chosen_transform(
                    image=pasted_image_rgb,
                    bboxes=pasted_bboxes,
                    class_labels=pasted_labels,
                )
            except Exception as e:  # one bad sample must not stop the whole run
                tqdm.write(f"\n警告:底图 {base_filename} 增强失败(序号 {aug_idx}):{str(e)}")
                continue

            # Outputs are always JPEG: <base>_aug_<i>.jpg / .txt
            stem = f"{base_name}_aug_{aug_idx}"
            output_img_path = os.path.join(OUTPUT_IMAGE_DIR, stem + ".jpg")
            final_bgr = cv2.cvtColor(augmented['image'], cv2.COLOR_RGB2BGR)
            if not _imwrite_unicode(output_img_path, final_bgr):
                tqdm.write(f"\n警告:无法保存图片 {output_img_path},已跳过")
                continue  # no image -> no label, avoids orphan label files

            # Keep only boxes that are valid YOLO boxes after augmentation.
            label_lines = [
                f"{int(label)} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}\n"
                for (x_c, y_c, w, h), label in zip(augmented['bboxes'], augmented['class_labels'])
                if 0 <= x_c <= 1 and 0 <= y_c <= 1 and 0 < w <= 1 and 0 < h <= 1
            ]
            output_label_path = os.path.join(OUTPUT_LABEL_DIR, stem + ".txt")
            with open(output_label_path, 'w', encoding='utf-8') as f:
                f.writelines(label_lines)
            generated += 1

    skipped = len(base_image_files) * AUGMENTATION_FACTOR - generated
    print(f"\n✅ 数据增强全部完成!")
    print(f"📊 生成数据统计:")
    print(f" - 底图数量:{len(base_image_files)} 张")
    print(f" - 每张底图增强次数:{AUGMENTATION_FACTOR} 次")
    print(f" - 总生成图片/标签:{generated} 组")
    if skipped:
        print(f" - 跳过/失败:{skipped} 组")
    print(f" - 输出路径:")
    print(f" 图片 → {OUTPUT_IMAGE_DIR}")
    print(f" 标签 → {OUTPUT_LABEL_DIR}")


if __name__ == "__main__":
    # Before running, double-check:
    # 1. SOURCE_FEATURE_IMAGE_DIR / SOURCE_FEATURE_LABEL_DIR hold the labelled
    #    images that contain the objects to paste.
    # 2. BASE_IMAGE_DIR holds the background (base) images.
    # 3. SMALL_OBJECT_CLASSES_TO_PASTE lists the class IDs to paste (e.g. no_helmet=2).
    main()