import os import cv2 import numpy as np import torch from torch.utils.data import Dataset class SegmentationDataset(Dataset): """ 自定义图像分割数据集。 根据指定的 seg_mode,可以生成适用于 'multiclass' 或 'multilabel' 任务的掩码。 """ def __init__(self, image_dir, mask_dir, classes, class_rgb_values, seg_mode, augmentation=None): """ Args: image_dir (str): 图像文件目录。 mask_dir (str): 掩码文件目录。 classes (list): 类别名称列表。 class_rgb_values (list): 每个类别在灰度掩码中对应的像素值。 augmentation (albumentations.Compose, optional): 数据增强流程。 seg_mode (str): 分割模式, 'multiclass' 或 'multilabel'。 """ self.image_dir = image_dir self.mask_dir = mask_dir self.image_filenames = sorted(os.listdir(image_dir)) self.augmentation = augmentation self.classes = classes self.class_rgb_values = class_rgb_values # 【新增】存储分割模式并进行验证 self.seg_mode = seg_mode if self.seg_mode not in ['multiclass', 'multilabel']: raise ValueError(f"seg_mode must be 'multiclass' or 'multilabel', but got {self.seg_mode}") print(f"Found {len(self.image_filenames)} images in {image_dir}. Dataset mode: '{self.seg_mode}'") def __len__(self): return len(self.image_filenames) def __getitem__(self, idx): img_path = os.path.join(self.image_dir, self.image_filenames[idx]) image = cv2.imread(img_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) mask_path = os.path.join(self.mask_dir, self.image_filenames[idx]) # 读取灰度图 mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) if mask is None: raise FileNotFoundError(f"Mask file not found or could not be read: {mask_path}") # 【核心修改】根据 seg_mode 选择掩码的处理方式 if self.seg_mode == 'multilabel': # 为 'multilabel' 模式创建 one-hot 编码掩码 processed_mask = self._create_one_hot_mask(mask) else: # 'multiclass' # 为 'multiclass' 模式创建类索引掩码 processed_mask = self._create_class_index_mask(mask) if self.augmentation: # Albumentations 可以同时处理 (H,W,C) 和 (H,W) 格式的掩码 sample = self.augmentation(image=image, mask=processed_mask) image = sample['image'] mask = sample['mask'] else: mask = processed_mask # 【核心修改】根据 seg_mode 对最终的掩码张量进行处理 if self.seg_mode == 'multilabel': # 调整维度顺序 (H, W, C) -> (C, H, W) 并转换为 float mask = mask.permute(2, 0, 1).float() else: # 'multiclass' # 直接转换为 long 类型,不需要调整维度 mask = mask.long() return image, mask def _create_one_hot_mask(self, mask): """ 将单通道的灰度掩码 (H, W) 转换为 one-hot 编码的掩码 (H, W, C)。 这是为 'multilabel' 模式准备的。 返回一个 NumPy 数组。 """ semantic_map = np.zeros((mask.shape[0], mask.shape[1], len(self.class_rgb_values)), dtype=np.uint8) for i, value in enumerate(self.class_rgb_values): semantic_map[:, :, i] = (mask == value).astype(np.uint8) return semantic_map def _create_class_index_mask(self, mask): """ 【新增函数】 将单通道的灰度掩码 (H, W) 转换为类索引掩码 (H, W)。 这是为 'multiclass' 模式准备的。 返回一个 NumPy 数组。 """ class_index_map = np.zeros(mask.shape, dtype=np.uint8) for i, value in enumerate(self.class_rgb_values): # 将灰度值为 value 的像素,其类别索引设置为 i class_index_map[mask == value] = i return class_index_map