← Назад к вопросам

Как работает beam search?

2.8 Senior🔥 71 комментариев
#NLP и обработка текста#Глубокое обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI2 апр. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Как работает beam search?

Beam search - это эвристический алгоритм поиска, используемый для генерации последовательностей (текст, перевод, речь) с целью найти оптимальную последовательность при минимизации вычислительных затрат. Вместо того, чтобы исследовать все возможные пути (что экспоненциально дорого), beam search хранит только несколько наиболее вероятных кандидатов на каждом шаге.

Основная идея

В отличие от жадного поиска (greedy search), который выбирает только самое вероятное слово на каждом шаге, beam search поддерживает k лучших гипотез (beam width). Например, с beam width = 3, на каждом шаге мы следим за 3 наиболее вероятными последовательностями.

Простой пример

Представьте генерацию простого предложения с beam width = 2:

Шаг 1 (стартовое слово):

  • Вероятность "The" = 0.8
  • Вероятность "A" = 0.15
  • Вероятность "An" = 0.05

Оставляем top-2: "The" и "A"

Шаг 2 (второе слово): Для каждого из top-2 кандидатов рассчитываем вероятности следующего слова:

  • "The" + "cat" = 0.8 * 0.6 = 0.48
  • "The" + "dog" = 0.8 * 0.3 = 0.24
  • "A" + "cat" = 0.15 * 0.7 = 0.105
  • "A" + "dog" = 0.15 * 0.2 = 0.03

Оставляем top-2: "The cat" (0.48) и "The dog" (0.24)

Шаг 3: Повторяем процесс и так далее...

Реализация на практике

import numpy as np
from heapq import heappush, heappop

class BeamSearch:
    def __init__(self, vocab_size, beam_width=3, max_length=20):
        self.vocab_size = vocab_size
        self.beam_width = beam_width
        self.max_length = max_length
    
    def search(self, model, start_token, vocab_size):
        # Инициализируем: (логарифм вероятности, последовательность, состояние)
        # Используем логарифмы для числовой стабильности
        beams = [(-0.0, [start_token], None)]  # (score, sequence, state)
        completed = []
        
        for step in range(self.max_length):
            candidates = []
            
            for score, seq, state in beams:
                # Модель предсказывает распределение для следующего токена
                logits, new_state = model(seq[-1], state)
                
                # Берем логарифм вероятностей
                log_probs = np.log(softmax(logits) + 1e-10)
                
                # Рассматриваем top-k наиболее вероятных токенов
                top_k = np.argsort(log_probs)[-self.beam_width:]
                
                for token in top_k:
                    new_score = score + log_probs[token]
                    new_seq = seq + [token]
                    candidates.append((new_score, new_seq, new_state))
            
            # Сортируем по очкам и оставляем top beam_width
            candidates.sort(reverse=True)
            beams = candidates[:self.beam_width]
            
            # Проверяем, завершены ли последовательности (достигнут END токен)
            new_beams = []
            for score, seq, state in beams:
                if seq[-1] == vocab_size - 1:  # END token
                    completed.append((score, seq))
                else:
                    new_beams.append((score, seq, state))
            beams = new_beams
            
            if not beams:
                break
        
        # Добавляем неполные последовательности
        completed.extend(beams)
        
        # Возвращаем лучшие результаты
        completed.sort(reverse=True)
        return completed[:self.beam_width]

Beam Search в TensorFlow/Keras

import tensorflow as tf
from tensorflow.text import BeamSearchDecoder

class SequenceToSequenceModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTM(units, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)
    
    def call(self, inputs, states=None):
        x = self.embedding(inputs)
        output, h, c = self.lstm(x, initial_state=states)
        logits = self.dense(output)
        return logits, [h, c]

# Использование
model = SequenceToSequenceModel(vocab_size=10000, embedding_dim=256, units=512)

# Процесс декодирования с beam search
beam_width = 5
max_length = 50

# В реальных фреймворках это делается встроенными функциями
# tf.keras.layers.BeamSearchDecoder или torch.nn.utils.beam_search

Сравнение стратегий генерации

import numpy as np
from scipy.special import softmax

def greedy_search(model, start_token, max_length, vocab_size):
    """Выбирает самый вероятный токен на каждом шаге"""
    seq = [start_token]
    for _ in range(max_length):
        logits = model.predict([seq])
        next_token = np.argmax(logits[0, -1])
        seq.append(next_token)
        if next_token == vocab_size - 1:  # END token
            break
    return seq

def random_sampling(model, start_token, max_length, vocab_size, temperature=1.0):
    """Случайно выбирает токен пропорционально вероятности"""
    seq = [start_token]
    for _ in range(max_length):
        logits = model.predict([seq])[0, -1]
        probs = softmax(logits / temperature)
        next_token = np.random.choice(vocab_size, p=probs)
        seq.append(next_token)
        if next_token == vocab_size - 1:
            break
    return seq

def top_k_sampling(model, start_token, max_length, vocab_size, k=10):
    """Выбирает случайно из top-k вероятных токенов"""
    seq = [start_token]
    for _ in range(max_length):
        logits = model.predict([seq])[0, -1]
        probs = softmax(logits)
        top_k_indices = np.argsort(probs)[-k:]
        top_k_probs = probs[top_k_indices]
        top_k_probs /= top_k_probs.sum()
        next_token = np.random.choice(top_k_indices, p=top_k_probs)
        seq.append(next_token)
        if next_token == vocab_size - 1:
            break
    return seq
МетодСкоростьКачествоРазнообразие
Greedy SearchБыстроСреднееНизкое
Beam SearchМедленнееВысокоеСреднее
Random SamplingБыстроНизкоеВысокое
Top-k SamplingМедленнееХорошееВысокое

Параметры beam search

Beam Width (k):

  • k=1 - это greedy search
  • k=3-5 - типичный выбор
  • k>10 - редко дает значительное улучшение
# Пример: оптимизация beam width
for beam_width in [1, 3, 5, 10]:
    results = beam_search(model, beam_width=beam_width)
    print(f"Beam width {beam_width}: BLEU score = {evaluate(results):.3f}")

Length Penalty

Без штрафа за длину beam search предпочитает короткие последовательности (они имеют высокую совместную вероятность). Поэтому часто используется длинный штраф:

def length_normalized_score(sequence_score, length, alpha=0.6):
    # alpha=0 означает без штрафа
    # alpha=1 означает полная нормализация на длину
    return sequence_score / (length ** alpha)

Практические применения

  • Machine Translation (MT): Google Translate, DeepL
  • Text Generation: GPT, T5
  • Speech Recognition: Automatic Speech Recognition (ASR)
  • Summarization: Extractive и abstractive summarization
  • Question Answering: Генеративные модели для ответов

Выводы

Beam search - это мощный компромисс между качеством (как при переборе всех путей) и вычислительной эффективностью. Он используется в большинстве производственных систем генерации последовательностей, где нужны релевантные и вероятные результаты.

Как работает beam search? | PrepBro