Что такое дистилляция модели?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Что такое дистилляция модели?
Дистилляция модели (Knowledge Distillation) — это метод трансфера знаний, при котором более простая модель-ученик (Student Model) обучается воспроизводить поведение более сложной модели-учителя (Teacher Model). Цель — достичь близкой точности при значительном снижении размера и вычислительных затрат.
Основной принцип
Teacher Model (сложная, точная):
- Большая нейронная сеть или ансамбль моделей
- Обучена на исходной задаче
- Высокая точность, но требует много ресурсов
Student Model (простая, быстрая):
- Меньше параметров и слоёв
- Обучается предсказываниям учителя
- Удобна для production (низкая задержка)
Как это работает
Temperature Scaling (важная концепция):
import torch
import torch.nn.functional as F
def soft_target_loss(student_logits, teacher_logits, temperature=4.0):
# Мягкие вероятности от учителя
teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
# Мягкие вероятности от студента
student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
# KL divergence между распределениями
loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
return loss * (temperature ** 2)
Полный процесс обучения:
import torch.nn as nn
class DistillationLoss(nn.Module):
def __init__(self, temperature=4.0, alpha=0.7):
super().__init__()
self.temperature = temperature
self.alpha = alpha # вес для loss дистилляции
self.ce_loss = nn.CrossEntropyLoss()
def forward(self, student_logits, teacher_logits, targets):
# Loss от дистилляции (KL divergence)
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
distill_loss = F.kl_div(student_log_probs, teacher_probs)
# Обычный supervised loss
ce_loss = self.ce_loss(student_logits, targets)
# Итоговый loss
total_loss = self.alpha * ce_loss + (1 - self.alpha) * distill_loss
return total_loss
Практический пример: Классификация изображений
Шаг 1: Обучение Teacher Model
teacher = EfficientNet(depth='b7') # большая модель
teacher.train(train_loader, epochs=100)
# Получаем accuracy ~95%
Шаг 2: Инициализация Student Model
student = EfficientNet(depth='b0') # маленькая модель
# Параметров в 100 раз меньше
Шаг 3: Обучение с дистилляцией
distill_loss_fn = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
for epoch in range(50):
for batch_x, batch_y in train_loader:
student_logits = student(batch_x)
teacher_logits = teacher(batch_x).detach() # no gradients
loss = distill_loss_fn(student_logits, teacher_logits, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Student достигает ~93% точности
Преимущества дистилляции
- Снижение размера модели — уменьшение памяти в 10-100 раз
- Ускорение инференса — быстрее на 5-20 раз
- Сохранение точности — потеря обычно 1-3%
- Оптимизация для мобильных — развёртывание на телефонах
- Энергоэффективность — снижение потребления батареи
Когда использовать
- Развёртывание на edge-устройствах (мобильные, IoT)
- Снижение latency в production
- Оптимизация подConstraints по памяти
- Real-time приложения (рекомендации, поиск)
Вариации метода
Attention Transfer — трансфер attention maps Feature-based Distillation — обучение на скрытых слоях Relational Knowledge Distillation — использование relational информации
Дистилляция — мощный инструмент в арсенале ML-инженера для оптимизации моделей без потери качества.