From e5ef7ec066b0f0b186b043cb83b08c84578901f8 Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Fri, 10 Oct 2025 11:39:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=A2=9E=E5=BC=BA=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- copypaste整张特征图.py | 319 +++++++++++++++++++++++++++++++++++++++++ 扩大边缘像素.py | 65 +++++++++ 抽取样本数量.py | 217 ++++++++++++++++++++++++++++ 根据位置画框.py | 199 +++++++++++++++++++++++++ 视频抽帧.py | 28 ++-- 5 files changed, 814 insertions(+), 14 deletions(-) create mode 100644 copypaste整张特征图.py create mode 100644 扩大边缘像素.py create mode 100644 抽取样本数量.py create mode 100644 根据位置画框.py diff --git a/copypaste整张特征图.py b/copypaste整张特征图.py new file mode 100644 index 0000000..f342d1e --- /dev/null +++ b/copypaste整张特征图.py @@ -0,0 +1,319 @@ +import cv2 +import numpy as np +import os +import random +from collections import defaultdict +from tqdm import tqdm + +# --- 1. 核心配置 --- +# 特征素材来源目录(包含完整图片和对应标签) +# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\images" +# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\labels" +# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\images" +# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\labels" +SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\images" +SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\labels" + +# 底图目录 +# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu\train" +BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen\train" +# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\test" + +# 输出目录 +OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train3\images" +OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\train3\labels" +# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\val3\images" +# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\val3\labels" +# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\images" +# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\test\da\labels" + +# 底图数量 +SELECT_BASE_IMAGE_COUNT = 0 + + +AUGMENTATION_FACTOR = 1 +TARGET_CLASSES = [0, 1, 2, 3, 4, 5, 6] +PASTE_COUNT_RANGE = (2, 3) +MIN_SCALED_SIZE = 50 +MAX_OVERLAP_IOU = 0.0 # 仅限制“跨特征图”无重叠 +MAX_RETRY_COUNT = 300 +SCALE_FACTOR_RANGE = (0.3, 0.9) + + +# --- 2. 工具函数 ---(核心修改:删除内部重叠检测) +def calculate_iou(box1, box2): + """计算两个边界框的交并比(像素坐标)——保持不变""" + box1 = [int(round(x)) for x in box1] + box2 = [int(round(x)) for x in box2] + + inter_x1 = max(box1[0], box2[0]) + inter_y1 = max(box1[1], box2[1]) + inter_x2 = min(box1[2], box2[2]) + inter_y2 = min(box1[3], box2[3]) + + inter_width = max(0, inter_x2 - inter_x1) + inter_height = max(0, inter_y2 - inter_y1) + inter_area = inter_width * inter_height + + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union_area = box1_area + box2_area - inter_area + + return inter_area / union_area if union_area > 0 else 0.0 + + +def collect_feature_images(feature_image_dir, feature_label_dir, target_classes): + """收集特征图片及其标签——核心修改:删除内部重叠过滤""" + print(f"从 {feature_image_dir} 收集特征图片及标签,目标类别 {target_classes}...") + feature_list = [] # 元素格式:(图片路径, 标签列表, 原始宽, 原始高) + + feature_image_files = [f for f in os.listdir(feature_image_dir) + if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + if not feature_image_files: + raise FileNotFoundError(f"未找到特征图片!") + + for filename in tqdm(feature_image_files, desc="收集特征图片"): + img_path = os.path.join(feature_image_dir, filename) + label_file = os.path.splitext(filename)[0] + ".txt" + label_path = os.path.join(feature_label_dir, label_file) + + if not os.path.exists(label_path): + continue + + img = cv2.imread(img_path) + if img is None: + tqdm.write(f"警告:无法读取图片 {filename},已跳过") + continue + orig_h, orig_w = img.shape[:2] + + # 仅过滤缩小后可能过小的特征图(保留内部重叠的图) + min_required_orig_size = MIN_SCALED_SIZE / max(SCALE_FACTOR_RANGE) + if orig_w < min_required_orig_size or orig_h < min_required_orig_size: + tqdm.write(f"警告:特征图片 {filename} 原始尺寸过小,已跳过") + continue + + # 读取并过滤目标类别(不再检查内部重叠) + labels = [] + with open(label_path, 'r') as f: + for line in f.readlines(): + line = line.strip() + if not line: + continue + parts = line.split() + cls_id = int(float(parts[0])) + if cls_id not in target_classes: + continue + xc, yc, w, h = [float(p) for p in parts[1:]] + if 0 <= xc <= 1 and 0 <= yc <= 1 and 0 < w <= 1 and 0 < h <= 1: + labels.append((cls_id, xc, yc, w, h)) + + # 只要有有效标签就保留(无论内部是否重叠) + if labels: + feature_list.append((img_path, labels, orig_w, orig_h)) + + if not feature_list: + raise ValueError(f"未收集到有效的特征图片及标签!") + + print(f"特征图片收集完成:共 {len(feature_list)} 张有效特征图片(允许内部目标重叠)") + return feature_list + + +def paste_feature_images_to_base(base_image, feature_list): + """将特征图粘贴到底图——保留“跨特征图无重叠”逻辑(核心)""" + base_h, base_w = base_image.shape[:2] + pasted_labels = [] # 最终输出的标签 + pasted_target_boxes = [] # 已粘贴的目标框(像素坐标) + pasted_feature_regions = [] # 已粘贴的特征图整体区域(防跨图重叠) + pasted_feature_count = 0 # 成功粘贴的特征图数量 + skipped_small = 0 # 因尺寸过小跳过的次数 + skipped_region_overlap = 0 # 因特征图整体区域重叠跳过的次数 + skipped_target_overlap = 0 # 因跨图目标重叠跳过的次数 + retry_count = 0 # 重试次数 + + target_paste_count = random.randint(*PASTE_COUNT_RANGE) + print(f" 计划粘贴 {target_paste_count} 张特征图...") + + while pasted_feature_count < target_paste_count and retry_count < MAX_RETRY_COUNT: + retry_count += 1 + + # 1. 随机选择一张特征图(允许内部重叠) + feature_img_path, feature_labels, orig_w, orig_h = random.choice(feature_list) + + # 2. 随机缩放特征图,确保尺寸符合要求 + scale_factor = random.uniform(*SCALE_FACTOR_RANGE) + scaled_w = int(round(orig_w * scale_factor)) + scaled_h = int(round(orig_h * scale_factor)) + + if scaled_w < MIN_SCALED_SIZE or scaled_h < MIN_SCALED_SIZE: + skipped_small += 1 + continue + if scaled_w >= base_w or scaled_h >= base_h: + skipped_small += 1 + continue + + # 3. 随机生成粘贴位置(确保特征图完全在底图内) + paste_x1 = random.randint(0, base_w - scaled_w) + paste_y1 = random.randint(0, base_h - scaled_h) + paste_x2 = paste_x1 + scaled_w + paste_y2 = paste_y1 + scaled_h + current_feature_region = (paste_x1, paste_y1, paste_x2, paste_y2) + + # 4. 关键:检测当前特征图整体区域与已粘贴区域是否重叠(防跨图重叠) + region_overlap = False + for existing_region in pasted_feature_regions: + if calculate_iou(current_feature_region, existing_region) > MAX_OVERLAP_IOU: + region_overlap = True + skipped_region_overlap += 1 + break + if region_overlap: + continue + + # 5. 计算当前特征图目标在底图上的像素坐标(保留原始内部重叠) + temp_target_boxes = [] + valid = True + for (cls_id, xc, yc, w, h) in feature_labels: + # 特征图内目标坐标 → 缩放后 → 底图坐标(内部重叠会保留) + orig_x1 = (xc - w/2) * orig_w + orig_y1 = (yc - h/2) * orig_h + orig_x2 = (xc + w/2) * orig_w + orig_y2 = (yc + h/2) * orig_h + + scaled_x1 = orig_x1 * scale_factor + scaled_y1 = orig_y1 * scale_factor + scaled_x2 = orig_x2 * scale_factor + scaled_y2 = orig_y2 * scale_factor + + base_x1 = paste_x1 + scaled_x1 + base_y1 = paste_y1 + scaled_y1 + base_x2 = paste_x1 + scaled_x2 + base_y2 = paste_y1 + scaled_y2 + + # 确保目标框完全在底图内(不考虑内部重叠) + if base_x1 < 0 or base_y1 < 0 or base_x2 > base_w or base_y2 > base_h: + valid = False + break + temp_target_boxes.append((base_x1, base_y1, base_x2, base_y2)) + + if not valid: + continue + + # 6. 关键:检测当前特征图目标与已粘贴目标是否重叠(防跨图目标重叠) + target_overlap = False + for temp_box in temp_target_boxes: + for existing_box in pasted_target_boxes: + if calculate_iou(temp_box, existing_box) > MAX_OVERLAP_IOU: + target_overlap = True + skipped_target_overlap += 1 + break + if target_overlap: + break + if target_overlap: + continue + + # 7. 粘贴特征图并记录信息(保留内部重叠) + feature_img = cv2.imread(feature_img_path) + if feature_img is None: + continue + scaled_feature_img = cv2.resize(feature_img, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR) + base_image[paste_y1:paste_y2, paste_x1:paste_x2] = scaled_feature_img + + # 记录标签、目标框、特征图整体区域 + for (cls_id, xc, yc, w, h), temp_box in zip(feature_labels, temp_target_boxes): + base_xc = (temp_box[0] + temp_box[2]) / 2 / base_w + base_yc = (temp_box[1] + temp_box[3]) / 2 / base_h + base_w_box = (temp_box[2] - temp_box[0]) / base_w + base_h_box = (temp_box[3] - temp_box[1]) / base_h + + pasted_labels.append((cls_id, base_xc, base_yc, base_w_box, base_h_box)) + pasted_target_boxes.append(temp_box) + pasted_feature_regions.append(current_feature_region) + pasted_feature_count += 1 + print(f" 已成功粘贴 {pasted_feature_count}/{target_paste_count} 张特征图") + + # 输出统计(无内部重叠相关提示) + print( + f" 粘贴统计:成功{pasted_feature_count}张 → " + f"跳小尺寸{skipped_small} → 跳区域重叠{skipped_region_overlap} → " + f"跳目标重叠{skipped_target_overlap} → 总目标数{len(pasted_labels)}" + ) + return base_image, pasted_labels + + +# --- 3. 主函数 ---(无修改) +def main(): + os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True) + os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True) + + try: + feature_list = collect_feature_images( + SOURCE_FEATURE_IMAGE_DIR, SOURCE_FEATURE_LABEL_DIR, TARGET_CLASSES + ) + except (FileNotFoundError, ValueError) as e: + print(f"错误:{e}") + return + + # 筛选指定数量的底图 + base_files = [f for f in os.listdir(BASE_IMAGE_DIR) + if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + total_base_count = len(base_files) + if total_base_count == 0: + print(f"错误:底图目录 {BASE_IMAGE_DIR} 无有效图片!") + return + + if SELECT_BASE_IMAGE_COUNT < 0: + print(f"错误:底图数量不能为负数(当前设置:{SELECT_BASE_IMAGE_COUNT})!") + return + elif SELECT_BASE_IMAGE_COUNT == 0: + selected_base_files = base_files + print(f"\n找到 {total_base_count} 张底图,使用全部底图...\n") + else: + if SELECT_BASE_IMAGE_COUNT > total_base_count: + selected_base_files = base_files + print(f"\n警告:设置底图数({SELECT_BASE_IMAGE_COUNT})超过可用数({total_base_count}),使用全部底图...\n") + else: + selected_base_files = random.sample(base_files, SELECT_BASE_IMAGE_COUNT) + print(f"\n找到 {total_base_count} 张底图,随机选择 {SELECT_BASE_IMAGE_COUNT} 张...\n") + + # 处理底图 + for base_filename in tqdm(selected_base_files, desc="处理底图"): + base_name, ext = os.path.splitext(base_filename) + base_path = os.path.join(BASE_IMAGE_DIR, base_filename) + base_img = cv2.imread(base_path) + if base_img is None: + tqdm.write(f"\n警告:无法读取底图 {base_filename},已跳过") + continue + + for aug_idx in range(AUGMENTATION_FACTOR): + print(f"\n底图 {base_filename}(增强序号 {aug_idx}):", end="") + base_copy = base_img.copy() + pasted_img, labels = paste_feature_images_to_base(base_copy, feature_list) + + if not labels: + tqdm.write(f"\n警告:未粘贴任何目标,已跳过该结果") + continue + + # 保存图片和标签 + output_img_name = f"{base_name}_feat_paste_{aug_idx}.jpg" + output_img_path = os.path.join(OUTPUT_IMAGE_DIR, output_img_name) + cv2.imwrite(output_img_path, pasted_img) + + output_label_name = f"{base_name}_feat_paste_{aug_idx}.txt" + output_label_path = os.path.join(OUTPUT_LABEL_DIR, output_label_name) + with open(output_label_path, 'w') as f: + for (cls_id, x, y, w, h) in labels: + f.write(f"{cls_id} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n") + + print(f" → 已保存(目标数:{len(labels)})") + + # 统计结果 + output_img_count = len([f for f in os.listdir(OUTPUT_IMAGE_DIR) if f.endswith(('.jpg', '.png'))]) + output_label_count = len([f for f in os.listdir(OUTPUT_LABEL_DIR) if f.endswith('.txt')]) + print(f"\n✅ 全部处理完成!") + print(f" - 实际处理底图数量:{len(selected_base_files)} 张") + print(f" - 生成图片:{output_img_count} 张") + print(f" - 生成标签:{output_label_count} 个") + print(f" - 输出路径:\n 图片 → {OUTPUT_IMAGE_DIR}\n 标签 → {OUTPUT_LABEL_DIR}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/扩大边缘像素.py b/扩大边缘像素.py new file mode 100644 index 0000000..71119d4 --- /dev/null +++ b/扩大边缘像素.py @@ -0,0 +1,65 @@ +import os + + +def expand_bbox(label_path, k=1.1): + """ + 扩大YOLO标签的边界框 + :param label_path: 标签文件路径(.txt) + :param k: 扩大系数(k>1) + """ + with open(label_path, 'r') as f: + lines = f.readlines() + + new_lines = [] + for line in lines: + line = line.strip() + if not line: + continue + # 解析标签 + class_id, xc, yc, w, h = line.split() + xc = float(xc) + yc = float(yc) + w = float(w) + h = float(h) + + # 计算新宽高 + new_w = w * k + new_h = h * k + + # 计算边界 + x1 = xc - new_w / 2 + y1 = yc - new_h / 2 + x2 = xc + new_w / 2 + y2 = yc + new_h / 2 + + # 截断超出图像的部分(0~1范围) + x1 = max(0.0, x1) + y1 = max(0.0, y1) + x2 = min(1.0, x2) + y2 = min(1.0, y2) + + # 重新计算中心和宽高 + new_xc = (x1 + x2) / 2 + new_yc = (y1 + y2) / 2 + new_w = x2 - x1 + new_h = y2 - y1 + + # 保留6位小数,拼接新标签 + new_line = f"{class_id} {new_xc:.6f} {new_yc:.6f} {new_w:.6f} {new_h:.6f}\n" + new_lines.append(new_line) + + # 写入新标签(覆盖原文件,或改为新路径) + with open(label_path, 'w') as f: + f.writelines(new_lines) + + +# 批量处理文件夹中的所有标签 +label_dir = r"D:\DataPreHandler\yuanshi_data\images\val\labels" # 替换为你的标签文件夹路径 +k = 1.25 # 扩大系数,根据需求调整 + +for filename in os.listdir(label_dir): + if filename.endswith('.txt'): + label_path = os.path.join(label_dir, filename) + expand_bbox(label_path, k) + +print("处理完成!") \ No newline at end of file diff --git a/抽取样本数量.py b/抽取样本数量.py new file mode 100644 index 0000000..e95f440 --- /dev/null +++ b/抽取样本数量.py @@ -0,0 +1,217 @@ +import os +import random +import shutil +from collections import defaultdict + + +def add_balanced_samples(source_dir, target_dir, add_samples): + """ + 在已有目标数据集基础上添加新样本(不重复),并保持类别平衡 + 参数: + source_dir: 源数据集目录(包含images和labels) + target_dir: 目标数据集目录(已有样本,包含images和labels) + add_samples: 需要新增的样本数量 + """ + # 定义类别名称映射 + class_names = { + 0: "Abdomen", + 1: "Hips", + 2: "Chest", + 3: "vulva", + 4: "back", + 5: "penis", + 6: "Horror" + } + num_classes = len(class_names) + + # 确保目标目录存在 + os.makedirs(os.path.join(target_dir, "images"), exist_ok=True) + os.makedirs(os.path.join(target_dir, "labels"), exist_ok=True) + + # -------------------------- + # 关键步骤1:获取目标目录中已有的文件(避免重复) + # -------------------------- + existing_images = set() + target_image_dir = os.path.join(target_dir, "images") + for f in os.listdir(target_image_dir): + if os.path.isfile(os.path.join(target_image_dir, f)): + existing_images.add(f) # 记录已存在的图片文件名 + print(f"目标目录中已存在 {len(existing_images)} 个样本,将避免重复添加") + + # -------------------------- + # 关键步骤2:获取源目录中未被目标目录包含的可用文件 + # -------------------------- + source_image_dir = os.path.join(source_dir, "images") + source_label_dir = os.path.join(source_dir, "labels") + + if not os.path.exists(source_image_dir) or not os.path.exists(source_label_dir): + print(f"错误: 源目录 {source_dir} 中缺少images或labels子目录") + return + + # 源目录所有图片中,排除已存在于目标目录的文件 + candidate_images = [ + f for f in os.listdir(source_image_dir) + if os.path.isfile(os.path.join(source_image_dir, f)) and f not in existing_images + ] + + if not candidate_images: + print(f"错误: 源目录中没有可添加的新样本(所有样本已存在于目标目录)") + return + + # -------------------------- + # 分析候选文件的类别(兼容浮点数ID) + # -------------------------- + file_classes = {} # 候选文件名 -> 包含的类别集合 + class_files = defaultdict(list) # 类别 -> 包含该类别的候选文件列表 + + for img_file in candidate_images: + base_name = os.path.splitext(img_file)[0] + label_file = f"{base_name}.txt" + label_path = os.path.join(source_label_dir, label_file) + + if not os.path.exists(label_path): + print(f"警告: 图片 {img_file} 对应的标签文件 {label_file} 不存在,已跳过") + continue + + classes_in_file = set() + with open(label_path, 'r') as f: + for line in f: + parts = line.strip().split() + if not parts: + continue + try: + class_id = int(float(parts[0])) # 兼容浮点数类别ID + if class_id in class_names: + classes_in_file.add(class_id) + else: + print(f"警告: 标签文件 {label_file} 包含未知类别 {class_id}") + except (ValueError, IndexError): + print(f"警告: 标签文件 {label_file} 格式不正确,行内容: {line.strip()}") + + if not classes_in_file: + print(f"警告: 标签文件 {label_file} 无有效类别,已跳过") + continue + + file_classes[img_file] = classes_in_file + for class_id in classes_in_file: + class_files[class_id].append(img_file) + + # 检查候选文件中是否有类别缺失 + for class_id in class_names: + if class_id not in class_files or len(class_files[class_id]) == 0: + print(f"警告: 候选文件中没有类别 {class_id} ({class_names[class_id]}) 的样本") + + # -------------------------- + # 处理新增样本数量(不超过候选文件数量) + # -------------------------- + available_candidates = len(file_classes) + if add_samples > available_candidates: + print(f"警告: 请求新增 {add_samples} 个样本,但候选文件仅 {available_candidates} 个,将全部添加") + add_samples = available_candidates + elif add_samples <= 0: + print("错误: 新增样本数量必须大于0") + return + + # -------------------------- + # 平衡抽取新增样本 + # -------------------------- + selected_files = set() + remaining = add_samples + + # 每个类别理想的新增数量(基于总新增数量平均) + ideal_per_class = max(1, add_samples // num_classes) + + # 先为每个类别抽取理想数量的样本 + for class_id in class_names: + if class_id not in class_files: + continue + + # 候选文件中未被选中的 + available = [f for f in class_files[class_id] if f not in selected_files] + num_to_add = min(ideal_per_class, len(available), remaining) + + if num_to_add > 0: + selected = random.sample(available, num_to_add) + selected_files.update(selected) + remaining -= num_to_add + + if remaining <= 0: + break + + # 补充剩余需要新增的样本 + if remaining > 0: + all_available = [f for f in file_classes if f not in selected_files] + num_to_add = min(remaining, len(all_available)) + if num_to_add > 0: + selected_additional = random.sample(all_available, num_to_add) + selected_files.update(selected_additional) + + # -------------------------- + # 复制新增样本到目标目录(确保不重复) + # -------------------------- + for img_file in selected_files: + # 复制图片 + src_img = os.path.join(source_image_dir, img_file) + dst_img = os.path.join(target_image_dir, img_file) + shutil.copy2(src_img, dst_img) + + # 复制标签 + base_name = os.path.splitext(img_file)[0] + label_file = f"{base_name}.txt" + src_label = os.path.join(source_label_dir, label_file) + dst_label = os.path.join(target_dir, "labels", label_file) + if os.path.exists(src_label): + shutil.copy2(src_label, dst_label) + + # -------------------------- + # 统计结果(包含目标目录原有+新增的总分布) + # -------------------------- + # 1. 统计目标目录原有样本的类别 + existing_class_counts = defaultdict(int) + for img_file in existing_images: + # 从源目录找原有样本的标签(因为目标目录标签可能已存在,但源目录更可靠) + base_name = os.path.splitext(img_file)[0] + label_file = f"{base_name}.txt" + label_path = os.path.join(source_label_dir, label_file) + if os.path.exists(label_path): + with open(label_path, 'r') as f: + for line in f: + parts = line.strip().split() + if parts: + try: + class_id = int(float(parts[0])) + if class_id in class_names: + existing_class_counts[class_id] += 1 + except (ValueError, IndexError): + pass # 忽略格式错误的行 + + # 2. 统计新增样本的类别 + new_class_counts = defaultdict(int) + for img_file in selected_files: + for class_id in file_classes[img_file]: + new_class_counts[class_id] += 1 + + # 3. 总分布(原有+新增) + total_class_counts = { + class_id: existing_class_counts.get(class_id, 0) + new_class_counts.get(class_id, 0) + for class_id in class_names + } + + # 输出结果 + print(f"\n已成功新增 {len(selected_files)} 个样本,目标目录总样本数: {len(existing_images) + len(selected_files)}") + print("目标目录总类别分布(原有+新增):") + for class_id in class_names: + print(f"类别 {class_id} ({class_names[class_id]}): {total_class_counts[class_id]} 次出现") + + +if __name__ == "__main__": + # 源数据集路径(从中抽取新样本) + source_directory = r"D:\DataPreHandler\data\train3" + # 目标数据集路径(已有样本,将新增到这里) + target_directory = r"D:\DataPreHandler\data\train2" + + # 需要新增的样本数量(根据需求修改) + add_samples = 1200 # 例如:再添加100个新样本 + + # 执行新增操作 + add_balanced_samples(source_directory, target_directory, add_samples) \ No newline at end of file diff --git a/根据位置画框.py b/根据位置画框.py new file mode 100644 index 0000000..d4a139c --- /dev/null +++ b/根据位置画框.py @@ -0,0 +1,199 @@ +import cv2 +import numpy as np +import os +from tqdm import tqdm + +# -------------------------- 1. 核心配置(请根据需求修改) -------------------------- +# 输入目录(图片和标签需一一对应,文件名相同,仅后缀不同) +INPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train\images" # 原始图片目录 +INPUT_LABEL_DIR = r"D:\DataPreHandler\data\train\labels" # 原始YOLO标签目录 +# 输出目录(标注后的图片会保存在这里) +OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\output2" + +# 关键配置:类别ID与类别名称的映射(必须与你的YOLO训练类别顺序一致!) +CLASS_CONFIG = [ + (0, "Abdomen", (0, 255, 0)), + (1, "Hips", (0, 255, 0)), + (2, "Chest", (0, 255, 0)), + (3, "vulva", (0, 255, 0)), + (4, "back", (0, 255, 0)), + (5, "penis", (0, 255, 0)), + (6, "Horror", (0, 255, 0)) +] + +# 绘制参数(可按需调整) +BOX_THICKNESS = 2 # 边界框线条厚度(像素) +FONT_FACE = cv2.FONT_HERSHEY_SIMPLEX # 字体类型 +FONT_SCALE = 0.6 # 字体大小(根据图片尺寸调整) +FONT_THICKNESS = 1 # 字体线条厚度 +TEXT_PADDING = 5 # 文字与边界框的间距(像素) +TEXT_BG_OPACITY = 0.7 # 文字背景的透明度(0-1,0为完全透明) + +# 支持的图片格式(无需修改) +SUPPORTED_IMAGE_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff') + + +# -------------------------- 2. 工具函数 -------------------------- +def yolo2pixel(yolo_coords, img_w, img_h): + """ + 将YOLO相对坐标转换为图片像素坐标(边界框:x1, y1, x2, y2) + :param yolo_coords: YOLO坐标列表 [xc, yc, w, h](相对值,0-1) + :param img_w: 图片宽度(像素) + :param img_h: 图片高度(像素) + :return: 像素坐标元组 (x1, y1, x2, y2) + """ + xc, yc, w, h = yolo_coords + # 计算边界框左上角和右下角坐标 + x1 = int((xc - w / 2) * img_w) + y1 = int((yc - h / 2) * img_h) + x2 = int((xc + w / 2) * img_w) + y2 = int((yc + h / 2) * img_h) + # 确保坐标不超出图片范围 + x1 = max(0, x1) + y1 = max(0, y1) + x2 = min(img_w, x2) + y2 = min(img_h, y2) + return x1, y1, x2, y2 + + +def draw_annotation(img, bbox, class_name, color): + """ + 在图片上绘制边界框和类别名称 + :param img: 原始图片(OpenCV格式,BGR通道) + :param bbox: 像素坐标边界框 (x1, y1, x2, y2) + :param class_name: 类别名称(字符串) + :param color: 边界框和文字颜色(BGR元组,如 (0,255,0) 代表绿色) + :return: 标注后的图片 + """ + img_h, img_w = img.shape[:2] + x1, y1, x2, y2 = bbox + + # 1. 绘制边界框 + cv2.rectangle(img, (x1, y1), (x2, y2), color, BOX_THICKNESS) + + # 2. 计算文字尺寸(用于创建文字背景) + text_size, _ = cv2.getTextSize(class_name, FONT_FACE, FONT_SCALE, FONT_THICKNESS) + text_w, text_h = text_size + + # 3. 确定文字位置(避免超出图片范围) + # 文字默认放在边界框左上角,若左上角空间不足则放在右上角 + text_x = x1 + TEXT_PADDING + text_y = y1 - TEXT_PADDING - text_h # 文字基线在y轴上方 + if text_y < 0: # 左上角超出图片顶部,调整到右上角 + text_x = x2 - TEXT_PADDING - text_w + text_y = y1 + TEXT_PADDING + text_h + + # 4. 绘制文字背景(半透明矩形,避免遮挡图片内容) + bg_x1 = text_x - TEXT_PADDING + bg_y1 = text_y - text_h - TEXT_PADDING + bg_x2 = text_x + text_w + TEXT_PADDING + bg_y2 = text_y + TEXT_PADDING + # 确保背景不超出图片范围 + bg_x1 = max(0, bg_x1) + bg_y1 = max(0, bg_y1) + bg_x2 = min(img_w, bg_x2) + bg_y2 = min(img_h, bg_y2) + + # 半透明背景:先创建背景层,再与原图混合 + bg = img[bg_y1:bg_y2, bg_x1:bg_x2].copy() + bg = cv2.rectangle(bg, (0, 0), (bg_x2 - bg_x1, bg_y2 - bg_y1), color, -1) # 实心矩形 + img[bg_y1:bg_y2, bg_x1:bg_x2] = cv2.addWeighted(bg, TEXT_BG_OPACITY, img[bg_y1:bg_y2, bg_x1:bg_x2], 1 - TEXT_BG_OPACITY, 0) + + # 5. 绘制类别名称 + cv2.putText(img, class_name, (text_x, text_y), FONT_FACE, FONT_SCALE, (0, 0, 0), FONT_THICKNESS) # 白色文字 + + return img + + +# -------------------------- 3. 主函数 -------------------------- +def main(): + # 1. 创建输出目录(若不存在) + os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True) + print(f"标注后的图片将保存到:{OUTPUT_IMAGE_DIR}\n") + + # 2. 构建类别ID到(名称+颜色)的映射字典 + class_map = {cls_id: (cls_name, cls_color) for cls_id, cls_name, cls_color in CLASS_CONFIG} + print("类别配置:") + for cls_id, cls_name, cls_color in CLASS_CONFIG: + print(f" ID {cls_id} → 名称:{cls_name},颜色:{cls_color}") + print() + + # 3. 获取所有图片文件(仅处理支持的格式) + image_files = [f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith(SUPPORTED_IMAGE_FORMATS)] + if not image_files: + raise FileNotFoundError(f"在 {INPUT_IMAGE_DIR} 中未找到任何支持的图片文件({SUPPORTED_IMAGE_FORMATS})") + print(f"找到 {len(image_files)} 张图片,开始标注...\n") + + # 4. 遍历图片并标注 + for img_filename in tqdm(image_files, desc="处理进度"): + # 4.1 构建图片和标签的路径 + img_name, img_ext = os.path.splitext(img_filename) + img_path = os.path.join(INPUT_IMAGE_DIR, img_filename) + label_path = os.path.join(INPUT_LABEL_DIR, f"{img_name}.txt") # 标签文件与图片同名,后缀为txt + + # 4.2 读取图片(OpenCV默认读取为BGR通道) + img = cv2.imread(img_path) + if img is None: + tqdm.write(f"⚠️ 跳过:无法读取图片 {img_filename}(可能损坏或格式不支持)") + continue + img_h, img_w = img.shape[:2] + + # 4.3 读取标签文件(若不存在则跳过标注,直接保存原图) + if not os.path.exists(label_path): + tqdm.write(f"⚠️ 警告:图片 {img_filename} 无对应标签文件 {os.path.basename(label_path)},直接保存原图") + annotated_img = img.copy() + else: + # 复制原图用于标注(避免修改原始图片) + annotated_img = img.copy() + # 读取标签内容 + with open(label_path, 'r', encoding='utf-8') as f: + label_lines = [line.strip() for line in f.readlines() if line.strip()] # 过滤空行 + + # 4.4 解析每个标签并绘制 + for line_idx, line in enumerate(label_lines): + try: + # YOLO标签格式:class_id xc yc w h(空格分隔) + parts = line.split() + if len(parts) != 5: + raise ValueError(f"格式错误(需5个字段,实际{len(parts)}个)") + + # 解析类别ID和坐标 + cls_id = int(float(parts[0])) + yolo_coords = [float(p) for p in parts[1:]] + # 检查YOLO坐标有效性(必须在0-1范围内) + if not all(0 <= coord <= 1 for coord in yolo_coords): + raise ValueError(f"YOLO坐标超出0-1范围:{yolo_coords}") + + # 4.5 转换坐标并绘制 + # 检查类别ID是否在配置中 + if cls_id not in class_map: + tqdm.write(f"⚠️ 跳过:图片 {img_filename} 标签第{line_idx+1}行,未知类别ID {cls_id}(未在CLASS_CONFIG中配置)") + continue + + # 获取类别名称和颜色 + cls_name, cls_color = class_map[cls_id] + # 转换YOLO坐标为像素坐标 + bbox = yolo2pixel(yolo_coords, img_w, img_h) + # 绘制标注 + annotated_img = draw_annotation(annotated_img, bbox, cls_name, cls_color) + + except Exception as e: + tqdm.write(f"⚠️ 跳过:图片 {img_filename} 标签第{line_idx+1}行解析失败 → {str(e)}") + continue + + # 4.6 保存标注后的图片 + output_img_path = os.path.join(OUTPUT_IMAGE_DIR, f"{img_name}_annotated{img_ext}") + # 保存为JPG格式(若原始是PNG,也可改为img_ext保持原格式) + # 注:JPG不支持透明通道,若原始是PNG且有透明,建议保留img_ext + cv2.imwrite(output_img_path, annotated_img) + + # 5. 完成提示 + print(f"\n✅ 标注完成!共处理 {len(image_files)} 张图片,标注后的图片已保存到:") + print(f" {OUTPUT_IMAGE_DIR}") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"\n❌ 程序异常终止:{str(e)}") \ No newline at end of file diff --git a/视频抽帧.py b/视频抽帧.py index 58f15ad..adb54b0 100644 --- a/视频抽帧.py +++ b/视频抽帧.py @@ -6,23 +6,23 @@ from pathlib import Path # -------------------------- 1. 用户配置(根据实际情况修改)-------------------------- # 视频文件路径(绝对路径或相对路径均可,支持mp4/avi/mov等常见格式) -VIDEO_PATH = r"E:\geminicli\yolo\images\44.mp4" # 替换为你的视频路径 - -# 输出根目录(脚本会自动在该目录下创建train/valid/test子文件夹) -OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images" # 替换为你的输出根路径 +# VIDEO_PATH = r"E:\geminicli\yolo\images\屏幕操作视频.mp4" # 替换为你的视频路径 +VIDEO_PATH = r"D:\DataPreHandler\MP4\花生底图.mp4" +# 输出根目录(脚本会自动在该目录下创建train/val/test子文件夹) +OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen" # 替换为你的输出根路径 # 各文件夹需抽取的图片数量(总数量=2800+1400+700=4900) FRAME_COUNTS = { - "train": 28, - "valid": 14, - "test": 7 + "train": 1200, + "val": 150, + "test": 50 } # 图片保存格式(建议用jpg,兼容性更好;也可改为png) SAVE_FORMAT = "jpg" # 图片文件名编号位数(如0001.jpg,避免文件名乱序) -FILE_NUM_DIGITS = 4 # 对应最大编号9999,满足4900张需求 +FILE_NUM_DIGITS = 5 # 对应最大编号9999,满足4900张需求 # ----------------------------------------------------------------------------------- @@ -62,13 +62,13 @@ def check_video_validity(cap): print(f" - 帧率(FPS):{fps:.1f}") print(f" - 总时长:{video_duration // 60:.0f}分{video_duration % 60:.1f}秒") print( - f" - 需抽取总帧数:{required_total}(train:{FRAME_COUNTS['train']}, valid:{FRAME_COUNTS['valid']}, test:{FRAME_COUNTS['test']})") + f" - 需抽取总帧数:{required_total}(train:{FRAME_COUNTS['train']}, val:{FRAME_COUNTS['val']}, test:{FRAME_COUNTS['test']})") return total_frames def create_output_dirs(): - """创建输出根目录及train/valid/test子文件夹""" + """创建输出根目录及train/val/test子文件夹""" # 转换为Path对象,适配Windows/Linux/macOS路径格式 output_root = Path(OUTPUT_ROOT_DIR) subdirs = FRAME_COUNTS.keys() @@ -97,14 +97,14 @@ def generate_random_frame_indices(total_frames, required_total): def split_indices_by_dataset(random_indices): - """将总随机索引按train/valid/test的数量拆分""" + """将总随机索引按train/val/test的数量拆分""" train_count = FRAME_COUNTS["train"] - valid_count = FRAME_COUNTS["valid"] + valid_count = FRAME_COUNTS["val"] # 拆分逻辑:前N个给train,中间M个给valid,剩余给test indices_split = { "train": random_indices[:train_count], - "valid": random_indices[train_count:train_count + valid_count], + "val": random_indices[train_count:train_count + valid_count], "test": random_indices[train_count + valid_count:] } @@ -140,7 +140,7 @@ def extract_and_save_frames(cap, indices_split, output_dirs): continue # 生成文件名(如train_0001.jpg) - file_name = f"{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}" # 0{FILE_NUM_DIGITS}d表示补零到指定位数 + file_name = f"dths_{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}" # 0{FILE_NUM_DIGITS}d表示补零到指定位数 save_path = dataset_dir / file_name # 保存帧为图片(cv2.imwrite默认保存为BGR格式,符合图片存储标准)