86 lines
3.2 KiB
Python
86 lines
3.2 KiB
Python
# In utils.py or a new losses.py
|
|
import config
|
|
import torch.nn as nn
|
|
import segmentation_models_pytorch as smp
|
|
|
|
class UNetPlusPlusLoss(nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss, modelled on the UNet++ reference setup.

    Intended for the raw (logit) outputs of segmentation-models-pytorch models.

    Args:
        mode (str): Mode forwarded to ``smp.losses.DiceLoss``, e.g.
            ``'multilabel'`` or ``'multiclass'``. Defaults to ``config.SEG_MODE``.
        bce_weight (float): Weight of the BCE term in the total loss.
        dice_weight (float): Weight of the Dice term in the total loss.

    NOTE(review): the BCE term needs a target with the same shape as the
    logits, so class-index targets used with ``'multiclass'`` mode will
    presumably not fit it -- confirm which modes callers actually use.
    """

    def __init__(self, mode=config.SEG_MODE, bce_weight=0.5, dice_weight=0.5):
        super().__init__()

        # BCEWithLogitsLoss applies the sigmoid internally, which is more
        # numerically stable than a separate Sigmoid followed by BCELoss.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = smp.losses.DiceLoss(mode=mode)
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, y_pred, y_true):
        """Compute the combined loss.

        Args:
            y_pred: Model output (logits).
            y_true: Ground-truth mask. Integer/bool masks are accepted; they
                are cast to the logits' dtype for the BCE term only.

        Returns:
            ``bce_weight * BCE + dice_weight * Dice``.
        """
        # BCEWithLogitsLoss rejects non-floating targets, and masks are very
        # often stored as uint8/long. Only the BCE input is cast so the Dice
        # term keeps receiving exactly what the caller passed in.
        if y_true.is_floating_point():
            bce_target = y_true
        else:
            bce_target = y_true.to(dtype=y_pred.dtype)

        bce = self.bce_loss(y_pred, bce_target)
        dice = self.dice_loss(y_pred, y_true)

        # Blend the two terms according to their configured weights.
        return self.bce_weight * bce + self.dice_weight * dice
|
|
|
|
class MultiClassLoss(nn.Module):
    """Combined loss for multi-class segmentation.

    Sums a weighted ``nn.CrossEntropyLoss`` and a weighted
    ``smp.losses.DiceLoss``, a common pairing for semantic segmentation.

    Args:
        mode (str): DiceLoss mode, should be ``'multiclass'``.
        ce_weight (float): Weight for the CrossEntropyLoss component.
        dice_weight (float): Weight for the DiceLoss component.
        ignore_index (int): Target value excluded from the loss. It is applied
            to both the cross-entropy and the Dice component.
    """

    def __init__(self, mode='multiclass', ce_weight=0.5, dice_weight=0.5, ignore_index=-100):
        super().__init__()

        # CrossEntropyLoss combines LogSoftmax and NLLLoss in a single module.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

        # Forward the same ignore_index to the Dice term. Without it, pixels
        # carrying the ignore label are skipped by cross-entropy but still
        # reach DiceLoss, so the two components would disagree on which
        # pixels count.
        self.dice_loss = smp.losses.DiceLoss(mode=mode, ignore_index=ignore_index)

        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, y_pred, y_true):
        """Compute the combined loss.

        Args:
            y_pred: Model predictions (logits), shape (N, C, H, W).
            y_true: Ground-truth class indices, shape (N, H, W). Cast to
                ``torch.long`` before use.

        Returns:
            ``ce_weight * CE + dice_weight * Dice``.
        """
        # CrossEntropyLoss requires class-index targets of dtype long.
        y_true = y_true.long()

        ce = self.ce_loss(y_pred, y_true)
        dice = self.dice_loss(y_pred, y_true)

        return self.ce_weight * ce + self.dice_weight * dice