← Назад к вопросам

Расскажите про multi-head attention в деталях

1.0 Junior🔥 192 комментариев
#NLP и обработка текста#Глубокое обучение

Комментарии (2)

🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Multi-Head Attention в деталях

Multi-head attention — это ключевой компонент архитектуры Transformer, который позволяет модели одновременно учитывать различные представления информации из разных подпространств.

Что такое Attention

Внимание (attention) показывает, насколько сильно каждый элемент последовательности должен влиять на другие элементы. Механизм вычисляет взвешенную сумму значений на основе их релевантности к запросу.

Scaled Dot-Product Attention

Перед многоголовым вниманием нужно понять одноголовое:

import numpy as np
import torch
from torch import nn

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    
    # Вычисляем scores: (batch, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
    
    # Применяем маску (опционально)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Применяем softmax
    attention_weights = torch.softmax(scores, dim=-1)
    
    # Умножаем на значения
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

Q (Query) — что ищем (текущий элемент) K (Key) — на что обращаем внимание (все элементы) V (Value) — какие значения отправляем (все элементы)

Делим на sqrt(d_k) для стабилизации градиентов (иначе softmax станет очень острым).

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        
        assert d_model % num_heads == 0, "d_model должен делиться на num_heads"
        
        self.d_k = d_model // num_heads
        
        # Линейные слои для Q, K, V и выхода
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.shape[0]
        
        # 1. Преобразуем через линейные слои
        Q = self.W_q(Q)  # (batch, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)
        
        # 2. Разбиваем на num_heads
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # (batch, num_heads, seq_len, d_k)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 3. Применяем scaled dot-product attention к каждой голове
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
        # (batch, num_heads, seq_len, d_k)
        
        # 4. Объединяем головы
        attn_output = attn_output.transpose(1, 2).contiguous()
        # (batch, seq_len, num_heads, d_k)
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # (batch, seq_len, d_model)
        
        # 5. Финальное линейное преобразование
        output = self.W_o(attn_output)
        
        return output

Почему несколько голов?

1. Разные представления подпространств — каждая голова учится обращать внимание на разные аспекты. Первая может фокусироваться на синтаксисе, вторая на семантике, третья на дальних зависимостях.

2. Параллелизм вычислений — вместо одного большого attention'а размерности d_model, мы вычисляем h меньших внимаций размерности d_k = d_model / h. Это быстрее и стабильнее.

3. Капсулирование информации — каждая голова независимо изучает свой паттерн, что улучшает обобщение.

Визуализация процесса

Вход (d_model=512, num_heads=8):

  • Q, K, V размером (batch, seq_len, 512)

После разделения:

  • 8 голов размером (batch, seq_len, 64) каждая

После attention:

  • 8 выходов размером (batch, seq_len, 64)

После конкатенации:

  • Результат размером (batch, seq_len, 512)

На практике

В BERT и GPT обычно используются 8-12 голов. В больших моделях (GPT-3, GPT-4) могут быть 96 голов. Каждая голова "специализируется" на разных паттернах в тексте.

В компьютерном зрении (Vision Transformer) также используется multi-head attention для анализа пространственных взаимодействий между patch'ами изображения.