Что такое cross-entropy loss?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Cross-Entropy Loss
Cross-entropy loss — это одна из самых фундаментальных функций потерь в машинном обучении и глубоком обучении. Она измеряет расхождение между предсказанным распределением вероятностей и истинным распределением.
Информационная теория: основы
Для понимания cross-entropy нужны два понятия из теории информации:
Entropy (энтропия) H(p) — среднее количество информации в распределении вероятностей:
H(p) = -sum(p_i * log(p_i))
Высокая энтропия = распределение "размазанное" (много неопределённости) Низкая энтропия = распределение сосредоточено на одном значении
KL Divergence (расхождение Кульбака-Лейблера) — мера различия двух распределений:
KL(p || q) = sum(p_i * log(p_i / q_i)) = sum(p_i * (log(p_i) - log(q_i)))
Cross-entropy связана с KL divergence:
Cross-Entropy(p, q) = H(p) + KL(p || q)
= -sum(p_i * log(q_i))
Так что minimizing cross-entropy равносильно minimizing KL divergence (H(p) не зависит от q).
Определение
Cross-entropy loss для классификации:
L = -sum(y_i * log(y_pred_i))
Где:
- y_i — истинная метка (обычно one-hot vector: [0, 1, 0])
- y_pred_i — предсказанная вероятность (выход softmax)
Для бинарной классификации:
L = -[y * log(p) + (1 - y) * log(1 - p)]
Где:
- y = 0 или 1 (истинный класс)
- p = вероятность класса 1 (выход sigmoid)
Пример: бинарная классификация
Представим задачу: классифицировать email как спам или не спам.
Истинная метка: y = 1 (это спам) Предсказанная вероятность: p = 0.8 (модель уверена, что спам)
loss = -[1 * log(0.8) + 0 * log(0.2)]
= -log(0.8)
= 0.223 # низкий loss (хорошо)
Если модель ошиблась: Предсказанная вероятность: p = 0.2 (модель думает, что не спам)
loss = -[1 * log(0.2) + 0 * log(0.8)]
= -log(0.2)
= 1.609 # высокий loss (плохо)
Мультиклассовая классификация
У нас есть 3 класса: {кот, собака, птица} Истинная метка: y = [0, 1, 0] (это собака) Предсказания модели:
- P(кот) = 0.1
- P(собака) = 0.7
- P(птица) = 0.2
loss = -(0 * log(0.1) + 1 * log(0.7) + 0 * log(0.2))
= -log(0.7)
= 0.357 # низкий loss
Реализация в PyTorch
import torch
import torch.nn as nn
# Бинарная классификация с BCELoss
loss_fn = nn.BCELoss() # Требует sigmoid внутри модели
predictions = torch.tensor([0.8, 0.3, 0.6])
targets = torch.tensor([1.0, 0.0, 1.0])
loss = loss_fn(predictions, targets)
print(f"BCE Loss: {loss.item()}")
# Мультиклассовая классификация с CrossEntropyLoss
loss_fn = nn.CrossEntropyLoss() # Внутри: log_softmax + NLLLoss
predictions = torch.tensor([[2.0, 1.0, 0.1],
[0.1, 3.0, 0.5],
[1.0, 0.5, 2.0]]) # logits (НЕ вероятности)
targets = torch.tensor([0, 1, 2]) # индексы классов
loss = loss_fn(predictions, targets)
print(f"CrossEntropy Loss: {loss.item()}")
# С softmax + log + NLLLoss явно
softmax = nn.Softmax(dim=1)
log_softmax = nn.LogSoftmax(dim=1)
nll_loss = nn.NLLLoss()
probs = softmax(predictions) # Применить softmax
log_probs = log_softmax(predictions) # Применить log_softmax
loss = nll_loss(log_probs, targets) # Вычислить NLL
print(f"Manual CrossEntropy Loss: {loss.item()}")
Почему логарифм?
Логарифм имеет важные свойства для машинного обучения:
- Штрафует ошибки сильно: log(0.1) = -2.3, а log(0.5) = -0.69
- Делает всё выпуклым — есть один глобальный минимум
- Численная стабильность — log_softmax избегает overflow
- Интерпретация: информационное содержание ошибки
Cross-Entropy vs MSE
Почему не использовать Mean Squared Error для классификации?
# MSE: (y - p)^2
# Для y=1, p=0.01: MSE = (1 - 0.01)^2 = 0.98
# Для y=1, p=0.5: MSE = (1 - 0.5)^2 = 0.25
# Cross-Entropy: -y*log(p) - (1-y)*log(1-p)
# Для y=1, p=0.01: CE = -log(0.01) = 4.6
# Для y=1, p=0.5: CE = -log(0.5) = 0.69
Cross-entropy штрафует неправильные предсказания НАМНОГО сильнее, чем MSE. Это лучше для классификации.
Практические замечания
-
Никогда не используй softmax + CrossEntropyLoss отдельно
- CrossEntropyLoss уже включает log_softmax
- Правильно:
nn.CrossEntropyLoss()с raw logits - Неправильно:
nn.CrossEntropyLoss()с softmax выходом
-
Class imbalance: используй
weightпараметрweights = torch.tensor([0.2, 0.8]) # Дать большую важность редкому классу loss_fn = nn.CrossEntropyLoss(weight=weights) -
Label smoothing: помогает от overfitting
# Вместо y=[0, 1, 0], используй y=[0.05, 0.9, 0.05] loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1) -
Focal Loss: для очень несбалансированных данных
# Упор на hard examples (неправильно классифицированные примеры) L = -alpha * (1 - p_t)^gamma * log(p_t)
Вывод
Cross-entropy loss — это стандарт для классификации потому что:
- Основана на теории информации (имеет смысл)
- Штрафует ошибки экспоненциально (эффективна)
- Выпукла (гарантирует схождение)
- Эмпирически работает лучше других
Если работаешь с классификацией и не используешь cross-entropy — это красный флаг на интервью.