← Назад к вопросам

Что такое Dice loss?

2.0 Middle🔥 111 комментариев
#Глубокое обучение#Метрики и оценка моделей

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI2 апр. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Dice Loss (F1 Loss)

Определение

Dice Loss — это функция потерь, часто используемая в задачах сегментации изображений. Она основана на коэффициенте Dice (F1-мере) и особенно эффективна для работы с несбалансированными классами, где один класс занимает значительно меньше пикселей.

История и мотивация

Dice Loss названа в честь коэффициента Dice (Sorensen-Dice coefficient), метрики из компьютерной томографии и медицинской визуализации. Она решает проблемы:

  1. Несбалансированные классы: если фон занимает 95% пикселей, Cross-Entropy Loss может просто предсказывать "фон" везде
  2. Геометрическое соответствие: Dice напрямую оптимизирует пересечение, а не посиксельную точность
  3. IoU-подобная метрика: близка к Intersection over Union (IoU), метрике оценки качества

Математика Dice Loss

Для бинарной классификации:

Dice = 2 * TP / (2 * TP + FP + FN)

где:
TP = true positives (правильно предсказанные позитивные)
FP = false positives (неправильно предсказанные позитивные)
FN = false negatives (пропущенные позитивные)

Dice Loss = 1 - Dice

Альтернативная формулировка через пересечение и объединение:

Dice = 2 * |X intersection Y| / (|X| + |Y|)

где X — предсказание, Y — true label

Для мультиклассовой сегментации (взвешенная):

Dice Loss = 1 - (2 * sum(TP_c) / (2 * sum(TP_c) + sum(FP_c) + sum(FN_c)))

Реализация в PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6, weight=None):
        """
        Dice Loss для бинарной и мультиклассовой сегментации.
        
        smooth: малое число для численной стабильности (по умолчанию 1e-6)
        weight: веса для классов (для несбалансированных данных)
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.weight = weight
    
    def forward(self, pred, target):
        """
        pred: предсказания модели (logits или softmax), shape: (B, C, H, W)
        target: ground truth labels, shape: (B, H, W) или (B, 1, H, W)
        """
        # Применяем softmax для получения вероятностей
        pred = F.softmax(pred, dim=1)  # shape: (B, C, H, W)
        
        # One-hot encode target
        if target.dim() == 3:
            target = target.unsqueeze(1)  # (B, 1, H, W)
        
        target_one_hot = torch.zeros_like(pred)
        target_one_hot.scatter_(1, target, 1)  # (B, C, H, W)
        
        # Вычисляем Dice для каждого класса
        num_classes = pred.shape[1]
        dice_loss = 0
        
        for c in range(num_classes):
            pred_c = pred[:, c, :, :]  # (B, H, W)
            target_c = target_one_hot[:, c, :, :]  # (B, H, W)
            
            intersection = (pred_c * target_c).sum()
            union = pred_c.sum() + target_c.sum()
            
            dice = (2 * intersection + self.smooth) / (union + self.smooth)
            
            # Применяем веса если они заданы
            if self.weight is not None:
                dice *= self.weight[c]
            
            dice_loss += (1 - dice)
        
        return dice_loss / num_classes

# Более простая версия (часто встречается в practice)
class SimpleDiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(SimpleDiceLoss, self).__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        """
        Простая реализация для бинарной сегментации.
        """
        # pred: (B, 1, H, W) или (B, 2, H, W) после softmax
        # target: (B, 1, H, W) или (B, H, W)
        
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        
        intersection = (pred_flat * target_flat).sum()
        union = pred_flat.sum() + target_flat.sum()
        
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice

Использование в PyTorch

import torch
from torch.utils.data import DataLoader
import torch.optim as optim

# Инициализация
model = ... # ваша U-Net, DeepLab и т.д.
criterion = DiceLoss(smooth=1e-6, weight=torch.tensor([0.1, 0.9]))  # веса для классов
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Обучение
for epoch in range(num_epochs):
    for images, masks in train_loader:
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

Dice Loss vs Cross-Entropy Loss

Пример на несбалансированных данных:

import torch
import torch.nn.functional as F

# Симуляция: маска объекта занимает только 5% пикселей
batch_size, h, w = 4, 128, 128

# Ground truth: 95% фона, 5% объекта
target = torch.zeros(batch_size, 1, h, w)
target[:, :, :int(h*0.2), :int(w*0.2)] = 1  # маленький объект

# Плохое предсказание (предсказываем везде фон)
pred_bad = torch.zeros_like(target)  # все нули (фон)

# Хорошее предсказание (правильно предсказываем объект)
pred_good = target.float() + 0.1 * torch.randn_like(target)

# Cross-Entropy Loss
ce_loss_bad = F.binary_cross_entropy_with_logits(
    pred_bad.float() * 10,
    target.float()
)
ce_loss_good = F.binary_cross_entropy_with_logits(
    pred_good,
    target.float()
)

print(f"CE Loss (плохое предсказание): {ce_loss_bad:.4f}")
print(f"CE Loss (хорошее предсказание): {ce_loss_good:.4f}")
print(f"Разница: {abs(ce_loss_bad - ce_loss_good):.4f}")
# Видно, что разница МАЛАЯ, CE Loss не штрафует сильно за плохое предсказание!

# Dice Loss
dice_loss_bad = SimpleDiceLoss()(pred_bad.float(), target.float())
dice_loss_good = SimpleDiceLoss()(pred_good, target.float())

print(f"\nDice Loss (плохое предсказание): {dice_loss_bad:.4f}")
print(f"Dice Loss (хорошее предсказание): {dice_loss_good:.4f}")
print(f"Разница: {abs(dice_loss_bad - dice_loss_good):.4f}")
# Dice Loss дает БОЛЬШУЮ разницу, сильно штрафует за плохое предсказание!

Варианты Dice Loss

1. Weighted Dice Loss

class WeightedDiceLoss(nn.Module):
    def __init__(self, weights, smooth=1e-6):
        super().__init__()
        self.weights = weights  # веса для каждого класса
        self.smooth = smooth
    
    def forward(self, pred, target):
        """
        Взвешенная Dice Loss для несбалансированных классов.
        """
        pred = F.softmax(pred, dim=1)
        total_loss = 0
        
        for c in range(pred.shape[1]):
            pred_c = pred[:, c].contiguous().view(-1)
            target_c = (target == c).float().view(-1)
            
            intersection = (pred_c * target_c).sum()
            union = pred_c.sum() + target_c.sum()
            
            dice = (2 * intersection + self.smooth) / (union + self.smooth)
            total_loss += self.weights[c] * (1 - dice)
        
        return total_loss

2. Focal Dice Loss (комбинация с Focal Loss)

class FocalDiceLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
    
    def forward(self, pred, target):
        """
        Focal Dice Loss: комбинирует Dice и Focal для ещё лучшего
        фокуса на сложные примеры.
        """
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        
        # Dice
        intersection = (pred_flat * target_flat).sum()
        union = pred_flat.sum() + target_flat.sum()
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        
        # Focal
        p_t = torch.where(target_flat == 1, pred_flat, 1 - pred_flat)
        focal = (1 - p_t) ** self.gamma
        
        return (1 - dice) + self.alpha * focal.mean()

3. Tversky Loss (обобщение Dice)

class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, smooth=1e-6):
        """
        Tversky Loss: обобщение Dice Loss с параметрами alpha и beta.
        alpha: вес False Positives
        beta: вес False Negatives
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        
        tp = (pred_flat * target_flat).sum()
        fp = (pred_flat * (1 - target_flat)).sum()
        fn = ((1 - pred_flat) * target_flat).sum()
        
        tversky = tp / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        return 1 - tversky

Когда использовать Dice Loss

Используй Dice Loss если:

  • Работаешь с сегментацией (медицинская визуализация, семантическая сегментация)
  • Классы несбалансированы (один класс занимает мало пикселей)
  • Нужна оптимизация IoU напрямую
  • Объекты интереса занимают малую часть изображения

Используй Cross-Entropy если:

  • Классы примерно сбалансированы
  • Работаешь с классификацией (не сегментацией)
  • Нужна стабильность обучения (CE более стабильна)

Best practice — комбинируй обе:

class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss()
    
    def forward(self, pred, target):
        dice = self.dice_loss(pred, target)
        ce = self.ce_loss(pred, target)
        return self.alpha * dice + (1 - self.alpha) * ce

Практические советы

  1. Smooth параметр: start с 1e-6, увеличивай если теряешь стабильность
  2. Нормализуй predictions: убедись что выход из модели в диапазоне [0, 1]
  3. Взвешивай классы: для несбалансированных данных используй weights
  4. Комбинируй с метриками: во время обучения смотри на IoU или мAP
  5. Экспериментируй: разные варианты работают лучше для разных задач

Dice Loss стала стандартом в медицинской визуализации и сегментации и рекомендуется для всех задач с несбалансированными данными.

Что такое Dice loss? | PrepBro