What types of attention in neural networks do you know, besides those used in transformers?
#Deep Learning
Types of Attention in Neural Networks
The attention mechanism is a powerful technique that lets a network focus on the relevant parts of its input. Below is an overview of the main types of attention, both those used in transformers and those that exist outside them.
1. History and Evolution of Attention
Timeline:
2014 → Bahdanau Attention (for seq2seq with RNNs)
2015 → Luong Attention (simpler, multiplicative scoring)
2016 → Self-Attention (Cheng et al.)
2017 → Multi-Head Self-Attention (Transformer, Vaswani et al.)
2020+ → Variations: Linear Attention, Sparse Attention, etc.
2. Bahdanau Attention (Additive/Concat Attention)
import torch
import torch.nn as nn
import torch.nn.functional as F
class BahdanauAttention(nn.Module):
    """
    Bahdanau Attention (2014, additive/concat attention):
    used in seq2seq models with RNN/LSTM encoders and decoders.
    Score:
        score(h, s) = v^T * tanh(W_1 * h + W_2 * s)
    where h = encoder hidden state, s = decoder hidden state.
    """
    def __init__(self, decoder_dim, encoder_dim):
        super().__init__()
        self.decoder_dim = decoder_dim
        self.encoder_dim = encoder_dim
        # Weights for computing the scores
        self.linear_query = nn.Linear(decoder_dim, encoder_dim)
        self.linear_context = nn.Linear(encoder_dim, encoder_dim)
        self.v = nn.Linear(encoder_dim, 1)

    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: [batch, decoder_dim]
        encoder_outputs: [batch, seq_len, encoder_dim]
        Returns: context, attention_weights
        """
        # Project and expand decoder_hidden for broadcasting
        # [batch, 1, encoder_dim]
        decoder_hidden_expanded = self.linear_query(decoder_hidden).unsqueeze(1)
        # Compute energy (scores):
        # energy = v^T * tanh(W_1 * encoder_outputs + W_2 * decoder_hidden)
        encoder_transformed = self.linear_context(encoder_outputs)            # [batch, seq, encoder_dim]
        combined = torch.tanh(decoder_hidden_expanded + encoder_transformed)  # [batch, seq, encoder_dim]
        energy = self.v(combined).squeeze(-1)                                 # [batch, seq]
        # Softmax over source positions to get the weights
        attention_weights = F.softmax(energy, dim=1)                          # [batch, seq]
        # Context = weighted sum of encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # [batch, 1, encoder_dim]
        context = context.squeeze(1)                                          # [batch, encoder_dim]
        return context, attention_weights
# Example usage
bahdanau = BahdanauAttention(decoder_dim=128, encoder_dim=256)
decoder_hidden = torch.randn(32, 128) # batch=32, hidden_dim=128
encoder_outputs = torch.randn(32, 10, 256) # batch=32, seq_len=10, hidden_dim=256
context, weights = bahdanau(decoder_hidden, encoder_outputs)
print(f"Context shape: {context.shape}") # [32, 256]
print(f"Attention weights shape: {weights.shape}") # [32, 10]
3. Luong Attention (Multiplicative/Dot-Product)
class LuongAttention(nn.Module):
    """
    Luong Attention (2015, multiplicative/dot-product attention):
    a simpler and faster alternative to Bahdanau attention.
    Score:
        general: score(h, s) = s^T * W * h
        dot:     score(h, s) = s^T * h   (no weight matrix)
    """
    def __init__(self, encoder_dim, attention_type='dot'):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.attention_type = attention_type
        if attention_type == 'general':
            # Multiplicative scoring with a learned weight matrix
            self.weight = nn.Linear(encoder_dim, encoder_dim)

    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: [batch, encoder_dim]
        encoder_outputs: [batch, seq_len, encoder_dim]
        """
        if self.attention_type == 'dot':
            # Plain dot product:
            # score = decoder_hidden @ encoder_outputs^T
            scores = torch.bmm(
                decoder_hidden.unsqueeze(1),       # [batch, 1, dim]
                encoder_outputs.transpose(1, 2)    # [batch, dim, seq]
            )                                      # [batch, 1, seq]
            scores = scores.squeeze(1)             # [batch, seq]
        elif self.attention_type == 'general':
            # score = decoder_hidden @ W @ encoder_outputs^T
            decoder_transformed = self.weight(decoder_hidden)  # [batch, dim]
            scores = torch.bmm(
                decoder_transformed.unsqueeze(1),
                encoder_outputs.transpose(1, 2)
            )                                      # [batch, 1, seq]
            scores = scores.squeeze(1)             # [batch, seq]
        else:
            raise ValueError(f"Unknown attention_type: {self.attention_type}")
        # Normalize (scaling by sqrt(dim) is borrowed from the transformer and is optional here)
        attention_weights = F.softmax(
            scores / torch.sqrt(torch.tensor(self.encoder_dim, dtype=torch.float32)), dim=1
        )
        # Context = weighted sum of encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)
        return context, attention_weights
# Example
luong = LuongAttention(encoder_dim=256, attention_type='general')
context, weights = luong(decoder_hidden, encoder_outputs)
print(f"\nLuong - Context shape: {context.shape}") # [32, 256]
4. Self-Attention (Single-Head, not Multi-Head)
class SimpleScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention (the core of the transformer),
    shown here for a single head.
    (Multi-Head Attention runs several of these in parallel over learned projections.)
    """
    def forward(self, query, key, value, mask=None):
        """
        query, key, value: [batch, seq_len, d_model]
        mask: broadcastable to [batch, seq_len, seq_len], e.g. [batch, 1, seq_len]
              (used for padding or future/causal masking)
        """
        # 1. Compute scores
        scores = torch.matmul(query, key.transpose(-1, -2))  # [batch, seq, seq]
        scores = scores / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
        # 2. Apply mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # 3. Softmax over key positions
        attention_weights = F.softmax(scores, dim=-1)         # [batch, seq, seq]
        # 4. Weighted sum of the values
        output = torch.matmul(attention_weights, value)       # [batch, seq, d_model]
        return output, attention_weights
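A quick usage sketch for the single-head class above (shapes chosen arbitrarily; in self-attention Q, K and V all come from the same sequence):
sdpa = SimpleScaledDotProductAttention()
x = torch.randn(32, 10, 256)  # batch=32, seq_len=10, d_model=256
out, attn = sdpa(x, x, x)
print(f"Self-attention output shape: {out.shape}")  # [32, 10, 256]
print(f"Attention matrix shape: {attn.shape}")  # [32, 10, 10]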
5. Cross-Attention (for seq2seq)
class CrossAttention(nn.Module):
    """
    Cross-Attention:
    the decoder attends to the encoder outputs.
    Query: from the decoder
    Key, Value: from the encoder
    Used in:
    - seq2seq models (translation, summarization)
    - multimodal models (e.g. text attends to image features)
    """
    def __init__(self, d_model):
        super().__init__()
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: [batch, seq_dec, d_model]
        encoder_outputs: [batch, seq_enc, d_model]
        Returns: attended representation
        """
        # Q from the decoder, K and V from the encoder
        query = self.query_linear(decoder_hidden)   # [batch, seq_dec, d_model]
        key = self.key_linear(encoder_outputs)      # [batch, seq_enc, d_model]
        value = self.value_linear(encoder_outputs)  # [batch, seq_enc, d_model]
        # Scaled dot-product attention scores
        scores = torch.matmul(query, key.transpose(-1, -2))  # [batch, seq_dec, seq_enc]
        scores = scores / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
        attention_weights = F.softmax(scores, dim=-1)
        # Weighted sum of the values
        attended = torch.matmul(attention_weights, value)    # [batch, seq_dec, d_model]
        output = self.out_linear(attended)
        return output, attention_weights
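A usage sketch for the cross-attention block above (a decoder sequence of length 5 attending to an encoder sequence of length 10; shapes are illustrative):
cross_att = CrossAttention(d_model=256)
decoder_states = torch.randn(32, 5, 256)
encoder_states = torch.randn(32, 10, 256)
out, attn = cross_att(decoder_states, encoder_states)
print(f"Cross-attention output shape: {out.shape}")  # [32, 5, 256]
print(f"Attention weights shape: {attn.shape}")  # [32, 5, 10]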
6. Spatial Attention (for computer vision)
class SpatialAttention(nn.Module):
    """
    Spatial Attention (for CNNs/images):
    the network learns which spatial regions matter.
    Differences from sequence attention:
    - operates on 2D spatial maps (height x width)
    - typically applied after convolutional layers
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels=2,  # concatenated mean and max pooling over channels
            out_channels=1,
            kernel_size=kernel_size,
            padding=padding,
            bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        x: [batch, channels, height, width]
        Returns: x weighted by the spatial attention map
        """
        # Channel-wise statistics
        avg_pool = torch.mean(x, dim=1, keepdim=True)    # [batch, 1, h, w]
        max_pool, _ = torch.max(x, dim=1, keepdim=True)  # [batch, 1, h, w]
        # Concatenate along the channel dimension
        concat = torch.cat([avg_pool, max_pool], dim=1)  # [batch, 2, h, w]
        # Generate the attention map
        attention_map = self.sigmoid(self.conv(concat))  # [batch, 1, h, w]
        # Apply the attention map
        output = x * attention_map
        return output, attention_map
# Example
spatial_att = SpatialAttention(kernel_size=7)
feature_maps = torch.randn(32, 64, 28, 28) # batch=32, channels=64, H=28, W=28
attended, att_map = spatial_att(feature_maps)
print(f"\nSpatial Attention output shape: {attended.shape}") # [32, 64, 28, 28]
7. Channel Attention (for computer vision)
class ChannelAttention(nn.Module):
    """
    Channel Attention (Squeeze-and-Excitation style; this variant with both
    avg- and max-pooling branches follows CBAM):
    the network learns which channels (feature maps) matter.
    Steps:
    1. Squeeze: global pooling over the spatial dimensions
    2. Excitation: FC bottleneck that re-weights the channels
    """
    def __init__(self, num_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # FC bottleneck
        reduced_channels = max(num_channels // reduction_ratio, 1)
        self.fc = nn.Sequential(
            nn.Linear(num_channels, reduced_channels),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, num_channels)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        x: [batch, channels, height, width]
        """
        # Squeeze
        avg = self.avg_pool(x)               # [batch, channels, 1, 1]
        max_ = self.max_pool(x)              # [batch, channels, 1, 1]
        avg = avg.view(avg.shape[0], -1)     # [batch, channels]
        max_ = max_.view(max_.shape[0], -1)  # [batch, channels]
        # Excitation
        avg_out = self.fc(avg)               # [batch, channels]
        max_out = self.fc(max_)              # [batch, channels]
        # Combine and squash with sigmoid
        channel_weights = self.sigmoid(avg_out + max_out)  # [batch, channels]
        # Apply per-channel weights
        channel_weights = channel_weights.view(channel_weights.shape[0], -1, 1, 1)
        output = x * channel_weights
        return output, channel_weights
# Example
channel_att = ChannelAttention(num_channels=64, reduction_ratio=16)
feature_maps = torch.randn(32, 64, 28, 28)
attended, weights = channel_att(feature_maps)
print(f"Channel Attention output shape: {attended.shape}") # [32, 64, 28, 28]
8. Sparse Attention (Efficient Transformers)
class SparseAttention(nn.Module):
    """
    Sparse Attention (for long sequences):
    instead of the full O(n^2) attention matrix,
    attention is computed only over a restricted set of tokens.
    Variants:
    1. Local attention: only neighbouring tokens
    2. Strided attention: every k-th token
    3. Fixed attention: fixed positions
    4. Reformer (LSH): locality-sensitive hashing
    """
    def __init__(self, window_size=64):
        super().__init__()
        self.window_size = window_size

    def forward(self, query, key, value):
        """
        query, key, value: [batch, seq_len, d_model]
        Example: local attention with window_size=64 —
        each token attends only to ~64 neighbouring tokens
        instead of the whole sequence.
        """
        batch_size, seq_len, d_model = query.shape
        # Build the mask for the local window
        attention_mask = torch.zeros((seq_len, seq_len), device=query.device)
        for i in range(seq_len):
            # Window around position i
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2)
            attention_mask[i, start:end] = 1
        # Ordinary scaled dot-product attention,
        # but with a mask enforcing sparse connectivity
        scores = torch.matmul(query, key.transpose(-1, -2))
        scores = scores / torch.sqrt(torch.tensor(d_model, dtype=torch.float32))
        # Apply the sparse mask
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights
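A usage sketch for the local-window variant above (the dense mask here is only for illustration; a real implementation would avoid materialising the full n x n matrix):
sparse_att = SparseAttention(window_size=64)
x = torch.randn(4, 512, 128)  # batch=4, seq_len=512, d_model=128
out, attn = sparse_att(x, x, x)
print(f"Sparse attention output shape: {out.shape}")  # [4, 512, 128]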
9. Comparison of Attention Types
┌──────────────────────┬─────────────────────┬───────────────────────────┐
│ Type                 │ Use Case            │ Complexity                │
├──────────────────────┼─────────────────────┼───────────────────────────┤
│ Bahdanau (additive)  │ seq2seq RNN         │ O(n*m), n,m = seq lengths │
│ Luong (multiplicat.) │ seq2seq RNN         │ O(n*m), faster in practice│
│ Self-Attention       │ intra-sequence      │ O(n^2) for the full seq   │
│ Cross-Attention      │ seq2seq Transformer │ O(n*m)                    │
│ Multi-Head Self-Att  │ Transformer         │ O(n^2), parallelizable    │
│ Spatial Attention    │ CNN features        │ O(h*w) over the 2D map    │
│ Channel Attention    │ CNN features        │ O(c), very cheap          │
│ Sparse Attention     │ long sequences      │ O(n*w), w = window size   │
│ Linear Attention     │ efficient variants  │ O(n) via approximation    │
└──────────────────────┴─────────────────────┴───────────────────────────┘
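Linear Attention is listed in the table but not implemented above. Below is a minimal, non-causal sketch of the kernel-feature-map formulation (in the spirit of Katharopoulos et al., 2020), where softmax(Q K^T) V is approximated by phi(Q) (phi(K)^T V) with phi(x) = elu(x) + 1; the feature map and the shapes in the example are illustrative assumptions, not the only possible choice.
class LinearAttention(nn.Module):
    """
    Linear Attention (sketch, non-causal):
    replaces softmax(Q K^T) V with phi(Q) (phi(K)^T V),
    reducing the cost from O(n^2 * d) to O(n * d^2).
    """
    def forward(self, query, key, value, eps=1e-6):
        # Feature map phi(x) = elu(x) + 1 keeps all entries positive
        q = F.elu(query) + 1                               # [batch, n, d]
        k = F.elu(key) + 1                                 # [batch, n, d]
        # Aggregate keys and values first: phi(K)^T V -> [batch, d, d_v]
        kv = torch.einsum('bnd,bne->bde', k, value)
        # Per-query normalizer: phi(Q) . sum_j phi(K_j) -> [batch, n]
        z = 1.0 / (torch.einsum('bnd,bd->bn', q, k.sum(dim=1)) + eps)
        # Output: normalized phi(Q) (phi(K)^T V)
        return torch.einsum('bnd,bde,bn->bne', q, kv, z)   # [batch, n, d_v]

# Example (illustrative shapes)
linear_att = LinearAttention()
x = torch.randn(4, 1024, 128)
print(f"Linear attention output shape: {linear_att(x, x, x).shape}")  # [4, 1024, 128]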
10. Practical Choice
def choose_attention_mechanism(task):
    """
    Rules of thumb for choosing an attention mechanism
    """
    if task == 'seq2seq_translation':
        return "Cross-Attention (encoder-decoder)"
    elif task == 'language_modeling':
        return "Multi-Head Self-Attention (causal mask)"
    elif task == 'image_classification':
        return "Channel Attention (lightweight) or Spatial Attention"
    elif task == 'object_detection':
        return "Spatial + Channel Attention"
    elif task == 'long_document_processing':
        return "Sparse Attention or Linear Attention"
    elif task == 'rnn_sequence_model':
        return "Luong Attention (dot-product) or Bahdanau"
    elif task == 'efficient_transformer':
        return "Linear Attention or Sparse Attention with a local window"
    else:
        return "Multi-Head Self-Attention (universal choice)"
# Examples
print("Choosing an attention mechanism:")
for task in ['seq2seq_translation', 'image_classification', 'long_document_processing']:
    print(f"- {task}: {choose_attention_mechanism(task)}")
Conclusion
Main types of attention:
- Bahdanau (additive): for RNN seq2seq, slower
- Luong (multiplicative): for RNN seq2seq, faster
- Scaled Dot-Product: the core of the transformer
- Multi-Head Self-Attention: the transformer standard
- Cross-Attention: for encoder-decoder interaction
- Spatial Attention: for 2D feature maps (images)
- Channel Attention (SE-Net/CBAM): for weighting channel importance
- Sparse Attention: for long sequences
- Linear Attention: for efficient, near-linear processing
Key takeaways:
- Attention as a concept is universal
- The implementation details depend on the task and the data
- Multi-Head Self-Attention has become the standard, but it is far from the only option
- Choosing the right kind of attention is critical for performance