← Назад к вопросам

Что такое дистилляция нейронных сетей?

3.0 Senior🔥 101 комментариев
#MLOps и инфраструктура#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Дистилляция нейронных сетей (Knowledge Distillation)

Дистилляция нейронных сетей — это техника машинного обучения, которая позволяет передать знания от большой, сложной модели (учителя) к меньшей, более эффективной модели (ученика). Цель — достичь близкой к оригинальной производительности при значительном снижении размера модели, времени инференса и требований к вычислительным ресурсам.

Принцип работы

Дистилляция основана на простой идее: вместо обучения ученика напрямую на исходных данных и жёстких метках (hard targets), мы обучаем его на мягких вероятностных предсказаниях учителя (soft targets).

Процесс дистилляции:

  1. Обучение модели-учителя — большая сеть обучается на полном датасете и достигает высокой точности
  2. Получение soft targets — модель-учитель генерирует вероятностные предсказания для всех примеров
  3. Обучение модели-ученика — маленькая сеть обучается минимизировать расстояние между своими предсказаниями и предсказаниями учителя
  4. Тонкая настройка — опционально, ученик дообучается на исходных данных с жёсткими метками

Роль температуры (Temperature)

Ключевой параметр дистилляции — температура (T). Она контролирует мягкость вероятностных распределений:

# Без температуры (T=1) — жёсткие распределения
soft_targets = softmax(logits_teacher)

# С температурой (T>1) — мягкие распределения
soft_targets = softmax(logits_teacher / T)

# Пример: если учитель уверен на 99%, с T=4 это становится примерно 85%

Высокая температура (T=4-10) создаёт более "размытые" распределения, что раскрывает информацию о том, какие неправильные классы почти верны. Это очень информативно для обучения ученика.

Функция потерь при дистилляции

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_targets, T=4.0, alpha=0.7):
    # KL-divergence loss между soft targets
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction=batchmean
    ) * (T ** 2)  # Масштабирование для компенсации деления на T
    
    # Cross-entropy loss на жёстких метках
    hard_loss = F.cross_entropy(student_logits, hard_targets)
    
    # Комбинированная потеря
    total_loss = alpha * hard_loss + (1 - alpha) * soft_loss
    return total_loss

Здесь:

  • alpha — вес жёсткой потери (обычно 0.5-0.9)
  • T — температура (обычно 3-20)
  • Множитель T² — компенсирует меньшие градиенты при делении на T

Преимущества дистилляции

1. Сжатие модели

  • Ученик может быть в 10-100 раз меньше учителя
  • Снижение памяти, требуемой для инференса
  • Возможность развёртывания на мобильных устройствах и edge devices

2. Ускорение инференса

Время инференса: Teacher=200ms → Student=20ms (10x ускорение)
Точность: Teacher=90% → Student=88% (небольшое падение)

3. Лучшая генерализация Ученик часто обобщается лучше, так как учится на гладких распределениях вероятностей, а не на жёстких метках

4. Улучшение через ансамбли Множество учителей могут обучить одного ученика лучше, чем один учитель

Практический пример

import torch
import torch.nn as nn

class StudentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)  # Маленькая сеть
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

class TeacherModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)  # Большая сеть
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Обучение ученика
teacher = TeacherModel().eval()  # Учитель в режиме оценки
student = StudentModel().train()
optimizer = torch.optim.Adam(student.parameters())

for images, labels in train_loader:
    with torch.no_grad():
        teacher_logits = teacher(images)
    
    student_logits = student(images)
    loss = distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Варианты и расширения

1. Attention Transfer — передача информации о том, на какие части входа обращает внимание учитель

2. Feature-based distillation — обучение ученика воспроизводить промежуточные представления учителя

3. Multi-teacher distillation — один ученик обучается у нескольких учителей

4. Self-distillation — сеть обучается собственным предсказаниям с разными температурами

Применение в реальных проектах

  • Мобильные приложения — компактные модели для распознавания лиц, речи
  • Облачные сервисы — быстрый инференс при сохранении качества
  • Edge AI — развёртывание на IoT устройствах
  • Real-time системы — низкая латентность критична

Выводы

Дистилляция нейронных сетей — это мощная техника для оптимизации моделей глубокого обучения. Она позволяет достичь практического компромисса между точностью и эффективностью, что критично для реальных приложений. Знание этого метода необходимо для работы с ресурсоограниченными окружениями и оптимизации production-систем.

Что такое дистилляция нейронных сетей? | PrepBro