Расскажите про multi-head attention в деталях
Комментарии (2)
Ответ сгенерирован нейросетью и может содержать ошибки
Multi-Head Attention в деталях
Multi-head attention — это ключевой компонент архитектуры Transformer, который позволяет модели одновременно учитывать различные представления информации из разных подпространств.
Что такое Attention
Внимание (attention) показывает, насколько сильно каждый элемент последовательности должен влиять на другие элементы. Механизм вычисляет взвешенную сумму значений на основе их релевантности к запросу.
Scaled Dot-Product Attention
Перед многоголовым вниманием нужно понять одноголовое:
import numpy as np
import torch
from torch import nn
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.shape[-1]
# Вычисляем scores: (batch, seq_len, seq_len)
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
# Применяем маску (опционально)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Применяем softmax
attention_weights = torch.softmax(scores, dim=-1)
# Умножаем на значения
output = torch.matmul(attention_weights, V)
return output, attention_weights
Q (Query) — что ищем (текущий элемент) K (Key) — на что обращаем внимание (все элементы) V (Value) — какие значения отправляем (все элементы)
Делим на sqrt(d_k) для стабилизации градиентов (иначе softmax станет очень острым).
Multi-Head Attention
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % num_heads == 0, "d_model должен делиться на num_heads"
self.d_k = d_model // num_heads
# Линейные слои для Q, K, V и выхода
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.shape[0]
# 1. Преобразуем через линейные слои
Q = self.W_q(Q) # (batch, seq_len, d_model)
K = self.W_k(K)
V = self.W_v(V)
# 2. Разбиваем на num_heads
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# (batch, num_heads, seq_len, d_k)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 3. Применяем scaled dot-product attention к каждой голове
attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
# (batch, num_heads, seq_len, d_k)
# 4. Объединяем головы
attn_output = attn_output.transpose(1, 2).contiguous()
# (batch, seq_len, num_heads, d_k)
attn_output = attn_output.view(batch_size, -1, self.d_model)
# (batch, seq_len, d_model)
# 5. Финальное линейное преобразование
output = self.W_o(attn_output)
return output
Почему несколько голов?
1. Разные представления подпространств — каждая голова учится обращать внимание на разные аспекты. Первая может фокусироваться на синтаксисе, вторая на семантике, третья на дальних зависимостях.
2. Параллелизм вычислений — вместо одного большого attention'а размерности d_model, мы вычисляем h меньших внимаций размерности d_k = d_model / h. Это быстрее и стабильнее.
3. Капсулирование информации — каждая голова независимо изучает свой паттерн, что улучшает обобщение.
Визуализация процесса
Вход (d_model=512, num_heads=8):
- Q, K, V размером (batch, seq_len, 512)
После разделения:
- 8 голов размером (batch, seq_len, 64) каждая
После attention:
- 8 выходов размером (batch, seq_len, 64)
После конкатенации:
- Результат размером (batch, seq_len, 512)
На практике
В BERT и GPT обычно используются 8-12 голов. В больших моделях (GPT-3, GPT-4) могут быть 96 голов. Каждая голова "специализируется" на разных паттернах в тексте.
В компьютерном зрении (Vision Transformer) также используется multi-head attention для анализа пространственных взаимодействий между patch'ами изображения.