first commit
This commit is contained in:
99
Seg_All_In_One_SegModel/dataset.py
Normal file
99
Seg_All_In_One_SegModel/dataset.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class SegmentationDataset(Dataset):
|
||||
"""
|
||||
自定义图像分割数据集。
|
||||
根据指定的 seg_mode,可以生成适用于 'multiclass' 或 'multilabel' 任务的掩码。
|
||||
"""
|
||||
def __init__(self, image_dir, mask_dir, classes, class_rgb_values, seg_mode, augmentation=None):
|
||||
"""
|
||||
Args:
|
||||
image_dir (str): 图像文件目录。
|
||||
mask_dir (str): 掩码文件目录。
|
||||
classes (list): 类别名称列表。
|
||||
class_rgb_values (list): 每个类别在灰度掩码中对应的像素值。
|
||||
augmentation (albumentations.Compose, optional): 数据增强流程。
|
||||
seg_mode (str): 分割模式, 'multiclass' 或 'multilabel'。
|
||||
"""
|
||||
self.image_dir = image_dir
|
||||
self.mask_dir = mask_dir
|
||||
self.image_filenames = sorted(os.listdir(image_dir))
|
||||
self.augmentation = augmentation
|
||||
self.classes = classes
|
||||
self.class_rgb_values = class_rgb_values
|
||||
|
||||
# 【新增】存储分割模式并进行验证
|
||||
self.seg_mode = seg_mode
|
||||
if self.seg_mode not in ['multiclass', 'multilabel']:
|
||||
raise ValueError(f"seg_mode must be 'multiclass' or 'multilabel', but got {self.seg_mode}")
|
||||
|
||||
print(f"Found {len(self.image_filenames)} images in {image_dir}. Dataset mode: '{self.seg_mode}'")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_filenames)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = os.path.join(self.image_dir, self.image_filenames[idx])
|
||||
image = cv2.imread(img_path)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
mask_path = os.path.join(self.mask_dir, self.image_filenames[idx])
|
||||
# 读取灰度图
|
||||
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
if mask is None:
|
||||
raise FileNotFoundError(f"Mask file not found or could not be read: {mask_path}")
|
||||
|
||||
# 【核心修改】根据 seg_mode 选择掩码的处理方式
|
||||
if self.seg_mode == 'multilabel':
|
||||
# 为 'multilabel' 模式创建 one-hot 编码掩码
|
||||
processed_mask = self._create_one_hot_mask(mask)
|
||||
else: # 'multiclass'
|
||||
# 为 'multiclass' 模式创建类索引掩码
|
||||
processed_mask = self._create_class_index_mask(mask)
|
||||
|
||||
if self.augmentation:
|
||||
# Albumentations 可以同时处理 (H,W,C) 和 (H,W) 格式的掩码
|
||||
sample = self.augmentation(image=image, mask=processed_mask)
|
||||
image = sample['image']
|
||||
mask = sample['mask']
|
||||
else:
|
||||
mask = processed_mask
|
||||
|
||||
# 【核心修改】根据 seg_mode 对最终的掩码张量进行处理
|
||||
if self.seg_mode == 'multilabel':
|
||||
# 调整维度顺序 (H, W, C) -> (C, H, W) 并转换为 float
|
||||
mask = mask.permute(2, 0, 1).float()
|
||||
else: # 'multiclass'
|
||||
# 直接转换为 long 类型,不需要调整维度
|
||||
mask = mask.long()
|
||||
|
||||
return image, mask
|
||||
|
||||
def _create_one_hot_mask(self, mask):
|
||||
"""
|
||||
将单通道的灰度掩码 (H, W) 转换为 one-hot 编码的掩码 (H, W, C)。
|
||||
这是为 'multilabel' 模式准备的。
|
||||
返回一个 NumPy 数组。
|
||||
"""
|
||||
semantic_map = np.zeros((mask.shape[0], mask.shape[1], len(self.class_rgb_values)), dtype=np.uint8)
|
||||
for i, value in enumerate(self.class_rgb_values):
|
||||
semantic_map[:, :, i] = (mask == value).astype(np.uint8)
|
||||
return semantic_map
|
||||
|
||||
def _create_class_index_mask(self, mask):
|
||||
"""
|
||||
【新增函数】
|
||||
将单通道的灰度掩码 (H, W) 转换为类索引掩码 (H, W)。
|
||||
这是为 'multiclass' 模式准备的。
|
||||
返回一个 NumPy 数组。
|
||||
"""
|
||||
class_index_map = np.zeros(mask.shape, dtype=np.uint8)
|
||||
for i, value in enumerate(self.class_rgb_values):
|
||||
# 将灰度值为 value 的像素,其类别索引设置为 i
|
||||
class_index_map[mask == value] = i
|
||||
return class_index_map
|
||||
Reference in New Issue
Block a user