← Назад к вопросам

Какие знаешь виды attention в нейросетях, кроме тех, что используются в трансформерах?

2.0 Middle🔥 171 комментариев
#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI29 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Виды Attention в нейросетях

Attention механизм — это мощная техника, которая позволяет сети концентрироваться на релевантных частях входных данных. Я расскажу о различных типах attention, включая те, что используются в трансформерах и за их пределами.

1. История и эволюция Attention

Хронология:
2014 → Bahdanau Attention (для seq2seq, RNN)
2015 → Luong Attention (улучшение)
2016 → Self-Attention (Cheng et al)
2017 → Multi-Head Self-Attention (Transformer, Vaswani et al)
2020+ → Variations: Linear Attention, Sparse Attention, etc.

2. Bahdanau Attention (Additive/Concat Attention)

import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    """
    Bahdanau Attention (2014):
    Используется в seq2seq моделях с RNN/LSTM
    
    Формула:
    score(h, s) = tanh(W[h; s])
    где h = encoder hidden state, s = decoder hidden state
    """
    
    def __init__(self, decoder_dim, encoder_dim):
        super().__init__()
        self.decoder_dim = decoder_dim
        self.encoder_dim = encoder_dim
        
        # Веса для вычисления scores
        self.linear_query = nn.Linear(decoder_dim, encoder_dim)
        self.linear_context = nn.Linear(encoder_dim, encoder_dim)
        self.v = nn.Linear(encoder_dim, 1)
    
    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: [batch, decoder_dim]
        encoder_outputs: [batch, seq_len, encoder_dim]
        
        Returns: context, attention_weights
        """
        batch_size = encoder_outputs.shape[0]
        seq_len = encoder_outputs.shape[1]
        
        # Expand decoder_hidden для broadcasting
        # [batch, 1, encoder_dim]
        decoder_hidden_expanded = self.linear_query(decoder_hidden).unsqueeze(1)
        
        # Compute energy (scores)
        # energy = tanh(W_1 * encoder_hidden + W_2 * decoder_hidden)
        encoder_transformed = self.linear_context(encoder_outputs)  # [batch, seq, encoder_dim]
        combined = torch.tanh(decoder_hidden_expanded + encoder_transformed)  # [batch, seq, encoder_dim]
        energy = self.v(combined).squeeze(-1)  # [batch, seq]
        
        # Softmax для получения weights
        attention_weights = F.softmax(energy, dim=1)  # [batch, seq]
        
        # Context = weighted sum of encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # [batch, 1, encoder_dim]
        context = context.squeeze(1)  # [batch, encoder_dim]
        
        return context, attention_weights

# Пример использования
bahdanau = BahdanauAttention(decoder_dim=128, encoder_dim=256)

decoder_hidden = torch.randn(32, 128)  # batch=32, hidden_dim=128
encoder_outputs = torch.randn(32, 10, 256)  # batch=32, seq_len=10, hidden_dim=256

context, weights = bahdanau(decoder_hidden, encoder_outputs)
print(f"Context shape: {context.shape}")  # [32, 256]
print(f"Attention weights shape: {weights.shape}")  # [32, 10]

3. Luong Attention (Multiplicative/Dot-Product)

class LuongAttention(nn.Module):
    """
    Luong Attention (2015):
    Более простая и эффективная версия Bahdanau
    
    Формула:
    score(h, s) = h^T * W * s (dot product with weight matrix)
    или даже просто: score = h^T * s (без W)
    """
    
    def __init__(self, encoder_dim, attention_type='dot'):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.attention_type = attention_type
        
        if attention_type == 'general':
            # Multiplicative avec weight matrix
            self.weight = nn.Linear(encoder_dim, encoder_dim)
    
    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: [batch, encoder_dim]
        encoder_outputs: [batch, seq_len, encoder_dim]
        """
        
        if self.attention_type == 'dot':
            # Простой dot product
            # score = decoder_hidden @ encoder_outputs^T
            scores = torch.bmm(
                decoder_hidden.unsqueeze(1),  # [batch, 1, dim]
                encoder_outputs.transpose(1, 2)  # [batch, dim, seq]
            )  # [batch, 1, seq]
            scores = scores.squeeze(1)  # [batch, seq]
        
        elif self.attention_type == 'general':
            # score = decoder_hidden @ W @ encoder_outputs^T
            decoder_transformed = self.weight(decoder_hidden)  # [batch, dim]
            scores = torch.bmm(
                decoder_transformed.unsqueeze(1),
                encoder_outputs.transpose(1, 2)
            )  # [batch, 1, seq]
            scores = scores.squeeze(1)  # [batch, seq]
        
        # Normalize
        attention_weights = F.softmax(scores / torch.sqrt(torch.tensor(self.encoder_dim, dtype=torch.float32)), dim=1)
        
        # Context
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)
        
        return context, attention_weights

# Пример
luong = LuongAttention(encoder_dim=256, attention_type='general')
context, weights = luong(decoder_hidden, encoder_outputs)
print(f"\nLuong - Context shape: {context.shape}")  # [32, 256]

4. Self-Attention (но не Multi-Head)

class SimpleScaledDotProductAttention(nn.Module):
    """
    Простой Scaled Dot-Product Attention (основа трансформера)
    Но это улучшенная версия для одного head
    (Multi-Head это набор таких слоёв)
    """
    
    def forward(self, query, key, value, mask=None):
        """
        query, key, value: [batch, seq_len, d_model]
        mask: [batch, seq_len] (для padding/future masking)
        """
        # 1. Compute scores
        scores = torch.matmul(query, key.transpose(-1, -2))  # [batch, seq, seq]
        scores = scores / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
        
        # 2. Apply mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 3. Softmax
        attention_weights = F.softmax(scores, dim=-1)  # [batch, seq, seq]
        
        # 4. Apply to values
        output = torch.matmul(attention_weights, value)  # [batch, seq, d_model]
        
        return output, attention_weights

5. Cross-Attention (для seq2seq)

class CrossAttention(nn.Module):
    """
    Cross-Attention:
    Decoder attends to Encoder outputs
    
    Query: from decoder
    Key, Value: from encoder
    
    Используется в:
    - seq2seq models (translation, summarization)
    - Multimodal models (text attends to image)
    """
    
    def __init__(self, d_model):
        super().__init__()
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)
    
    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: [batch, seq_dec, d_model]
        encoder_outputs: [batch, seq_enc, d_model]
        
        Returns: attended representation
        """
        # Generate Q from decoder, K, V from encoder
        query = self.query_linear(decoder_hidden)  # [batch, seq_dec, d_model]
        key = self.key_linear(encoder_outputs)  # [batch, seq_enc, d_model]
        value = self.value_linear(encoder_outputs)  # [batch, seq_enc, d_model]
        
        # Compute attention
        scores = torch.matmul(query, key.transpose(-1, -2))  # [batch, seq_dec, seq_enc]
        scores = scores / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
        
        attention_weights = F.softmax(scores, dim=-1)
        
        # Combine with values
        attended = torch.matmul(attention_weights, value)  # [batch, seq_dec, d_model]
        output = self.out_linear(attended)
        
        return output, attention_weights

6. Spatial Attention (для компьютерного зрения)

class SpatialAttention(nn.Module):
    """
    Spatial Attention (для CNN/images):
    Сеть учится какие пространственные области важны
    
    Отличие от sequence attention:
    - Работает с 2D spatial maps (height x width)
    - Обычно используется после CNN слоёв
    """
    
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = kernel_size // 2
        
        self.conv = nn.Conv2d(
            in_channels=2,  # Concatenate mean и max pooling
            out_channels=1,
            kernel_size=kernel_size,
            padding=padding,
            bias=False
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        """
        x: [batch, channels, height, width]
        Returns: weighted x with spatial attention map
        """
        # Channel-wise statistics
        avg_pool = torch.mean(x, dim=1, keepdim=True)  # [batch, 1, h, w]
        max_pool, _ = torch.max(x, dim=1, keepdim=True)  # [batch, 1, h, w]
        
        # Concatenate
        concat = torch.cat([avg_pool, max_pool], dim=1)  # [batch, 2, h, w]
        
        # Generate attention map
        attention_map = self.sigmoid(self.conv(concat))  # [batch, 1, h, w]
        
        # Apply attention
        output = x * attention_map
        
        return output, attention_map

# Пример
spatial_att = SpatialAttention(kernel_size=7)
feature_maps = torch.randn(32, 64, 28, 28)  # batch=32, channels=64, H=28, W=28
attended, att_map = spatial_att(feature_maps)
print(f"\nSpatial Attention output shape: {attended.shape}")  # [32, 64, 28, 28]

7. Channel Attention (для компьютерного зрения)

class ChannelAttention(nn.Module):
    """
    Channel Attention (SE-Net, Squeeze-and-Excitation):
    Сеть учится какие каналы (features) важны
    
    Процесс:
    1. Squeeze: Global Average Pooling
    2. Excitation: FC layers для переобучения weights каналов
    """
    
    def __init__(self, num_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        # FC bottleneck
        reduced_channels = max(num_channels // reduction_ratio, 1)
        self.fc = nn.Sequential(
            nn.Linear(num_channels, reduced_channels),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, num_channels)
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        """
        x: [batch, channels, height, width]
        """
        # Squeeze
        avg = self.avg_pool(x)  # [batch, channels, 1, 1]
        max_ = self.max_pool(x)  # [batch, channels, 1, 1]
        
        avg = avg.view(avg.shape[0], -1)  # [batch, channels]
        max_ = max_.view(max_.shape[0], -1)  # [batch, channels]
        
        # Excitation
        avg_out = self.fc(avg)  # [batch, channels]
        max_out = self.fc(max_)  # [batch, channels]
        
        # Combine и sigmoid
        channel_weights = self.sigmoid(avg_out + max_out)  # [batch, channels]
        
        # Apply
        channel_weights = channel_weights.view(channel_weights.shape[0], -1, 1, 1)
        output = x * channel_weights
        
        return output, channel_weights

# Пример
channel_att = ChannelAttention(num_channels=64, reduction_ratio=16)
feature_maps = torch.randn(32, 64, 28, 28)
attended, weights = channel_att(feature_maps)
print(f"Channel Attention output shape: {attended.shape}")  # [32, 64, 28, 28]

8. Sparse Attention (Efficient Transformers)

class SparseAttention(nn.Module):
    """
    Sparse Attention (для длинных последовательностей):
    
    Вместо полной O(n^2) матрицы внимания,
    вычисляем внимание только к ограниченному окну токенов
    
    Типы:
    1. Local attention: только соседние токены
    2. Strided attention: каждый k-й токен
    3. Fixed attention: фиксированные позиции
    4. Reformer (LSH): использует locality-sensitive hashing
    """
    
    def __init__(self, window_size=64):
        self.window_size = window_size
    
    def forward(self, query, key, value):
        """
        query, key, value: [batch, seq_len, d_model]
        
        Пример: Local attention window_size=64
        - Каждый токен внимает только к 64 соседним токенам
        - Вместо всех seq_len
        """
        batch_size, seq_len, d_model = query.shape
        
        # Create mask для локального окна
        attention_mask = torch.zeros((seq_len, seq_len))
        for i in range(seq_len):
            # Окно around позиции i
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2)
            attention_mask[i, start:end] = 1
        
        # Обычное scaled dot-product attention,
        # но с mask для sparse connectivity
        scores = torch.matmul(query, key.transpose(-1, -2))
        scores = scores / torch.sqrt(torch.tensor(d_model, dtype=torch.float32))
        
        # Apply sparse mask
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        
        return output, attention_weights

9. Сравнение видов Attention

┌──────────────────────┬────────────────────┬──────────────────────┐
│ Type                 │ Use Case           │ Complexity           │
├──────────────────────┼────────────────────┼──────────────────────┤
│ Bahdanau Additive    │ seq2seq RNN        │ O(n*m) where n,m=lengths │
│ Luong Multiplicative │ seq2seq RNN        │ O(n*m) (faster than BA) │
│ Self-Attention       │ intra-sequence     │ O(n^2) for full seq  │
│ Cross-Attention      │ seq2seq Transformer│ O(n*m)               │
│ Multi-Head Self-Att  │ Transformer        │ O(n^2) but parallelizable │
│ Spatial Attention    │ CNN features       │ O(h*w) for 2D spatial│
│ Channel Attention    │ CNN features       │ O(c) very fast       │
│ Sparse Attention     │ Long sequences     │ O(n*w) where w=window│
│ Linear Attention     │ Efficient          │ O(n) with approximation │
└──────────────────────┴────────────────────┴──────────────────────┘

10. Практический выбор

def choose_attention_mechanism(task):
    """
    Рекомендации по выбору attention
    """
    
    if task == 'seq2seq_translation':
        return "Cross-Attention (encoder-decoder)"
    
    elif task == 'language_modeling':
        return "Multi-Head Self-Attention (causal mask)"
    
    elif task == 'image_classification':
        return "Channel Attention (lightweight) or Spatial Attention"
    
    elif task == 'object_detection':
        return "Spatial + Channel Attention"
    
    elif task == 'long_document_processing':
        return "Sparse Attention or Linear Attention"
    
    elif task == 'rnn_sequence_model':
        return "Luong Attention (dot-product) или Bahdanau"
    
    elif task == 'efficient_transformer':
        return "Linear Attention или Sparse with local window"
    
    else:
        return "Multi-Head Self-Attention (universal choice)"

# Примеры
print("Выбор attention механизма:")
for task in ['seq2seq_translation', 'image_classification', 'long_document_processing']:
    print(f"- {task}: {choose_attention_mechanism(task)}")

Заключение

Основные виды Attention:

  1. Bahdanau (Additive): Для RNN seq2seq, медленнее
  2. Luong (Multiplicative): Для RNN seq2seq, быстрее
  3. Scaled Dot-Product: Основа трансформера
  4. Multi-Head Self-Attention: Стандарт для трансформеров
  5. Cross-Attention: Для decoder-encoder interaction
  6. Spatial Attention: Для 2D features (images)
  7. Channel Attention (SE-Net): Для важности каналов
  8. Sparse Attention: Для long sequences
  9. Linear Attention: Для efficient processing

Key Takeaway:

  • Attention как концепция универсален
  • Детали реализации зависят от task и data
  • Multi-Head Self-Attention стал стандартом, но не единственный вариант
  • Выбор attention вида критичен для performance