Files
prehadler/数据增强底图版.py
2025-09-26 10:23:45 +08:00

272 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
import random
import albumentations as A
from tqdm import tqdm
# --- 1. User configuration (EDIT THESE!) ---
# Adjust to your actual paths; keep the three core directories distinct:
# 1. Feature-source directories: images + labels that DO contain the paste targets (e.g. no_helmet); objects are harvested from here
SOURCE_FEATURE_IMAGE_DIR = r"E:\NSFW-Detection-YOLO\data\images\val\images" # images containing target objects
SOURCE_FEATURE_LABEL_DIR = r"E:\NSFW-Detection-YOLO\data\images\val\labels" # matching YOLO label files
# 2. Standalone base-image directory: blank/background images to paste onto (no labels needed)
BASE_IMAGE_DIR = r"D:\DataPreHandler\images\valid" # your base-image folder
# 3. Output directories: final augmented images and labels are written here
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\valid\images"
OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\valid\labels"
# Augmentation parameters
AUGMENTATION_FACTOR = 1 # number of augmented images generated per base image (e.g. 40)
# --- Copy-Paste core configuration ---
SMALL_OBJECT_CLASSES_TO_PASTE = [0,1,2,3,4,5,6] # class IDs to paste (e.g. no_helmet is 2)
PASTE_COUNT_RANGE = (5, 10) # objects pasted per augmented image, random within this inclusive range
# --- 2. Standard augmentation pipelines (albumentations 2.x parameter names) ---
# NOTE(review): the rest of this file already uses the 2.x API
# (ImageCompression(quality_range=...)), so the pipelines below are fixed to
# match — confirm the installed albumentations is >= 2.0.
# Fixes versus the previous revision:
#   - `pad_val` is not a valid Affine/Rotate argument in any release and raised
#     TypeError at construction; the 2.x name for the constant border color is
#     `fill` (1.x used `cval` / `value`).
#   - `std_limit` is not a valid GaussNoise argument; the 2.x name is
#     `std_range`, expressed as a FRACTION of the max pixel value, so the old
#     pixel-unit std of 3..8 becomes 3/255..8/255.
#   - The scalar `rotate=30` in Affine means a FIXED +30° rotation; the original
#     `rotate_limit=30` meant a symmetric ±30° range, restored as (-30, 30).
transform_geometric = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Affine(scale=(0.8, 1.2), shear=(-10, 10), translate_percent=0.1,
             rotate=(-30, 30), border_mode=cv2.BORDER_CONSTANT, fill=0, p=0.8),
    A.Perspective(scale=(0.02, 0.05), p=0.4),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.25))
transform_quality = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.8),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.7),
    # std_range is normalized: 3/255 ≈ 0.012, 8/255 ≈ 0.031 of the max pixel value.
    A.OneOf([A.GaussNoise(std_range=(3 / 255, 8 / 255), p=1.0), A.ISONoise(p=1.0)], p=0.6),
    A.OneOf([A.Blur(blur_limit=(3, 7), p=1.0), A.MotionBlur(blur_limit=(3, 7), p=1.0)], p=0.5),
    A.ImageCompression(quality_range=(70, 95), p=0.3),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.25))
transform_mixed = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5, border_mode=cv2.BORDER_CONSTANT, fill=0),
    A.RandomBrightnessContrast(p=0.6),
    A.GaussNoise(std_range=(2 / 255, 6 / 255), p=0.4),  # same std_range conversion as above
    A.Blur(blur_limit=3, p=0.3),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.25))
base_transforms = [transform_geometric, transform_quality, transform_mixed]  # one strategy is picked at random per image
# --- 3. 核心工具函数 ---
def harvest_objects_for_pasting(feature_image_dir, feature_label_dir, target_classes):
    """Crop target-class objects out of a labelled image set to build a paste library.

    :param feature_image_dir: directory of images that contain the target objects
    :param feature_label_dir: directory of the matching YOLO .txt label files
    :param target_classes: class IDs to harvest (e.g. [2])
    :return: asset library {class_id: [cropped BGR object image, ...]}
    :raises FileNotFoundError: if feature_image_dir contains no images
    :raises ValueError: if no object of the requested classes could be extracted
    """
    print(f"正在从 {feature_image_dir} 提取目标类别 {target_classes}...")
    asset_library = {cls_id: [] for cls_id in target_classes}
    # Only consider image files in the feature-source directory.
    feature_image_files = [f for f in os.listdir(feature_image_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
    if not feature_image_files:
        raise FileNotFoundError(f"特征素材目录 {feature_image_dir} 中未找到图片!")
    for filename in tqdm(feature_image_files, desc="提取目标素材"):
        label_file = os.path.splitext(filename)[0] + ".txt"
        label_path = os.path.join(feature_label_dir, label_file)
        if not os.path.exists(label_path):
            continue  # skip images that have no label file
        # Read the image and get its size.
        img = cv2.imread(os.path.join(feature_image_dir, filename))
        if img is None:
            # FIX: the message used to print the literal "(unknown)" instead of the file name.
            tqdm.write(f"警告:无法读取图片 {filename},已跳过")
            continue
        img_h, img_w, _ = img.shape
        # Parse the label file and crop each target object.
        with open(label_path, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                # FIX: a YOLO detection line needs exactly "cls x y w h"; skip shorter,
                # malformed lines instead of crashing on the unpack below.
                if len(parts) < 5:
                    continue
                # Class IDs are sometimes exported as floats ("6.0"); float→int handles both.
                cls_id = int(float(parts[0]))
                if cls_id not in target_classes:
                    continue  # keep only the requested classes
                # Normalized YOLO coords → pixel corner coords (x1,y1 top-left, x2,y2 bottom-right).
                # FIX: slice exactly the 4 bbox fields so extra trailing fields
                # (e.g. segmentation exports) no longer raise on unpacking.
                x_center, y_center, box_w, box_h = [float(p) for p in parts[1:5]]
                x1 = int((x_center - box_w / 2) * img_w)
                y1 = int((y_center - box_h / 2) * img_h)
                x2 = int((x_center + box_w / 2) * img_w)
                y2 = int((y_center + box_h / 2) * img_h)
                # Clamp to the image so slicing never goes out of range.
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(img_w, x2), min(img_h, y2)
                # Crop and store (skip degenerate/empty crops).
                if x1 < x2 and y1 < y2:
                    cropped_obj = img[y1:y2, x1:x2]
                    if cropped_obj.size > 0:
                        asset_library[cls_id].append(cropped_obj)
    # Fail loudly if nothing was harvested — downstream pasting would be a no-op.
    total_assets = sum(len(v) for v in asset_library.values())
    if total_assets == 0:
        raise ValueError(f"未从特征素材目录提取到任何目标请检查类别ID {target_classes} 是否正确")
    print(f"素材库创建完成!共提取 {total_assets} 个目标(类别:{target_classes}")
    return asset_library
def paste_objects_to_base(base_image, asset_library, paste_count_range=None):
    """Paste randomly chosen library objects onto one base image (in place).

    :param base_image: BGR base image (as read by cv2); modified in place
    :param asset_library: {class_id: [cropped object image, ...]}
    :param paste_count_range: optional (lo, hi) inclusive range for the number of
        pastes; defaults to the module-level PASTE_COUNT_RANGE, so existing
        callers are unaffected (backward-compatible generalization)
    :return: (pasted image, list of YOLO bboxes [x_c, y_c, w, h], list of class IDs)
    """
    if paste_count_range is None:
        paste_count_range = PASTE_COUNT_RANGE
    base_h, base_w, _ = base_image.shape
    pasted_bboxes = []  # YOLO bboxes of the pasted objects
    pasted_labels = []  # class IDs of the pasted objects
    # Randomly decide how many objects to paste this time.
    num_to_paste = random.randint(*paste_count_range)
    # Hoisted out of the loop: asset_library is never mutated here, so the set
    # of classes that actually have assets is loop-invariant.
    valid_classes = [cls for cls, assets in asset_library.items() if len(assets) > 0]
    for _ in range(num_to_paste):
        if not valid_classes:
            break  # empty library: nothing to paste
        # Pick a class, then one asset of that class.
        target_cls = random.choice(valid_classes)
        target_obj = random.choice(asset_library[target_cls])
        obj_h, obj_w, _ = target_obj.shape
        # Skip objects that would not fit entirely inside the base image.
        if obj_h >= base_h or obj_w >= base_w:
            continue
        # Random top-left corner such that the object stays fully inside.
        paste_x1 = random.randint(0, base_w - obj_w)
        paste_y1 = random.randint(0, base_h - obj_h)
        paste_x2 = paste_x1 + obj_w
        paste_y2 = paste_y1 + obj_h
        # Overwrite the region via numpy slicing (plain copy-paste, no blending).
        base_image[paste_y1:paste_y2, paste_x1:paste_x2] = target_obj
        # Pixel coords → normalized YOLO (x_center, y_center, w, h).
        yolo_x_center = (paste_x1 + obj_w / 2) / base_w
        yolo_y_center = (paste_y1 + obj_h / 2) / base_h
        yolo_w = obj_w / base_w
        yolo_h = obj_h / base_h
        pasted_bboxes.append([yolo_x_center, yolo_y_center, yolo_w, yolo_h])
        pasted_labels.append(target_cls)
    return base_image, pasted_bboxes, pasted_labels
def main():
    """Entry point: build the paste asset library, then generate augmented
    images and YOLO labels for every base image.

    Side effects: creates OUTPUT_IMAGE_DIR / OUTPUT_LABEL_DIR and writes one
    .jpg plus one .txt per successfully generated augmentation.
    """
    # 1. Init: make sure the output directories exist.
    os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
    os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True)
    # 2. Step one: build the asset library of paste-able objects.
    try:
        asset_library = harvest_objects_for_pasting(
            feature_image_dir=SOURCE_FEATURE_IMAGE_DIR,
            feature_label_dir=SOURCE_FEATURE_LABEL_DIR,
            target_classes=SMALL_OBJECT_CLASSES_TO_PASTE
        )
    except (FileNotFoundError, ValueError) as e:
        print(f"错误:{e}")
        return
    # 3. Step two: collect all base (background) images.
    base_image_files = [f for f in os.listdir(BASE_IMAGE_DIR) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
    if not base_image_files:
        print(f"错误:底图目录 {BASE_IMAGE_DIR} 中未找到任何图片!")
        return
    print(f"\n找到 {len(base_image_files)} 张底图,开始生成增强数据(每张底图生成 {AUGMENTATION_FACTOR} 张)")
    # 4. Main loop: paste objects onto each base image, then augment.
    # FIX: count only augmentations actually written to disk; the previous
    # `len(base_image_files) * AUGMENTATION_FACTOR` over-reported whenever a
    # base image was unreadable or an augmentation raised and was skipped.
    total_generated = 0
    for base_filename in tqdm(base_image_files, desc="处理底图"):
        base_name, base_ext = os.path.splitext(base_filename)
        base_image_path = os.path.join(BASE_IMAGE_DIR, base_filename)
        # Skip unreadable base images.
        base_image = cv2.imread(base_image_path)
        if base_image is None:
            tqdm.write(f"\n警告:无法读取底图 {base_filename},已跳过")
            continue
        # Generate AUGMENTATION_FACTOR augmented images from this base image.
        for aug_idx in range(AUGMENTATION_FACTOR):
            # Step 1: copy the base image (keep the original pristine) and paste objects.
            base_image_copy = base_image.copy()
            pasted_image, pasted_bboxes, pasted_labels = paste_objects_to_base(
                base_image=base_image_copy,
                asset_library=asset_library
            )
            # Step 2: run a randomly chosen augmentation pipeline (albumentations expects RGB).
            pasted_image_rgb = cv2.cvtColor(pasted_image, cv2.COLOR_BGR2RGB)
            chosen_transform = random.choice(base_transforms)
            try:
                # Transform image, bboxes and labels together.
                augmented_result = chosen_transform(
                    image=pasted_image_rgb,
                    bboxes=pasted_bboxes,
                    class_labels=pasted_labels
                )
                final_image_rgb = augmented_result['image']
                final_bboxes = augmented_result['bboxes']
                final_labels = augmented_result['class_labels']
            except Exception as e:
                tqdm.write(f"\n警告:底图 {base_filename} 增强失败(序号 {aug_idx}{str(e)}")
                continue
            # Step 3: save the augmented image and its label file.
            # Image name: <base>_aug_<idx>.jpg — always .jpg to keep formats uniform.
            output_img_name = f"{base_name}_aug_{aug_idx}.jpg"
            output_img_path = os.path.join(OUTPUT_IMAGE_DIR, output_img_name)
            # RGB → BGR: cv2.imwrite expects BGR.
            cv2.imwrite(output_img_path, cv2.cvtColor(final_image_rgb, cv2.COLOR_RGB2BGR))
            # Label file: same stem, .txt, YOLO format.
            output_label_name = f"{base_name}_aug_{aug_idx}.txt"
            output_label_path = os.path.join(OUTPUT_LABEL_DIR, output_label_name)
            with open(output_label_path, 'w') as f:
                for bbox, label in zip(final_bboxes, final_labels):
                    x_c, y_c, w, h = bbox
                    # Drop boxes the augmentation pushed outside [0, 1] to avoid training errors.
                    if 0 <= x_c <= 1 and 0 <= y_c <= 1 and 0 <= w <= 1 and 0 <= h <= 1:
                        f.write(f"{label} {x_c:.6f} {y_c:.6f} {w:.6f} {h:.6f}\n")
            total_generated += 1
    # 5. Summary.
    print(f"\n✅ 数据增强全部完成!")
    print(f"📊 生成数据统计:")
    print(f" - 底图数量:{len(base_image_files)}")
    print(f" - 每张底图增强次数:{AUGMENTATION_FACTOR}")
    print(f" - 总生成图片/标签:{total_generated}")
    print(f" - 输出路径:")
    print(f" 图片 → {OUTPUT_IMAGE_DIR}")
    print(f" 标签 → {OUTPUT_LABEL_DIR}")
if __name__ == "__main__":
    # Before running, confirm:
    # 1. SOURCE_FEATURE_IMAGE_DIR / SOURCE_FEATURE_LABEL_DIR point at the labelled "feature source" directories
    # 2. BASE_IMAGE_DIR is your directory of blank/background base images
    # 3. SMALL_OBJECT_CLASSES_TO_PASTE lists the class IDs to paste (e.g. no_helmet=2)
    main()