319 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			319 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import cv2
 | ||
| import numpy as np
 | ||
| import os
 | ||
| import random
 | ||
| from collections import defaultdict
 | ||
| from tqdm import tqdm
 | ||
| 
 | ||
# --- 1. Core configuration ---
# Source directories for the feature material (full images and their label files).
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\images"
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\labels"
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\images"
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\labels"
SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\images"
SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\labels"

# Directory holding the base (background) images.
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu\train"
BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen\train"
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\test"

# Output directories for the composited images and their label files.
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train3\images"
OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\train3\labels"
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\val3\images"
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\val3\labels"
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\images"
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\test\da\labels"

# Number of base images to use: 0 uses every base image, a positive value
# randomly samples that many (all of them if fewer exist), a negative value
# makes main() abort with an error.
SELECT_BASE_IMAGE_COUNT = 0

# Number of composited outputs generated per base image.
AUGMENTATION_FACTOR = 1
# Class ids kept when reading the feature label files; other classes are dropped.
TARGET_CLASSES = [0, 1, 2, 3, 4, 5, 6]
# Inclusive (min, max) range for how many feature images are pasted per output.
PASTE_COUNT_RANGE = (2, 3)
# Minimum width and height, in pixels, of a feature image after scaling.
MIN_SCALED_SIZE = 50
# Overlap between *different* pasted feature images is rejected above this IoU;
# objects inside a single feature image may still overlap each other.
MAX_OVERLAP_IOU = 0.0
# Maximum number of placement attempts per output image.
MAX_RETRY_COUNT = 300
# (min, max) range of the random scale factor applied to each feature image.
SCALE_FACTOR_RANGE = (0.3, 0.9)
 | ||
| 
 | ||
| 
 | ||
| # --- 2. 工具函数 ---(核心修改:删除内部重叠检测)
 | ||
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is an ``(x1, y1, x2, y2)`` sequence in pixel coordinates. Every
    coordinate is rounded to the nearest integer before any area is computed,
    so sub-pixel differences are ignored. ``0.0`` is returned when the union
    area is not positive.
    """
    a = [int(round(value)) for value in box1]
    b = [int(round(value)) for value in box2]
    ax1, ay1, ax2, ay2 = a[0], a[1], a[2], a[3]
    bx1, by1, bx2, by2 = b[0], b[1], b[2], b[3]

    # Width/height of the overlap rectangle; negative means the boxes are disjoint.
    overlap_w = min(ax2, bx2) - max(ax1, bx1)
    overlap_h = min(ay2, by2) - max(ay1, by1)
    intersection = max(0, overlap_w) * max(0, overlap_h)

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0.0
 | ||
| 
 | ||
| 
 | ||
def collect_feature_images(feature_image_dir, feature_label_dir, target_classes):
    """Collect feature images together with their YOLO-style labels.

    Every ``.jpg``/``.png``/``.jpeg`` file in ``feature_image_dir`` is paired
    with the ``.txt`` file of the same stem in ``feature_label_dir``. An image
    is kept when it can be read, is large enough to satisfy
    ``MIN_SCALED_SIZE`` at the largest scale factor, and has at least one
    valid label whose class id is in ``target_classes``. Overlap between
    objects inside a single feature image is not checked.

    Args:
        feature_image_dir: Directory containing the feature images.
        feature_label_dir: Directory containing the matching label files, one
            ``class xc yc w h`` line per object with normalized coordinates.
        target_classes: Container of class ids to keep.

    Returns:
        A list of ``(image_path, labels, width, height)`` tuples, where
        ``labels`` is a list of ``(cls_id, xc, yc, w, h)`` tuples.

    Raises:
        FileNotFoundError: If the directory contains no image files.
        ValueError: If no image ends up with a usable label.
    """
    print(f"从 {feature_image_dir} 收集特征图片及标签,目标类别 {target_classes}...")
    feature_list = []  # items: (image path, label list, original width, original height)

    feature_image_files = [
        f for f in os.listdir(feature_image_dir)
        if f.lower().endswith(('.jpg', '.png', '.jpeg'))
    ]
    if not feature_image_files:
        raise FileNotFoundError("未找到特征图片!")

    # Smallest original side that can still reach MIN_SCALED_SIZE at the
    # largest scale factor; it does not depend on the image, so compute it once.
    min_required_orig_size = MIN_SCALED_SIZE / max(SCALE_FACTOR_RANGE)

    for filename in tqdm(feature_image_files, desc="收集特征图片"):
        img_path = os.path.join(feature_image_dir, filename)
        label_file = os.path.splitext(filename)[0] + ".txt"
        label_path = os.path.join(feature_label_dir, label_file)

        # Images without a label file are silently ignored.
        if not os.path.exists(label_path):
            continue

        img = cv2.imread(img_path)
        if img is None:
            tqdm.write(f"警告:无法读取图片 {filename},已跳过")
            continue
        orig_h, orig_w = img.shape[:2]

        # Drop images that would always be too small after scaling.
        if orig_w < min_required_orig_size or orig_h < min_required_orig_size:
            tqdm.write(f"警告:特征图片 {filename} 原始尺寸过小,已跳过")
            continue

        # Read the labels, keeping only target classes with in-range boxes.
        labels = []
        with open(label_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue

                # A malformed line must not abort the whole collection: the
                # caller treats ValueError from this function as fatal.
                try:
                    cls_id = int(float(parts[0]))
                except (ValueError, OverflowError):
                    tqdm.write(f"警告:标签 {label_file} 第 {line_no} 行格式错误,已跳过该行")
                    continue
                if cls_id not in target_classes:
                    continue

                try:
                    xc, yc, w, h = (float(p) for p in parts[1:])
                except ValueError:
                    # Wrong number of fields or a non-numeric coordinate.
                    tqdm.write(f"警告:标签 {label_file} 第 {line_no} 行格式错误,已跳过该行")
                    continue

                if 0 <= xc <= 1 and 0 <= yc <= 1 and 0 < w <= 1 and 0 < h <= 1:
                    labels.append((cls_id, xc, yc, w, h))

        # Keep the image as soon as it has at least one valid label.
        if labels:
            feature_list.append((img_path, labels, orig_w, orig_h))

    if not feature_list:
        raise ValueError("未收集到有效的特征图片及标签!")

    print(f"特征图片收集完成:共 {len(feature_list)} 张有效特征图片(允许内部目标重叠)")
    return feature_list
 | ||
| 
 | ||
| 
 | ||
def paste_feature_images_to_base(base_image, feature_list):
    """Paste randomly scaled feature images onto a base image.

    A target count is drawn from ``PASTE_COUNT_RANGE``. Feature images are then
    picked at random, scaled by a factor drawn from ``SCALE_FACTOR_RANGE`` and
    placed at a random position, until the target count is reached or
    ``MAX_RETRY_COUNT`` attempts have been made. A placement is rejected when
    the scaled image is smaller than ``MIN_SCALED_SIZE``, does not fit strictly
    inside the base image, overlaps an already pasted feature image, or has a
    target box that leaves the base image or overlaps an already pasted target
    box. Objects inside one feature image may overlap each other.

    Args:
        base_image: Image array to paste onto; it is modified in place.
        feature_list: ``(image_path, labels, width, height)`` tuples as
            produced by ``collect_feature_images``.

    Returns:
        ``(base_image, pasted_labels)`` where ``pasted_labels`` is a list of
        ``(cls_id, xc, yc, w, h)`` tuples normalized by the base image size.
    """
    pasted_labels = []               # labels that will be written out
    if not feature_list:
        # Nothing to choose from; random.choice would raise on an empty list.
        return base_image, pasted_labels

    base_h, base_w = base_image.shape[:2]
    pasted_target_boxes = []         # pixel boxes of every pasted target
    pasted_feature_regions = []      # pixel rectangles of every pasted feature image
    pasted_feature_count = 0         # feature images pasted so far
    skipped_small = 0                # attempts rejected for size (too small or too large)
    skipped_region_overlap = 0       # attempts rejected for feature-image overlap
    skipped_target_overlap = 0       # attempts rejected for cross-image target overlap
    retry_count = 0                  # attempts made so far

    target_paste_count = random.randint(*PASTE_COUNT_RANGE)
    print(f"  计划粘贴 {target_paste_count} 张特征图...")

    while pasted_feature_count < target_paste_count and retry_count < MAX_RETRY_COUNT:
        retry_count += 1

        # 1. Pick a feature image at random.
        feature_img_path, feature_labels, orig_w, orig_h = random.choice(feature_list)

        # 2. Pick a random scale and make sure the result has a usable size.
        scale_factor = random.uniform(*SCALE_FACTOR_RANGE)
        scaled_w = int(round(orig_w * scale_factor))
        scaled_h = int(round(orig_h * scale_factor))

        if scaled_w < MIN_SCALED_SIZE or scaled_h < MIN_SCALED_SIZE:
            skipped_small += 1
            continue
        if scaled_w >= base_w or scaled_h >= base_h:
            skipped_small += 1
            continue

        # 3. Random position that keeps the whole feature image inside the base.
        paste_x1 = random.randint(0, base_w - scaled_w)
        paste_y1 = random.randint(0, base_h - scaled_h)
        paste_x2 = paste_x1 + scaled_w
        paste_y2 = paste_y1 + scaled_h
        current_feature_region = (paste_x1, paste_y1, paste_x2, paste_y2)

        # 4. Reject placements that overlap an already pasted feature image.
        if any(
            calculate_iou(current_feature_region, region) > MAX_OVERLAP_IOU
            for region in pasted_feature_regions
        ):
            skipped_region_overlap += 1
            continue

        # 5. Map every target box into base-image pixel coordinates.
        #    The normalized box is multiplied by the *rounded* pasted size so
        #    it matches the pixels cv2.resize actually produces; using the raw
        #    scale factor would shift boxes by up to half a pixel and could
        #    push an edge-touching box just outside the base image.
        temp_target_boxes = []
        valid = True
        for (_cls_id, xc, yc, w, h) in feature_labels:
            base_x1 = paste_x1 + (xc - w / 2) * scaled_w
            base_y1 = paste_y1 + (yc - h / 2) * scaled_h
            base_x2 = paste_x1 + (xc + w / 2) * scaled_w
            base_y2 = paste_y1 + (yc + h / 2) * scaled_h

            # Every target box must lie completely inside the base image.
            if base_x1 < 0 or base_y1 < 0 or base_x2 > base_w or base_y2 > base_h:
                valid = False
                break
            temp_target_boxes.append((base_x1, base_y1, base_x2, base_y2))

        if not valid:
            continue

        # 6. Reject placements whose targets overlap already pasted targets.
        if any(
            calculate_iou(temp_box, existing_box) > MAX_OVERLAP_IOU
            for temp_box in temp_target_boxes
            for existing_box in pasted_target_boxes
        ):
            skipped_target_overlap += 1
            continue

        # 7. Load, resize and paste the feature image.
        feature_img = cv2.imread(feature_img_path)
        if feature_img is None:
            continue
        scaled_feature_img = cv2.resize(
            feature_img, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR
        )
        base_image[paste_y1:paste_y2, paste_x1:paste_x2] = scaled_feature_img

        # Record the labels (normalized by the base size), boxes and region.
        for (cls_id, _xc, _yc, _w, _h), temp_box in zip(feature_labels, temp_target_boxes):
            base_xc = (temp_box[0] + temp_box[2]) / 2 / base_w
            base_yc = (temp_box[1] + temp_box[3]) / 2 / base_h
            base_w_box = (temp_box[2] - temp_box[0]) / base_w
            base_h_box = (temp_box[3] - temp_box[1]) / base_h

            pasted_labels.append((cls_id, base_xc, base_yc, base_w_box, base_h_box))
            pasted_target_boxes.append(temp_box)
        pasted_feature_regions.append(current_feature_region)
        pasted_feature_count += 1
        print(f"  已成功粘贴 {pasted_feature_count}/{target_paste_count} 张特征图")

    # Summary of this call.
    print(
        f"  粘贴统计:成功{pasted_feature_count}张 → "
        f"跳小尺寸{skipped_small} → 跳区域重叠{skipped_region_overlap} → "
        f"跳目标重叠{skipped_target_overlap} → 总目标数{len(pasted_labels)}"
    )
    return base_image, pasted_labels
 | ||
| 
 | ||
| 
 | ||
| # --- 3. 主函数 ---(无修改)
 | ||
def main():
    """Run the copy-paste augmentation pipeline.

    Collects the feature images, selects the base images according to
    ``SELECT_BASE_IMAGE_COUNT``, produces ``AUGMENTATION_FACTOR`` composited
    images per base image and writes each result as a ``.jpg`` plus a label
    ``.txt`` into the output directories. Configuration or directory errors
    are printed and end the run early; nothing is returned.
    """
    os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
    os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True)

    try:
        feature_list = collect_feature_images(
            SOURCE_FEATURE_IMAGE_DIR, SOURCE_FEATURE_LABEL_DIR, TARGET_CLASSES
        )
    except (FileNotFoundError, ValueError) as e:
        print(f"错误:{e}")
        return

    # List the candidate base images; a missing or unreadable directory is
    # reported like the other configuration errors instead of a traceback.
    try:
        base_files = [
            f for f in os.listdir(BASE_IMAGE_DIR)
            if f.lower().endswith(('.jpg', '.png', '.jpeg'))
        ]
    except OSError as e:
        print(f"错误:{e}")
        return

    total_base_count = len(base_files)
    if total_base_count == 0:
        print(f"错误:底图目录 {BASE_IMAGE_DIR} 无有效图片!")
        return

    if SELECT_BASE_IMAGE_COUNT < 0:
        print(f"错误:底图数量不能为负数(当前设置:{SELECT_BASE_IMAGE_COUNT})!")
        return
    elif SELECT_BASE_IMAGE_COUNT == 0:
        selected_base_files = base_files
        print(f"\n找到 {total_base_count} 张底图,使用全部底图...\n")
    elif SELECT_BASE_IMAGE_COUNT > total_base_count:
        selected_base_files = base_files
        print(f"\n警告:设置底图数({SELECT_BASE_IMAGE_COUNT})超过可用数({total_base_count}),使用全部底图...\n")
    else:
        selected_base_files = random.sample(base_files, SELECT_BASE_IMAGE_COUNT)
        print(f"\n找到 {total_base_count} 张底图,随机选择 {SELECT_BASE_IMAGE_COUNT} 张...\n")

    # Counted here rather than by listing the output directories afterwards,
    # so files left over from earlier runs are not reported as generated.
    output_img_count = 0
    output_label_count = 0

    for base_filename in tqdm(selected_base_files, desc="处理底图"):
        base_name = os.path.splitext(base_filename)[0]
        base_path = os.path.join(BASE_IMAGE_DIR, base_filename)
        base_img = cv2.imread(base_path)
        if base_img is None:
            tqdm.write(f"\n警告:无法读取底图 {base_filename},已跳过")
            continue

        for aug_idx in range(AUGMENTATION_FACTOR):
            print(f"\n底图 {base_filename}(增强序号 {aug_idx}):", end="")
            # Paste onto a copy so every augmentation starts from a clean base.
            base_copy = base_img.copy()
            pasted_img, labels = paste_feature_images_to_base(base_copy, feature_list)

            if not labels:
                tqdm.write(f"\n警告:未粘贴任何目标,已跳过该结果")
                continue

            # Save the image first; cv2.imwrite reports failure through its
            # return value, and a label without its image must not be written.
            output_img_name = f"{base_name}_feat_paste_{aug_idx}.jpg"
            output_img_path = os.path.join(OUTPUT_IMAGE_DIR, output_img_name)
            if not cv2.imwrite(output_img_path, pasted_img):
                tqdm.write(f"\n警告:无法写入图片 {output_img_path},已跳过该结果")
                continue
            output_img_count += 1

            output_label_name = f"{base_name}_feat_paste_{aug_idx}.txt"
            output_label_path = os.path.join(OUTPUT_LABEL_DIR, output_label_name)
            with open(output_label_path, 'w') as f:
                f.writelines(
                    f"{cls_id} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n"
                    for (cls_id, x, y, w, h) in labels
                )
            output_label_count += 1

            print(f" → 已保存(目标数:{len(labels)})")

    # Final summary.
    print(f"\n✅ 全部处理完成!")
    print(f"  - 实际处理底图数量:{len(selected_base_files)} 张")
    print(f"  - 生成图片:{output_img_count} 张")
    print(f"  - 生成标签:{output_label_count} 个")
    print(f"  - 输出路径:\n    图片 → {OUTPUT_IMAGE_DIR}\n    标签 → {OUTPUT_LABEL_DIR}")
 | ||
| 
 | ||
| 
 | ||
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()