Что такое дистилляция нейронных сетей?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Дистилляция нейронных сетей (Knowledge Distillation)
Дистилляция нейронных сетей — это техника машинного обучения, которая позволяет передать знания от большой, сложной модели (учителя) к меньшей, более эффективной модели (ученика). Цель — достичь близкой к оригинальной производительности при значительном снижении размера модели, времени инференса и требований к вычислительным ресурсам.
Принцип работы
Дистилляция основана на простой идее: вместо обучения ученика напрямую на исходных данных и жёстких метках (hard targets), мы обучаем его на мягких вероятностных предсказаниях учителя (soft targets).
Процесс дистилляции:
- Обучение модели-учителя — большая сеть обучается на полном датасете и достигает высокой точности
- Получение soft targets — модель-учитель генерирует вероятностные предсказания для всех примеров
- Обучение модели-ученика — маленькая сеть обучается минимизировать расстояние между своими предсказаниями и предсказаниями учителя
- Тонкая настройка — опционально, ученик дообучается на исходных данных с жёсткими метками
Роль температуры (Temperature)
Ключевой параметр дистилляции — температура (T). Она контролирует мягкость вероятностных распределений:
# Без температуры (T=1) — жёсткие распределения
soft_targets = softmax(logits_teacher)
# С температурой (T>1) — мягкие распределения
soft_targets = softmax(logits_teacher / T)
# Пример: если учитель уверен на 99%, с T=4 это становится примерно 85%
Высокая температура (T=4-10) создаёт более "размытые" распределения, что раскрывает информацию о том, какие неправильные классы почти верны. Это очень информативно для обучения ученика.
Функция потерь при дистилляции
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, hard_targets, T=4.0, alpha=0.7):
# KL-divergence loss между soft targets
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reduction=batchmean
) * (T ** 2) # Масштабирование для компенсации деления на T
# Cross-entropy loss на жёстких метках
hard_loss = F.cross_entropy(student_logits, hard_targets)
# Комбинированная потеря
total_loss = alpha * hard_loss + (1 - alpha) * soft_loss
return total_loss
Здесь:
- alpha — вес жёсткой потери (обычно 0.5-0.9)
- T — температура (обычно 3-20)
- Множитель T² — компенсирует меньшие градиенты при делении на T
Преимущества дистилляции
1. Сжатие модели
- Ученик может быть в 10-100 раз меньше учителя
- Снижение памяти, требуемой для инференса
- Возможность развёртывания на мобильных устройствах и edge devices
2. Ускорение инференса
Время инференса: Teacher=200ms → Student=20ms (10x ускорение)
Точность: Teacher=90% → Student=88% (небольшое падение)
3. Лучшая генерализация Ученик часто обобщается лучше, так как учится на гладких распределениях вероятностей, а не на жёстких метках
4. Улучшение через ансамбли Множество учителей могут обучить одного ученика лучше, чем один учитель
Практический пример
import torch
import torch.nn as nn
class StudentModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128) # Маленькая сеть
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)
class TeacherModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 512) # Большая сеть
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 10)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
# Обучение ученика
teacher = TeacherModel().eval() # Учитель в режиме оценки
student = StudentModel().train()
optimizer = torch.optim.Adam(student.parameters())
for images, labels in train_loader:
with torch.no_grad():
teacher_logits = teacher(images)
student_logits = student(images)
loss = distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7)
optimizer.zero_grad()
loss.backward()
optimizer.step()
Варианты и расширения
1. Attention Transfer — передача информации о том, на какие части входа обращает внимание учитель
2. Feature-based distillation — обучение ученика воспроизводить промежуточные представления учителя
3. Multi-teacher distillation — один ученик обучается у нескольких учителей
4. Self-distillation — сеть обучается собственным предсказаниям с разными температурами
Применение в реальных проектах
- Мобильные приложения — компактные модели для распознавания лиц, речи
- Облачные сервисы — быстрый инференс при сохранении качества
- Edge AI — развёртывание на IoT устройствах
- Real-time системы — низкая латентность критична
Выводы
Дистилляция нейронных сетей — это мощная техника для оптимизации моделей глубокого обучения. Она позволяет достичь практического компромисса между точностью и эффективностью, что критично для реальных приложений. Знание этого метода необходимо для работы с ресурсоограниченными окружениями и оптимизации production-систем.