← Назад к вопросам

Что такое дистилляция модели?

2.7 Senior🔥 141 комментариев
#MLOps и инфраструктура#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Что такое дистилляция модели?

Дистилляция модели (Knowledge Distillation) — это метод трансфера знаний, при котором более простая модель-ученик (Student Model) обучается воспроизводить поведение более сложной модели-учителя (Teacher Model). Цель — достичь близкой точности при значительном снижении размера и вычислительных затрат.

Основной принцип

Teacher Model (сложная, точная):

  • Большая нейронная сеть или ансамбль моделей
  • Обучена на исходной задаче
  • Высокая точность, но требует много ресурсов

Student Model (простая, быстрая):

  • Меньше параметров и слоёв
  • Обучается предсказываниям учителя
  • Удобна для production (низкая задержка)

Как это работает

Temperature Scaling (важная концепция):

import torch
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature=4.0):
    # Мягкие вероятности от учителя
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    
    # Мягкие вероятности от студента
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    
    # KL divergence между распределениями
    loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return loss * (temperature ** 2)

Полный процесс обучения:

import torch.nn as nn

class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha  # вес для loss дистилляции
        self.ce_loss = nn.CrossEntropyLoss()
    
    def forward(self, student_logits, teacher_logits, targets):
        # Loss от дистилляции (KL divergence)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs)
        
        # Обычный supervised loss
        ce_loss = self.ce_loss(student_logits, targets)
        
        # Итоговый loss
        total_loss = self.alpha * ce_loss + (1 - self.alpha) * distill_loss
        return total_loss

Практический пример: Классификация изображений

Шаг 1: Обучение Teacher Model

teacher = EfficientNet(depth='b7')  # большая модель
teacher.train(train_loader, epochs=100)
# Получаем accuracy ~95%

Шаг 2: Инициализация Student Model

student = EfficientNet(depth='b0')  # маленькая модель
# Параметров в 100 раз меньше

Шаг 3: Обучение с дистилляцией

distill_loss_fn = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

for epoch in range(50):
    for batch_x, batch_y in train_loader:
        student_logits = student(batch_x)
        teacher_logits = teacher(batch_x).detach()  # no gradients
        
        loss = distill_loss_fn(student_logits, teacher_logits, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# Student достигает ~93% точности

Преимущества дистилляции

  • Снижение размера модели — уменьшение памяти в 10-100 раз
  • Ускорение инференса — быстрее на 5-20 раз
  • Сохранение точности — потеря обычно 1-3%
  • Оптимизация для мобильных — развёртывание на телефонах
  • Энергоэффективность — снижение потребления батареи

Когда использовать

  • Развёртывание на edge-устройствах (мобильные, IoT)
  • Снижение latency в production
  • Оптимизация подConstraints по памяти
  • Real-time приложения (рекомендации, поиск)

Вариации метода

Attention Transfer — трансфер attention maps Feature-based Distillation — обучение на скрытых слоях Relational Knowledge Distillation — использование relational информации

Дистилляция — мощный инструмент в арсенале ML-инженера для оптимизации моделей без потери качества.

Что такое дистилляция модели? | PrepBro