数据增强相关代码
This commit is contained in:
		
							
								
								
									
										319
									
								
								copypaste整张特征图.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										319
									
								
								copypaste整张特征图.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,319 @@ | ||||
| import cv2 | ||||
| import numpy as np | ||||
| import os | ||||
| import random | ||||
| from collections import defaultdict | ||||
| from tqdm import tqdm | ||||
|  | ||||
# --- 1. Core configuration ---
# Source directories of the feature material (full images and their label files).
# The commented-out pairs below are alternative dataset splits.
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\images"
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\train\labels"
# SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\images"
# SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\val\labels"
SOURCE_FEATURE_IMAGE_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\images"
SOURCE_FEATURE_LABEL_DIR = r"D:\DataPreHandler\yuanshi_data\images\test\labels"

# Directory of base (background) images that feature images are pasted onto.
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu\train"
BASE_IMAGE_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen\train"
# BASE_IMAGE_DIR = r"D:\DataPreHandler\images\test"

# Output directories for the composed images and their labels.
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train3\images"
OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\train3\labels"
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\val3\images"
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\val3\labels"
# OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\images"
# OUTPUT_LABEL_DIR = r"D:\DataPreHandler\data\test\da\labels"

# Number of base images to use; 0 means "use every base image", a positive
# value selects that many at random (see main()).
SELECT_BASE_IMAGE_COUNT = 0


AUGMENTATION_FACTOR = 1  # number of composed outputs attempted per base image
TARGET_CLASSES = [0, 1, 2, 3, 4, 5, 6]  # class ids kept when reading feature labels
PASTE_COUNT_RANGE = (2, 3)  # inclusive bounds for how many feature images to paste per output
MIN_SCALED_SIZE = 50  # minimum width/height (pixels) of a feature image after scaling
MAX_OVERLAP_IOU = 0.0  # only limits overlap *between* different pasted feature images
MAX_RETRY_COUNT = 300  # maximum placement attempts per composed output
SCALE_FACTOR_RANGE = (0.3, 0.9)  # bounds of the uniformly sampled scale factor
|  | ||||
|  | ||||
| # --- 2. 工具函数 ---(核心修改:删除内部重叠检测) | ||||
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two pixel-space boxes.

    Both boxes are ``(x1, y1, x2, y2)`` sequences. Every coordinate is rounded
    to the nearest integer before any area is computed.

    :param box1: first box as (x1, y1, x2, y2)
    :param box2: second box as (x1, y1, x2, y2)
    :return: intersection area divided by union area, or 0.0 when the union
             area is not positive
    """
    a = [int(round(value)) for value in box1]
    b = [int(round(value)) for value in box2]

    # Overlap extent on each axis, floored at zero when the boxes are disjoint.
    overlap_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    overlap_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    intersection = overlap_w * overlap_h

    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0.0
|  | ||||
|  | ||||
def collect_feature_images(feature_image_dir, feature_label_dir, target_classes):
    """Collect feature images together with their target-class labels.

    An image is kept when it can be read, has a label file with the same base
    name, is large enough to reach ``MIN_SCALED_SIZE`` at the largest scale
    factor, and has at least one valid label of a target class. Overlap
    between objects inside one feature image is not checked.

    Malformed label lines (wrong field count or non-numeric values) are
    reported and skipped instead of aborting the whole collection.

    :param feature_image_dir: directory containing the feature images
    :param feature_label_dir: directory containing the matching ``.txt`` labels
    :param target_classes: container of class ids to keep
    :return: list of ``(image_path, labels, width, height)`` tuples where
             ``labels`` is a list of ``(cls_id, xc, yc, w, h)`` in normalized
             coordinates
    :raises FileNotFoundError: if the directory holds no image files
    :raises ValueError: if no image ends up with a usable label
    """
    print(f"从 {feature_image_dir} 收集特征图片及标签,目标类别 {target_classes}...")
    feature_list = []  # items: (image path, label list, original width, original height)

    feature_image_files = [
        f for f in os.listdir(feature_image_dir)
        if f.lower().endswith(('.jpg', '.png', '.jpeg'))
    ]
    if not feature_image_files:
        raise FileNotFoundError("未找到特征图片!")

    # Loop-invariant: the smallest original side that can still reach
    # MIN_SCALED_SIZE when scaled by the largest allowed factor.
    min_required_orig_size = MIN_SCALED_SIZE / max(SCALE_FACTOR_RANGE)

    for filename in tqdm(feature_image_files, desc="收集特征图片"):
        img_path = os.path.join(feature_image_dir, filename)
        label_file = os.path.splitext(filename)[0] + ".txt"
        label_path = os.path.join(feature_label_dir, label_file)

        if not os.path.exists(label_path):
            continue

        img = cv2.imread(img_path)
        if img is None:
            tqdm.write(f"警告:无法读取图片 {filename},已跳过")
            continue
        orig_h, orig_w = img.shape[:2]

        if orig_w < min_required_orig_size or orig_h < min_required_orig_size:
            tqdm.write(f"警告:特征图片 {filename} 原始尺寸过小,已跳过")
            continue

        labels = []
        with open(label_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue
                try:
                    cls_id = int(float(parts[0]))
                    if cls_id not in target_classes:
                        continue
                    # Unpacking fails with ValueError unless exactly four
                    # coordinate fields follow the class id.
                    xc, yc, w, h = (float(p) for p in parts[1:])
                except (ValueError, OverflowError):
                    tqdm.write(f"警告:标签文件 {label_file} 第 {line_no} 行格式不正确,已跳过该行")
                    continue
                if 0 <= xc <= 1 and 0 <= yc <= 1 and 0 < w <= 1 and 0 < h <= 1:
                    labels.append((cls_id, xc, yc, w, h))

        # Keep the image as long as it has at least one valid label.
        if labels:
            feature_list.append((img_path, labels, orig_w, orig_h))

    if not feature_list:
        raise ValueError("未收集到有效的特征图片及标签!")

    print(f"特征图片收集完成:共 {len(feature_list)} 张有效特征图片(允许内部目标重叠)")
    return feature_list
|  | ||||
|  | ||||
def _project_labels(feature_labels, paste_x1, paste_y1, scaled_w, scaled_h):
    """Map normalized feature labels to pixel boxes on the base image, clipped to the pasted patch."""
    paste_x2 = paste_x1 + scaled_w
    paste_y2 = paste_y1 + scaled_h
    placed = []
    for (cls_id, xc, yc, w, h) in feature_labels:
        # Use the real resized dimensions so the boxes line up with the pixels
        # that are actually pasted (the resize target is rounded to integers).
        x1 = min(max(paste_x1 + (xc - w / 2) * scaled_w, paste_x1), paste_x2)
        y1 = min(max(paste_y1 + (yc - h / 2) * scaled_h, paste_y1), paste_y2)
        x2 = min(max(paste_x1 + (xc + w / 2) * scaled_w, paste_x1), paste_x2)
        y2 = min(max(paste_y1 + (yc + h / 2) * scaled_h, paste_y1), paste_y2)
        if x2 > x1 and y2 > y1:
            placed.append((cls_id, (x1, y1, x2, y2)))
    return placed


def paste_feature_images_to_base(base_image, feature_list):
    """Paste randomly scaled feature images onto a base image without cross-image overlap.

    A random number of feature images (within ``PASTE_COUNT_RANGE``) is placed
    at random positions. A candidate is rejected when its scaled size is out
    of range, when its region overlaps an already pasted region, or when one
    of its objects overlaps an already pasted object (IoU above
    ``MAX_OVERLAP_IOU``). Overlap between objects of the same feature image is
    kept as-is. At most ``MAX_RETRY_COUNT`` attempts are made.

    Object boxes are mapped with the actual resized width/height and clipped
    to the pasted patch, so a label that extends slightly past its source
    image can neither spill over the patch nor cause the whole feature image
    to be rejected.

    :param base_image: image array to paste onto; it is modified in place
    :param feature_list: list of ``(image_path, labels, width, height)`` as
                         produced by ``collect_feature_images``
    :return: ``(base_image, labels)`` where ``labels`` is a list of
             ``(cls_id, xc, yc, w, h)`` normalized to the base image
    """
    base_h, base_w = base_image.shape[:2]
    pasted_labels = []               # labels of the composed image
    pasted_target_boxes = []         # pixel boxes of every pasted object
    pasted_feature_regions = []      # pixel regions of every pasted feature image
    pasted_feature_count = 0
    skipped_small = 0                # attempts rejected because of the scaled size
    skipped_region_overlap = 0       # attempts rejected because regions overlap
    skipped_target_overlap = 0       # attempts rejected because objects overlap
    retry_count = 0

    target_paste_count = random.randint(*PASTE_COUNT_RANGE)
    print(f"  计划粘贴 {target_paste_count} 张特征图...")

    while pasted_feature_count < target_paste_count and retry_count < MAX_RETRY_COUNT:
        retry_count += 1

        # 1. Pick a feature image and a scale.
        feature_img_path, feature_labels, orig_w, orig_h = random.choice(feature_list)
        scale_factor = random.uniform(*SCALE_FACTOR_RANGE)
        scaled_w = int(round(orig_w * scale_factor))
        scaled_h = int(round(orig_h * scale_factor))

        too_small = scaled_w < MIN_SCALED_SIZE or scaled_h < MIN_SCALED_SIZE
        too_large = scaled_w >= base_w or scaled_h >= base_h
        if too_small or too_large:
            skipped_small += 1
            continue

        # 2. Pick a position that keeps the patch fully inside the base image.
        paste_x1 = random.randint(0, base_w - scaled_w)
        paste_y1 = random.randint(0, base_h - scaled_h)
        paste_x2 = paste_x1 + scaled_w
        paste_y2 = paste_y1 + scaled_h
        current_feature_region = (paste_x1, paste_y1, paste_x2, paste_y2)

        # 3. Reject patches that overlap an already pasted patch.
        if any(calculate_iou(current_feature_region, region) > MAX_OVERLAP_IOU
               for region in pasted_feature_regions):
            skipped_region_overlap += 1
            continue

        # 4. Project the objects onto the base image.
        placed = _project_labels(feature_labels, paste_x1, paste_y1, scaled_w, scaled_h)
        if not placed:
            continue

        # 5. Reject patches whose objects overlap already pasted objects.
        if any(calculate_iou(box, existing_box) > MAX_OVERLAP_IOU
               for _, box in placed
               for existing_box in pasted_target_boxes):
            skipped_target_overlap += 1
            continue

        # 6. Paste the resized feature image and record its labels.
        feature_img = cv2.imread(feature_img_path)
        if feature_img is None:
            continue
        scaled_feature_img = cv2.resize(feature_img, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR)
        base_image[paste_y1:paste_y2, paste_x1:paste_x2] = scaled_feature_img

        for cls_id, box in placed:
            base_xc = (box[0] + box[2]) / 2 / base_w
            base_yc = (box[1] + box[3]) / 2 / base_h
            base_w_box = (box[2] - box[0]) / base_w
            base_h_box = (box[3] - box[1]) / base_h

            pasted_labels.append((cls_id, base_xc, base_yc, base_w_box, base_h_box))
            pasted_target_boxes.append(box)
        pasted_feature_regions.append(current_feature_region)
        pasted_feature_count += 1
        print(f"  已成功粘贴 {pasted_feature_count}/{target_paste_count} 张特征图")

    print(
        f"  粘贴统计:成功{pasted_feature_count}张 → "
        f"跳小尺寸{skipped_small} → 跳区域重叠{skipped_region_overlap} → "
        f"跳目标重叠{skipped_target_overlap} → 总目标数{len(pasted_labels)}"
    )
    return base_image, pasted_labels
|  | ||||
|  | ||||
| # --- 3. 主函数 ---(无修改) | ||||
def main():
    """Compose augmented images by pasting feature images onto base images.

    Collects the feature images, selects the base images according to
    ``SELECT_BASE_IMAGE_COUNT``, produces ``AUGMENTATION_FACTOR`` composed
    outputs per base image and writes each image with its YOLO label file.
    A label file is only written when its image was written successfully, and
    the final statistics count the files produced by this run only.
    """
    os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
    os.makedirs(OUTPUT_LABEL_DIR, exist_ok=True)

    try:
        feature_list = collect_feature_images(
            SOURCE_FEATURE_IMAGE_DIR, SOURCE_FEATURE_LABEL_DIR, TARGET_CLASSES
        )
    except (FileNotFoundError, ValueError) as e:
        print(f"错误:{e}")
        return

    # A missing base directory is reported the same way as an empty one
    # instead of raising from os.listdir.
    base_files = []
    if os.path.isdir(BASE_IMAGE_DIR):
        base_files = [f for f in os.listdir(BASE_IMAGE_DIR)
                      if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
    total_base_count = len(base_files)
    if total_base_count == 0:
        print(f"错误:底图目录 {BASE_IMAGE_DIR} 无有效图片!")
        return

    if SELECT_BASE_IMAGE_COUNT < 0:
        print(f"错误:底图数量不能为负数(当前设置:{SELECT_BASE_IMAGE_COUNT})!")
        return
    if SELECT_BASE_IMAGE_COUNT == 0:
        selected_base_files = base_files
        print(f"\n找到 {total_base_count} 张底图,使用全部底图...\n")
    elif SELECT_BASE_IMAGE_COUNT > total_base_count:
        selected_base_files = base_files
        print(f"\n警告:设置底图数({SELECT_BASE_IMAGE_COUNT})超过可用数({total_base_count}),使用全部底图...\n")
    else:
        selected_base_files = random.sample(base_files, SELECT_BASE_IMAGE_COUNT)
        print(f"\n找到 {total_base_count} 张底图,随机选择 {SELECT_BASE_IMAGE_COUNT} 张...\n")

    output_img_count = 0
    output_label_count = 0

    for base_filename in tqdm(selected_base_files, desc="处理底图"):
        base_name = os.path.splitext(base_filename)[0]
        base_path = os.path.join(BASE_IMAGE_DIR, base_filename)
        base_img = cv2.imread(base_path)
        if base_img is None:
            tqdm.write(f"\n警告:无法读取底图 {base_filename},已跳过")
            continue

        for aug_idx in range(AUGMENTATION_FACTOR):
            print(f"\n底图 {base_filename}(增强序号 {aug_idx}):", end="")
            # Paste onto a copy so every augmentation starts from a clean base.
            pasted_img, labels = paste_feature_images_to_base(base_img.copy(), feature_list)

            if not labels:
                tqdm.write(f"\n警告:未粘贴任何目标,已跳过该结果")
                continue

            output_img_name = f"{base_name}_feat_paste_{aug_idx}.jpg"
            output_img_path = os.path.join(OUTPUT_IMAGE_DIR, output_img_name)
            # cv2.imwrite reports failure through its return value; do not
            # leave a label file behind without its image.
            if not cv2.imwrite(output_img_path, pasted_img):
                tqdm.write(f"\n警告:无法写入图片 {output_img_path},已跳过该结果")
                continue
            output_img_count += 1

            output_label_name = f"{base_name}_feat_paste_{aug_idx}.txt"
            output_label_path = os.path.join(OUTPUT_LABEL_DIR, output_label_name)
            with open(output_label_path, 'w') as f:
                for (cls_id, x, y, w, h) in labels:
                    f.write(f"{cls_id} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")
            output_label_count += 1

            print(f" → 已保存(目标数:{len(labels)})")

    # Statistics for this run only (files already present in the output
    # directories are not counted).
    print(f"\n✅ 全部处理完成!")
    print(f"  - 实际处理底图数量:{len(selected_base_files)} 张")
    print(f"  - 生成图片:{output_img_count} 张")
    print(f"  - 生成标签:{output_label_count} 个")
    print(f"  - 输出路径:\n    图片 → {OUTPUT_IMAGE_DIR}\n    标签 → {OUTPUT_LABEL_DIR}")


if __name__ == "__main__":
    main()
							
								
								
									
										65
									
								
								扩大边缘像素.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								扩大边缘像素.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,65 @@ | ||||
| import os | ||||
|  | ||||
|  | ||||
def expand_bbox(label_path, k=1.1):
    """Scale every bounding box in a YOLO label file about its center.

    Each ``class xc yc w h`` line has its width and height multiplied by
    ``k``; the resulting box is clipped to the normalized image range [0, 1]
    and center/size are recomputed from the clipped corners. Blank lines are
    dropped. Lines that are not in the five-field numeric format are kept
    unchanged instead of aborting, so no annotation is lost.

    The file is rewritten in place, so calling this twice on the same file
    applies the factor twice.

    :param label_path: path of the label file (.txt)
    :param k: scale factor; values above 1 enlarge the boxes
    :raises ValueError: if ``k`` is not positive
    """
    if k <= 0:
        raise ValueError(f"扩大系数 k 必须大于 0(当前值:{k})")

    with open(label_path, 'r') as f:
        lines = f.readlines()

    new_lines = []
    for line in lines:
        line = line.strip()
        if not line:
            continue

        parts = line.split()
        try:
            class_id = parts[0]
            # Unpacking fails with ValueError unless exactly four numeric
            # fields follow the class id.
            xc, yc, w, h = (float(p) for p in parts[1:])
        except ValueError:
            new_lines.append(line + "\n")
            continue

        half_w = w * k / 2
        half_h = h * k / 2

        # Clip the enlarged box to the image (normalized 0..1 range).
        x1 = max(0.0, xc - half_w)
        y1 = max(0.0, yc - half_h)
        x2 = min(1.0, xc + half_w)
        y2 = min(1.0, yc + half_h)

        # Recompute center and size from the clipped corners.
        new_xc = (x1 + x2) / 2
        new_yc = (y1 + y2) / 2
        new_w = x2 - x1
        new_h = y2 - y1

        new_lines.append(f"{class_id} {new_xc:.6f} {new_yc:.6f} {new_w:.6f} {new_h:.6f}\n")

    # Overwrite the original file with the updated labels.
    with open(label_path, 'w') as f:
        f.writelines(new_lines)
|  | ||||
|  | ||||
| # 批量处理文件夹中的所有标签 | ||||
| label_dir = r"D:\DataPreHandler\yuanshi_data\images\val\labels"  # 替换为你的标签文件夹路径 | ||||
| k = 1.25  # 扩大系数,根据需求调整 | ||||
|  | ||||
| for filename in os.listdir(label_dir): | ||||
|     if filename.endswith('.txt'): | ||||
|         label_path = os.path.join(label_dir, filename) | ||||
|         expand_bbox(label_path, k) | ||||
|  | ||||
| print("处理完成!") | ||||
							
								
								
									
										217
									
								
								抽取样本数量.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								抽取样本数量.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,217 @@ | ||||
| import os | ||||
| import random | ||||
| import shutil | ||||
| from collections import defaultdict | ||||
|  | ||||
|  | ||||
def _count_label_classes(label_path, class_names, label_file=None):
    """Count the boxes per known class id in one YOLO label file.

    When ``label_file`` is given, unknown class ids and malformed lines are
    reported under that name; otherwise they are ignored silently.
    """
    counts = defaultdict(int)
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            try:
                class_id = int(float(parts[0]))  # tolerate float-formatted class ids
            except (ValueError, OverflowError):
                if label_file is not None:
                    print(f"警告: 标签文件 {label_file} 格式不正确,行内容: {line.strip()}")
                continue
            if class_id in class_names:
                counts[class_id] += 1
            elif label_file is not None:
                print(f"警告: 标签文件 {label_file} 包含未知类别 {class_id}")
    return counts


def add_balanced_samples(source_dir, target_dir, add_samples):
    """Copy new, class-balanced samples from a source dataset into a target dataset.

    Files already present in ``target_dir/images`` are never selected again.
    Selection first draws up to an equal share per class, then fills the
    remainder at random. After copying, the class distribution of the target
    (existing plus new samples) is printed as box occurrences; labels of the
    existing samples are read from the target directory, falling back to the
    source directory only when the target has no label file for them.

    :param source_dir: source dataset directory (contains images and labels)
    :param target_dir: target dataset directory (contains images and labels)
    :param add_samples: number of samples to add
    """
    class_names = {
        0: "Abdomen",
        1: "Hips",
        2: "Chest",
        3: "vulva",
        4: "back",
        5: "penis",
        6: "Horror"
    }
    num_classes = len(class_names)

    target_image_dir = os.path.join(target_dir, "images")
    target_label_dir = os.path.join(target_dir, "labels")
    os.makedirs(target_image_dir, exist_ok=True)
    os.makedirs(target_label_dir, exist_ok=True)

    # Step 1: remember what the target already holds so nothing is added twice.
    existing_images = {
        f for f in os.listdir(target_image_dir)
        if os.path.isfile(os.path.join(target_image_dir, f))
    }
    print(f"目标目录中已存在 {len(existing_images)} 个样本,将避免重复添加")

    # Step 2: candidates are source files that are not in the target yet.
    source_image_dir = os.path.join(source_dir, "images")
    source_label_dir = os.path.join(source_dir, "labels")

    if not os.path.exists(source_image_dir) or not os.path.exists(source_label_dir):
        print(f"错误: 源目录 {source_dir} 中缺少images或labels子目录")
        return

    candidate_images = [
        f for f in os.listdir(source_image_dir)
        if os.path.isfile(os.path.join(source_image_dir, f)) and f not in existing_images
    ]

    if not candidate_images:
        print(f"错误: 源目录中没有可添加的新样本(所有样本已存在于目标目录)")
        return

    # Step 3: analyse which classes each candidate contains.
    file_classes = {}                # candidate file name -> {class id: box count}
    class_files = defaultdict(list)  # class id -> candidate files containing it

    for img_file in candidate_images:
        label_file = f"{os.path.splitext(img_file)[0]}.txt"
        label_path = os.path.join(source_label_dir, label_file)

        if not os.path.exists(label_path):
            print(f"警告: 图片 {img_file} 对应的标签文件 {label_file} 不存在,已跳过")
            continue

        counts = _count_label_classes(label_path, class_names, label_file)
        if not counts:
            print(f"警告: 标签文件 {label_file} 无有效类别,已跳过")
            continue

        file_classes[img_file] = counts
        for class_id in counts:
            class_files[class_id].append(img_file)

    for class_id in class_names:
        if class_id not in class_files or len(class_files[class_id]) == 0:
            print(f"警告: 候选文件中没有类别 {class_id} ({class_names[class_id]}) 的样本")

    # Step 4: cap the request at the number of usable candidates.
    available_candidates = len(file_classes)
    if add_samples > available_candidates:
        print(f"警告: 请求新增 {add_samples} 个样本,但候选文件仅 {available_candidates} 个,将全部添加")
        add_samples = available_candidates
    elif add_samples <= 0:
        print("错误: 新增样本数量必须大于0")
        return

    # Step 5: balanced selection - an equal share per class first.
    selected_files = set()
    remaining = add_samples
    ideal_per_class = max(1, add_samples // num_classes)

    for class_id in class_names:
        if class_id not in class_files:
            continue

        available = [f for f in class_files[class_id] if f not in selected_files]
        num_to_add = min(ideal_per_class, len(available), remaining)

        if num_to_add > 0:
            selected_files.update(random.sample(available, num_to_add))
            remaining -= num_to_add

        if remaining <= 0:
            break

    # Fill whatever is still missing from the unselected candidates.
    if remaining > 0:
        all_available = [f for f in file_classes if f not in selected_files]
        num_to_add = min(remaining, len(all_available))
        if num_to_add > 0:
            selected_files.update(random.sample(all_available, num_to_add))

    # Step 6: copy the selected images and labels into the target.
    for img_file in selected_files:
        shutil.copy2(os.path.join(source_image_dir, img_file),
                     os.path.join(target_image_dir, img_file))

        label_file = f"{os.path.splitext(img_file)[0]}.txt"
        src_label = os.path.join(source_label_dir, label_file)
        if os.path.exists(src_label):
            shutil.copy2(src_label, os.path.join(target_label_dir, label_file))

    # Step 7: statistics (existing + new), counted as box occurrences.
    existing_class_counts = defaultdict(int)
    for img_file in existing_images:
        label_file = f"{os.path.splitext(img_file)[0]}.txt"
        # The target's own label is authoritative for a sample that already
        # lives there; the source is only a fallback.
        for label_dir in (target_label_dir, source_label_dir):
            label_path = os.path.join(label_dir, label_file)
            if os.path.exists(label_path):
                for class_id, count in _count_label_classes(label_path, class_names).items():
                    existing_class_counts[class_id] += count
                break

    new_class_counts = defaultdict(int)
    for img_file in selected_files:
        for class_id, count in file_classes[img_file].items():
            new_class_counts[class_id] += count

    total_class_counts = {
        class_id: existing_class_counts.get(class_id, 0) + new_class_counts.get(class_id, 0)
        for class_id in class_names
    }

    print(f"\n已成功新增 {len(selected_files)} 个样本,目标目录总样本数: {len(existing_images) + len(selected_files)}")
    print("目标目录总类别分布(原有+新增):")
    for class_id in class_names:
        print(f"类别 {class_id} ({class_names[class_id]}): {total_class_counts[class_id]} 次出现")
|  | ||||
|  | ||||
if __name__ == "__main__":
    # Source dataset path (new samples are drawn from here).
    source_directory = r"D:\DataPreHandler\data\train3"
    # Target dataset path (already holds samples; new ones are added here).
    target_directory = r"D:\DataPreHandler\data\train2"

    # Number of samples to add (change as needed).
    add_samples = 1200  # i.e. add 1200 more samples

    # Run the addition.
    add_balanced_samples(source_directory, target_directory, add_samples)
							
								
								
									
										199
									
								
								根据位置画框.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										199
									
								
								根据位置画框.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,199 @@ | ||||
| import cv2 | ||||
| import numpy as np | ||||
| import os | ||||
| from tqdm import tqdm | ||||
|  | ||||
# -------------------------- 1. Core configuration (adjust as needed) --------------------------
# Input directories (images and labels must correspond one-to-one: same file name, different suffix).
INPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train\images"  # source image directory
INPUT_LABEL_DIR = r"D:\DataPreHandler\data\train\labels"  # source YOLO label directory
# Output directory (annotated images are saved here).
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\output2"

# Class configuration: (class id, class name, BGR colour) per entry.
# The ids must match the class order used for YOLO training.
CLASS_CONFIG = [
    (0, "Abdomen", (0, 255, 0)),
    (1, "Hips", (0, 255, 0)),
    (2, "Chest", (0, 255, 0)),
    (3, "vulva", (0, 255, 0)),
    (4, "back", (0, 255, 0)),
    (5, "penis", (0, 255, 0)),
    (6, "Horror", (0, 255, 0))
]

# Drawing parameters (adjust as needed).
BOX_THICKNESS = 2          # bounding-box line thickness (pixels)
FONT_FACE = cv2.FONT_HERSHEY_SIMPLEX  # font face
FONT_SCALE = 0.6           # font scale (tune to the image size)
FONT_THICKNESS = 1         # text stroke thickness
TEXT_PADDING = 5           # gap between text and bounding box (pixels)
TEXT_BG_OPACITY = 0.7      # opacity of the text background (0-1, 0 is fully transparent)

# Supported image formats (no need to change).
SUPPORTED_IMAGE_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
|  | ||||
|  | ||||
| # -------------------------- 2. 工具函数 -------------------------- | ||||
def yolo2pixel(yolo_coords, img_w, img_h):
    """
    Convert a normalized YOLO box into a pixel-space corner box.

    The corners are truncated to integers and then clamped so that they stay
    within ``[0, img_w]`` horizontally and ``[0, img_h]`` vertically.

    :param yolo_coords: YOLO coordinates [xc, yc, w, h] (relative values)
    :param img_w: image width in pixels
    :param img_h: image height in pixels
    :return: pixel coordinates as a tuple (x1, y1, x2, y2)
    """
    xc, yc, w, h = yolo_coords

    # Top-left corner, never left of / above the image.
    left = max(0, int((xc - w / 2) * img_w))
    top = max(0, int((yc - h / 2) * img_h))
    # Bottom-right corner, never right of / below the image.
    right = min(img_w, int((xc + w / 2) * img_w))
    bottom = min(img_h, int((yc + h / 2) * img_h))

    return left, top, right, bottom
|  | ||||
|  | ||||
def draw_annotation(img, bbox, class_name, color):
    """
    Draw a bounding box and its class name onto an image.

    The label is placed above the top-left corner of the box; when that would
    leave the image at the top, it is moved inside the box at its top-right
    corner. A semi-transparent rectangle in ``color`` is blended behind the
    text. The background is skipped when its clipped region is empty (for
    example a box that starts at the right image edge), so no empty slice is
    handed to OpenCV.

    :param img: image to draw on (OpenCV array, BGR); it is modified in place
    :param bbox: pixel-space box (x1, y1, x2, y2)
    :param class_name: class name to render
    :param color: BGR tuple used for the box and the text background
    :return: the annotated image (the same array as ``img``)
    """
    img_h, img_w = img.shape[:2]
    x1, y1, x2, y2 = bbox

    # 1. Bounding box.
    cv2.rectangle(img, (x1, y1), (x2, y2), color, BOX_THICKNESS)

    # 2. Text size, needed to size the text background.
    (text_w, text_h), _ = cv2.getTextSize(class_name, FONT_FACE, FONT_SCALE, FONT_THICKNESS)

    # 3. Text position: above the top-left corner by default, otherwise
    #    inside the box at the top-right corner.
    text_x = x1 + TEXT_PADDING
    text_y = y1 - TEXT_PADDING - text_h
    if text_y < 0:
        text_x = x2 - TEXT_PADDING - text_w
        text_y = y1 + TEXT_PADDING + text_h

    # 4. Text background, clipped to the image.
    bg_x1 = max(0, text_x - TEXT_PADDING)
    bg_y1 = max(0, text_y - text_h - TEXT_PADDING)
    bg_x2 = min(img_w, text_x + text_w + TEXT_PADDING)
    bg_y2 = min(img_h, text_y + TEXT_PADDING)

    if bg_x2 > bg_x1 and bg_y2 > bg_y1:
        region = img[bg_y1:bg_y2, bg_x1:bg_x2]
        # Solid colour layer of the same size, blended with the original pixels.
        bg = region.copy()
        bg = cv2.rectangle(bg, (0, 0), (bg_x2 - bg_x1, bg_y2 - bg_y1), color, -1)
        img[bg_y1:bg_y2, bg_x1:bg_x2] = cv2.addWeighted(bg, TEXT_BG_OPACITY, region, 1 - TEXT_BG_OPACITY, 0)

    # 5. Class name in black.
    cv2.putText(img, class_name, (text_x, text_y), FONT_FACE, FONT_SCALE, (0, 0, 0), FONT_THICKNESS)

    return img
|  | ||||
|  | ||||
| # -------------------------- 3. 主函数 -------------------------- | ||||
def _parse_yolo_label_line(line):
    """
    Parse one YOLO label line of the form ``class_id xc yc w h``.

    :param line: a stripped, non-empty label line
    :return: ``(class_id, [xc, yc, w, h])`` with the coordinates as floats
    :raises ValueError: if the line does not have exactly 5 fields, a field is
        not numeric, or any coordinate lies outside the 0-1 range
    """
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"格式错误(需5个字段,实际{len(parts)}个)")

    # The class id may be written as "3" or "3.0"; accept both.
    cls_id = int(float(parts[0]))
    yolo_coords = [float(p) for p in parts[1:]]
    if not all(0 <= coord <= 1 for coord in yolo_coords):
        raise ValueError(f"YOLO坐标超出0-1范围:{yolo_coords}")
    return cls_id, yolo_coords


def main():
    """
    Draw the YOLO labels of every image in INPUT_IMAGE_DIR and save the result.

    For each supported image, the label file with the same stem is read from
    INPUT_LABEL_DIR, every valid label is drawn with ``draw_annotation``, and
    the image is written to OUTPUT_IMAGE_DIR as ``<stem>_annotated<ext>``
    (the original extension is kept). Images without a label file are saved
    unannotated. Unreadable images, bad label lines and unknown class ids are
    reported through ``tqdm.write`` and skipped; they do not stop the run.

    :raises FileNotFoundError: if INPUT_IMAGE_DIR contains no supported image
    """
    # 1. Make sure the output directory exists.
    os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
    print(f"标注后的图片将保存到:{OUTPUT_IMAGE_DIR}\n")

    # 2. Map class id -> (name, colour).
    class_map = {cls_id: (cls_name, cls_color) for cls_id, cls_name, cls_color in CLASS_CONFIG}
    print("类别配置:")
    for cls_id, cls_name, cls_color in CLASS_CONFIG:
        print(f"  ID {cls_id} → 名称:{cls_name},颜色:{cls_color}")
    print()

    # 3. Collect the images to process; sorted so runs are reproducible.
    image_files = sorted(
        f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith(SUPPORTED_IMAGE_FORMATS)
    )
    if not image_files:
        raise FileNotFoundError(f"在 {INPUT_IMAGE_DIR} 中未找到任何支持的图片文件({SUPPORTED_IMAGE_FORMATS})")
    print(f"找到 {len(image_files)} 张图片,开始标注...\n")

    # 4. Annotate each image.
    saved_count = 0
    for img_filename in tqdm(image_files, desc="处理进度"):
        # 4.1 Image path and the label path sharing the same stem.
        img_name, img_ext = os.path.splitext(img_filename)
        img_path = os.path.join(INPUT_IMAGE_DIR, img_filename)
        label_path = os.path.join(INPUT_LABEL_DIR, f"{img_name}.txt")

        # 4.2 Load the image; cv2.imread signals failure by returning None.
        img = cv2.imread(img_path)
        if img is None:
            tqdm.write(f"⚠️  跳过:无法读取图片 {img_filename}(可能损坏或格式不支持)")
            continue
        img_h, img_w = img.shape[:2]

        # Draw on a copy so the loaded image itself is left untouched.
        annotated_img = img.copy()

        # 4.3 Read the label lines; without labels the image is saved as-is.
        label_lines = []
        if not os.path.exists(label_path):
            tqdm.write(f"⚠️  警告:图片 {img_filename} 无对应标签文件 {os.path.basename(label_path)},直接保存原图")
        else:
            try:
                with open(label_path, 'r', encoding='utf-8') as f:
                    label_lines = [text for text in (raw.strip() for raw in f) if text]
            except (OSError, UnicodeDecodeError) as e:
                # One unreadable label file must not abort the whole batch.
                tqdm.write(f"⚠️  警告:图片 {img_filename} 的标签文件 {os.path.basename(label_path)} 读取失败 → {e},直接保存原图")

        # 4.4 Parse and draw every label line.
        for line_no, line in enumerate(label_lines, start=1):
            # Deliberately broad: any failure on one label (parsing, coordinate
            # conversion or drawing) is reported and only that label is skipped.
            try:
                cls_id, yolo_coords = _parse_yolo_label_line(line)

                if cls_id not in class_map:
                    tqdm.write(f"⚠️  跳过:图片 {img_filename} 标签第{line_no}行,未知类别ID {cls_id}(未在CLASS_CONFIG中配置)")
                    continue

                cls_name, cls_color = class_map[cls_id]
                bbox = yolo2pixel(yolo_coords, img_w, img_h)
                annotated_img = draw_annotation(annotated_img, bbox, cls_name, cls_color)
            except Exception as e:
                tqdm.write(f"⚠️  跳过:图片 {img_filename} 标签第{line_no}行解析失败 → {e}")
                continue

        # 4.5 Save with the original extension. cv2.imwrite returns False when
        #     the file could not be written, so check it instead of assuming success.
        output_img_path = os.path.join(OUTPUT_IMAGE_DIR, f"{img_name}_annotated{img_ext}")
        if cv2.imwrite(output_img_path, annotated_img):
            saved_count += 1
        else:
            tqdm.write(f"⚠️  保存失败:无法写入 {output_img_path}")

    # 5. Summary: report how many images were actually saved out of those found.
    print(f"\n✅ 标注完成!共处理 {saved_count}/{len(image_files)} 张图片,标注后的图片已保存到:")
    print(f"   {OUTPUT_IMAGE_DIR}")
|  | ||||
|  | ||||
if __name__ == "__main__":
    # Top-level boundary: report any failure in a readable form, then exit
    # with a non-zero status so callers and shell scripts can detect it.
    try:
        main()
    except Exception as e:
        print(f"\n❌ 程序异常终止:{str(e)}")
        raise SystemExit(1) from e
							
								
								
									
										28
									
								
								视频抽帧.py
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								视频抽帧.py
									
									
									
									
									
								
							| @ -6,23 +6,23 @@ from pathlib import Path | ||||
|  | ||||
| # -------------------------- 1. 用户配置(根据实际情况修改)-------------------------- | ||||
| # 视频文件路径(绝对路径或相对路径均可,支持mp4/avi/mov等常见格式) | ||||
| VIDEO_PATH = r"E:\geminicli\yolo\images\44.mp4"  # 替换为你的视频路径 | ||||
|  | ||||
| # 输出根目录(脚本会自动在该目录下创建train/valid/test子文件夹) | ||||
| OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images"  # 替换为你的输出根路径 | ||||
| # VIDEO_PATH = r"E:\geminicli\yolo\images\屏幕操作视频.mp4"  # 替换为你的视频路径 | ||||
| VIDEO_PATH = r"D:\DataPreHandler\MP4\花生底图.mp4" | ||||
| # 输出根目录(脚本会自动在该目录下创建train/val/test子文件夹) | ||||
| OUTPUT_ROOT_DIR = r"D:\DataPreHandler\images\dituchoqu_huashen"  # 替换为你的输出根路径 | ||||
|  | ||||
| # 各文件夹需抽取的图片数量(总数量=1200+150+50=1400) | ||||
| FRAME_COUNTS = { | ||||
|     "train": 28, | ||||
|     "valid": 14, | ||||
|     "test": 7 | ||||
|     "train": 1200, | ||||
|     "val": 150, | ||||
|     "test": 50 | ||||
| } | ||||
|  | ||||
| # 图片保存格式(建议用jpg,兼容性更好;也可改为png) | ||||
| SAVE_FORMAT = "jpg" | ||||
|  | ||||
| # 图片文件名编号位数(如0001.jpg,避免文件名乱序) | ||||
| FILE_NUM_DIGITS = 4  # 对应最大编号9999,满足4900张需求 | ||||
| FILE_NUM_DIGITS = 5  # 对应最大编号99999,满足1400张需求 | ||||
|  | ||||
|  | ||||
| # ----------------------------------------------------------------------------------- | ||||
| @ -62,13 +62,13 @@ def check_video_validity(cap): | ||||
|     print(f"   - 帧率(FPS):{fps:.1f}") | ||||
|     print(f"   - 总时长:{video_duration // 60:.0f}分{video_duration % 60:.1f}秒") | ||||
|     print( | ||||
|         f"   - 需抽取总帧数:{required_total}(train:{FRAME_COUNTS['train']}, valid:{FRAME_COUNTS['valid']}, test:{FRAME_COUNTS['test']})") | ||||
|         f"   - 需抽取总帧数:{required_total}(train:{FRAME_COUNTS['train']}, val:{FRAME_COUNTS['val']}, test:{FRAME_COUNTS['test']})") | ||||
|  | ||||
|     return total_frames | ||||
|  | ||||
|  | ||||
| def create_output_dirs(): | ||||
|     """创建输出根目录及train/valid/test子文件夹""" | ||||
|     """创建输出根目录及train/val/test子文件夹""" | ||||
|     # 转换为Path对象,适配Windows/Linux/macOS路径格式 | ||||
|     output_root = Path(OUTPUT_ROOT_DIR) | ||||
|     subdirs = FRAME_COUNTS.keys() | ||||
| @ -97,14 +97,14 @@ def generate_random_frame_indices(total_frames, required_total): | ||||
|  | ||||
|  | ||||
| def split_indices_by_dataset(random_indices): | ||||
|     """将总随机索引按train/valid/test的数量拆分""" | ||||
|     """将总随机索引按train/val/test的数量拆分""" | ||||
|     train_count = FRAME_COUNTS["train"] | ||||
|     valid_count = FRAME_COUNTS["valid"] | ||||
|     valid_count = FRAME_COUNTS["val"] | ||||
|  | ||||
|     # 拆分逻辑:前N个给train,中间M个给valid,剩余给test | ||||
|     indices_split = { | ||||
|         "train": random_indices[:train_count], | ||||
|         "valid": random_indices[train_count:train_count + valid_count], | ||||
|         "val": random_indices[train_count:train_count + valid_count], | ||||
|         "test": random_indices[train_count + valid_count:] | ||||
|     } | ||||
|  | ||||
| @ -140,7 +140,7 @@ def extract_and_save_frames(cap, indices_split, output_dirs): | ||||
|                 continue | ||||
|  | ||||
|             # 生成文件名(如train_0001.jpg) | ||||
|             file_name = f"{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}"  # 0{FILE_NUM_DIGITS}d表示补零到指定位数 | ||||
|             file_name = f"dths_{dataset}_{idx:0{FILE_NUM_DIGITS}d}.{SAVE_FORMAT}"  # 0{FILE_NUM_DIGITS}d表示补零到指定位数 | ||||
|             save_path = dataset_dir / file_name | ||||
|  | ||||
|             # 保存帧为图片(cv2.imwrite默认保存为BGR格式,符合图片存储标准) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user