first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

View File

@@ -0,0 +1,99 @@
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
"""
自定义图像分割数据集。
根据指定的 seg_mode可以生成适用于 'multiclass''multilabel' 任务的掩码。
"""
def __init__(self, image_dir, mask_dir, classes, class_rgb_values, seg_mode, augmentation=None):
"""
Args:
image_dir (str): 图像文件目录。
mask_dir (str): 掩码文件目录。
classes (list): 类别名称列表。
class_rgb_values (list): 每个类别在灰度掩码中对应的像素值。
augmentation (albumentations.Compose, optional): 数据增强流程。
seg_mode (str): 分割模式, 'multiclass''multilabel'
"""
self.image_dir = image_dir
self.mask_dir = mask_dir
self.image_filenames = sorted(os.listdir(image_dir))
self.augmentation = augmentation
self.classes = classes
self.class_rgb_values = class_rgb_values
# 【新增】存储分割模式并进行验证
self.seg_mode = seg_mode
if self.seg_mode not in ['multiclass', 'multilabel']:
raise ValueError(f"seg_mode must be 'multiclass' or 'multilabel', but got {self.seg_mode}")
print(f"Found {len(self.image_filenames)} images in {image_dir}. Dataset mode: '{self.seg_mode}'")
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
img_path = os.path.join(self.image_dir, self.image_filenames[idx])
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask_path = os.path.join(self.mask_dir, self.image_filenames[idx])
# 读取灰度图
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise FileNotFoundError(f"Mask file not found or could not be read: {mask_path}")
# 【核心修改】根据 seg_mode 选择掩码的处理方式
if self.seg_mode == 'multilabel':
# 为 'multilabel' 模式创建 one-hot 编码掩码
processed_mask = self._create_one_hot_mask(mask)
else: # 'multiclass'
# 为 'multiclass' 模式创建类索引掩码
processed_mask = self._create_class_index_mask(mask)
if self.augmentation:
# Albumentations 可以同时处理 (H,W,C) 和 (H,W) 格式的掩码
sample = self.augmentation(image=image, mask=processed_mask)
image = sample['image']
mask = sample['mask']
else:
mask = processed_mask
# 【核心修改】根据 seg_mode 对最终的掩码张量进行处理
if self.seg_mode == 'multilabel':
# 调整维度顺序 (H, W, C) -> (C, H, W) 并转换为 float
mask = mask.permute(2, 0, 1).float()
else: # 'multiclass'
# 直接转换为 long 类型,不需要调整维度
mask = mask.long()
return image, mask
def _create_one_hot_mask(self, mask):
"""
将单通道的灰度掩码 (H, W) 转换为 one-hot 编码的掩码 (H, W, C)。
这是为 'multilabel' 模式准备的。
返回一个 NumPy 数组。
"""
semantic_map = np.zeros((mask.shape[0], mask.shape[1], len(self.class_rgb_values)), dtype=np.uint8)
for i, value in enumerate(self.class_rgb_values):
semantic_map[:, :, i] = (mask == value).astype(np.uint8)
return semantic_map
def _create_class_index_mask(self, mask):
"""
【新增函数】
将单通道的灰度掩码 (H, W) 转换为类索引掩码 (H, W)。
这是为 'multiclass' 模式准备的。
返回一个 NumPy 数组。
"""
class_index_map = np.zeros(mask.shape, dtype=np.uint8)
for i, value in enumerate(self.class_rgb_values):
# 将灰度值为 value 的像素,其类别索引设置为 i
class_index_map[mask == value] = i
return class_index_map