← Назад к вопросам

Что такое teacher forcing?

2.0 Middle🔥 91 комментариев
#NLP и обработка текста#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI2 апр. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Teacher Forcing

Teacher forcing — это техника обучения последовательностей (sequence-to-sequence моделей), где на каждом шаге декодера используется истинный output из обучающего набора, а не предсказанный output модели на предыдущем шаге. Это стандартная техника для RNN, LSTM, transformer-based моделей.

Базовая идея

Представим task машинного перевода: английское предложение -> французское предложение.

Без teacher forcing (autoregressive, при инференсе):

Input: "Hello world"
Шаг 1: предсказать 1-е слово -> "Bonjour" (угаданное)
Шаг 2: использовать "Bonjour" как input -> предсказать 2-е слово -> "le" (угаданное)
Шаг 3: использовать "Bonjour le" как input -> предсказать 3-е слово
...
Проблема: ошибка на шаге 1 распространяется на шаг 2, 3, ...

С teacher forcing (во время обучения):

Input: "Hello world"
Шаг 1: предсказать 1-е слово -> "Bonjour" (но используем истинное: "Bonjour")
Шаг 2: использовать истинное слово "Bonjour" -> предсказать 2-е слово -> "monde" (но используем истинное: "monde")
Шаг 3: использовать истинное "monde" -> предсказать 3-е слово
...
Преимущество: каждый шаг смотрит на правильный контекст

Математическое определение

Без teacher forcing:

y_1 = model.predict(x)          # предсказанный 1-й token
y_2 = model.predict(x, y_1)     # второй token предсказан по y_1
y_3 = model.predict(x, y_1, y_2) # ошибки накапливаются

С teacher forcing:

y_1 = model.predict(x, y_true_0)        # используем истинный 0-й token (START)
y_2 = model.predict(x, y_true_1)        # используем истинный 1-й token
y_3 = model.predict(x, y_true_2)        # используем истинный 2-й token
# Все шаги независимы! Легче обучаться.

Пример: Text Generation with LSTM

import torch
import torch.nn as nn
from torch.optim import Adam

class SimpleLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, input_ids, target_ids=None, teacher_forcing_ratio=1.0):
        """
        input_ids: [batch_size, seq_len]
        target_ids: [batch_size, target_seq_len] (для обучения)
        teacher_forcing_ratio: 1.0 = всегда teacher forcing, 0.0 = никогда
        """
        batch_size = input_ids.shape[0]
        vocab_size = self.fc.out_features
        
        # Encode input
        embedded = self.embedding(input_ids)  # [batch, seq_len, embed_dim]
        _, (hidden, cell) = self.lstm(embedded)  # hidden: [1, batch, hidden_dim]
        
        # Decode with teacher forcing
        outputs = []
        input_token = torch.tensor([2] * batch_size)  # START token
        
        for t in range(len(target_ids[0])):
            # Декодируем один шаг
            embedded_token = self.embedding(input_token).unsqueeze(1)  # [batch, 1, embed_dim]
            output, (hidden, cell) = self.lstm(embedded_token, (hidden, cell))
            logits = self.fc(output.squeeze(1))  # [batch, vocab_size]
            outputs.append(logits)
            
            # Teacher forcing: используем истинный следующий token
            if target_ids is not None and torch.rand(1).item() < teacher_forcing_ratio:
                input_token = target_ids[:, t]  # Истинный token
            else:
                input_token = logits.argmax(dim=1)  # Предсказанный token
        
        return torch.stack(outputs, dim=1)  # [batch, target_seq_len, vocab_size]

# Обучение
model = SimpleLSTM(vocab_size=1000, embed_dim=64, hidden_dim=128)
optimizer = Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    # Обучение с teacher forcing (соотношение = 1.0)
    output = model(input_ids, target_ids, teacher_forcing_ratio=1.0)
    loss = loss_fn(output.reshape(-1, 1000), target_ids.reshape(-1))
    loss.backward()
    optimizer.step()

Проблема: Exposure Bias

Teacher forcing создаёт серьёзную проблему называется exposure bias или distribution mismatch:

Во время обучения:

  • Декодер видит истинные tokens
  • Модель никогда не видит свои собственные ошибки
  • Learns от чистых, правильных примеров

Во время инференса:

  • Декодер видит свои собственные предсказания
  • Если на шаге 1 ошибка, то шаг 2 видит ошибку!
  • Модель не была обучена на своих ошибках

Результат: модель работает хорошо на валидации, но плохо в production.

# Пример exposure bias
# Обучение:
Шаг 1: предсказать "Bonjour" по истинному START
       Модель не видит никогда, что может предсказать "Привет"

# Инференс:
Шаг 1: предсказать "Привет" (ошибка!)
Шаг 2: предсказать слово по "Привет" (ошибочный контекст)
       Модель не была обучена на таком контексте -> плохое предсказание

Решения: Scheduled Sampling

Scheduled Sampling — постепенно уменьшать teacher_forcing_ratio во время обучения:

epoch = 5
teacher_forcing_ratio = 1.0 - (epoch / total_epochs) * 0.9
# Epoch 0: 1.0 (полный teacher forcing)
# Epoch 5: 0.55
# Epoch 10: 0.1 (в основном autoregressive)

output = model(input_ids, target_ids, teacher_forcing_ratio)

Это позволяет модели:

  1. Быстро сходиться (teacher forcing в начале)
  2. Привыкнуть к своим ошибкам (меньше teacher forcing в конце)

Почему teacher forcing нужен?

Без teacher forcing обучение медленнее и нестабильнее:

  1. Скорость обучения: с teacher forcing все шаги независимы -> параллелизм
  2. Стабильность: учимся от правильных примеров, не от ошибок
  3. Сходимость: градиенты более информативны
  4. Практичность: иначе обучение может не сойтись вообще
# Без teacher forcing: slow, unstable, long training
# С teacher forcing: fast, stable, short training
# Компромисс: exposure bias

Когда teacher forcing НЕ подходит

  • Ненаблюдаемые состояния: если decoder states не соответствуют истинным (например, в генерации).
  • Контролируемое поколение: когда нужен specific стиль (пример: poem generation с формальностью)
  • Interactive systems: chatbots должны быть обучены на своих ошибках

Альтернативы / Варианты

  1. Curriculum Learning — начать с лёгких примеров, перейти к сложным
  2. Reinforcement Learning — обучить напрямую от reward (например, BLEU score)
  3. Minimum Risk Training — минимизировать ожидаемый risk, не точный loss
  4. Levenshtein Transformer — не autoregressive, предсказывает все сразу

Практический совет

  1. Используй teacher forcing для обучения — почти всегда
  2. Плавно уменьшай во время обучения — scheduled sampling
  3. На инференсе, никогда не используй teacher forcing — это cheating!
  4. Мониторь gap между валидацией и инференсом — признак exposure bias

Teacher forcing — это то, что позволило обучить трансформеры и большие языковые модели. Без этого, современный AI не существовал бы.

Что такое teacher forcing? | PrepBro