Какой learning rate использовать для большого vs маленького батча?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Какой learning rate использовать для большого vs маленького батча?
Размер батча и learning rate — две критически связанные гиперпараметры в глубоком обучении. Правильный их подбор определяет как скорость обучения, так и финальное качество модели.
Основная теория
Градиент, вычисленный на батче, используется для обновления весов:
W_new = W_old - learning_rate × gradient
Критическое наблюдение: размер батча влияет на дисперсию оценки градиента.
- Маленький батч → шумная оценка градиента → высокая дисперсия
- Большой батч → стабильная оценка градиента → низкая дисперсия
Правило масштабирования learning rate
Классический подход (Linear Scaling Rule):
learning_rate_new = learning_rate_base × (batch_size_new / batch_size_base)
Пример:
Исходная конфигурация:
- batch_size = 32
- learning_rate = 0.001
Если увеличиваем батч в 4 раза (до 128):
- learning_rate_new = 0.001 × (128 / 32) = 0.004
Интуиция: большой батч дает более точный градиент, поэтому можно смелее сдвигаться в этом направлении.
Маленький батч (small batch)
# Типичные параметры
batch_size = 32 # или даже 16, 8
learning_rate = 0.001 # или 0.0005
# Характеристики маленького батча:
# 1. Более шумная оценка градиента
# 2. Более частые обновления весов (больше итераций per epoch)
# 3. Лучше обобщение (может действовать как регуляризатор)
# 4. Медленнее GPU (плохое использование паралелизма)
import torch
import torch.nn as nn
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Обучение с маленьким батчом
for epoch in range(100):
for batch_idx, (data, labels) in enumerate(small_batch_loader):
optimizer.zero_grad()
output = model(data)
loss = nn.functional.mse_loss(output, labels)
loss.backward()
optimizer.step() # частые обновления!
Когда использовать маленький батч:
- Ограниченная память GPU
- Хотите лучше обобщение
- Регуляризация данных важна
Большой батч (large batch)
# Типичные параметры
batch_size = 256 # или 512, 1024
learning_rate = 0.01 # увеличился в 8 раз
# Характеристики большого батча:
# 1. Стабильная, точная оценка градиента
# 2. Меньше обновлений весов per epoch
# 3. Может привести к переобучению
# 4. Быстрее GPU (хорошее использование паралелизма)
# 5. ТРЕБУЕТ БОЛЬШЕГО LEARNING RATE!
import torch
import torch.nn as nn
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Обучение с большим батчом
for epoch in range(100):
for batch_idx, (data, labels) in enumerate(large_batch_loader):
optimizer.zero_grad()
output = model(data)
loss = nn.functional.mse_loss(output, labels)
loss.backward()
optimizer.step() # меньше обновлений!
Когда использовать большой батч:
- Достаточно памяти
- Нужна скорость обучения
- Хорошая регуляризация (dropout, L2 и т.д.)
- Стабильность обучения важна
Практический пример масштабирования
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
# Синтетические данные
X = torch.randn(10000, 20)
y = torch.randn(10000, 1)
dataset = TensorDataset(X, y)
# Конфигурации
configs = [
{'batch_size': 32, 'lr': 0.001},
{'batch_size': 128, 'lr': 0.004}, # 4x batch → 4x lr
{'batch_size': 256, 'lr': 0.008}, # 8x batch → 8x lr
{'batch_size': 512, 'lr': 0.016}, # 16x batch → 16x lr
]
for config in configs:
batch_size = config['batch_size']
lr = config['lr']
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
model = nn.Linear(20, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Обучение и мониторинг
print(f"\nBatch size: {batch_size}, LR: {lr}")
for epoch in range(5):
total_loss = 0
for data, labels in loader:
optimizer.zero_grad()
output = model(data)
loss = nn.functional.mse_loss(output, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}: Loss = {total_loss / len(loader):.6f}")
Проблема "генерализации" большого батча
Исследования показали, что большие батчи могут привести к худшей обобщаемости, даже при оптимальном learning rate.
# Проблема: острые минимумы vs пологие минимумы
# Маленький батч + noise → находит пологие минимумы (лучше обобщает)
# Большой батч + clean gradient → может найти острые минимумы (плохо обобщает)
# Решение: использовать learning rate warmup и schedule
import torch.optim.lr_scheduler as lr_scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = lr_scheduler.PolynomialLR(optimizer, total_iters=100)
for epoch in range(100):
train(epoch)
scheduler.step() # уменьшает learning rate со временем
Адаптивные методы оптимизации
Adam, RMSprop и другие адаптивные методы менее чувствительны к размеру батча:
import torch.optim as optim
# SGD: требует тщательной настройки LR при изменении batch size
optimizer_sgd = optim.SGD(model.parameters(), lr=0.01)
# Adam: более прощающий к изменениям batch size
optimizer_adam = optim.Adam(model.parameters(), lr=0.001)
# На практике:
# SGD с большим батчом → нужна высокая LR
# Adam → можно использовать одну LR для разных batch sizes
Learning Rate Scheduling
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Стратегия 1: Step decay
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Стратегия 2: Cosine annealing (часто используется с большим батчом)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# Стратегия 3: Warmup (увеличение LR в начале)
# Особенно полезно для больших батчей
warmup_iters = 1000
base_lr = 0.01
for iter in range(warmup_iters):
lr = base_lr * (iter + 1) / warmup_iters
for param_group in optimizer.param_groups:
param_group['lr'] = lr
Практические рекомендации
1. Начните с базовых параметров
# Стартовая конфигурация для экспериментов
base_config = {
'batch_size': 32,
'learning_rate': 0.001,
'optimizer': 'Adam' # менее критичен к LR
}
2. Масштабируйте при изменении batch size
# Если увеличили batch_size в N раз:
# - Для SGD: увеличьте LR в sqrt(N) раз (консервативно) или N раз (агрессивно)
# - Для Adam: поэкспериментируйте, но часто можно не менять
scaling_factor = new_batch_size / base_batch_size
if optimizer_type == 'SGD':
new_lr = base_lr * (scaling_factor ** 0.5) # sqrt масштабирование
else:
new_lr = base_lr # Adam обычно не требует масштабирования
3. Используйте learning rate finder
# PyTorch Lightning имеет встроенный lr_finder
from pytorch_lightning import Trainer
trainer = Trainer()
trainer.tune(model, train_dataloaders=train_loader)
# Автоматически найдёт оптимальный learning rate
4. Мониторьте loss curve
# Хорошее обучение: плавное падение loss
# Проблемы:
# - loss растет → LR слишком большой
# - loss не падает → LR слишком маленький
# - loss нестабилен → проблемы с градиентом или batch size
import matplotlib.pyplot as plt
plt.plot(train_losses)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()
Сравнительная таблица
| Параметр | Маленький батч | Большой батч |
|---|---|---|
| Learning rate | Низкий (~0.001) | Высокий (~0.01-0.1) |
| Обновлений per epoch | Много | Мало |
| Шум градиента | Высокий | Низкий |
| Обобщение | Лучше | Может быть хуже |
| GPU эффективность | Низкая | Высокая |
| Масштабирование LR | 1x | Nx (где N = ratio) |
Заключение
Ключевые принципы:
- Маленький батч + низкий LR → медленно, но стабильно
- Большой батч + высокий LR → быстро, но требует внимательной настройки
- Linear Scaling Rule → хорошая стартовая точка для SGD
- Адаптивные методы (Adam) → менее чувствительны к этим параметрам
- Learning rate scheduling → критичен для больших батчей
- Экспериментируйте → оптимальные параметры зависят от задачи и данных