Files
Seg_Data_Server/Seg_All_In_One_SegModel/loss.py
2026-05-20 15:05:35 +08:00

86 lines
3.2 KiB
Python

# In utils.py or a new losses.py
import config
import torch.nn as nn
import segmentation_models_pytorch as smp
class UNetPlusPlusLoss(nn.Module):
    """
    Weighted BCE + Dice loss, modelled on the official UNet++ implementation.

    Intended for the raw logits produced by segmentation-models-pytorch models.

    Args:
        mode (str): Mode forwarded to ``smp.losses.DiceLoss``, e.g. 'binary',
            'multilabel', 'multiclass'. Defaults to ``config.SEG_MODE``.
        bce_weight (float): Weight of the BCE term in the total loss.
        dice_weight (float): Weight of the Dice term in the total loss.
    """

    def __init__(self, mode=config.SEG_MODE, bce_weight=0.5, dice_weight=0.5):
        super(UNetPlusPlusLoss, self).__init__()
        # Kept so forward() knows how to turn an index-shaped target into a BCE target.
        self.mode = mode
        # BCEWithLogitsLoss fuses the sigmoid into the loss, which is numerically more stable.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = smp.losses.DiceLoss(mode=mode)
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def _bce_target(self, y_pred, y_true):
        """Return ``y_true`` reshaped and cast so BCEWithLogitsLoss can compare it with ``y_pred``.

        BCEWithLogitsLoss requires a floating target of exactly the logits' shape.
        The Dice term accepts integer masks and index maps, so the BCE target
        is derived here instead of requiring callers to pre-convert it.
        """
        if y_true.dim() == y_pred.dim() - 1:
            if self.mode == "multiclass":
                # Class-index map (N, H, W) -> one-hot (N, C, H, W).
                y_true = nn.functional.one_hot(y_true.long(), num_classes=y_pred.size(1))
                y_true = y_true.movedim(-1, 1)
            else:
                # Single-channel mask without its channel dim: (N, H, W) -> (N, 1, H, W).
                y_true = y_true.unsqueeze(1)
        # Integer masks (long/uint8) would make BCEWithLogitsLoss raise a dtype error.
        return y_true.to(dtype=y_pred.dtype)

    def forward(self, y_pred, y_true):
        """
        Compute the combined loss.

        Args:
            y_pred: Model output (logits).
            y_true: Ground-truth mask. It may be float or integer. For 'multiclass'
                mode it may be a class-index map with no channel dimension.

        Returns:
            ``bce_weight * BCE + dice_weight * Dice``.
        """
        bce = self.bce_loss(y_pred, self._bce_target(y_pred, y_true))
        # smp's DiceLoss handles integer targets and index maps itself, so it gets the raw mask.
        dice = self.dice_loss(y_pred, y_true)
        return self.bce_weight * bce + self.dice_weight * dice
class MultiClassLoss(nn.Module):
    """
    Combined CrossEntropy + Dice loss for multi-class segmentation.

    Combining the two is a common and effective practice for semantic segmentation.

    Args:
        mode (str): DiceLoss mode; should be 'multiclass'.
        ce_weight (float): Weight of the CrossEntropyLoss term.
        dice_weight (float): Weight of the DiceLoss term.
        ignore_index (int): Target value that is excluded from BOTH the CE and
            the Dice term and does not contribute to the gradient.
    """

    def __init__(self, mode='multiclass', ce_weight=0.5, dice_weight=0.5, ignore_index=-100):
        super(MultiClassLoss, self).__init__()
        self.ignore_index = ignore_index
        # CrossEntropyLoss applies log-softmax + NLL directly on the raw logits.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # Dice must skip the same pixels as CE. Otherwise smp one-hot encodes the
        # ignore value, which raises (negative / >= C) or counts it as a real class.
        # When no pixel equals ignore_index, the result is the same as without it.
        self.dice_loss = smp.losses.DiceLoss(mode=mode, ignore_index=ignore_index)
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, y_pred, y_true):
        """
        Compute the combined loss.

        Args:
            y_pred: Model predictions (logits), shape (N, C, H, W).
            y_true: Ground-truth class indices, shape (N, H, W) or (N, 1, H, W).

        Returns:
            ``ce_weight * CE + dice_weight * Dice``.
        """
        # Accept masks that carry a singleton channel dim; CrossEntropyLoss rejects them.
        if y_true.dim() == y_pred.dim() and y_true.size(1) == 1:
            y_true = y_true.squeeze(1)
        # CrossEntropyLoss (and smp's one-hot encoding) need integer class indices.
        y_true = y_true.long()
        ce = self.ce_loss(y_pred, y_true)
        dice = self.dice_loss(y_pred, y_true)
        return self.ce_weight * ce + self.dice_weight * dice