← Назад к вопросам

Как работает механизм attention?

1.0 Junior🔥 251 комментариев
#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI29 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Как работает механизм attention?

Mechanism attention (внимание) — ключевой компонент современных нейронных сетей, особенно в обработке последовательностей. Он позволяет модели сосредоточиться на наиболее релевантных частях входных данных.

1. Основная идея

Вместо того чтобы одинаково обрабатывать все входные элементы, attention механизм взвешивает (присваивает важность) каждому элементу. Элементы с высокой важностью больше влияют на выходной результат.

Вход: [слово_1, слово_2, слово_3, слово_4, слово_5]
Веса: [0.1,   0.05,   0.6,   0.15,  0.1]  <- механизм внимания
Выход: 0.1*слово_1 + 0.05*слово_2 + 0.6*слово_3 + 0.15*слово_4 + 0.1*слово_5

2. Scaled Dot-Product Attention

Самый распространённый механизм. Вычисляется в три этапа:

import numpy as np
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled Dot-Product Attention механизм
    
    query (Q): что мы ищем (размер: seq_len, d_k)
    key (K): что есть в памяти (размер: seq_len, d_k)
    value (V): значения в памяти (размер: seq_len, d_v)
    
    Возвращает: выход внимания и веса внимания
    """
    
    d_k = query.shape[-1]
    
    # Шаг 1: Вычисляем scores = Q * K^T
    # Это показывает, насколько каждый query похож на каждый key
    scores = torch.matmul(query, key.transpose(-2, -1))  # (seq_len, seq_len)
    
    print(f'Scores shape: {scores.shape}')
    print(f'Scores example:\n{scores[0, :]}')
    
    # Шаг 2: Нормализуем (делим на sqrt(d_k)) - это улучшает градиенты
    scores = scores / np.sqrt(d_k)
    
    # Шаг 3: Применяем маску (если нужна) - предотвращает внимание к будущим словам
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Шаг 4: Применяем softmax - получаем веса внимания от 0 до 1
    attention_weights = F.softmax(scores, dim=-1)  # (seq_len, seq_len)
    
    print(f'Attention weights shape: {attention_weights.shape}')
    print(f'Attention weights (сумма по строке = 1):\n{attention_weights[0, :]}')
    
    # Шаг 5: Умножаем веса на values - получаем выход
    output = torch.matmul(attention_weights, value)  # (seq_len, d_v)
    
    return output, attention_weights

# Пример
batch_size = 2
seq_len = 4
d_k = 64
d_v = 64

Q = torch.randn(batch_size, seq_len, d_k)  # Queries
K = torch.randn(batch_size, seq_len, d_k)  # Keys
V = torch.randn(batch_size, seq_len, d_v)  # Values

output, weights = scaled_dot_product_attention(Q[0], K[0], V[0])
print(f'\nOutput shape: {output.shape}')

Интуиция:

  • Q * K^T вычисляет «сходство» между query и key
  • Softmax превращает scores в вероятности (веса)
  • Взвешенное суммирование values дает результат

3. Multi-Head Attention

Множественные головы внимания позволяют модели смотреть на разные аспекты данных:

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, 'd_model должен делиться на num_heads'
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Размер каждой головы
        
        # Линейные проекции для Q, K, V и выхода
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        
        # Проекция Q, K, V
        Q = self.W_q(query)  # (batch, seq_len, d_model)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # Разбиваем на multiple heads
        # (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_k) -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Применяем attention для каждой головы параллельно
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)  # (batch, num_heads, seq_len, d_k)
        
        # Объединяем heads обратно
        context = context.transpose(1, 2).contiguous()  # (batch, seq_len, num_heads, d_k)
        context = context.view(batch_size, -1, self.d_model)  # (batch, seq_len, d_model)
        
        # Финальная линейная проекция
        output = self.W_o(context)
        
        return output, attention_weights

# Пример использования
multi_head_attn = MultiHeadAttention(d_model=512, num_heads=8)
output, weights = multi_head_attn(Q, K, V)
print(f'Output shape: {output.shape}')  # (batch_size, seq_len, d_model)
print(f'Attention weights shape: {weights.shape}')  # (batch_size, num_heads, seq_len, seq_len)

4. Self-Attention

Специальный случай, где query, key, value берутся из одного источника — сама последовательность смотрит на себя:

class SelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.multi_head_attn = MultiHeadAttention(d_model, num_heads)
    
    def forward(self, x, mask=None):
        # Q = K = V = x (самовнимание)
        return self.multi_head_attn(x, x, x, mask)

# Пример: каждое слово смотрит на все остальные слова в предложении
self_attn = SelfAttention(d_model=512, num_heads=8)
output, _ = self_attn(input_sequence)

5. Causal Attention (для языковых моделей)

Предотвращает внимание к будущим словам (маска для авторегрессивного предсказания):

def create_causal_mask(seq_len, device):
    """
    Создаёт маску, которая позволяет модели смотреть только на прошлые токены
    
    Пример для seq_len=4:
    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask

# Применение
causal_mask = create_causal_mask(seq_len=4, device='cpu')
print(f'Causal mask:\n{causal_mask}')

# При вычислении attention, где mask == 0, выходит -inf (мягкая маска)
output, weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

6. Практический пример: Transformer блок

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Self-attention
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        
        # Feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention с residual connection
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        
        # Feed-forward с residual connection
        ff_output = self.ff(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        
        return x

# Использование в стеке
transformer_block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048)
output = transformer_block(input_sequence)

7. Визуализация внимания

import matplotlib.pyplot as plt
import numpy as np

def visualize_attention(attention_weights, words):
    """
    Визуализирует матрицу внимания для одной головы
    
    attention_weights: shape (seq_len, seq_len)
    words: список слов в последовательности
    """
    
    plt.figure(figsize=(10, 8))
    plt.imshow(attention_weights.detach().numpy(), cmap='viridis')
    plt.colorbar(label='Attention weight')
    
    plt.xticks(range(len(words)), words, rotation=45)
    plt.yticks(range(len(words)), words)
    plt.xlabel('Key (на что смотрим)')
    plt.ylabel('Query (из какого слова смотрим)')
    plt.title('Attention Heatmap')
    plt.tight_layout()
    plt.show()

# Пример
words = ['The', 'cat', 'sat', 'on', 'mat']
output, weights = self_attn(torch.randn(1, 5, 512))  # Размер (1, 5, seq_len, seq_len) для одной головы
visualize_attention(weights[0, 0], words)  # Первый батч, первая голова

Ключевые характеристики attention

Преимущества:

  • Параллельная обработка всех позиций (в отличие от RNN)
  • Прямое взаимодействие между дальними элементами
  • Интерпретируемость (можем посмотреть, на что модель обращает внимание)
  • Масштабируемость для длинных последовательностей (в некоторых вариантах)

Сложность:

  • Квадратичная память O(seq_len²) при стандартной реализации
  • Требует больше памяти для длинных последовательностей

Варианты:

  • Linear attention — O(seq_len) сложность
  • Sparse attention — не все позиции смотрят друг на друга
  • Local attention — только соседние позиции

Mechanism attention — основа современных моделей (Transformers, BERT, GPT) и позволяет им достичь впечатляющих результатов в NLP и beyond!

Как работает механизм attention? | PrepBro