Files
Seg_Data_Server/Seg_All_In_One_SegModel/dataset.py
2026-05-20 15:05:35 +08:00

99 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
"""
自定义图像分割数据集。
根据指定的 seg_mode可以生成适用于 'multiclass''multilabel' 任务的掩码。
"""
def __init__(self, image_dir, mask_dir, classes, class_rgb_values, seg_mode, augmentation=None):
"""
Args:
image_dir (str): 图像文件目录。
mask_dir (str): 掩码文件目录。
classes (list): 类别名称列表。
class_rgb_values (list): 每个类别在灰度掩码中对应的像素值。
augmentation (albumentations.Compose, optional): 数据增强流程。
seg_mode (str): 分割模式, 'multiclass''multilabel'
"""
self.image_dir = image_dir
self.mask_dir = mask_dir
self.image_filenames = sorted(os.listdir(image_dir))
self.augmentation = augmentation
self.classes = classes
self.class_rgb_values = class_rgb_values
# 【新增】存储分割模式并进行验证
self.seg_mode = seg_mode
if self.seg_mode not in ['multiclass', 'multilabel']:
raise ValueError(f"seg_mode must be 'multiclass' or 'multilabel', but got {self.seg_mode}")
print(f"Found {len(self.image_filenames)} images in {image_dir}. Dataset mode: '{self.seg_mode}'")
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
img_path = os.path.join(self.image_dir, self.image_filenames[idx])
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask_path = os.path.join(self.mask_dir, self.image_filenames[idx])
# 读取灰度图
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise FileNotFoundError(f"Mask file not found or could not be read: {mask_path}")
# 【核心修改】根据 seg_mode 选择掩码的处理方式
if self.seg_mode == 'multilabel':
# 为 'multilabel' 模式创建 one-hot 编码掩码
processed_mask = self._create_one_hot_mask(mask)
else: # 'multiclass'
# 为 'multiclass' 模式创建类索引掩码
processed_mask = self._create_class_index_mask(mask)
if self.augmentation:
# Albumentations 可以同时处理 (H,W,C) 和 (H,W) 格式的掩码
sample = self.augmentation(image=image, mask=processed_mask)
image = sample['image']
mask = sample['mask']
else:
mask = processed_mask
# 【核心修改】根据 seg_mode 对最终的掩码张量进行处理
if self.seg_mode == 'multilabel':
# 调整维度顺序 (H, W, C) -> (C, H, W) 并转换为 float
mask = mask.permute(2, 0, 1).float()
else: # 'multiclass'
# 直接转换为 long 类型,不需要调整维度
mask = mask.long()
return image, mask
def _create_one_hot_mask(self, mask):
"""
将单通道的灰度掩码 (H, W) 转换为 one-hot 编码的掩码 (H, W, C)。
这是为 'multilabel' 模式准备的。
返回一个 NumPy 数组。
"""
semantic_map = np.zeros((mask.shape[0], mask.shape[1], len(self.class_rgb_values)), dtype=np.uint8)
for i, value in enumerate(self.class_rgb_values):
semantic_map[:, :, i] = (mask == value).astype(np.uint8)
return semantic_map
def _create_class_index_mask(self, mask):
"""
【新增函数】
将单通道的灰度掩码 (H, W) 转换为类索引掩码 (H, W)。
这是为 'multiclass' 模式准备的。
返回一个 NumPy 数组。
"""
class_index_map = np.zeros(mask.shape, dtype=np.uint8)
for i, value in enumerate(self.class_rgb_values):
# 将灰度值为 value 的像素,其类别索引设置为 i
class_index_map[mask == value] = i
return class_index_map