Какие слои нейросети ведут себя по-разному на обучении и инференсе?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Слои нейросети с разным поведением на обучении и инференсе
Важно помнить, что некоторые слои используют разные вычисления во время обучения и во время инференса (предсказания). Это может привести к ошибкам, если забыть переключить режимы.
1. Dropout
Как работает на обучении: Randomly зануляет активации с вероятностью p (обычно 0.5).
На обучении:
input: [1.0, 2.0, 3.0, 4.0]
mask: [1, 0, 1, 0] <- случайная маска
output: [1.0*scale, 0, 3.0*scale, 0]
Как работает на инференсе: Все активации ПРОПУСКАЮТСЯ (никакие не удаляются). Обычно используется масштабирование: активации делятся на (1-p).
import torch
import torch.nn as nn
class DropoutExample(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.dropout = nn.Dropout(p=0.5) # Вероятность 50%
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.dropout(x) # На обучении: ~50% зануляется
x = self.fc2(x)
return x
model = DropoutExample()
# На обучении
model.train() # ВАЖНО!
x = torch.randn(32, 10)
y = model(x) # Dropout активен
# На инференсе
model.eval() # ВАЖНО! Иначе результаты будут неправильные
x = torch.randn(32, 10)
y = model(x) # Dropout выключен
Почему это нужно: Dropout — это техника регуляризации для борьбы с переобучением. На обучении мы случайно удаляем информацию, чтобы сеть была более robust. На инференсе мы хотим использовать всю информацию для лучших предсказаний.
2. Batch Normalization
На обучении: Нормализует входы по статистике текущего мини-батча:
y = (x - batch_mean) / sqrt(batch_var + epsilon) * gamma + beta
Где batch_mean и batch_var вычисляются ПО ТЕКУЩЕМУ БАТЧУ.
На инференсе: Использует экспоненциально взвешенные скользящие средние (running statistics), которые накопились во время обучения:
y = (x - running_mean) / sqrt(running_var + epsilon) * gamma + beta
class BatchNormExample(nn.Module):
def __init__(self):
super().__init__()
self.bn = nn.BatchNorm1d(10)
def forward(self, x):
return self.bn(x)
model = BatchNormExample()
# На обучении: использует статистику батча
model.train()
batch1 = torch.randn(32, 10)
batch2 = torch.randn(32, 10)
y1 = model(batch1) # Нормализует по batch1
y2 = model(batch2) # Нормализует по batch2
# На инференсе: использует накопленные скользящие средние
model.eval()
single_sample = torch.randn(1, 10)
y = model(single_sample) # Нормализует по running_mean и running_var
Проблема, если забыть переключиться:
model.train() # Ошибка: инференс в режиме train
single_sample = torch.randn(1, 10)
y = model(single_sample)
# Батч_mean и batch_var для одного примера близки к нулю!
# Результаты неправильные
3. Layer Normalization
В отличие от BatchNorm: LayerNorm нормализует ПО ПРИЗНАКАМ в одном примере, а не по батчу.
class LayerNormExample(nn.Module):
def __init__(self):
super().__init__()
self.ln = nn.LayerNorm(10)
def forward(self, x):
return self.ln(x)
model = LayerNormExample()
# LayerNorm одинаков на обучении и инференсе
model.train()
y1 = model(x) # Результат
model.eval()
y2 = model(x) # ТОЧНО ТОТ ЖЕ результат (если нет других слоев)
LayerNorm НЕ хранит статистику батча, поэтому ведёт себя одинаково везде.
4. Embedding (в режиме инференса для новых ID)
На обучении: может встретиться с ID токенов, которые есть в словаре.
На инференсе: может встретиться с новыми ID, которых не было в обучении.
Технически слой работает одинаково, но стратегия обработки новых ID отличается:
class EmbeddingExample(nn.Module):
def __init__(self, vocab_size=1000, embedding_dim=50):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
def forward(self, x):
return self.embedding(x)
model = EmbeddingExample()
# На обучении
x_train = torch.tensor([1, 5, 10, 20]) # Все ID есть в словаре
y_train = model(x_train) # Работает
# На инференсе
x_test = torch.tensor([1, 5, 10, 999]) # ID 999 может быть новым
try:
y_test = model(x_test)
except IndexError:
# IndexError если ID >= vocab_size
print("Новый ID не поддерживается!")
5. RNN / LSTM / GRU (stateful версии)
На обучении: обычно используем stateless RNN (скрытое состояние сбрасывается между батчами).
На инференсе: часто нужна stateful версия, которая запоминает скрытое состояние между инференсами.
class StatelessLSTM(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(10, 20, batch_first=True)
def forward(self, x):
out, (h, c) = self.lstm(x) # h и c = None на первый раз
return out
class StatefulLSTM(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(10, 20, batch_first=True)
self.h = None
self.c = None
def forward(self, x):
out, (self.h, self.c) = self.lstm(x, (self.h, self.c))
return out
def reset_state(self):
self.h = None
self.c = None
# На обучении: stateless
model_train = StatelessLSTM()
seq1 = torch.randn(32, 5, 10) # batch=32, seq_len=5, features=10
seq2 = torch.randn(32, 5, 10)
y1 = model_train(seq1)
y2 = model_train(seq2) # Независимо от seq1
# На инференсе: stateful (например, обработка потока данных)
model_infer = StatefulLSTM()
for token in stream_of_tokens:
output = model_infer(token.unsqueeze(0).unsqueeze(0))
# Скрытое состояние сохраняется между итерациями
6. Attention слои (с different scaling)
Некоторые attention механизмы используют разные scaling на обучении и инференсе для улучшения numerical stability.
Практический пример проблемы
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.bn = nn.BatchNorm1d(20)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(20, 1)
def forward(self, x):
x = self.fc1(x)
x = self.bn(x)
x = torch.relu(x)
x = self.dropout(x) # Опасно!
x = self.fc2(x)
return x
model = MyModel()
# Тренируем
model.train()
for epoch in range(10):
x = torch.randn(32, 10)
y = torch.randn(32, 1)
out = model(x)
loss = nn.MSELoss()(out, y)
# ... backward и update
# Инференс ПРАВИЛЬНЫЙ
model.eval()
with torch.no_grad():
x_test = torch.randn(5, 10)
y_pred = model(x_test)
print(f"Предсказания: {y_pred.numpy()}")
# Инференс НЕПРАВИЛЬНЫЙ (если забыли model.eval())
model.train() # Опасно!
with torch.no_grad():
x_test = torch.randn(5, 10)
y_pred = model(x_test) # BatchNorm использует статистику батча из 5 примеров!
print(f"Неправильные предсказания: {y_pred.numpy()}")
Чеклист перед инференсом
# Правильный инференс:
model.eval() # ✓ Выключить train-режим
with torch.no_grad(): # ✓ Выключить gradient tracking
predictions = model(x_test)
# Если нужен train режим для чего-то:
model.train()
# ... что-то
model.eval() # ✓ Переключить обратно перед инференсом
Итоговая таблица
| Слой | На обучении | На инференсе | Переключение обязательно? |
|---|---|---|---|
| Dropout | Зануляет ~p% | Все активны | ДА |
| BatchNorm | Статистика батча | Скользящие средние | ДА |
| LayerNorm | Нормализует признаки | То же | НЕТ |
| Embedding | Обычный lookup | Может быть новый ID | ЗАВИСИТ |
| RNN Stateful | Сбрасывает состояние | Запоминает состояние | ДА (конфигурация) |
Итог: ВСЕГДА используйте model.eval() перед инференсом если ваша модель содержит Dropout или BatchNorm!