← Назад к вопросам
Что такое focal loss и когда его применять?
2.0 Middle🔥 112 комментариев
#Глубокое обучение#Метрики и оценка моделей
Комментарии (2)
🐱
claude-haiku-4.5PrepBro AI2 апр. 2026 г.(ред.)
Ответ сгенерирован нейросетью и может содержать ошибки
Focal Loss: решение проблемы class imbalance
Focal Loss — это модифицированная функция потерь для классификации, разработанная для борьбы с дисбалансом классов. Она снижает вес легких примеров (которые модель уже хорошо предсказывает) и фокусирует обучение на сложных примерах (hard negatives).
Проблема class imbalance
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
# Создаём дисбалансированный датасет
X, y = make_classification(
n_samples=10000,
n_features=20,
n_classes=2,
weights=[0.95, 0.05], # 95% класс 0, 5% класс 1
random_state=42
)
print(f"Класс 0: {(y == 0).sum()}")
print(f"Класс 1: {(y == 1).sum()}")
# Класс 0: 9500
# Класс 1: 500
# Проблема: обычная Cross Entropy Loss ленива
# Может добиться 95% accuracy, просто предсказывая всегда класс 0
Почему обычная Cross Entropy неэффективна
import torch
import torch.nn as nn
# Standard Cross Entropy Loss
ce_loss = nn.CrossEntropyLoss()
# Примеры:
p_positive = 0.9 # модель уверена в положительном классе
p_negative = 0.1 # модель уверена в отрицательном классе (неверно)
# Cross Entropy для лёгких примеров всё равно больше!
loss_easy_correct = ce_loss(
torch.tensor([[0.1, 0.9]]), # уверено в классе 1
torch.tensor([1])
) # маленькая loss, но всё ещё уменьшает градиент
loss_hard_wrong = ce_loss(
torch.tensor([[0.6, 0.4]]), # не уверено, но неверно
torch.tensor([1])
) # большая loss, больше градиент
print(f"Loss лёгкий пример (правильный): {loss_easy_correct:.4f}")
print(f"Loss сложный пример (неправильный): {loss_hard_wrong:.4f}")
Focal Loss формула
FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
Где:
- p_t: вероятность правильного класса
- α_t: вес класса (для дисбаланса)
- γ (gamma): параметр фокусировки (usually 2.0)
- (1 - p_t)^γ: "focusing parameter" - даёт низкий вес лёгким примерам
Визуализация:
- Когда p_t близко к 1 (лёгкий пример): (1 - p_t) близко к 0, loss маленькая
- Когда p_t близко к 0 (сложный пример): (1 - p_t) близко к 1, loss большая
- Параметр γ управляет крутизной: γ=0 это обычная CE, γ=2 даёт сильную фокусировку
Реализация Focal Loss
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, logits, labels):
"""
logits: (batch_size, num_classes) - выход модели
labels: (batch_size,) - истинные метки
"""
# Вероятности
p = F.softmax(logits, dim=1)
# Получить p_t для каждого примера
# Если label=1, берём p[:, 1], если label=0, берём p[:, 0]
p_t = p.gather(1, labels.view(-1, 1)).squeeze(1)
# Focal weight: (1 - p_t)^gamma
focal_weight = (1 - p_t) ** self.gamma
# Cross entropy
ce_loss = F.cross_entropy(logits, labels, reduction='none')
# Focal loss = alpha * focal_weight * ce_loss
focal_loss = self.alpha * focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
# Использование
focal_loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
model = nn.Linear(20, 2)
optimizer = torch.optim.Adam(model.parameters())
for batch in data_loader:
X, y = batch
logits = model(X)
loss = focal_loss_fn(logits, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
Focal Loss в PyTorch (встроенная реализация в timm)
import torch
from torch import nn
from torchvision.ops import sigmoid_focal_loss
# PyTorch имеет встроенную реализацию для binary classification
loss = sigmoid_focal_loss(
predictions, # sigmoid выход
targets, # 0 или 1
alpha=0.25,
gamma=2.0,
reduction='mean'
)
Параметр gamma (γ)
# γ управляет фокусировкой на сложных примерах
# γ = 0: обычная Cross Entropy
# γ = 0.5: легкая фокусировка
# γ = 1.0: средняя фокусировка
# γ = 2.0: сильная фокусировка (дефолт)
# γ = 5.0: очень сильная фокусировка
# Визуализация
import matplotlib.pyplot as plt
import numpy as np
p_t = np.linspace(0.01, 1, 100)
for gamma in [0, 0.5, 1.0, 2.0, 5.0]:
focal_weight = (1 - p_t) ** gamma
ce_loss = -np.log(p_t)
focal_loss = focal_weight * ce_loss
plt.plot(p_t, focal_loss, label=f'gamma={gamma}')
plt.xlabel('p_t (вероятность правильного класса)')
plt.ylabel('Focal Loss')
plt.legend()
plt.title('Focal Loss для разных значений gamma')
plt.show()
Практический пример: обнаружение объектов (RetinaNet)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class RetinaNet(nn.Module):
def __init__(self, num_classes=80):
super().__init__()
# Backbone (ResNet)
self.backbone = ResNet50()
# Classification head
self.cls_head = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(256, num_classes, kernel_size=3, padding=1)
)
# Bounding box regression head
self.bbox_head = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(256, 4, kernel_size=3, padding=1) # 4 координаты
)
self.focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
def forward(self, x):
features = self.backbone(x)
cls_logits = self.cls_head(features)
bbox_preds = self.bbox_head(features)
return cls_logits, bbox_preds
model = RetinaNet(num_classes=80)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(100):
for images, cls_targets, bbox_targets in train_loader:
cls_logits, bbox_preds = model(images)
# Classification loss (Focal Loss)
cls_loss = model.focal_loss(cls_logits, cls_targets)
# Bounding box loss (Smooth L1)
bbox_loss = F.smooth_l1_loss(bbox_preds, bbox_targets)
total_loss = cls_loss + bbox_loss
total_loss.backward()
optimizer.step()
optimizer.zero_grad()
Когда использовать Focal Loss
Используй Focal Loss если:
-
Сильный дисбаланс классов (менее 10% минорного класса)
# Пример: обнаружение редкого события minority_ratio = (y == 1).sum() / len(y) if minority_ratio < 0.1: use_focal_loss = True -
Object detection / Face detection
- Миллионы легких негативных примеров (background)
- Считанные сложные позитивные примеры
-
Когда сбалансированные веса класса не помогают
# Сбалансированные веса class_weights = torch.tensor([1.0, 19.0]) # 1:19 дисбаланс ce_loss = nn.CrossEntropyLoss(weight=class_weights) # Focal loss может быть более эффективен focal_loss = FocalLoss(alpha=0.25, gamma=2.0) -
Medical imaging (обнаружение опухолей, переломов)
НЕ используй Focal Loss если:
- Классы примерно сбалансированы (45-55% разделение)
- Уже используешь weighted sampling или oversampling
- Hard negative mining уже реализована
Параметры alpha и gamma
# alpha: вес класса (для очень сильного дисбаланса)
# Если дисбаланс 1:9
alpha = 0.9 # даём больше веса редкому классу
# gamma: фокусировка на сложных примерах
# Начни с gamma=2.0, затем экспериментируй
# gamma=1.0 для мягкой фокусировки
# gamma=3.0-5.0 для очень сложных задач
for gamma in [1.0, 2.0, 3.0]:
for alpha in [0.25, 0.5, 0.75]:
model = train_with_focal_loss(
train_loader,
FocalLoss(alpha=alpha, gamma=gamma)
)
val_loss = evaluate(model, val_loader)
print(f"alpha={alpha}, gamma={gamma}: val_loss={val_loss:.4f}")
Выводы
- Focal Loss решает проблему class imbalance эффективнее, чем простое взвешивание
- Стандартные параметры (α=0.25, γ=2.0) работают хорошо в большинстве случаев
- Object detection (RetinaNet) показал, что Focal Loss снижает loss в 100+ раз по сравнению с обычной CE
- Экспериментируй с γ: начни с 2.0, попробуй 1.0-5.0 в зависимости от задачи
- Комбинируй с другими техниками (hard negative mining, class balancing)
Focal Loss — это мощная техника для работы с дисбалансированными данными, особенно в задачах обнаружения и детекции.