← Назад к вопросам

Что такое прунинг деревьев?

2.0 Middle🔥 171 комментариев
#Машинное обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Прунинг деревьев (Tree Pruning)

Прунинг (pruning) — метод сокращения размера дерева решений путём удаления ветвей, которые не улучшают качество предсказаний на независимых данных. Это один из самых эффективных способов борьбы с переобучением в деревьях решений.

Проблема переобучения в деревьях

Без ограничений дерево растёт до полного разделения всех примеров в листьях (чистота = 100%). Это приводит к переобучению на тренировочных данных и плохой обобщаемости.

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Данные
X = np.random.rand(100, 4)
y = (X[:, 0] + X[:, 1] > 1).astype(int)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# Без ограничений
tree_full = DecisionTreeClassifier(random_state=42)
tree_full.fit(X_train, y_train)

# Ограничиваем глубину
tree_limited = DecisionTreeClassifier(max_depth=3, random_state=42)
tree_limited.fit(X_train, y_train)

print(f"Full Tree - Train: {tree_full.score(X_train, y_train):.3f}, Test: {tree_full.score(X_test, y_test):.3f}")
print(f"Limited Tree - Train: {tree_limited.score(X_train, y_train):.3f}, Test: {tree_limited.score(X_test, y_test):.3f}")
print(f"Переобучение: {tree_full.score(X_train, y_train) - tree_full.score(X_test, y_test):.3f}")

Виды прунинга

1. Pre-pruning (Препрунинг)

Останавливает рост дерева до того, как оно полностью разовьётся. Критерии остановки:

from sklearn.tree import DecisionTreeClassifier

# Настройки pre-pruning
tree_prepruned = DecisionTreeClassifier(
    max_depth=5,                  # Максимальная глубина
    min_samples_split=10,         # Минимум примеров для разделения
    min_samples_leaf=5,           # Минимум примеров в листе
    max_leaf_nodes=10,            # Максимум листьев
    min_impurity_decrease=0.01    # Минимальное снижение примеси
)

tree_prepruned.fit(X_train, y_train)
print(f"Pre-pruned - Глубина: {tree_prepruned.get_depth()}")
print(f"Pre-pruned - Листьев: {tree_prepruned.get_n_leaves()}")

Параметры pre-pruning:

  • max_depth: максимальная глубина дерева
  • min_samples_split: минимум примеров для разделения узла
  • min_samples_leaf: минимум примеров в каждом листе
  • max_leaf_nodes: максимальное количество листьев
  • min_impurity_decrease: минимальное снижение примеси для разделения

2. Post-pruning (Постпрунинг)

Сначала строит полное дерево, затем удаляет ветви, не улучшающие качество на validation наборе.

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

# Разбиваем на train, validation, test
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5)

# Строим полное дерево
tree_full = DecisionTreeClassifier(random_state=42)
tree_full.fit(X_train, y_train)

# Используем alpha параметр для постпрунинга
ccp_alphas = tree_full.cost_complexity_pruning_path(X_val, y_val)[0]

# Строим последовательность деревьев
trees = []
for ccp_alpha in ccp_alphas:
    tree = DecisionTreeClassifier(ccp_alpha=ccp_alpha, random_state=42)
    tree.fit(X_train, y_train)
    trees.append(tree)

# Выбираем лучшее дерево по validation accuracy
val_scores = [tree.score(X_val, y_val) for tree in trees]
best_tree = trees[np.argmax(val_scores)]

test_score = best_tree.score(X_test, y_test)
print(f"Best tree test accuracy: {test_score:.3f}")
print(f"Best tree depth: {best_tree.get_depth()}")

Как работает cost_complexity pruning

# Функция стоимости дерева
# Cost(T) = Error(T) + alpha * |leaves(T)|

# alpha = 0: используется полное дерево
# alpha = ∞: используется только корневой узел

# Пример расчёта
alphas = [0, 0.01, 0.05, 0.1, 0.5]
for alpha in alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    tree.fit(X_train, y_train)
    leaves = tree.get_n_leaves()
    train_acc = tree.score(X_train, y_train)
    val_acc = tree.score(X_val, y_val)
    print(f"alpha={alpha:4.2f} | leaves={leaves:3d} | train={train_acc:.3f} | val={val_acc:.3f}")

Полный пример: Выбор оптимального дерева

import matplotlib.pyplot as plt

# Строим полное дерево
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)

# Получаем alpha значения
path = tree.cost_complexity_pruning_path(X_val, y_val)
alphas = path.ccp_alphas[:-1]  # Исключаем последний (только корень)

# Тренируем деревья для каждой alpha
trees = []
for alpha in alphas:
    t = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    t.fit(X_train, y_train)
    trees.append(t)

# Вычисляем метрики
train_accs = [t.score(X_train, y_train) for t in trees]
val_accs = [t.score(X_val, y_val) for t in trees]
leaves = [t.get_n_leaves() for t in trees]

# Визуализируем
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# График 1: Accuracy vs alpha
axes[0].plot(alphas, train_accs, "o-", label="Train", linewidth=2)
axes[0].plot(alphas, val_accs, "s-", label="Validation", linewidth=2)
axes[0].set_xlabel("Alpha (complexity parameter)")
axes[0].set_ylabel("Accuracy")
axes[0].set_xscale("log")
axes[0].legend()
axes[0].grid()

# График 2: Размер дерева vs alpha
axes[1].plot(alphas, leaves, "d-", color="green", linewidth=2)
axes[1].set_xlabel("Alpha")
axes[1].set_ylabel("Number of leaves")
axes[1].set_xscale("log")
axes[1].grid()

plt.tight_layout()
plt.show()

# Выбираем оптимальное дерево
best_idx = np.argmax(val_accs)
best_alpha = alphas[best_idx]
best_tree = trees[best_idx]

print(f"Optimal alpha: {best_alpha:.4f}")
print(f"Final test accuracy: {best_tree.score(X_test, y_test):.3f}")
print(f"Tree depth: {best_tree.get_depth()}")
print(f"Number of leaves: {best_tree.get_n_leaves()}")

Pre-pruning vs Post-pruning

ХарактеристикаPre-pruningPost-pruning
Скорость обученияБыстрееМедленнее
КачествоХорошееОбычно лучше
СложностьПросто настроитьТребует validation данных
Поиск оптимумаМожет пропуститьБолее систематичный
ИспользованиеRandomForestDecision Tree

Применение в Random Forest

from sklearn.ensemble import RandomForestClassifier

# Random Forest использует pre-pruning параметры
rf = RandomForestClassifier(
    n_estimators=100,
    max_depth=5,
    min_samples_split=10,
    min_samples_leaf=5,
    random_state=42
)

rf.fit(X_train, y_train)
print(f"Random Forest test accuracy: {rf.score(X_test, y_test):.3f}")

Лучшие практики

  • Используй pre-pruning для быстрого обучения больших датасетов
  • Используй post-pruning для получения лучшего качества
  • Всегда проверяй performance на validation наборе
  • Экспериментируй с разными параметрами
  • Визуализируй дерево, чтобы понять, какие ветви удаляются

Прунинг — один из самых важных инструментов для борьбы с переобучением в древесных моделях.

Что такое прунинг деревьев? | PrepBro