Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
LSTM (Long Short-Term Memory): полный разбор
LSTM — это специализированный вид рекуррентной нейронной сети, разработанный для решения проблемы исчезающих градиентов при работе с долгосрочными зависимостями.
Проблема RNN (базовой рекуррентной сети)
Проблема "исчезающих градиентов" (Vanishing Gradient):
Обычная RNN использует простое повторение:
h_t = tanh(W_h * h_{t-1} + W_x * x_t)
При backprop через много временных шагов:
graadient ∝ tanh'(...) * W_h * tanh'(...) * W_h * ... (много раз)
Если |W_h| < 1, градиенты экспоненциально убывают
Если |W_h| > 1, градиенты экспоненциально растут (exploding gradients)
Результат: RNN не может выучить долгосрочные зависимости.
# Пример: Предсказать последнее слово в предложении
sentence = "The cat, which had been sitting on the mat, suddenly..."
# Модель не может пройти 20+ слов без потери информации
Решение: LSTM архитектура
LSTM добавляет cell state (состояние ячейки) и gates (вентили) для контроля потока информации.
Четыре ключевых компонента:
1. Cell State (C_t) — "память ячейки" (долгосрочная память)
2. Hidden State (h_t) — "выходное скрытое состояние" (краткосрочная память)
3. Input Gate (i_t) — контролирует, что добавить в память
4. Forget Gate (f_t) — контролирует, что забыть
5. Output Gate (o_t) — контролирует, что выхлопить из памяти
Уравнения LSTM
Забывающий вентиль (Forget Gate):
f_t = σ(W_f * [h_{t-1}, x_t] + b_f)
Ответ: какую долю старого состояния забыть (0 = забыть, 1 = сохранить)
Входной вентиль (Input Gate):
i_t = σ(W_i * [h_{t-1}, x_t] + b_i)
C̃_t = tanh(W_c * [h_{t-1}, x_t] + b_c)
i_t: как много новой информации добавить
C̃_t: кандидат новой информации
Обновление состояния ячейки (Cell State Update):
C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
⊙ означает поэлементное умножение (Hadamard product)
Читаем:
C_t = (забыть старое) + (добавить новое)
Выходной вентиль (Output Gate):
o_t = σ(W_o * [h_{t-1}, x_t] + b_o)
h_t = o_t ⊙ tanh(C_t)
o_t: какую долю состояния показать на выход
h_t: итоговое скрытое состояние
Пример работы LSTM step-by-step
import torch
import torch.nn as nn
# LSTM ячейка
class SimpleLSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# Все 4 вентиля используют одну матрицу (эффективность)
self.weight_ih = nn.Linear(input_size, 4 * hidden_size)
self.weight_hh = nn.Linear(hidden_size, 4 * hidden_size)
self.bias = nn.Parameter(torch.zeros(4 * hidden_size))
def forward(self, x, state):
h, c = state # h = скрытое состояние, c = состояние ячейки
# Объединяем входы
gates = self.weight_ih(x) + self.weight_hh(h) + self.bias
# Разделяем на 4 вентиля
i_gate, f_gate, g_gate, o_gate = gates.chunk(4, 1)
# Применяем активации
i_gate = torch.sigmoid(i_gate) # Input gate [0,1]
f_gate = torch.sigmoid(f_gate) # Forget gate [0,1]
g_gate = torch.tanh(g_gate) # Cell candidate [-1,1]
o_gate = torch.sigmoid(o_gate) # Output gate [0,1]
# Обновляем cell state
c_new = f_gate * c + i_gate * g_gate
# Вычисляем новое скрытое состояние
h_new = o_gate * torch.tanh(c_new)
return h_new, c_new
Визуализация LSTM
┌──────────────────────────────────────────────┐
│ LSTM CELL │
├──────────────────────────────────────────────┤
│ │
│ x_t (input) h_{t-1} (prev hidden) │
│ │ │ │
│ └────────┬───────┘ │
│ /\ │
│ / \ │
│ Forget Gate Σ Input Gate │
│ (f_t ∈ [0,1]) │ (i_t ∈ [0,1]) │
│ │ │ │ │
│ × | tanh | × (Hadamard) │
│ │ │ │ │
│ C_{t-1} ⊙ C̃_t Output Gate │
│ │ │ │ (o_t ∈ [0,1]) │
│ └─ + ──┘ │ │
│ │ │ │
│ C_t ─ tanh ─ × │
│ │ │ │
│ │ h_t (output) │
│ │ │ │
│ (cell state) (hidden state) │
│ │
└──────────────────────────────────────────────┘
Почему LSTM решает проблему исчезающих градиентов
При backprop:
градиент_C_t_to_C_{t-1} = f_t (значение между 0 и 1, но контролируется)
Вместо:
градиент ∝ tanh'(...) * W * tanh'(...) * W * ... (много умножений)
Теперь:
градиент ∝ f_t * f_{t-1} * ... * f_0
Люди могут обучить сеть:
f_t ≈ 1 ("не забывай информацию")
Tогда градиент ≈ 1 * 1 * ... * 1 = 1 (не исчезает!)
Ключевое отличие:
- RNN: градиент = произведение активаций (экспоненциальный рост/падение)
- LSTM: градиент контролируется вентилями (может оставаться стабильным)
Пример: использование LSTM в PyTorch
import torch
import torch.nn as nn
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super().__init__()
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True
)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x shape: (batch_size, seq_length, input_size)
lstm_out, (h_n, c_n) = self.lstm(x)
# lstm_out shape: (batch_size, seq_length, hidden_size)
# h_n shape: (num_layers, batch_size, hidden_size)
# c_n shape: (num_layers, batch_size, hidden_size)
# Берём последний скрытый слой
last_hidden = h_n[-1] # (batch_size, hidden_size)
output = self.fc(last_hidden)
return output
# Использование
model = LSTMModel(
input_size=10,
hidden_size=32,
num_layers=2,
output_size=1
)
x = torch.randn(16, 20, 10) # batch_size=16, seq_len=20, input_size=10
y = model(x) # output shape: (16, 1)
GRU (упрощённая версия LSTM)
GRU (Gated Recurrent Unit) — упрощённая версия LSTM:
# Вместо 4 вентилей, GRU использует 3:
gru = nn.GRU(
input_size=10,
hidden_size=32,
num_layers=2,
batch_first=True
)
# Использование похоже на LSTM
gru_out, h_n = gru(x)
Отличие:
- LSTM: 4 вентиля, запоминает cell state и hidden state
- GRU: 3 вентиля, более простая, быстрее, часто работает так же хорошо
Практические примеры LSTM
1. Prediction временного ряда
# Предсказать следующее значение цены акций
model = LSTMModel(input_size=1, hidden_size=64, num_layers=2, output_size=1)
# Входные данные: цены за последние 30 дней
x = torch.randn(32, 30, 1) # batch=32, seq=30 дней, features=1
y = model(x) # Предсказание на день 31
2. Natural Language Processing (NLP)
# Классификация текста (sentiment analysis)
# [word_1, word_2, ..., word_n] → [positive/negative]
model = nn.Sequential(
nn.Embedding(vocab_size=10000, embedding_dim=128),
nn.LSTM(128, 64, num_layers=2, batch_first=True),
nn.AdaptiveAvgPool1d(1), #Pooling по временной оси
nn.Flatten(),
nn.Linear(64, 2) # 2 класса: positive/negative
)
3. Machine Translation (seq2seq)
# Encoder LSTM: читает исходный язык
encoder = nn.LSTM(input_size=input_vocab, hidden_size=256)
# Decoder LSTM: генерирует целевой язык
decoder = nn.LSTM(input_size=output_vocab, hidden_size=256)
# Передаём скрытое состояние от encoder'а в decoder
encoder_output, (h_n, c_n) = encoder(source_text)
decoder_output, _ = decoder(target_text, (h_n, c_n))
Плюсы и минусы LSTM
Плюсы:
- Решает проблему исчезающих градиентов
- Может запомнить долгосрочные зависимости (100+ шагов)
- Стандарт де-факто для задач с последовательностями
- Хорошо интерпретируемо (вентили показывают, что помнит)
Минусы:
- Медленнее, чем feedforward сети
- Последовательная обработка (не параллелизуется)
- Требует много памяти (состояния хранятся для каждого шага)
- Вытеснена Transformer'ами (параллельные, быстрее)
Когда использовать LSTM vs Transformer
| Сценарий | LSTM | Transformer |
|---|---|---|
| Маленький датасет | ✓ | ✗ (требует много данных) |
| Долгие последовательности | ✗ (медленно) | ✓ (параллельно) |
| Real-time inference | ✓ | ✗ (требует полный input) |
| Интерпретируемость | ✓ | ✗ (чёрный ящик) |
| Стандартная NLP | ✗ | ✓ (BERT, GPT) |
Итог
LSTM = RNN с памятью и вентилями
Ключевые идеи:
1. Cell state (C_t) = долгосрочная память
2. Hidden state (h_t) = краткосрочная память
3. Вентили (f, i, o) = контроль потока информации
4. Решает исчезающие градиенты через controlled gradients
Может запомнить зависимости на 100+ временных шагов,
в то время как обычная RNN забывает за 10-20 шагов.
Это было революционно в 2014-2019 годах, но теперь вытеснено Transformer'ами. Однако LSTM остаётся полезной архитектурой для маленьких датасетов и real-time приложений.