Комментарии (1)
🐱
claude-haiku-4.5PrepBro AI30 мар. 2026 г.(ред.)
Ответ сгенерирован нейросетью и может содержать ошибки
Прунинг деревьев (Tree Pruning)
Прунинг (pruning) — метод сокращения размера дерева решений путём удаления ветвей, которые не улучшают качество предсказаний на независимых данных. Это один из самых эффективных способов борьбы с переобучением в деревьях решений.
Проблема переобучения в деревьях
Без ограничений дерево растёт до полного разделения всех примеров в листьях (чистота = 100%). Это приводит к переобучению на тренировочных данных и плохой обобщаемости.
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Данные
X = np.random.rand(100, 4)
y = (X[:, 0] + X[:, 1] > 1).astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
# Без ограничений
tree_full = DecisionTreeClassifier(random_state=42)
tree_full.fit(X_train, y_train)
# Ограничиваем глубину
tree_limited = DecisionTreeClassifier(max_depth=3, random_state=42)
tree_limited.fit(X_train, y_train)
print(f"Full Tree - Train: {tree_full.score(X_train, y_train):.3f}, Test: {tree_full.score(X_test, y_test):.3f}")
print(f"Limited Tree - Train: {tree_limited.score(X_train, y_train):.3f}, Test: {tree_limited.score(X_test, y_test):.3f}")
print(f"Переобучение: {tree_full.score(X_train, y_train) - tree_full.score(X_test, y_test):.3f}")
Виды прунинга
1. Pre-pruning (Препрунинг)
Останавливает рост дерева до того, как оно полностью разовьётся. Критерии остановки:
from sklearn.tree import DecisionTreeClassifier
# Настройки pre-pruning
tree_prepruned = DecisionTreeClassifier(
max_depth=5, # Максимальная глубина
min_samples_split=10, # Минимум примеров для разделения
min_samples_leaf=5, # Минимум примеров в листе
max_leaf_nodes=10, # Максимум листьев
min_impurity_decrease=0.01 # Минимальное снижение примеси
)
tree_prepruned.fit(X_train, y_train)
print(f"Pre-pruned - Глубина: {tree_prepruned.get_depth()}")
print(f"Pre-pruned - Листьев: {tree_prepruned.get_n_leaves()}")
Параметры pre-pruning:
- max_depth: максимальная глубина дерева
- min_samples_split: минимум примеров для разделения узла
- min_samples_leaf: минимум примеров в каждом листе
- max_leaf_nodes: максимальное количество листьев
- min_impurity_decrease: минимальное снижение примеси для разделения
2. Post-pruning (Постпрунинг)
Сначала строит полное дерево, затем удаляет ветви, не улучшающие качество на validation наборе.
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
# Разбиваем на train, validation, test
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5)
# Строим полное дерево
tree_full = DecisionTreeClassifier(random_state=42)
tree_full.fit(X_train, y_train)
# Используем alpha параметр для постпрунинга
ccp_alphas = tree_full.cost_complexity_pruning_path(X_val, y_val)[0]
# Строим последовательность деревьев
trees = []
for ccp_alpha in ccp_alphas:
tree = DecisionTreeClassifier(ccp_alpha=ccp_alpha, random_state=42)
tree.fit(X_train, y_train)
trees.append(tree)
# Выбираем лучшее дерево по validation accuracy
val_scores = [tree.score(X_val, y_val) for tree in trees]
best_tree = trees[np.argmax(val_scores)]
test_score = best_tree.score(X_test, y_test)
print(f"Best tree test accuracy: {test_score:.3f}")
print(f"Best tree depth: {best_tree.get_depth()}")
Как работает cost_complexity pruning
# Функция стоимости дерева
# Cost(T) = Error(T) + alpha * |leaves(T)|
# alpha = 0: используется полное дерево
# alpha = ∞: используется только корневой узел
# Пример расчёта
alphas = [0, 0.01, 0.05, 0.1, 0.5]
for alpha in alphas:
tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
tree.fit(X_train, y_train)
leaves = tree.get_n_leaves()
train_acc = tree.score(X_train, y_train)
val_acc = tree.score(X_val, y_val)
print(f"alpha={alpha:4.2f} | leaves={leaves:3d} | train={train_acc:.3f} | val={val_acc:.3f}")
Полный пример: Выбор оптимального дерева
import matplotlib.pyplot as plt
# Строим полное дерево
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)
# Получаем alpha значения
path = tree.cost_complexity_pruning_path(X_val, y_val)
alphas = path.ccp_alphas[:-1] # Исключаем последний (только корень)
# Тренируем деревья для каждой alpha
trees = []
for alpha in alphas:
t = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
t.fit(X_train, y_train)
trees.append(t)
# Вычисляем метрики
train_accs = [t.score(X_train, y_train) for t in trees]
val_accs = [t.score(X_val, y_val) for t in trees]
leaves = [t.get_n_leaves() for t in trees]
# Визуализируем
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# График 1: Accuracy vs alpha
axes[0].plot(alphas, train_accs, "o-", label="Train", linewidth=2)
axes[0].plot(alphas, val_accs, "s-", label="Validation", linewidth=2)
axes[0].set_xlabel("Alpha (complexity parameter)")
axes[0].set_ylabel("Accuracy")
axes[0].set_xscale("log")
axes[0].legend()
axes[0].grid()
# График 2: Размер дерева vs alpha
axes[1].plot(alphas, leaves, "d-", color="green", linewidth=2)
axes[1].set_xlabel("Alpha")
axes[1].set_ylabel("Number of leaves")
axes[1].set_xscale("log")
axes[1].grid()
plt.tight_layout()
plt.show()
# Выбираем оптимальное дерево
best_idx = np.argmax(val_accs)
best_alpha = alphas[best_idx]
best_tree = trees[best_idx]
print(f"Optimal alpha: {best_alpha:.4f}")
print(f"Final test accuracy: {best_tree.score(X_test, y_test):.3f}")
print(f"Tree depth: {best_tree.get_depth()}")
print(f"Number of leaves: {best_tree.get_n_leaves()}")
Pre-pruning vs Post-pruning
| Характеристика | Pre-pruning | Post-pruning |
|---|---|---|
| Скорость обучения | Быстрее | Медленнее |
| Качество | Хорошее | Обычно лучше |
| Сложность | Просто настроить | Требует validation данных |
| Поиск оптимума | Может пропустить | Более систематичный |
| Использование | RandomForest | Decision Tree |
Применение в Random Forest
from sklearn.ensemble import RandomForestClassifier
# Random Forest использует pre-pruning параметры
rf = RandomForestClassifier(
n_estimators=100,
max_depth=5,
min_samples_split=10,
min_samples_leaf=5,
random_state=42
)
rf.fit(X_train, y_train)
print(f"Random Forest test accuracy: {rf.score(X_test, y_test):.3f}")
Лучшие практики
- Используй pre-pruning для быстрого обучения больших датасетов
- Используй post-pruning для получения лучшего качества
- Всегда проверяй performance на validation наборе
- Экспериментируй с разными параметрами
- Визуализируй дерево, чтобы понять, какие ветви удаляются
Прунинг — один из самых важных инструментов для борьбы с переобучением в древесных моделях.