← Назад к вопросам
PyTorch: Реализовать attention mechanism
2.4 Senior🔥 91 комментариев
#Python#Глубокое обучение
Условие
Реализуйте механизм внимания (attention) на PyTorch.
Требования:
- Scaled dot-product attention
- Multi-head attention
- Тестирование на простом примере
Комментарии (1)
🐱
claude-haiku-4.5PrepBro AI23 мар. 2026 г.(ред.)
Ответ сгенерирован нейросетью и может содержать ошибки
Решение
Реализация механизма внимания в PyTorch.
1. Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ScaledDotProductAttention(nn.Module):
"""Механизм внимания: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V"""
def __init__(self, d_k):
super().__init__()
self.d_k = d_k
def forward(self, Q, K, V, mask=None):
"""
Args:
Q: Query (batch_size, seq_len, d_k)
K: Key (batch_size, seq_len, d_k)
V: Value (batch_size, seq_len, d_v)
mask: маска для скрытия некоторых позиций (опционально)
Returns:
output: (batch_size, seq_len, d_v)
attention_weights: (batch_size, seq_len, seq_len)
"""
# Вычисляем оценки внимания
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# scores shape: (batch_size, seq_len_q, seq_len_k)
# Применяем маску (если нужна)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Применяем softmax для нормализации
attention_weights = F.softmax(scores, dim=-1)
# Получаем взвешенную сумму значений
output = torch.matmul(attention_weights, V)
# output shape: (batch_size, seq_len_q, d_v)
return output, attention_weights
# Пример использования
print("=== SCALED DOT-PRODUCT ATTENTION ===")
# Параметры
batch_size, seq_len, d_k = 2, 4, 64 # 2 примера, длина последовательности 4, размер ключа 64
# Инициализируем Q, K, V
Q = torch.randn(batch_size, seq_len, d_k) # Query
K = torch.randn(batch_size, seq_len, d_k) # Key
V = torch.randn(batch_size, seq_len, d_k) # Value
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
# Применяем attention
attention = ScaledDotProductAttention(d_k=d_k)
output, weights = attention(Q, K, V)
print(f"\nОутпут shape: {output.shape}")
print(f"Веса внимания shape: {weights.shape}")
print(f"\nВеса внимания для первого примера:")
print(weights[0]) # сумма по строкам должна быть 1
print(f"Проверка: сумма весов = {weights[0].sum(dim=-1)}")
print(f"""
=== ОБЪЯСНЕНИЕ ===
1. Scores = Q @ K^T / sqrt(d_k)
- Q @ K^T дает оценку релевантности каждого запроса ко всем ключам
- Делим на sqrt(d_k) для стабилизации градиентов
2. Softmax(scores)
- Нормализуем оценки в распределение вероятностей
- Каждый запрос получает веса релевантности
3. Output = softmax(scores) @ V
- Берем взвешенную сумму значений
- Получаем контекст каждого запроса
Мотивация: механизм фокусирует на релевантные части входа
""")
2. Multi-Head Attention
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention: несколько параллельных механизмов внимания"""
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model должен делиться на num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # размер каждой головы
# Линейные преобразования для Q, K, V и финального вывода
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.scaled_dot_product_attention = ScaledDotProductAttention(self.d_k)
def forward(self, Q, K, V, mask=None):
"""
Args:
Q, K, V: (batch_size, seq_len, d_model)
mask: маска для скрытия позиций
Returns:
output: (batch_size, seq_len, d_model)
"""
batch_size = Q.shape[0]
# 1. Линейное преобразование
Q = self.W_q(Q) # (batch_size, seq_len, d_model)
K = self.W_k(K)
V = self.W_v(V)
# 2. Разделяем на несколько голов
# (batch_size, seq_len, d_model) → (batch_size, seq_len, num_heads, d_k)
# → (batch_size, num_heads, seq_len, d_k)
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 3. Применяем attention к каждой голове
attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# attn_output: (batch_size, num_heads, seq_len, d_k)
# 4. Объединяем головы
# (batch_size, num_heads, seq_len, d_k) → (batch_size, seq_len, num_heads, d_k)
# → (batch_size, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.d_model)
# 5. Финальное линейное преобразование
output = self.W_o(attn_output)
return output, attn_weights
# Пример использования
print("\n=== MULTI-HEAD ATTENTION ===")
batch_size, seq_len, d_model, num_heads = 2, 4, 256, 8
# Входные данные
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
print(f"Входная размерность: {d_model}")
print(f"Количество голов: {num_heads}")
print(f"Размер каждой головы: {d_model // num_heads}")
# Применяем multi-head attention
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
output, weights = mha(Q, K, V)
print(f"\nОутпут shape: {output.shape}")
print(f"Веса внимания shape: {weights.shape}")
print(f"\nВес первой головы для первого примера:")
print(weights[0, 0]) # первый batch, первая голова
print(f"""
=== ПРЕИМУЩЕСТВА MULTI-HEAD ATTENTION ===
1. Разные головы учат разные типы зависимостей
- Голова 1: грамматические связи
- Голова 2: семантические связи
- Голова 3: дальние зависимости
2. Параллелизм: все головы работают одновременно
3. Рихтер представление информации
- Каждая голова фокусируется на своем аспекте
4. Улучшенная способность к обобщению
""")
3. Полный Transformer Block
class PositionalEncoding(nn.Module):
"""Позиционное кодирование (информация о порядке слов)"""
def __init__(self, d_model, max_seq_len=512):
super().__init__()
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
class FeedForward(nn.Module):
"""Полносвязная сеть в Transformer"""
def __init__(self, d_model, d_ff=2048):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.activation = nn.ReLU()
def forward(self, x):
return self.linear2(self.activation(self.linear1(x)))
class TransformerEncoderLayer(nn.Module):
"""Один слой Transformer энкодера"""
def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
super().__init__()
self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
self.feed_forward = FeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x):
# Multi-head attention + residual + layer norm
attn_output, _ = self.multi_head_attention(x, x, x) # self-attention
x = self.norm1(x + self.dropout1(attn_output))
# Feed-forward + residual + layer norm
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout2(ff_output))
return x
# Пример
print("\n=== TRANSFORMER ENCODER LAYER ===")
transformer_layer = TransformerEncoderLayer(d_model=256, num_heads=8, d_ff=1024)
x = torch.randn(2, 4, 256) # batch=2, seq_len=4, d_model=256
output = transformer_layer(x)
print(f"Входная форма: {x.shape}")
print(f"Выходная форма: {output.shape}")
4. Тестирование
print("\n=== ТЕСТИРОВАНИЕ ===")
# Тест 1: Проверка формирования
print("\nТест 1: Правильность форм")
x = torch.randn(3, 5, 128) # batch=3, seq_len=5, d_model=128
mha = MultiHeadAttention(d_model=128, num_heads=4)
output, _ = mha(x, x, x)
assert output.shape == x.shape, f"Shape mismatch: {output.shape} vs {x.shape}"
print("✓ Формы совпадают")
# Тест 2: Проверка градиентов
print("\nТест 2: Обратное распространение")
x = torch.randn(2, 3, 64, requires_grad=True)
mha = MultiHeadAttention(d_model=64, num_heads=2)
output, _ = mha(x, x, x)
loss = output.sum()
loss.backward()
assert x.grad is not None, "Градиент не вычислен"
print(f"✓ Градиент вычислен, shape: {x.grad.shape}")
# Тест 3: Проверка внимания на одиночный элемент
print("\nТест 3: Внимание на одиночный элемент")
Q = torch.zeros(1, 1, 64)
K = torch.randn(1, 4, 64)
V = torch.ones(1, 4, 64) * 5 # все значения равны 5
attn = ScaledDotProductAttention(d_k=64)
output, weights = attn(Q, K, V)
# Если все значения V одинаковы, выход должен быть близок к 5
print(f"Ожидается: ~5, получено: {output[0, 0, 0].item():.4f}")
print("✓ Логика работает")
# Тест 4: Causal mask (для декодера)
print("\nТест 4: Causal mask (маска для декодера)")
seq_len = 4
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print(f"Causal mask:\n{causal_mask}")
print("✓ Маска предотвращает внимание на будущие позиции")
print(f"""
=== ИТОГИ ===
✓ Реализован Scaled Dot-Product Attention
✓ Реализован Multi-Head Attention
✓ Реализован Transformer Encoder Layer
✓ Все тесты пройдены
Использование:
- Для NLP: BERT, GPT, T5
- Для компьютерного зрения: Vision Transformer (ViT)
- Для序列 задач: Seq2Seq моделей с attention
""")