Какие знаешь архитектуры декодера?
Комментарии (2)
Ответ сгенерирован нейросетью и может содержать ошибки
Архитектуры декодера в глубоком обучении
Декодер (Decoder) — это компонент нейросети, который генерирует последовательность на выходе на основе входного контекста. Существует несколько основных архитектур, каждая с уникальными характеристиками и областями применения.
1. Seq2Seq с LSTM/GRU (Sequence-to-Sequence)
Архитектура:
- Энкодер (обычно LSTM/GRU) сжимает входную последовательность в контекстный вектор
- Декодер (тоже LSTM/GRU) развертывает этот вектор в выходную последовательность
- Используется Attention механизм для улучшения качества
import torch
import torch.nn as nn
class Seq2SeqDecoder(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(
embedding_dim,
hidden_dim,
num_layers,
batch_first=True,
dropout=0.3
)
self.fc = nn.Linear(hidden_dim, vocab_size)
def forward(self, x, encoder_hidden):
embedded = self.embedding(x) # (batch, seq_len, embedding_dim)
lstm_out, hidden = self.lstm(embedded, encoder_hidden)
logits = self.fc(lstm_out) # (batch, seq_len, vocab_size)
return logits, hidden
# Использование для машинного перевода, summarization
Преимущества:
- Хорошо работает с переменной длиной последовательностей
- Можно обучать end-to-end
Недостатки:
- Потеря информации при сжатии в один вектор
- Плохо масштабируется на длинные последовательности
2. Attention Mechanism (Bahdanau Attention)
Идея: Декодер не просто использует последний скрытый вектор энкодера, а сосредотачивается на разных частях входа при генерации каждого токена.
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionDecoder(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, encoder_hidden_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim + encoder_hidden_dim, hidden_dim)
# Attention слои
self.attention = nn.Linear(hidden_dim + encoder_hidden_dim, hidden_dim)
self.v = nn.Linear(hidden_dim, 1)
self.fc = nn.Linear(hidden_dim, vocab_size)
def forward(self, x, encoder_outputs, hidden):
embedded = self.embedding(x).unsqueeze(0) # (1, batch, embedding_dim)
# Attention: что смотреть в encoder_outputs
attention_input = torch.cat([
hidden[0], # decoder hidden state
encoder_outputs # все выходы энкодера
], dim=2)
attention_scores = torch.tanh(self.attention(attention_input))
attention_weights = F.softmax(self.v(attention_scores), dim=1)
context = torch.sum(attention_weights * encoder_outputs, dim=1)
# Декодер использует контекст
lstm_input = torch.cat([embedded, context.unsqueeze(0)], dim=2)
lstm_out, new_hidden = self.lstm(lstm_input, hidden)
logits = self.fc(lstm_out)
return logits, new_hidden, attention_weights
# Применение: машинный перевод, question answering, summarization
Улучшение: Значительно повышает качество на длинных последовательностях.
3. Transformer Decoder (Self-Attention)
Архитектура:
- Отказ от RNN в пользу pure attention
- Позиционные кодирования для сохранения порядка
- Мультиголовое внимание (Multi-Head Attention)
- Feed-forward сети в каждом слое
import torch
import torch.nn as nn
import math
class TransformerDecoderLayer(nn.Module):
def __init__(self, hidden_dim, num_heads, ffn_dim, dropout=0.1):
super().__init__()
# Self-attention: декодер смотрит на свои выходы
self.self_attention = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=dropout
)
# Cross-attention: декодер смотрит на выходы энкодера
self.cross_attention = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=dropout
)
# Feed-forward сеть
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, ffn_dim),
nn.ReLU(),
nn.Linear(ffn_dim, hidden_dim)
)
# Layer normalization
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.norm3 = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, tgt_mask=None):
# Self-attention (с маской, чтобы не смотреть на будущие токены)
self_attn_out, _ = self.self_attention(
x, x, x,
attn_mask=tgt_mask # Важно для авторегрессивной генерации!
)
x = self.norm1(x + self.dropout(self_attn_out))
# Cross-attention (к выходам энкодера)
cross_attn_out, _ = self.cross_attention(
x, encoder_output, encoder_output
)
x = self.norm2(x + self.dropout(cross_attn_out))
# Feed-forward
ffn_out = self.ffn(x)
x = self.norm3(x + self.dropout(ffn_out))
return x
class TransformerDecoder(nn.Module):
def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, ffn_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.pos_encoding = self._create_positional_encoding(hidden_dim)
self.decoder_layers = nn.ModuleList([
TransformerDecoderLayer(hidden_dim, num_heads, ffn_dim)
for _ in range(num_layers)
])
self.fc_out = nn.Linear(hidden_dim, vocab_size)
def _create_positional_encoding(self, hidden_dim, max_len=1000):
pe = torch.zeros(max_len, hidden_dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, hidden_dim, 2).float() *
(-math.log(10000.0) / hidden_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
def forward(self, x, encoder_output, tgt_mask=None):
x = self.embedding(x)
x = x + self.pos_encoding[:x.size(0)].unsqueeze(0)
for layer in self.decoder_layers:
x = layer(x, encoder_output, tgt_mask)
logits = self.fc_out(x)
return logits
Преимущества:
- Параллелизм — можно обрабатывать всю последовательность сразу
- Лучшее масштабирование на длинные последовательности
- Используется в BERT, GPT, T5
Недостатки:
- Требует больше памяти (квадратическая сложность внимания)
- Сложнее в реализации
4. Autoregressive Decoder (GPT-style)
Идея: Генерировать токены по одному, каждый раз предсказывая P(token_i | token_1...token_{i-1})
class GPTDecoder(nn.Module):
def __init__(self, vocab_size, hidden_dim, num_layers, num_heads):
super().__init__()
self.transformer_decoder = TransformerDecoder(
vocab_size, hidden_dim, num_layers, num_heads, ffn_dim=4*hidden_dim
)
def generate(self, prompt, max_length=100, temperature=1.0, top_k=50):
"""Авторегрессивная генерация"""
tokens = prompt.clone()
for _ in range(max_length):
# Предсказываем следующий токен
logits = self.transformer_decoder(tokens)
next_token_logits = logits[-1, :] / temperature
# Top-k sampling
top_k_logits, top_k_indices = torch.topk(
next_token_logits, min(top_k, len(logits[0]))
)
probabilities = torch.softmax(top_k_logits, dim=0)
next_token = torch.multinomial(probabilities, 1)
tokens = torch.cat([tokens, top_k_indices[next_token].unsqueeze(0)])
if next_token == self.eos_token:
break
return tokens
Применение: GPT, LLaMA, Mistral (генерация текста)
5. Beam Search Decoding
Вместо жадной генерации (выбираем топ-1 токен):
class BeamSearchDecoder:
def __init__(self, beam_width=5):
self.beam_width = beam_width
def decode(self, encoder_output, max_length=50):
"""Beam search для поиска лучшей последовательности"""
batch_size = encoder_output.size(0)
# Начинаем с <START> токена
sequences = torch.full(
(batch_size, 1), self.start_token, dtype=torch.long
)
sequence_scores = torch.zeros(batch_size)
for step in range(max_length):
decoder_output = self.decoder(sequences, encoder_output)
# Получаем log вероятности
log_probs = torch.log_softmax(decoder_output[:, -1], dim=-1)
# Выбираем top-k кандидатов
scores, next_tokens = torch.topk(
log_probs, self.beam_width, dim=-1
)
# Обновляем последовательности
new_sequences = []
for beam in range(self.beam_width):
new_seq = torch.cat([
sequences,
next_tokens[:, beam].unsqueeze(-1)
], dim=-1)
new_sequences.append(new_seq)
sequences = torch.cat(new_sequences, dim=0)
return sequences
6. Pointer Networks (для задач с выбором из входа)
Применение: TSP, graph problems, copy mechanism
class PointerNetwork(nn.Module):
"""Декодер указывает на позиции в входе"""
def __init__(self, hidden_dim):
super().__init__()
self.attention = nn.Linear(2 * hidden_dim, 1)
def forward(self, decoder_hidden, encoder_outputs):
# Вычисляем веса внимания для каждой позиции входа
context = torch.cat([
decoder_hidden.expand(encoder_outputs.size(0), -1),
encoder_outputs
], dim=1)
attention_scores = self.attention(context)
probabilities = torch.softmax(attention_scores, dim=0)
# Возвращаем индекс с максимальной вероятностью
return torch.argmax(probabilities)
Сравнение архитектур
| Архитектура | Скорость | Качество | Память | Использование |
|---|---|---|---|---|
| LSTM | Медленно | Среднее | Мало | Наследие, RNN задачи |
| LSTM + Attention | Медленно | Хорошее | Мало | Seq2seq, NMT |
| Transformer | Быстро | Отличное | Много | Современный стандарт |
| Autoregressive | Медленно | Отличное | Много | Генерация текста |
| Beam Search | Очень медленно | Максимальное | Много | Production |
| Pointer Network | Медленно | Хорошее | Мало | Комбинаторные задачи |
Практический выбор
# Для Machine Translation: Transformer + Beam Search
# Для Summarization: BERT Encoder + Transformer Decoder
# Для Text Generation: GPT-style Autoregressive
# Для Q&A: Pointer Networks
# Для старого кода: LSTM + Attention