Files
prehadler/copypaste整张特征图.py
2025-10-10 11:39:23 +08:00

319 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
import random
from collections import defaultdict
from tqdm import tqdm
# --- 1. Core configuration ---
# Source directories for the feature material: full images plus their matching
# label files (one "<image stem>.txt" per image, parsed as "class xc yc w h").
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\images"
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\labels"
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\images"
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\labels"
SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\images"
SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\labels"
# NOTE(review): the active feature source is the "test" split while the active
# output below is "train3" -- confirm this pairing is intended.
# Directory of base images that the feature images are pasted onto.
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu\train"
BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen\train"
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\test"
# Output directories for the composited images and their label files.
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train3\images"
OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\train3\labels"
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\val3\images"
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\val3\labels"
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\images"
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\test\da\labels"
# Number of base images to process; 0 means use every base image, and a value
# larger than the number available also falls back to using all of them.
SELECT_BASE_IMAGE_COUNT = 0
# Number of composited outputs generated per base image.
AUGMENTATION_FACTOR = 1
# Class ids whose labels are kept when reading the feature label files.
TARGET_CLASSES = [0, 1, 2, 3, 4, 5, 6]
# Inclusive (min, max) number of feature images to paste onto one base image.
PASTE_COUNT_RANGE = (2, 3)
# Minimum width and height, in pixels, of a feature image after scaling.
MIN_SCALED_SIZE = 50
MAX_OVERLAP_IOU = 0.0 # maximum IoU allowed between different pasted feature images (0.0 = no overlap across feature images)
# Maximum number of placement attempts per composited output.
MAX_RETRY_COUNT = 300
# (min, max) random scale factor applied to a feature image before pasting.
SCALE_FACTOR_RANGE = (0.3, 0.9)
# --- 2. Helper functions --- (key change: internal-overlap detection removed)
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is an ``(x1, y1, x2, y2)`` sequence in pixel coordinates. Every
    coordinate is rounded to the nearest integer before any arithmetic, so the
    result describes whole-pixel overlap.

    Returns:
        The IoU as a float, or ``0.0`` when the union area is not positive.
    """
    a = tuple(int(round(coord)) for coord in box1)
    b = tuple(int(round(coord)) for coord in box2)

    # Signed extents of the intersection; a negative value means the boxes are
    # disjoint along that axis, so it is clamped to zero.
    overlap_w = min(a[2], b[2]) - max(a[0], b[0])
    overlap_h = min(a[3], b[3]) - max(a[1], b[1])
    intersection = max(0, overlap_w) * max(0, overlap_h)

    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0.0
def _read_target_labels(label_path, target_classes):
    """Parse one label file and keep the boxes of the requested classes.

    Each non-empty line is expected to be ``class xc yc w h`` with the four
    box values normalized to the image size.

    Returns:
        ``(labels, malformed)`` where ``labels`` is a list of
        ``(cls_id, xc, yc, w, h)`` tuples and ``malformed`` is the number of
        lines that could not be parsed and were skipped.
    """
    labels = []
    malformed = 0
    # utf-8-sig drops a leading BOM; errors="replace" turns undecodable bytes
    # into characters that fail the numeric parsing below instead of raising.
    with open(label_path, 'r', encoding='utf-8-sig', errors='replace') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            try:
                cls_id = int(float(parts[0]))
            except (ValueError, OverflowError):
                malformed += 1
                continue
            if cls_id not in target_classes:
                continue
            # Anything other than exactly four box values (e.g. polygon or
            # confidence-augmented lines) cannot be read as a plain box.
            if len(parts) != 5:
                malformed += 1
                continue
            try:
                xc, yc, w, h = (float(p) for p in parts[1:])
            except ValueError:
                malformed += 1
                continue
            if 0 <= xc <= 1 and 0 <= yc <= 1 and 0 < w <= 1 and 0 < h <= 1:
                labels.append((cls_id, xc, yc, w, h))
    return labels, malformed


def collect_feature_images(feature_image_dir, feature_label_dir, target_classes):
    """Collect feature images together with their target-class labels.

    An image is kept when it has a label file of the same stem in
    ``feature_label_dir``, can be decoded, is large enough that scaling by the
    largest factor in ``SCALE_FACTOR_RANGE`` can reach ``MIN_SCALED_SIZE``, and
    has at least one valid label whose class is in ``target_classes``.
    Overlap between boxes inside one image is not checked.

    Malformed label lines are skipped with a warning instead of aborting the
    whole collection.

    Args:
        feature_image_dir: Directory containing .jpg/.png/.jpeg feature images.
        feature_label_dir: Directory containing the matching .txt label files.
        target_classes: Container of class ids to keep.

    Returns:
        A list of ``(image_path, labels, width, height)`` tuples, where
        ``labels`` is a list of ``(cls_id, xc, yc, w, h)``.

    Raises:
        FileNotFoundError: If the image directory holds no supported images.
        ValueError: If no image ends up with a usable label.
    """
    print(f"{feature_image_dir} 收集特征图片及标签,目标类别 {target_classes}...")
    # Sorted so the resulting list order does not depend on the file system.
    feature_image_files = sorted(
        f for f in os.listdir(feature_image_dir)
        if f.lower().endswith(('.jpg', '.png', '.jpeg'))
    )
    if not feature_image_files:
        raise FileNotFoundError("未找到特征图片!")

    # Smallest original side that can still reach MIN_SCALED_SIZE at the
    # largest allowed scale factor; constant for the whole run.
    min_required_orig_size = MIN_SCALED_SIZE / max(SCALE_FACTOR_RANGE)

    feature_list = []  # items: (image path, labels, original width, original height)
    for filename in tqdm(feature_image_files, desc="收集特征图片"):
        img_path = os.path.join(feature_image_dir, filename)
        label_file = os.path.splitext(filename)[0] + ".txt"
        label_path = os.path.join(feature_label_dir, label_file)
        if not os.path.exists(label_path):
            continue

        img = cv2.imread(img_path)
        if img is None:
            tqdm.write(f"警告:无法读取图片 {filename},已跳过")
            continue
        orig_h, orig_w = img.shape[:2]
        if orig_w < min_required_orig_size or orig_h < min_required_orig_size:
            tqdm.write(f"警告:特征图片 {filename} 原始尺寸过小,已跳过")
            continue

        labels, malformed = _read_target_labels(label_path, target_classes)
        if malformed:
            tqdm.write(f"警告:标签文件 {label_file} 中有 {malformed} 行格式错误,已忽略")
        if labels:
            feature_list.append((img_path, labels, orig_w, orig_h))

    if not feature_list:
        raise ValueError("未收集到有效的特征图片及标签!")
    print(f"特征图片收集完成:共 {len(feature_list)} 张有效特征图片(允许内部目标重叠)")
    return feature_list
def _project_label_boxes(feature_labels, region):
    """Map normalized feature-image labels to pixel boxes inside ``region``.

    ``region`` is the ``(x1, y1, x2, y2)`` rectangle the resized feature image
    occupies on the base image. The normalized coordinates are multiplied by
    the actual patch size, and each box is clipped to the patch so a label that
    reaches past the source image edge never spills onto base-image pixels.
    """
    x_min, y_min, x_max, y_max = region
    patch_w = x_max - x_min
    patch_h = y_max - y_min
    boxes = []
    for _cls_id, xc, yc, w, h in feature_labels:
        x1 = x_min + (xc - w / 2) * patch_w
        y1 = y_min + (yc - h / 2) * patch_h
        x2 = x_min + (xc + w / 2) * patch_w
        y2 = y_min + (yc + h / 2) * patch_h
        boxes.append((
            min(max(x1, x_min), x_max),
            min(max(y1, y_min), y_max),
            min(max(x2, x_min), x_max),
            min(max(y2, y_min), y_max),
        ))
    return boxes


def paste_feature_images_to_base(base_image, feature_list):
    """Paste randomly scaled feature images onto ``base_image``.

    A random number of feature images (drawn from ``PASTE_COUNT_RANGE``) is
    placed at random positions. A placement is rejected when the scaled patch
    is smaller than ``MIN_SCALED_SIZE``, does not fit strictly inside the base
    image, or when its region or any of its target boxes has an IoU above
    ``MAX_OVERLAP_IOU`` with something already pasted. Overlap between targets
    of the same feature image is kept as-is. At most ``MAX_RETRY_COUNT``
    placements are attempted.

    Args:
        base_image: Image array to paste onto; it is modified in place.
        feature_list: Items of ``(image_path, labels, width, height)`` as
            produced by ``collect_feature_images``.

    Returns:
        ``(base_image, labels)`` where ``labels`` is a list of
        ``(cls_id, xc, yc, w, h)`` normalized to the base image size. The
        list is empty when nothing could be pasted.
    """
    pasted_labels = []           # labels written to the output file
    if not feature_list:
        # Nothing to choose from; random.choice would raise on an empty list.
        return base_image, pasted_labels

    base_h, base_w = base_image.shape[:2]
    pasted_target_boxes = []     # pixel boxes of every pasted target
    pasted_feature_regions = []  # pixel rectangles of every pasted patch
    pasted_feature_count = 0
    skipped_small = 0
    skipped_large = 0
    skipped_region_overlap = 0
    skipped_target_overlap = 0
    retry_count = 0

    target_paste_count = random.randint(*PASTE_COUNT_RANGE)
    print(f" 计划粘贴 {target_paste_count} 张特征图...")

    while pasted_feature_count < target_paste_count and retry_count < MAX_RETRY_COUNT:
        retry_count += 1

        # 1. Pick a feature image and a scale.
        feature_img_path, feature_labels, orig_w, orig_h = random.choice(feature_list)
        scale_factor = random.uniform(*SCALE_FACTOR_RANGE)
        scaled_w = int(round(orig_w * scale_factor))
        scaled_h = int(round(orig_h * scale_factor))
        if scaled_w < MIN_SCALED_SIZE or scaled_h < MIN_SCALED_SIZE:
            skipped_small += 1
            continue
        if scaled_w >= base_w or scaled_h >= base_h:
            # Counted separately: the patch is too large for this base image.
            skipped_large += 1
            continue

        # 2. Pick a position that keeps the whole patch inside the base image.
        paste_x1 = random.randint(0, base_w - scaled_w)
        paste_y1 = random.randint(0, base_h - scaled_h)
        paste_x2 = paste_x1 + scaled_w
        paste_y2 = paste_y1 + scaled_h
        current_feature_region = (paste_x1, paste_y1, paste_x2, paste_y2)

        # 3. Reject patches overlapping an already pasted patch.
        if any(
            calculate_iou(current_feature_region, existing_region) > MAX_OVERLAP_IOU
            for existing_region in pasted_feature_regions
        ):
            skipped_region_overlap += 1
            continue

        # 4. Project the labels onto the base image (clipped to the patch).
        temp_target_boxes = _project_label_boxes(feature_labels, current_feature_region)

        # 5. Reject patches whose targets overlap targets of earlier patches.
        if any(
            calculate_iou(temp_box, existing_box) > MAX_OVERLAP_IOU
            for temp_box in temp_target_boxes
            for existing_box in pasted_target_boxes
        ):
            skipped_target_overlap += 1
            continue

        # 6. Paste the resized patch and record its labels.
        feature_img = cv2.imread(feature_img_path)
        if feature_img is None:
            print(f" 警告:无法读取特征图片 {feature_img_path},已跳过")
            continue
        scaled_feature_img = cv2.resize(
            feature_img, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR
        )
        base_image[paste_y1:paste_y2, paste_x1:paste_x2] = scaled_feature_img

        for label, temp_box in zip(feature_labels, temp_target_boxes):
            box_x1, box_y1, box_x2, box_y2 = temp_box
            pasted_labels.append((
                label[0],
                (box_x1 + box_x2) / 2 / base_w,
                (box_y1 + box_y2) / 2 / base_h,
                (box_x2 - box_x1) / base_w,
                (box_y2 - box_y1) / base_h,
            ))
            pasted_target_boxes.append(temp_box)
        pasted_feature_regions.append(current_feature_region)
        pasted_feature_count += 1
        print(f" 已成功粘贴 {pasted_feature_count}/{target_paste_count} 张特征图")

    print(
        f" 粘贴统计:成功{pasted_feature_count}张 → "
        f"跳小尺寸{skipped_small} → 跳超大尺寸{skipped_large} → "
        f"跳区域重叠{skipped_region_overlap} → "
        f"跳目标重叠{skipped_target_overlap} → 总目标数{len(pasted_labels)}"
    )
    return base_image, pasted_labels
# --- 3. Main entry point ---
def main():
    """Run the copy-paste augmentation over the configured directories.

    Collects the feature images, selects the base images according to
    ``SELECT_BASE_IMAGE_COUNT``, and writes ``AUGMENTATION_FACTOR`` composited
    images per base image, each with its label file. A label file is written
    only after its image was saved successfully, and the final summary counts
    only the files produced by this run.
    """
    # Validate the configuration before touching the file system.
    if SELECT_BASE_IMAGE_COUNT < 0:
        print(f"错误:底图数量不能为负数(当前设置:{SELECT_BASE_IMAGE_COUNT})")
        return

    os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
    os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True)
    try:
        feature_list = collect_feature_images(
            SOURCE_FEATURE_IMAGE_DIR, SOURCE_FEATURE_LABEL_DIR, TARGET_CLASSES
        )
    except (FileNotFoundError, ValueError) as e:
        print(f"错误:{e}")
        return

    # Select the base images.
    try:
        base_files = [f for f in os.listdir(BASE_IMAGE_DIR)
                      if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
    except FileNotFoundError:
        # A missing directory is reported the same way as an empty one.
        base_files = []
    total_base_count = len(base_files)
    if total_base_count == 0:
        print(f"错误:底图目录 {BASE_IMAGE_DIR} 无有效图片!")
        return

    if SELECT_BASE_IMAGE_COUNT == 0:
        selected_base_files = base_files
        print(f"\n找到 {total_base_count} 张底图,使用全部底图...\n")
    elif SELECT_BASE_IMAGE_COUNT > total_base_count:
        selected_base_files = base_files
        print(f"\n警告:设置底图数({SELECT_BASE_IMAGE_COUNT})超过可用数({total_base_count}),使用全部底图...\n")
    else:
        selected_base_files = random.sample(base_files, SELECT_BASE_IMAGE_COUNT)
        print(f"\n找到 {total_base_count} 张底图,随机选择 {SELECT_BASE_IMAGE_COUNT} 张...\n")

    # Process the base images.
    saved_count = 0  # image/label pairs written by this run
    for base_filename in tqdm(selected_base_files, desc="处理底图"):
        base_name = os.path.splitext(base_filename)[0]
        base_path = os.path.join(BASE_IMAGE_DIR, base_filename)
        base_img = cv2.imread(base_path)
        if base_img is None:
            tqdm.write(f"\n警告:无法读取底图 {base_filename},已跳过")
            continue

        for aug_idx in range(AUGMENTATION_FACTOR):
            print(f"\n底图 {base_filename}(增强序号 {aug_idx})", end="")
            # Paste onto a copy so every augmentation starts from a clean base.
            pasted_img, labels = paste_feature_images_to_base(base_img.copy(), feature_list)
            if not labels:
                tqdm.write("\n警告:未粘贴任何目标,已跳过该结果")
                continue

            output_stem = f"{base_name}_feat_paste_{aug_idx}"
            output_img_path = os.path.join(OUTPUT_IMAGE_DIR, output_stem + ".jpg")
            # cv2.imwrite reports failure through its return value; skip the
            # label so no label file exists without its image.
            if not cv2.imwrite(output_img_path, pasted_img):
                tqdm.write(f"\n警告:无法写入图片 {output_img_path},已跳过该结果")
                continue

            output_label_path = os.path.join(OUTPUT_LABEL_DIR, output_stem + ".txt")
            with open(output_label_path, 'w') as f:
                f.writelines(
                    f"{cls_id} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n"
                    for (cls_id, x, y, w, h) in labels
                )
            saved_count += 1
            print(f" → 已保存(目标数:{len(labels)})")

    # Summary of this run only (pre-existing output files are not counted).
    print("\n✅ 全部处理完成!")
    print(f" - 实际处理底图数量:{len(selected_base_files)}")
    print(f" - 生成图片:{saved_count}")
    print(f" - 生成标签:{saved_count}")
    print(f" - 输出路径:\n 图片 → {OUTPUT_IMAGE_DIR}\n 标签 → {OUTPUT_LABEL_DIR}")
# Run the augmentation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()