← Назад к вопросам

Какие знаешь улучшения бинарной кросс-энтропии?

2.7 Senior🔥 301 комментариев
#Машинное обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI29 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Улучшения бинарной кросс-энтропии в современном ML

Бинарная кросс-энтропия (Binary Cross-Entropy, BCE) — это одна из самых распространённых функций потерь для задач бинарной классификации. Однако за годы развития ML появилось множество улучшений и альтернатив, которые решают её ограничения и улучшают обучение моделей.

Классическая бинарная кросс-энтропия

Формула:

BCE = -1/n * Σ[y*log(p) + (1-y)*log(1-p)]

где:
- y — истинная метка (0 или 1)
- p — вероятность положительного класса
- n — количество образцов

Реализация:

import torch
import torch.nn as nn

criterion = nn.BCELoss()  # Классическая BCE
loss = criterion(y_pred, y_true)

1. Focal Loss (потеря для несбалансированных данных)

Проблема, которую решает: BCE обрабатывает легкие примеры одинаково с трудными. Если класс 0 встречается в 99% случаев, модель может достичь 99% accuracy, просто предсказывая все 0.

Решение — Focal Loss:

FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

где:
- α_t — вес класса
- γ — фокусирующий параметр (обычно 2)
- p_t — вероятность истинного класса

Интуиция: Focal loss уменьшает вес легких примеров (которые модель уже правильно предсказывает) и фокусируется на трудных примерах.

# Реализация Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, pred, target):
        # pred: [0, 1], target: 0 или 1
        p_t = torch.where(target == 1, pred, 1 - pred)
        focal_weight = (1 - p_t) ** self.gamma
        bce_loss = nn.functional.binary_cross_entropy(pred, target, reduction='none')
        return (focal_weight * bce_loss).mean()

criterion = FocalLoss(gamma=2.0)
loss = criterion(y_pred, y_true)

Когда использовать: Задачи с сильным дисбалансом классов (fraud detection, rare disease detection).

2. Label Smoothing (сглаживание меток)

Проблема: BCE может привести к переуверенности модели, когда она предсказывает p=1.0 для положительного класса.

Решение — Label Smoothing:

Вместо y=1 используем y=0.9, а вместо y=0 используем y=0.1:

smoothing = 0.1
y_smoothed = y * (1 - smoothing) + 0.5 * smoothing
# y=1 становится 0.95
# y=0 становится 0.05

criterion = nn.BCELoss()
loss = criterion(y_pred, y_smoothed)

Эффект: Модель не переоптимизируется на шум в данных, становится более калибрированной.

3. Weighted Binary Cross-Entropy (взвешенная BCE)

Проблема: При дисбалансе классов отрицательные примеры доминируют над положительными.

Решение:

# Подсчитываем веса классов
n_pos = (y_true == 1).sum()
n_neg = (y_true == 0).sum()
pos_weight = n_neg / n_pos  # Вес положительного класса

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
loss = criterion(logits, y_true)  # logits, не вероятности!

Пример:

  • 1000 отрицательных примеров (0)
  • 10 положительных примеров (1)
  • pos_weight = 100

Каждая ошибка на положительном примере штрафуется в 100 раз сильнее.

4. OHM Loss (Online Hard Example Mining)

Проблема: BCE одинаково штрафует все ошибки, даже очень лёгкие.

Решение — OHM Loss:

Выбираем только самые трудные примеры для обучения:

class OHMLoss(nn.Module):
    def __init__(self, ratio=0.25):
        self.ratio = ratio  # Доля трудных примеров
    
    def forward(self, pred, target):
        bce_loss = nn.functional.binary_cross_entropy(pred, target, reduction='none')
        
        # Выбираем топ самых трудных примеров
        num_hard = int(len(bce_loss) * self.ratio)
        hard_losses = torch.topk(bce_loss, num_hard)[0]
        
        return hard_losses.mean()

criterion = OHMLoss(ratio=0.5)  # Обучаем на 50% самых трудных
loss = criterion(y_pred, y_true)

5. LSEP Loss (Large Margin Separation)

Проблема: BCE не заботится о том, насколько хорошо разделены классы в пространстве.

Решение: Добавляем margin term, чтобы увеличить расстояние между классами:

class LSEPLoss(nn.Module):
    def __init__(self, margin=0.5):
        self.margin = margin
    
    def forward(self, pred, target):
        bce = nn.functional.binary_cross_entropy(pred, target, reduction='none')
        
        # Добавляем margin penalty
        pos_mask = target == 1
        neg_mask = target == 0
        
        margin_loss = torch.zeros_like(pred)
        margin_loss[pos_mask] = torch.clamp(self.margin - pred[pos_mask], min=0)
        margin_loss[neg_mask] = torch.clamp(pred[neg_mask] - (1 - self.margin), min=0)
        
        return (bce + margin_loss).mean()

6. AUC Loss (оптимизация под метрику AUC)

Проблема: BCE оптимизирует логарифмическую вероятность, но нас часто интересует AUC-ROC.

Решение: Использовать loss, которая напрямую оптимизирует AUC:

class AUCLoss(nn.Module):
    def __init__(self):
        self.margin = 1.0
    
    def forward(self, pred, target):
        # Ранжируем по предсказаниям
        pos_pred = pred[target == 1]
        neg_pred = pred[target == 0]
        
        # Штрафуем, если положительное предсказание < отрицательного
        diff = neg_pred.unsqueeze(1) - pos_pred.unsqueeze(0)
        loss = torch.clamp(diff + self.margin, min=0).mean()
        
        return loss

7. Focal Tversky Loss (для медицинской визуализации)

Для задач с очень маленькими объектами (например, опухоли на снимках):

class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, gamma=1.0):
        self.alpha = alpha  # Штраф за false negatives
        self.beta = beta    # Штраф за false positives
        self.gamma = gamma  # Фокусирующий параметр
    
    def forward(self, pred, target):
        smooth = 1e-6
        
        true_pos = (pred * target).sum()
        false_pos = (pred * (1 - target)).sum()
        false_neg = ((1 - pred) * target).sum()
        
        tversky = true_pos / (true_pos + self.alpha*false_neg + self.beta*false_pos + smooth)
        focal_loss = (1 - tversky) ** self.gamma
        
        return focal_loss

8. DiceLoss (для сегментации с дисбалансом)

class DiceLoss(nn.Module):
    def forward(self, pred, target):
        smooth = 1
        
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        
        intersection = (pred_flat * target_flat).sum()
        dice = (2 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        
        return 1 - dice

Сравнение функций потерь

Loss FunctionСбалансированностьТрудные примерыКалибровкаКогда использовать
BCEПлохаяНетПереобучениеСбалансированные данные
Focal LossОтличнаяДаХорошаяДисбаланс классов
Weighted BCEХорошаяНетХорошаяЛёгкий дисбаланс
OHM LossХорошаяДаХорошаяHard negatives
AUC LossОтличнаяДаОтличнаяОптимизация под AUC
Focal TverskyОтличнаяДаОтличнаяМедицинская визуализация
Dice LossОтличнаяНетХорошаяСегментация

Практический пример: выбор функции потерь

from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# Анализируем дисбаланс
class_weights = compute_class_weight('balanced', 
                                     classes=np.unique(y_train), 
                                     y=y_train)
pos_weight = class_weights[1] / class_weights[0]

if pos_weight > 100:
    # Очень большой дисбаланс
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
elif pos_weight > 10:
    # Средний дисбаланс
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
else:
    # Почти сбалансирован
    criterion = nn.BCELoss()

Заключение

Выбор функции потерь зависит от:

  1. Дисбаланса классов — Focal Loss, Weighted BCE
  2. Трудных примеров — OHM Loss, Focal Loss
  3. Метрики, которую оптимизируем — AUC Loss для AUC, Dice для Dice coefficient
  4. Специфики задачи — Focal Tversky для медицины, Dice для сегментации

В современном ML редко используют чистую BCE — почти всегда применяют одно из её улучшений.