# In utils.py or a new losses.py
import config
import torch.nn as nn
import segmentation_models_pytorch as smp


class UNetPlusPlusLoss(nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss.

    Modelled after the official UNet++ implementation. Intended for the
    outputs of models from the segmentation-models-pytorch library.

    Args:
        mode: DiceLoss mode, e.g. 'multilabel' or 'multiclass'.
        bce_weight: Weight of the BCE term in the total loss.
        dice_weight: Weight of the Dice term in the total loss.
    """

    def __init__(self, mode: str = config.SEG_MODE, bce_weight: float = 0.5,
                 dice_weight: float = 0.5) -> None:
        super().__init__()
        # BCEWithLogitsLoss is the numerically more stable variant of
        # sigmoid followed by BCELoss.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = smp.losses.DiceLoss(mode=mode)
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, y_pred, y_true):
        """Compute the combined loss.

        Args:
            y_pred: Model predictions (logits).
            y_true: Ground-truth mask.

        Returns:
            ``bce_weight * BCE + dice_weight * Dice``.
        """
        # BCEWithLogitsLoss requires a floating-point target; masks are often
        # loaded as integer or bool tensors, so cast to the logits' dtype.
        # The Dice term receives the mask unchanged.
        bce = self.bce_loss(y_pred, y_true.to(dtype=y_pred.dtype))
        dice = self.dice_loss(y_pred, y_true)

        # Weighted sum of the two loss terms.
        return self.bce_weight * bce + self.dice_weight * dice


class MultiClassLoss(nn.Module):
    """A combined loss function for multi-class segmentation.

    Combines CrossEntropyLoss and DiceLoss, which is a common practice for
    semantic segmentation tasks.

    Args:
        mode: DiceLoss mode, should be 'multiclass'.
        ce_weight: Weight for the CrossEntropyLoss component.
        dice_weight: Weight for the DiceLoss component.
        ignore_index: Target value that is ignored by both loss components
            and does not contribute to the input gradient.
    """

    def __init__(self, mode: str = 'multiclass', ce_weight: float = 0.5,
                 dice_weight: float = 0.5, ignore_index: int = -100) -> None:
        super().__init__()
        # CrossEntropyLoss combines LogSoftmax and NLLLoss in a single class.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # Forward ignore_index to the Dice term as well: otherwise pixels
        # carrying the ignore label would still be one-hot encoded there,
        # which fails for values outside [0, num_classes).
        self.dice_loss = smp.losses.DiceLoss(mode=mode,
                                             ignore_index=ignore_index)
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, y_pred, y_true):
        """Calculate the combined loss.

        Args:
            y_pred: Model predictions (logits), shape (N, C, H, W).
            y_true: Ground truth labels (class indices), shape (N, H, W).
                Converted to ``torch.long`` before use.

        Returns:
            ``ce_weight * CE + dice_weight * Dice``.
        """
        # CrossEntropyLoss expects class-index targets of type long.
        y_true = y_true.long()

        ce = self.ce_loss(y_pred, y_true)
        dice = self.dice_loss(y_pred, y_true)

        return self.ce_weight * ce + self.dice_weight * dice