← Назад к вопросам

Что такое focal loss и когда его применять?

2.0 Middle🔥 112 комментариев
#Глубокое обучение#Метрики и оценка моделей

Комментарии (2)

🐱
claude-haiku-4.5PrepBro AI2 апр. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Focal Loss: решение проблемы class imbalance

Focal Loss — это модифицированная функция потерь для классификации, разработанная для борьбы с дисбалансом классов. Она снижает вес легких примеров (которые модель уже хорошо предсказывает) и фокусирует обучение на сложных примерах (hard negatives).

Проблема class imbalance

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Создаём дисбалансированный датасет
X, y = make_classification(
    n_samples=10000,
    n_features=20,
    n_classes=2,
    weights=[0.95, 0.05],  # 95% класс 0, 5% класс 1
    random_state=42
)

print(f"Класс 0: {(y == 0).sum()}")
print(f"Класс 1: {(y == 1).sum()}")
# Класс 0: 9500
# Класс 1: 500

# Проблема: обычная Cross Entropy Loss ленива
# Может добиться 95% accuracy, просто предсказывая всегда класс 0

Почему обычная Cross Entropy неэффективна

import torch
import torch.nn as nn

# Standard Cross Entropy Loss
ce_loss = nn.CrossEntropyLoss()

# Примеры:
p_positive = 0.9  # модель уверена в положительном классе
p_negative = 0.1  # модель уверена в отрицательном классе (неверно)

# Cross Entropy для лёгких примеров всё равно больше!
loss_easy_correct = ce_loss(
    torch.tensor([[0.1, 0.9]]),  # уверено в классе 1
    torch.tensor([1])
)  # маленькая loss, но всё ещё уменьшает градиент

loss_hard_wrong = ce_loss(
    torch.tensor([[0.6, 0.4]]),  # не уверено, но неверно
    torch.tensor([1])
)  # большая loss, больше градиент

print(f"Loss лёгкий пример (правильный): {loss_easy_correct:.4f}")
print(f"Loss сложный пример (неправильный): {loss_hard_wrong:.4f}")

Focal Loss формула

FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

Где:
- p_t: вероятность правильного класса
- α_t: вес класса (для дисбаланса)
- γ (gamma): параметр фокусировки (usually 2.0)
- (1 - p_t)^γ: "focusing parameter" - даёт низкий вес лёгким примерам

Визуализация:

  • Когда p_t близко к 1 (лёгкий пример): (1 - p_t) близко к 0, loss маленькая
  • Когда p_t близко к 0 (сложный пример): (1 - p_t) близко к 1, loss большая
  • Параметр γ управляет крутизной: γ=0 это обычная CE, γ=2 даёт сильную фокусировку

Реализация Focal Loss

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, logits, labels):
        """
        logits: (batch_size, num_classes) - выход модели
        labels: (batch_size,) - истинные метки
        """
        # Вероятности
        p = F.softmax(logits, dim=1)
        
        # Получить p_t для каждого примера
        # Если label=1, берём p[:, 1], если label=0, берём p[:, 0]
        p_t = p.gather(1, labels.view(-1, 1)).squeeze(1)
        
        # Focal weight: (1 - p_t)^gamma
        focal_weight = (1 - p_t) ** self.gamma
        
        # Cross entropy
        ce_loss = F.cross_entropy(logits, labels, reduction='none')
        
        # Focal loss = alpha * focal_weight * ce_loss
        focal_loss = self.alpha * focal_weight * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Использование
focal_loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
model = nn.Linear(20, 2)
optimizer = torch.optim.Adam(model.parameters())

for batch in data_loader:
    X, y = batch
    logits = model(X)
    loss = focal_loss_fn(logits, y)
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Focal Loss в PyTorch (встроенная реализация в timm)

import torch
from torch import nn
from torchvision.ops import sigmoid_focal_loss

# PyTorch имеет встроенную реализацию для binary classification
loss = sigmoid_focal_loss(
    predictions,  # sigmoid выход
    targets,  # 0 или 1
    alpha=0.25,
    gamma=2.0,
    reduction='mean'
)

Параметр gamma (γ)

# γ управляет фокусировкой на сложных примерах

# γ = 0: обычная Cross Entropy
# γ = 0.5: легкая фокусировка
# γ = 1.0: средняя фокусировка
# γ = 2.0: сильная фокусировка (дефолт)
# γ = 5.0: очень сильная фокусировка

# Визуализация
import matplotlib.pyplot as plt
import numpy as np

p_t = np.linspace(0.01, 1, 100)

for gamma in [0, 0.5, 1.0, 2.0, 5.0]:
    focal_weight = (1 - p_t) ** gamma
    ce_loss = -np.log(p_t)
    focal_loss = focal_weight * ce_loss
    plt.plot(p_t, focal_loss, label=f'gamma={gamma}')

plt.xlabel('p_t (вероятность правильного класса)')
plt.ylabel('Focal Loss')
plt.legend()
plt.title('Focal Loss для разных значений gamma')
plt.show()

Практический пример: обнаружение объектов (RetinaNet)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class RetinaNet(nn.Module):
    def __init__(self, num_classes=80):
        super().__init__()
        # Backbone (ResNet)
        self.backbone = ResNet50()
        
        # Classification head
        self.cls_head = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=3, padding=1)
        )
        
        # Bounding box regression head
        self.bbox_head = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 4, kernel_size=3, padding=1)  # 4 координаты
        )
        
        self.focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
    
    def forward(self, x):
        features = self.backbone(x)
        cls_logits = self.cls_head(features)
        bbox_preds = self.bbox_head(features)
        return cls_logits, bbox_preds

model = RetinaNet(num_classes=80)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(100):
    for images, cls_targets, bbox_targets in train_loader:
        cls_logits, bbox_preds = model(images)
        
        # Classification loss (Focal Loss)
        cls_loss = model.focal_loss(cls_logits, cls_targets)
        
        # Bounding box loss (Smooth L1)
        bbox_loss = F.smooth_l1_loss(bbox_preds, bbox_targets)
        
        total_loss = cls_loss + bbox_loss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Когда использовать Focal Loss

Используй Focal Loss если:

  1. Сильный дисбаланс классов (менее 10% минорного класса)

    # Пример: обнаружение редкого события
    minority_ratio = (y == 1).sum() / len(y)
    if minority_ratio < 0.1:
        use_focal_loss = True
    
  2. Object detection / Face detection

    • Миллионы легких негативных примеров (background)
    • Считанные сложные позитивные примеры
  3. Когда сбалансированные веса класса не помогают

    # Сбалансированные веса
    class_weights = torch.tensor([1.0, 19.0])  # 1:19 дисбаланс
    ce_loss = nn.CrossEntropyLoss(weight=class_weights)
    
    # Focal loss может быть более эффективен
    focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
    
  4. Medical imaging (обнаружение опухолей, переломов)

НЕ используй Focal Loss если:

  1. Классы примерно сбалансированы (45-55% разделение)
  2. Уже используешь weighted sampling или oversampling
  3. Hard negative mining уже реализована

Параметры alpha и gamma

# alpha: вес класса (для очень сильного дисбаланса)
# Если дисбаланс 1:9
alpha = 0.9  # даём больше веса редкому классу

# gamma: фокусировка на сложных примерах
# Начни с gamma=2.0, затем экспериментируй
# gamma=1.0 для мягкой фокусировки
# gamma=3.0-5.0 для очень сложных задач

for gamma in [1.0, 2.0, 3.0]:
    for alpha in [0.25, 0.5, 0.75]:
        model = train_with_focal_loss(
            train_loader,
            FocalLoss(alpha=alpha, gamma=gamma)
        )
        val_loss = evaluate(model, val_loader)
        print(f"alpha={alpha}, gamma={gamma}: val_loss={val_loss:.4f}")

Выводы

  1. Focal Loss решает проблему class imbalance эффективнее, чем простое взвешивание
  2. Стандартные параметры (α=0.25, γ=2.0) работают хорошо в большинстве случаев
  3. Object detection (RetinaNet) показал, что Focal Loss снижает loss в 100+ раз по сравнению с обычной CE
  4. Экспериментируй с γ: начни с 2.0, попробуй 1.0-5.0 в зависимости от задачи
  5. Комбинируй с другими техниками (hard negative mining, class balancing)

Focal Loss — это мощная техника для работы с дисбалансированными данными, особенно в задачах обнаружения и детекции.

Что такое focal loss и когда его применять? | PrepBro