Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Как работает механизм attention?
Mechanism attention (внимание) — ключевой компонент современных нейронных сетей, особенно в обработке последовательностей. Он позволяет модели сосредоточиться на наиболее релевантных частях входных данных.
1. Основная идея
Вместо того чтобы одинаково обрабатывать все входные элементы, attention механизм взвешивает (присваивает важность) каждому элементу. Элементы с высокой важностью больше влияют на выходной результат.
Вход: [слово_1, слово_2, слово_3, слово_4, слово_5]
Веса: [0.1, 0.05, 0.6, 0.15, 0.1] <- механизм внимания
Выход: 0.1*слово_1 + 0.05*слово_2 + 0.6*слово_3 + 0.15*слово_4 + 0.1*слово_5
2. Scaled Dot-Product Attention
Самый распространённый механизм. Вычисляется в три этапа:
import numpy as np
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(query, key, value, mask=None):
"""
Scaled Dot-Product Attention механизм
query (Q): что мы ищем (размер: seq_len, d_k)
key (K): что есть в памяти (размер: seq_len, d_k)
value (V): значения в памяти (размер: seq_len, d_v)
Возвращает: выход внимания и веса внимания
"""
d_k = query.shape[-1]
# Шаг 1: Вычисляем scores = Q * K^T
# Это показывает, насколько каждый query похож на каждый key
scores = torch.matmul(query, key.transpose(-2, -1)) # (seq_len, seq_len)
print(f'Scores shape: {scores.shape}')
print(f'Scores example:\n{scores[0, :]}')
# Шаг 2: Нормализуем (делим на sqrt(d_k)) - это улучшает градиенты
scores = scores / np.sqrt(d_k)
# Шаг 3: Применяем маску (если нужна) - предотвращает внимание к будущим словам
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Шаг 4: Применяем softmax - получаем веса внимания от 0 до 1
attention_weights = F.softmax(scores, dim=-1) # (seq_len, seq_len)
print(f'Attention weights shape: {attention_weights.shape}')
print(f'Attention weights (сумма по строке = 1):\n{attention_weights[0, :]}')
# Шаг 5: Умножаем веса на values - получаем выход
output = torch.matmul(attention_weights, value) # (seq_len, d_v)
return output, attention_weights
# Пример
batch_size = 2
seq_len = 4
d_k = 64
d_v = 64
Q = torch.randn(batch_size, seq_len, d_k) # Queries
K = torch.randn(batch_size, seq_len, d_k) # Keys
V = torch.randn(batch_size, seq_len, d_v) # Values
output, weights = scaled_dot_product_attention(Q[0], K[0], V[0])
print(f'\nOutput shape: {output.shape}')
Интуиция:
- Q * K^T вычисляет «сходство» между query и key
- Softmax превращает scores в вероятности (веса)
- Взвешенное суммирование values дает результат
3. Multi-Head Attention
Множественные головы внимания позволяют модели смотреть на разные аспекты данных:
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, 'd_model должен делиться на num_heads'
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Размер каждой головы
# Линейные проекции для Q, K, V и выхода
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.shape[0]
# Проекция Q, K, V
Q = self.W_q(query) # (batch, seq_len, d_model)
K = self.W_k(key)
V = self.W_v(value)
# Разбиваем на multiple heads
# (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_k) -> (batch, num_heads, seq_len, d_k)
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Применяем attention для каждой головы параллельно
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attention_weights, V) # (batch, num_heads, seq_len, d_k)
# Объединяем heads обратно
context = context.transpose(1, 2).contiguous() # (batch, seq_len, num_heads, d_k)
context = context.view(batch_size, -1, self.d_model) # (batch, seq_len, d_model)
# Финальная линейная проекция
output = self.W_o(context)
return output, attention_weights
# Пример использования
multi_head_attn = MultiHeadAttention(d_model=512, num_heads=8)
output, weights = multi_head_attn(Q, K, V)
print(f'Output shape: {output.shape}') # (batch_size, seq_len, d_model)
print(f'Attention weights shape: {weights.shape}') # (batch_size, num_heads, seq_len, seq_len)
4. Self-Attention
Специальный случай, где query, key, value берутся из одного источника — сама последовательность смотрит на себя:
class SelfAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.multi_head_attn = MultiHeadAttention(d_model, num_heads)
def forward(self, x, mask=None):
# Q = K = V = x (самовнимание)
return self.multi_head_attn(x, x, x, mask)
# Пример: каждое слово смотрит на все остальные слова в предложении
self_attn = SelfAttention(d_model=512, num_heads=8)
output, _ = self_attn(input_sequence)
5. Causal Attention (для языковых моделей)
Предотвращает внимание к будущим словам (маска для авторегрессивного предсказания):
def create_causal_mask(seq_len, device):
"""
Создаёт маску, которая позволяет модели смотреть только на прошлые токены
Пример для seq_len=4:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
"""
mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
return mask
# Применение
causal_mask = create_causal_mask(seq_len=4, device='cpu')
print(f'Causal mask:\n{causal_mask}')
# При вычислении attention, где mask == 0, выходит -inf (мягкая маска)
output, weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
6. Практический пример: Transformer блок
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# Self-attention
self.self_attention = MultiHeadAttention(d_model, num_heads)
self.norm1 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
# Feed-forward network
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-attention с residual connection
attn_output, _ = self.self_attention(x, x, x, mask)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# Feed-forward с residual connection
ff_output = self.ff(x)
x = x + self.dropout2(ff_output)
x = self.norm2(x)
return x
# Использование в стеке
transformer_block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048)
output = transformer_block(input_sequence)
7. Визуализация внимания
import matplotlib.pyplot as plt
import numpy as np
def visualize_attention(attention_weights, words):
"""
Визуализирует матрицу внимания для одной головы
attention_weights: shape (seq_len, seq_len)
words: список слов в последовательности
"""
plt.figure(figsize=(10, 8))
plt.imshow(attention_weights.detach().numpy(), cmap='viridis')
plt.colorbar(label='Attention weight')
plt.xticks(range(len(words)), words, rotation=45)
plt.yticks(range(len(words)), words)
plt.xlabel('Key (на что смотрим)')
plt.ylabel('Query (из какого слова смотрим)')
plt.title('Attention Heatmap')
plt.tight_layout()
plt.show()
# Пример
words = ['The', 'cat', 'sat', 'on', 'mat']
output, weights = self_attn(torch.randn(1, 5, 512)) # Размер (1, 5, seq_len, seq_len) для одной головы
visualize_attention(weights[0, 0], words) # Первый батч, первая голова
Ключевые характеристики attention
Преимущества:
- Параллельная обработка всех позиций (в отличие от RNN)
- Прямое взаимодействие между дальними элементами
- Интерпретируемость (можем посмотреть, на что модель обращает внимание)
- Масштабируемость для длинных последовательностей (в некоторых вариантах)
Сложность:
- Квадратичная память O(seq_len²) при стандартной реализации
- Требует больше памяти для длинных последовательностей
Варианты:
- Linear attention — O(seq_len) сложность
- Sparse attention — не все позиции смотрят друг на друга
- Local attention — только соседние позиции
Mechanism attention — основа современных моделей (Transformers, BERT, GPT) и позволяет им достичь впечатляющих результатов в NLP и beyond!