← Назад к вопросам

Как работает LSTM?

2.7 Senior🔥 181 комментариев
#Временные ряды#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

LSTM (Long Short-Term Memory): полный разбор

LSTM — это специализированный вид рекуррентной нейронной сети, разработанный для решения проблемы исчезающих градиентов при работе с долгосрочными зависимостями.

Проблема RNN (базовой рекуррентной сети)

Проблема "исчезающих градиентов" (Vanishing Gradient):

Обычная RNN использует простое повторение:

h_t = tanh(W_h * h_{t-1} + W_x * x_t)

При backprop через много временных шагов:
graadient ∝ tanh'(...) * W_h * tanh'(...) * W_h * ... (много раз)

Если |W_h| < 1, градиенты экспоненциально убывают
Если |W_h| > 1, градиенты экспоненциально растут (exploding gradients)

Результат: RNN не может выучить долгосрочные зависимости.

# Пример: Предсказать последнее слово в предложении
sentence = "The cat, which had been sitting on the mat, suddenly..."
# Модель не может пройти 20+ слов без потери информации

Решение: LSTM архитектура

LSTM добавляет cell state (состояние ячейки) и gates (вентили) для контроля потока информации.

Четыре ключевых компонента:

1. Cell State (C_t) — "память ячейки" (долгосрочная память)
2. Hidden State (h_t) — "выходное скрытое состояние" (краткосрочная память)
3. Input Gate (i_t) — контролирует, что добавить в память
4. Forget Gate (f_t) — контролирует, что забыть
5. Output Gate (o_t) — контролирует, что выхлопить из памяти

Уравнения LSTM

Забывающий вентиль (Forget Gate):

f_t = σ(W_f * [h_{t-1}, x_t] + b_f)

Ответ: какую долю старого состояния забыть (0 = забыть, 1 = сохранить)

Входной вентиль (Input Gate):

i_t = σ(W_i * [h_{t-1}, x_t] + b_i)
C̃_t = tanh(W_c * [h_{t-1}, x_t] + b_c)

i_t: как много новой информации добавить
C̃_t: кандидат новой информации

Обновление состояния ячейки (Cell State Update):

C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t

⊙ означает поэлементное умножение (Hadamard product)

Читаем:
C_t = (забыть старое) + (добавить новое)

Выходной вентиль (Output Gate):

o_t = σ(W_o * [h_{t-1}, x_t] + b_o)
h_t = o_t ⊙ tanh(C_t)

o_t: какую долю состояния показать на выход
h_t: итоговое скрытое состояние

Пример работы LSTM step-by-step

import torch
import torch.nn as nn

# LSTM ячейка
class SimpleLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # Все 4 вентиля используют одну матрицу (эффективность)
        self.weight_ih = nn.Linear(input_size, 4 * hidden_size)
        self.weight_hh = nn.Linear(hidden_size, 4 * hidden_size)
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))
    
    def forward(self, x, state):
        h, c = state  # h = скрытое состояние, c = состояние ячейки
        
        # Объединяем входы
        gates = self.weight_ih(x) + self.weight_hh(h) + self.bias
        
        # Разделяем на 4 вентиля
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, 1)
        
        # Применяем активации
        i_gate = torch.sigmoid(i_gate)  # Input gate [0,1]
        f_gate = torch.sigmoid(f_gate)  # Forget gate [0,1]
        g_gate = torch.tanh(g_gate)      # Cell candidate [-1,1]
        o_gate = torch.sigmoid(o_gate)  # Output gate [0,1]
        
        # Обновляем cell state
        c_new = f_gate * c + i_gate * g_gate
        
        # Вычисляем новое скрытое состояние
        h_new = o_gate * torch.tanh(c_new)
        
        return h_new, c_new

Визуализация LSTM

┌──────────────────────────────────────────────┐
│                  LSTM CELL                   │
├──────────────────────────────────────────────┤
│                                              │
│     x_t (input)    h_{t-1} (prev hidden)    │
│      │                │                      │
│      └────────┬───────┘                      │
│              /\                              │
│             /  \                             │
│    Forget Gate Σ Input Gate                  │
│   (f_t ∈ [0,1]) │ (i_t ∈ [0,1])             │
│          │      │       │                    │
│        × | tanh |       × (Hadamard)         │
│          │      │       │                    │
│         C_{t-1} ⊙ C̃_t  Output Gate         │
│          │      │       │ (o_t ∈ [0,1])     │
│          └─ + ──┘       │                    │
│             │           │                    │
│           C_t ─ tanh ─ ×                    │
│             │           │                    │
│             │          h_t (output)         │
│             │           │                    │
│     (cell state)    (hidden state)           │
│                                              │
└──────────────────────────────────────────────┘

Почему LSTM решает проблему исчезающих градиентов

При backprop:

градиент_C_t_to_C_{t-1} = f_t (значение между 0 и 1, но контролируется)

Вместо:
градиент ∝ tanh'(...) * W * tanh'(...) * W * ... (много умножений)

Теперь:
градиент ∝ f_t * f_{t-1} * ... * f_0

Люди могут обучить сеть:
f_t ≈ 1 ("не забывай информацию")

Tогда градиент ≈ 1 * 1 * ... * 1 = 1 (не исчезает!)

Ключевое отличие:

  • RNN: градиент = произведение активаций (экспоненциальный рост/падение)
  • LSTM: градиент контролируется вентилями (может оставаться стабильным)

Пример: использование LSTM в PyTorch

import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # x shape: (batch_size, seq_length, input_size)
        lstm_out, (h_n, c_n) = self.lstm(x)
        # lstm_out shape: (batch_size, seq_length, hidden_size)
        # h_n shape: (num_layers, batch_size, hidden_size)
        # c_n shape: (num_layers, batch_size, hidden_size)
        
        # Берём последний скрытый слой
        last_hidden = h_n[-1]  # (batch_size, hidden_size)
        output = self.fc(last_hidden)
        return output

# Использование
model = LSTMModel(
    input_size=10,
    hidden_size=32,
    num_layers=2,
    output_size=1
)

x = torch.randn(16, 20, 10)  # batch_size=16, seq_len=20, input_size=10
y = model(x)  # output shape: (16, 1)

GRU (упрощённая версия LSTM)

GRU (Gated Recurrent Unit) — упрощённая версия LSTM:

# Вместо 4 вентилей, GRU использует 3:
gru = nn.GRU(
    input_size=10,
    hidden_size=32,
    num_layers=2,
    batch_first=True
)

# Использование похоже на LSTM
gru_out, h_n = gru(x)

Отличие:

  • LSTM: 4 вентиля, запоминает cell state и hidden state
  • GRU: 3 вентиля, более простая, быстрее, часто работает так же хорошо

Практические примеры LSTM

1. Prediction временного ряда

# Предсказать следующее значение цены акций
model = LSTMModel(input_size=1, hidden_size=64, num_layers=2, output_size=1)

# Входные данные: цены за последние 30 дней
x = torch.randn(32, 30, 1)  # batch=32, seq=30 дней, features=1
y = model(x)  # Предсказание на день 31

2. Natural Language Processing (NLP)

# Классификация текста (sentiment analysis)
# [word_1, word_2, ..., word_n] → [positive/negative]

model = nn.Sequential(
    nn.Embedding(vocab_size=10000, embedding_dim=128),
    nn.LSTM(128, 64, num_layers=2, batch_first=True),
    nn.AdaptiveAvgPool1d(1),  #Pooling по временной оси
    nn.Flatten(),
    nn.Linear(64, 2)  # 2 класса: positive/negative
)

3. Machine Translation (seq2seq)

# Encoder LSTM: читает исходный язык
encoder = nn.LSTM(input_size=input_vocab, hidden_size=256)

# Decoder LSTM: генерирует целевой язык
decoder = nn.LSTM(input_size=output_vocab, hidden_size=256)

# Передаём скрытое состояние от encoder'а в decoder
encoder_output, (h_n, c_n) = encoder(source_text)
decoder_output, _ = decoder(target_text, (h_n, c_n))

Плюсы и минусы LSTM

Плюсы:

  • Решает проблему исчезающих градиентов
  • Может запомнить долгосрочные зависимости (100+ шагов)
  • Стандарт де-факто для задач с последовательностями
  • Хорошо интерпретируемо (вентили показывают, что помнит)

Минусы:

  • Медленнее, чем feedforward сети
  • Последовательная обработка (не параллелизуется)
  • Требует много памяти (состояния хранятся для каждого шага)
  • Вытеснена Transformer'ами (параллельные, быстрее)

Когда использовать LSTM vs Transformer

СценарийLSTMTransformer
Маленький датасет✗ (требует много данных)
Долгие последовательности✗ (медленно)✓ (параллельно)
Real-time inference✗ (требует полный input)
Интерпретируемость✗ (чёрный ящик)
Стандартная NLP✓ (BERT, GPT)

Итог

LSTM = RNN с памятью и вентилями

Ключевые идеи:
1. Cell state (C_t) = долгосрочная память
2. Hidden state (h_t) = краткосрочная память
3. Вентили (f, i, o) = контроль потока информации
4. Решает исчезающие градиенты через controlled gradients

Может запомнить зависимости на 100+ временных шагов,
в то время как обычная RNN забывает за 10-20 шагов.

Это было революционно в 2014-2019 годах, но теперь вытеснено Transformer'ами. Однако LSTM остаётся полезной архитектурой для маленьких датасетов и real-time приложений.