Как работает beam search?
Комментарии (1)
Ответ сгенерирован нейросетью и может содержать ошибки
Как работает beam search?
Beam search - это эвристический алгоритм поиска, используемый для генерации последовательностей (текст, перевод, речь) с целью найти оптимальную последовательность при минимизации вычислительных затрат. Вместо того, чтобы исследовать все возможные пути (что экспоненциально дорого), beam search хранит только несколько наиболее вероятных кандидатов на каждом шаге.
Основная идея
В отличие от жадного поиска (greedy search), который выбирает только самое вероятное слово на каждом шаге, beam search поддерживает k лучших гипотез (beam width). Например, с beam width = 3, на каждом шаге мы следим за 3 наиболее вероятными последовательностями.
Простой пример
Представьте генерацию простого предложения с beam width = 2:
Шаг 1 (стартовое слово):
- Вероятность "The" = 0.8
- Вероятность "A" = 0.15
- Вероятность "An" = 0.05
Оставляем top-2: "The" и "A"
Шаг 2 (второе слово): Для каждого из top-2 кандидатов рассчитываем вероятности следующего слова:
- "The" + "cat" = 0.8 * 0.6 = 0.48
- "The" + "dog" = 0.8 * 0.3 = 0.24
- "A" + "cat" = 0.15 * 0.7 = 0.105
- "A" + "dog" = 0.15 * 0.2 = 0.03
Оставляем top-2: "The cat" (0.48) и "The dog" (0.24)
Шаг 3: Повторяем процесс и так далее...
Реализация на практике
import numpy as np
from heapq import heappush, heappop
class BeamSearch:
def __init__(self, vocab_size, beam_width=3, max_length=20):
self.vocab_size = vocab_size
self.beam_width = beam_width
self.max_length = max_length
def search(self, model, start_token, vocab_size):
# Инициализируем: (логарифм вероятности, последовательность, состояние)
# Используем логарифмы для числовой стабильности
beams = [(-0.0, [start_token], None)] # (score, sequence, state)
completed = []
for step in range(self.max_length):
candidates = []
for score, seq, state in beams:
# Модель предсказывает распределение для следующего токена
logits, new_state = model(seq[-1], state)
# Берем логарифм вероятностей
log_probs = np.log(softmax(logits) + 1e-10)
# Рассматриваем top-k наиболее вероятных токенов
top_k = np.argsort(log_probs)[-self.beam_width:]
for token in top_k:
new_score = score + log_probs[token]
new_seq = seq + [token]
candidates.append((new_score, new_seq, new_state))
# Сортируем по очкам и оставляем top beam_width
candidates.sort(reverse=True)
beams = candidates[:self.beam_width]
# Проверяем, завершены ли последовательности (достигнут END токен)
new_beams = []
for score, seq, state in beams:
if seq[-1] == vocab_size - 1: # END token
completed.append((score, seq))
else:
new_beams.append((score, seq, state))
beams = new_beams
if not beams:
break
# Добавляем неполные последовательности
completed.extend(beams)
# Возвращаем лучшие результаты
completed.sort(reverse=True)
return completed[:self.beam_width]
Beam Search в TensorFlow/Keras
import tensorflow as tf
from tensorflow.text import BeamSearchDecoder
class SequenceToSequenceModel(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, units):
super().__init__()
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
self.lstm = tf.keras.layers.LSTM(units, return_state=True)
self.dense = tf.keras.layers.Dense(vocab_size)
def call(self, inputs, states=None):
x = self.embedding(inputs)
output, h, c = self.lstm(x, initial_state=states)
logits = self.dense(output)
return logits, [h, c]
# Использование
model = SequenceToSequenceModel(vocab_size=10000, embedding_dim=256, units=512)
# Процесс декодирования с beam search
beam_width = 5
max_length = 50
# В реальных фреймворках это делается встроенными функциями
# tf.keras.layers.BeamSearchDecoder или torch.nn.utils.beam_search
Сравнение стратегий генерации
import numpy as np
from scipy.special import softmax
def greedy_search(model, start_token, max_length, vocab_size):
"""Выбирает самый вероятный токен на каждом шаге"""
seq = [start_token]
for _ in range(max_length):
logits = model.predict([seq])
next_token = np.argmax(logits[0, -1])
seq.append(next_token)
if next_token == vocab_size - 1: # END token
break
return seq
def random_sampling(model, start_token, max_length, vocab_size, temperature=1.0):
"""Случайно выбирает токен пропорционально вероятности"""
seq = [start_token]
for _ in range(max_length):
logits = model.predict([seq])[0, -1]
probs = softmax(logits / temperature)
next_token = np.random.choice(vocab_size, p=probs)
seq.append(next_token)
if next_token == vocab_size - 1:
break
return seq
def top_k_sampling(model, start_token, max_length, vocab_size, k=10):
"""Выбирает случайно из top-k вероятных токенов"""
seq = [start_token]
for _ in range(max_length):
logits = model.predict([seq])[0, -1]
probs = softmax(logits)
top_k_indices = np.argsort(probs)[-k:]
top_k_probs = probs[top_k_indices]
top_k_probs /= top_k_probs.sum()
next_token = np.random.choice(top_k_indices, p=top_k_probs)
seq.append(next_token)
if next_token == vocab_size - 1:
break
return seq
| Метод | Скорость | Качество | Разнообразие |
|---|---|---|---|
| Greedy Search | Быстро | Среднее | Низкое |
| Beam Search | Медленнее | Высокое | Среднее |
| Random Sampling | Быстро | Низкое | Высокое |
| Top-k Sampling | Медленнее | Хорошее | Высокое |
Параметры beam search
Beam Width (k):
- k=1 - это greedy search
- k=3-5 - типичный выбор
- k>10 - редко дает значительное улучшение
# Пример: оптимизация beam width
for beam_width in [1, 3, 5, 10]:
results = beam_search(model, beam_width=beam_width)
print(f"Beam width {beam_width}: BLEU score = {evaluate(results):.3f}")
Length Penalty
Без штрафа за длину beam search предпочитает короткие последовательности (они имеют высокую совместную вероятность). Поэтому часто используется длинный штраф:
def length_normalized_score(sequence_score, length, alpha=0.6):
# alpha=0 означает без штрафа
# alpha=1 означает полная нормализация на длину
return sequence_score / (length ** alpha)
Практические применения
- Machine Translation (MT): Google Translate, DeepL
- Text Generation: GPT, T5
- Speech Recognition: Automatic Speech Recognition (ASR)
- Summarization: Extractive и abstractive summarization
- Question Answering: Генеративные модели для ответов
Выводы
Beam search - это мощный компромисс между качеством (как при переборе всех путей) и вычислительной эффективностью. Он используется в большинстве производственных систем генерации последовательностей, где нужны релевантные и вероятные результаты.