数据增强相关代码
This commit is contained in:
199
根据位置画框.py
Normal file
199
根据位置画框.py
Normal file
@ -0,0 +1,199 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
# -------------------------- 1. 核心配置(请根据需求修改) --------------------------
|
||||
# 输入目录(图片和标签需一一对应,文件名相同,仅后缀不同)
|
||||
INPUT_IMAGE_DIR = r"D:\DataPreHandler\data\train\images" # 原始图片目录
|
||||
INPUT_LABEL_DIR = r"D:\DataPreHandler\data\train\labels" # 原始YOLO标签目录
|
||||
# 输出目录(标注后的图片会保存在这里)
|
||||
OUTPUT_IMAGE_DIR = r"D:\DataPreHandler\data\test\da\output2"
|
||||
|
||||
# 关键配置:类别ID与类别名称的映射(必须与你的YOLO训练类别顺序一致!)
|
||||
CLASS_CONFIG = [
|
||||
(0, "Abdomen", (0, 255, 0)),
|
||||
(1, "Hips", (0, 255, 0)),
|
||||
(2, "Chest", (0, 255, 0)),
|
||||
(3, "vulva", (0, 255, 0)),
|
||||
(4, "back", (0, 255, 0)),
|
||||
(5, "penis", (0, 255, 0)),
|
||||
(6, "Horror", (0, 255, 0))
|
||||
]
|
||||
|
||||
# 绘制参数(可按需调整)
|
||||
BOX_THICKNESS = 2 # 边界框线条厚度(像素)
|
||||
FONT_FACE = cv2.FONT_HERSHEY_SIMPLEX # 字体类型
|
||||
FONT_SCALE = 0.6 # 字体大小(根据图片尺寸调整)
|
||||
FONT_THICKNESS = 1 # 字体线条厚度
|
||||
TEXT_PADDING = 5 # 文字与边界框的间距(像素)
|
||||
TEXT_BG_OPACITY = 0.7 # 文字背景的透明度(0-1,0为完全透明)
|
||||
|
||||
# 支持的图片格式(无需修改)
|
||||
SUPPORTED_IMAGE_FORMATS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
|
||||
|
||||
|
||||
# -------------------------- 2. 工具函数 --------------------------
|
||||
def yolo2pixel(yolo_coords, img_w, img_h):
|
||||
"""
|
||||
将YOLO相对坐标转换为图片像素坐标(边界框:x1, y1, x2, y2)
|
||||
:param yolo_coords: YOLO坐标列表 [xc, yc, w, h](相对值,0-1)
|
||||
:param img_w: 图片宽度(像素)
|
||||
:param img_h: 图片高度(像素)
|
||||
:return: 像素坐标元组 (x1, y1, x2, y2)
|
||||
"""
|
||||
xc, yc, w, h = yolo_coords
|
||||
# 计算边界框左上角和右下角坐标
|
||||
x1 = int((xc - w / 2) * img_w)
|
||||
y1 = int((yc - h / 2) * img_h)
|
||||
x2 = int((xc + w / 2) * img_w)
|
||||
y2 = int((yc + h / 2) * img_h)
|
||||
# 确保坐标不超出图片范围
|
||||
x1 = max(0, x1)
|
||||
y1 = max(0, y1)
|
||||
x2 = min(img_w, x2)
|
||||
y2 = min(img_h, y2)
|
||||
return x1, y1, x2, y2
|
||||
|
||||
|
||||
def draw_annotation(img, bbox, class_name, color):
|
||||
"""
|
||||
在图片上绘制边界框和类别名称
|
||||
:param img: 原始图片(OpenCV格式,BGR通道)
|
||||
:param bbox: 像素坐标边界框 (x1, y1, x2, y2)
|
||||
:param class_name: 类别名称(字符串)
|
||||
:param color: 边界框和文字颜色(BGR元组,如 (0,255,0) 代表绿色)
|
||||
:return: 标注后的图片
|
||||
"""
|
||||
img_h, img_w = img.shape[:2]
|
||||
x1, y1, x2, y2 = bbox
|
||||
|
||||
# 1. 绘制边界框
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, BOX_THICKNESS)
|
||||
|
||||
# 2. 计算文字尺寸(用于创建文字背景)
|
||||
text_size, _ = cv2.getTextSize(class_name, FONT_FACE, FONT_SCALE, FONT_THICKNESS)
|
||||
text_w, text_h = text_size
|
||||
|
||||
# 3. 确定文字位置(避免超出图片范围)
|
||||
# 文字默认放在边界框左上角,若左上角空间不足则放在右上角
|
||||
text_x = x1 + TEXT_PADDING
|
||||
text_y = y1 - TEXT_PADDING - text_h # 文字基线在y轴上方
|
||||
if text_y < 0: # 左上角超出图片顶部,调整到右上角
|
||||
text_x = x2 - TEXT_PADDING - text_w
|
||||
text_y = y1 + TEXT_PADDING + text_h
|
||||
|
||||
# 4. 绘制文字背景(半透明矩形,避免遮挡图片内容)
|
||||
bg_x1 = text_x - TEXT_PADDING
|
||||
bg_y1 = text_y - text_h - TEXT_PADDING
|
||||
bg_x2 = text_x + text_w + TEXT_PADDING
|
||||
bg_y2 = text_y + TEXT_PADDING
|
||||
# 确保背景不超出图片范围
|
||||
bg_x1 = max(0, bg_x1)
|
||||
bg_y1 = max(0, bg_y1)
|
||||
bg_x2 = min(img_w, bg_x2)
|
||||
bg_y2 = min(img_h, bg_y2)
|
||||
|
||||
# 半透明背景:先创建背景层,再与原图混合
|
||||
bg = img[bg_y1:bg_y2, bg_x1:bg_x2].copy()
|
||||
bg = cv2.rectangle(bg, (0, 0), (bg_x2 - bg_x1, bg_y2 - bg_y1), color, -1) # 实心矩形
|
||||
img[bg_y1:bg_y2, bg_x1:bg_x2] = cv2.addWeighted(bg, TEXT_BG_OPACITY, img[bg_y1:bg_y2, bg_x1:bg_x2], 1 - TEXT_BG_OPACITY, 0)
|
||||
|
||||
# 5. 绘制类别名称
|
||||
cv2.putText(img, class_name, (text_x, text_y), FONT_FACE, FONT_SCALE, (0, 0, 0), FONT_THICKNESS) # 白色文字
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# -------------------------- 3. 主函数 --------------------------
|
||||
def main():
|
||||
# 1. 创建输出目录(若不存在)
|
||||
os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)
|
||||
print(f"标注后的图片将保存到:{OUTPUT_IMAGE_DIR}\n")
|
||||
|
||||
# 2. 构建类别ID到(名称+颜色)的映射字典
|
||||
class_map = {cls_id: (cls_name, cls_color) for cls_id, cls_name, cls_color in CLASS_CONFIG}
|
||||
print("类别配置:")
|
||||
for cls_id, cls_name, cls_color in CLASS_CONFIG:
|
||||
print(f" ID {cls_id} → 名称:{cls_name},颜色:{cls_color}")
|
||||
print()
|
||||
|
||||
# 3. 获取所有图片文件(仅处理支持的格式)
|
||||
image_files = [f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith(SUPPORTED_IMAGE_FORMATS)]
|
||||
if not image_files:
|
||||
raise FileNotFoundError(f"在 {INPUT_IMAGE_DIR} 中未找到任何支持的图片文件({SUPPORTED_IMAGE_FORMATS})")
|
||||
print(f"找到 {len(image_files)} 张图片,开始标注...\n")
|
||||
|
||||
# 4. 遍历图片并标注
|
||||
for img_filename in tqdm(image_files, desc="处理进度"):
|
||||
# 4.1 构建图片和标签的路径
|
||||
img_name, img_ext = os.path.splitext(img_filename)
|
||||
img_path = os.path.join(INPUT_IMAGE_DIR, img_filename)
|
||||
label_path = os.path.join(INPUT_LABEL_DIR, f"{img_name}.txt") # 标签文件与图片同名,后缀为txt
|
||||
|
||||
# 4.2 读取图片(OpenCV默认读取为BGR通道)
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
tqdm.write(f"⚠️ 跳过:无法读取图片 {img_filename}(可能损坏或格式不支持)")
|
||||
continue
|
||||
img_h, img_w = img.shape[:2]
|
||||
|
||||
# 4.3 读取标签文件(若不存在则跳过标注,直接保存原图)
|
||||
if not os.path.exists(label_path):
|
||||
tqdm.write(f"⚠️ 警告:图片 {img_filename} 无对应标签文件 {os.path.basename(label_path)},直接保存原图")
|
||||
annotated_img = img.copy()
|
||||
else:
|
||||
# 复制原图用于标注(避免修改原始图片)
|
||||
annotated_img = img.copy()
|
||||
# 读取标签内容
|
||||
with open(label_path, 'r', encoding='utf-8') as f:
|
||||
label_lines = [line.strip() for line in f.readlines() if line.strip()] # 过滤空行
|
||||
|
||||
# 4.4 解析每个标签并绘制
|
||||
for line_idx, line in enumerate(label_lines):
|
||||
try:
|
||||
# YOLO标签格式:class_id xc yc w h(空格分隔)
|
||||
parts = line.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError(f"格式错误(需5个字段,实际{len(parts)}个)")
|
||||
|
||||
# 解析类别ID和坐标
|
||||
cls_id = int(float(parts[0]))
|
||||
yolo_coords = [float(p) for p in parts[1:]]
|
||||
# 检查YOLO坐标有效性(必须在0-1范围内)
|
||||
if not all(0 <= coord <= 1 for coord in yolo_coords):
|
||||
raise ValueError(f"YOLO坐标超出0-1范围:{yolo_coords}")
|
||||
|
||||
# 4.5 转换坐标并绘制
|
||||
# 检查类别ID是否在配置中
|
||||
if cls_id not in class_map:
|
||||
tqdm.write(f"⚠️ 跳过:图片 {img_filename} 标签第{line_idx+1}行,未知类别ID {cls_id}(未在CLASS_CONFIG中配置)")
|
||||
continue
|
||||
|
||||
# 获取类别名称和颜色
|
||||
cls_name, cls_color = class_map[cls_id]
|
||||
# 转换YOLO坐标为像素坐标
|
||||
bbox = yolo2pixel(yolo_coords, img_w, img_h)
|
||||
# 绘制标注
|
||||
annotated_img = draw_annotation(annotated_img, bbox, cls_name, cls_color)
|
||||
|
||||
except Exception as e:
|
||||
tqdm.write(f"⚠️ 跳过:图片 {img_filename} 标签第{line_idx+1}行解析失败 → {str(e)}")
|
||||
continue
|
||||
|
||||
# 4.6 保存标注后的图片
|
||||
output_img_path = os.path.join(OUTPUT_IMAGE_DIR, f"{img_name}_annotated{img_ext}")
|
||||
# 保存为JPG格式(若原始是PNG,也可改为img_ext保持原格式)
|
||||
# 注:JPG不支持透明通道,若原始是PNG且有透明,建议保留img_ext
|
||||
cv2.imwrite(output_img_path, annotated_img)
|
||||
|
||||
# 5. 完成提示
|
||||
print(f"\n✅ 标注完成!共处理 {len(image_files)} 张图片,标注后的图片已保存到:")
|
||||
print(f" {OUTPUT_IMAGE_DIR}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
print(f"\n❌ 程序异常终止:{str(e)}")
|
||||
Reference in New Issue
Block a user