Что такое Mixture of Experts?
Комментарии (2)
Ответ сгенерирован нейросетью и может содержать ошибки
Что такое Mixture of Experts?
Mixture of Experts (MoE) — это архитектура нейронной сети, которая становится все популярнее в больших языковых моделях.
Суть идеи
В архитектуре MoE есть несколько специализированных "экспертов" (подсетей), и "маршрутизатор" определяет, какой эксперт(ы) обработает входные данные.
Как это работает?
import torch
import torch.nn as nn
class SimpleExpert(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
return self.fc2(x)
class MixtureOfExperts(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_experts=4):
super().__init__()
# Несколько экспертов
self.experts = nn.ModuleList([
SimpleExpert(input_dim, hidden_dim, output_dim)
for _ in range(num_experts)
])
# Маршрутизатор
self.router = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, num_experts),
nn.Softmax(dim=-1)
)
self.num_experts = num_experts
def forward(self, x):
# Веса от маршрутизатора
router_weights = self.router(x)
# Выходы от всех экспертов
expert_outputs = []
for expert in self.experts:
expert_outputs.append(expert(x))
# Складываем с весами
expert_outputs = torch.stack(expert_outputs, dim=1)
output = (router_weights.unsqueeze(-1) * expert_outputs).sum(dim=1)
return output, router_weights
Sparse Mixture of Experts
В реальных системах (Google GLaM, OpenAI GPT-4) используется Sparse MoE, где маршрутизатор выбирает только несколько экспертов (обычно 2-8 из 64).
Примеры использования
Google GLaM:
- 1.2 триллионов параметров
- 64 эксперта
- Активно 2 эксперта на примере
- 7x ускорение при меньших вычислениях
OpenAI GPT-4 (предполагаемо):
- 8 экспертов по 220 миллиардов параметров
- Total 1.76 триллионов параметров
- Используется 2-3 эксперта одновременно
Преимущества
1. Масштабируемость без пропорционального увеличения вычислений:
- 10x параметров → 1.5x вычислений
- 6-7x ускорение на практике
2. Специализированные эксперты: Разные эксперты могут выучить разные аспекты задачи: синтаксис, семантику, факты, логику.
3. Условные вычисления: Не все примеры используют все параметры модели.
Проблемы
1. Load Imbalance (дисбаланс нагрузки)
Маршрутизатор может отправлять все примеры одному эксперту. Решение — добавить auxiliary loss:
def load_balancing_loss(router_logits, num_experts):
router_probs = torch.softmax(router_logits, dim=-1)
expert_load = router_probs.sum(dim=0)
ideal_load = 1.0 / num_experts
loss = torch.sum((expert_load - ideal_load) ** 2)
return loss
total_loss = main_loss + 0.01 * load_balancing_loss(router_logits, num_experts)
2. Communication Overhead
Во время распределённого обучения нужно отправлять данные к разным экспертам.
3. Training Instability
Маршрутизатор может "прыгать" между экспертами, нарушая обучение.
MoE в Трансформере
class MoETransformerLayer(nn.Module):
def __init__(self, d_model, num_experts=8, num_active=2):
super().__init__()
# Self-attention
self.attention = nn.MultiheadAttention(d_model, num_heads=8)
# MoE вместо обычного feedforward
self.moe = SparseExpert(d_model, 4*d_model, d_model, num_experts, num_active)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
# Self-attention
attn_out, _ = self.attention(x, x, x)
x = x + attn_out
x = self.norm1(x)
# MoE feedforward
moe_out, _ = self.moe(x)
x = x + moe_out
x = self.norm2(x)
return x
Ключевые выводы
-
MoE — архитектура условных вычислений, где маршрутизатор выбирает активные эксперты
-
Основное преимущество: масштабируемость без пропорционального увеличения вычислений
-
Sparse MoE чаще всего в больших моделях (используется 2-8 экспертов из 8-64)
-
Основные вызовы: load imbalance, communication overhead, training instability
-
Используется в: Google GLaM, T5 MoE, OpenAI GPT-4
-
На практике MoE работает лучше всего:
- С достаточно большими данными
- На распределённом обучении
- Когда нужна производительность при низких вычислениях