← Назад к вопросам

Какие знаешь архитектуры декодера?

3.0 Senior🔥 112 комментариев
#NLP и обработка текста#Глубокое обучение

Комментарии (2)

🐱
claude-haiku-4.5PrepBro AI29 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Архитектуры декодера в глубоком обучении

Декодер (Decoder) — это компонент нейросети, который генерирует последовательность на выходе на основе входного контекста. Существует несколько основных архитектур, каждая с уникальными характеристиками и областями применения.

1. Seq2Seq с LSTM/GRU (Sequence-to-Sequence)

Архитектура:

  • Энкодер (обычно LSTM/GRU) сжимает входную последовательность в контекстный вектор
  • Декодер (тоже LSTM/GRU) развертывает этот вектор в выходную последовательность
  • Используется Attention механизм для улучшения качества
import torch
import torch.nn as nn

class Seq2SeqDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=0.3
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, x, encoder_hidden):
        embedded = self.embedding(x)  # (batch, seq_len, embedding_dim)
        lstm_out, hidden = self.lstm(embedded, encoder_hidden)
        logits = self.fc(lstm_out)  # (batch, seq_len, vocab_size)
        return logits, hidden

# Использование для машинного перевода, summarization

Преимущества:

  • Хорошо работает с переменной длиной последовательностей
  • Можно обучать end-to-end

Недостатки:

  • Потеря информации при сжатии в один вектор
  • Плохо масштабируется на длинные последовательности

2. Attention Mechanism (Bahdanau Attention)

Идея: Декодер не просто использует последний скрытый вектор энкодера, а сосредотачивается на разных частях входа при генерации каждого токена.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, encoder_hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim + encoder_hidden_dim, hidden_dim)
        
        # Attention слои
        self.attention = nn.Linear(hidden_dim + encoder_hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1)
        
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, x, encoder_outputs, hidden):
        embedded = self.embedding(x).unsqueeze(0)  # (1, batch, embedding_dim)
        
        # Attention: что смотреть в encoder_outputs
        attention_input = torch.cat([
            hidden[0],  # decoder hidden state
            encoder_outputs  # все выходы энкодера
        ], dim=2)
        
        attention_scores = torch.tanh(self.attention(attention_input))
        attention_weights = F.softmax(self.v(attention_scores), dim=1)
        
        context = torch.sum(attention_weights * encoder_outputs, dim=1)
        
        # Декодер использует контекст
        lstm_input = torch.cat([embedded, context.unsqueeze(0)], dim=2)
        lstm_out, new_hidden = self.lstm(lstm_input, hidden)
        
        logits = self.fc(lstm_out)
        return logits, new_hidden, attention_weights

# Применение: машинный перевод, question answering, summarization

Улучшение: Значительно повышает качество на длинных последовательностях.

3. Transformer Decoder (Self-Attention)

Архитектура:

  • Отказ от RNN в пользу pure attention
  • Позиционные кодирования для сохранения порядка
  • Мультиголовое внимание (Multi-Head Attention)
  • Feed-forward сети в каждом слое
import torch
import torch.nn as nn
import math

class TransformerDecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        
        # Self-attention: декодер смотрит на свои выходы
        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout
        )
        
        # Cross-attention: декодер смотрит на выходы энкодера
        self.cross_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout
        )
        
        # Feed-forward сеть
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, hidden_dim)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, encoder_output, tgt_mask=None):
        # Self-attention (с маской, чтобы не смотреть на будущие токены)
        self_attn_out, _ = self.self_attention(
            x, x, x,
            attn_mask=tgt_mask  # Важно для авторегрессивной генерации!
        )
        x = self.norm1(x + self.dropout(self_attn_out))
        
        # Cross-attention (к выходам энкодера)
        cross_attn_out, _ = self.cross_attention(
            x, encoder_output, encoder_output
        )
        x = self.norm2(x + self.dropout(cross_attn_out))
        
        # Feed-forward
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))
        
        return x

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, ffn_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_encoding = self._create_positional_encoding(hidden_dim)
        
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(hidden_dim, num_heads, ffn_dim)
            for _ in range(num_layers)
        ])
        
        self.fc_out = nn.Linear(hidden_dim, vocab_size)
    
    def _create_positional_encoding(self, hidden_dim, max_len=1000):
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * 
            (-math.log(10000.0) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe
    
    def forward(self, x, encoder_output, tgt_mask=None):
        x = self.embedding(x)
        x = x + self.pos_encoding[:x.size(0)].unsqueeze(0)
        
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, tgt_mask)
        
        logits = self.fc_out(x)
        return logits

Преимущества:

  • Параллелизм — можно обрабатывать всю последовательность сразу
  • Лучшее масштабирование на длинные последовательности
  • Используется в BERT, GPT, T5

Недостатки:

  • Требует больше памяти (квадратическая сложность внимания)
  • Сложнее в реализации

4. Autoregressive Decoder (GPT-style)

Идея: Генерировать токены по одному, каждый раз предсказывая P(token_i | token_1...token_{i-1})

class GPTDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads):
        super().__init__()
        self.transformer_decoder = TransformerDecoder(
            vocab_size, hidden_dim, num_layers, num_heads, ffn_dim=4*hidden_dim
        )
    
    def generate(self, prompt, max_length=100, temperature=1.0, top_k=50):
        """Авторегрессивная генерация"""
        tokens = prompt.clone()
        
        for _ in range(max_length):
            # Предсказываем следующий токен
            logits = self.transformer_decoder(tokens)
            next_token_logits = logits[-1, :] / temperature
            
            # Top-k sampling
            top_k_logits, top_k_indices = torch.topk(
                next_token_logits, min(top_k, len(logits[0]))
            )
            probabilities = torch.softmax(top_k_logits, dim=0)
            next_token = torch.multinomial(probabilities, 1)
            
            tokens = torch.cat([tokens, top_k_indices[next_token].unsqueeze(0)])
            
            if next_token == self.eos_token:
                break
        
        return tokens

Применение: GPT, LLaMA, Mistral (генерация текста)

5. Beam Search Decoding

Вместо жадной генерации (выбираем топ-1 токен):

class BeamSearchDecoder:
    def __init__(self, beam_width=5):
        self.beam_width = beam_width
    
    def decode(self, encoder_output, max_length=50):
        """Beam search для поиска лучшей последовательности"""
        batch_size = encoder_output.size(0)
        
        # Начинаем с <START> токена
        sequences = torch.full(
            (batch_size, 1), self.start_token, dtype=torch.long
        )
        sequence_scores = torch.zeros(batch_size)
        
        for step in range(max_length):
            decoder_output = self.decoder(sequences, encoder_output)
            
            # Получаем log вероятности
            log_probs = torch.log_softmax(decoder_output[:, -1], dim=-1)
            
            # Выбираем top-k кандидатов
            scores, next_tokens = torch.topk(
                log_probs, self.beam_width, dim=-1
            )
            
            # Обновляем последовательности
            new_sequences = []
            for beam in range(self.beam_width):
                new_seq = torch.cat([
                    sequences,
                    next_tokens[:, beam].unsqueeze(-1)
                ], dim=-1)
                new_sequences.append(new_seq)
            
            sequences = torch.cat(new_sequences, dim=0)
        
        return sequences

6. Pointer Networks (для задач с выбором из входа)

Применение: TSP, graph problems, copy mechanism

class PointerNetwork(nn.Module):
    """Декодер указывает на позиции в входе"""
    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(2 * hidden_dim, 1)
    
    def forward(self, decoder_hidden, encoder_outputs):
        # Вычисляем веса внимания для каждой позиции входа
        context = torch.cat([
            decoder_hidden.expand(encoder_outputs.size(0), -1),
            encoder_outputs
        ], dim=1)
        
        attention_scores = self.attention(context)
        probabilities = torch.softmax(attention_scores, dim=0)
        
        # Возвращаем индекс с максимальной вероятностью
        return torch.argmax(probabilities)

Сравнение архитектур

АрхитектураСкоростьКачествоПамятьИспользование
LSTMМедленноСреднееМалоНаследие, RNN задачи
LSTM + AttentionМедленноХорошееМалоSeq2seq, NMT
TransformerБыстроОтличноеМногоСовременный стандарт
AutoregressiveМедленноОтличноеМногоГенерация текста
Beam SearchОчень медленноМаксимальноеМногоProduction
Pointer NetworkМедленноХорошееМалоКомбинаторные задачи

Практический выбор

# Для Machine Translation: Transformer + Beam Search
# Для Summarization: BERT Encoder + Transformer Decoder
# Для Text Generation: GPT-style Autoregressive
# Для Q&A: Pointer Networks
# Для старого кода: LSTM + Attention
Какие знаешь архитектуры декодера? | PrepBro