← Назад к вопросам

Как строится дерево в задачи классификации?

1.2 Junior🔥 211 комментариев
#Машинное обучение

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI29 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Как строится дерево решений в задачах классификации

Дерево решений (Decision Tree) — это интерпретируемая модель машинного обучения, которая разделяет пространство признаков на прямоугольные регионы с рекурсивными сплитами. Рассмотрим полный процесс построения.

1. Основная идея

Дерево разбивает выборку на подгруппы, минимизируя примеси (impurity) в каждом узле. На каждом шаге выбирается признак и его значение, которые лучше всего разделяют классы.

from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Простой пример
dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(X_train, y_train)

# Визуализация
plt.figure(figsize=(20,10))
plot_tree(dt, feature_names=feature_names, class_names=class_names, filled=True)
plt.show()

print(f"Accuracy: {dt.score(X_test, y_test):.4f}")
print(f"Tree depth: {dt.get_depth()}")
print(f"Leaf count: {dt.get_n_leaves()}")

2. Критерии расщепления (Splitting Criteria)

Для выбора лучшего разбиения используются метрики, которые измеряют "чистоту" подмножеств.

A. Gini Index (CART алгоритм)

Gini измеряет вероятность неправильной классификации случайного элемента.

def calculate_gini(y):
    """Расчёт Gini индекса"""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    gini = 1.0 - np.sum(probabilities ** 2)
    return gini

def gini_split(y_parent, y_left, y_right):
    """Информационный выигрыш при разбиении (Gini gain)"""
    n = len(y_parent)
    n_left = len(y_left)
    n_right = len(y_right)
    
    # Взвешенный Gini детей
    gini_left = calculate_gini(y_left)
    gini_right = calculate_gini(y_right)
    gini_children = (n_left/n) * gini_left + (n_right/n) * gini_right
    
    # Информационный выигрыш
    gini_gain = calculate_gini(y_parent) - gini_children
    return gini_gain

B. Entropy и Information Gain (ID3/C4.5 алгоритмы)

Enтропия измеряет неопределённость.

from scipy.stats import entropy

def calculate_entropy(y):
    """Shannon энтропия"""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    ent = entropy(probabilities, base=2)
    return ent

def information_gain(y_parent, y_left, y_right):
    """Информационный выигрыш"""
    n = len(y_parent)
    n_left = len(y_left)
    n_right = len(y_right)
    
    # Взвешенная энтропия детей
    entropy_left = calculate_entropy(y_left)
    entropy_right = calculate_entropy(y_right)
    entropy_children = (n_left/n) * entropy_left + (n_right/n) * entropy_right
    
    # Информационный выигрыш
    ig = calculate_entropy(y_parent) - entropy_children
    return ig

# Пример
y_parent = np.array([0, 0, 0, 1, 1])  # 60% класс 0, 40% класс 1
y_left = np.array([0, 0, 0])           # 100% класс 0 (чистый)
y_right = np.array([1, 1])              # 100% класс 1 (чистый)

ig = information_gain(y_parent, y_left, y_right)
print(f"Information Gain: {ig:.4f}")  # Высокое значение = хорошее разбиение

3. Рекурсивный алгоритм построения (CART)

class SimpleDecisionTree:
    def __init__(self, max_depth=None, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None
    
    def _build_tree(self, X, y, depth=0):
        """Рекурсивное построение дерева"""
        n_samples = len(y)
        n_classes = len(np.unique(y))
        
        # Условия остановки
        if (depth >= self.max_depth or
            n_samples < self.min_samples_split or
            n_classes == 1):  # Узел чистый
            return {"type": "leaf", "value": np.argmax(np.bincount(y))}
        
        best_gain = 0
        best_split = None
        
        # Перебираем все признаки
        for feature_idx in range(X.shape[1]):
            # Перебираем все уникальные значения
            thresholds = np.unique(X[:, feature_idx])
            
            for threshold in thresholds:
                # Разбиваем по признаку
                left_mask = X[:, feature_idx] <= threshold
                right_mask = ~left_mask
                
                if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
                    continue
                
                # Рассчитываем Gini gain
                gain = gini_split(y, y[left_mask], y[right_mask])
                
                if gain > best_gain:
                    best_gain = gain
                    best_split = {
                        "feature": feature_idx,
                        "threshold": threshold
                    }
        
        # Если не нашли хорошее разбиение
        if best_split is None:
            return {"type": "leaf", "value": np.argmax(np.bincount(y))}
        
        # Разбиваем и рекурсивно строим поддеревья
        feature = best_split["feature"]
        threshold = best_split["threshold"]
        left_mask = X[:, feature] <= threshold
        
        return {
            "type": "node",
            "feature": feature,
            "threshold": threshold,
            "left": self._build_tree(X[left_mask], y[left_mask], depth+1),
            "right": self._build_tree(X[~left_mask], y[~left_mask], depth+1)
        }
    
    def fit(self, X, y):
        self.tree = self._build_tree(X, y)
        return self
    
    def _predict_sample(self, x, node):
        if node["type"] == "leaf":
            return node["value"]
        
        if x[node["feature"]] <= node["threshold"]:
            return self._predict_sample(x, node["left"])
        else:
            return self._predict_sample(x, node["right"])
    
    def predict(self, X):
        return np.array([self._predict_sample(x, self.tree) for x in X])

4. Гиперпараметры и их влияние

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

# max_depth — максимальная глубина дерева
# Маленькое значение → недообучение, большое → переобучение

# min_samples_split — минимум samples для разбиения узла
# Увеличение → проще модель, меньше переобучения

# min_samples_leaf — минимум samples в листовом узле
# Защита от очень маленьких узлов

# max_features — количество признаков для поиска best split
# "sqrt", "log2" уменьшают переобучение в ensemble'ях

params_grid = {
    "max_depth": [3, 5, 10, 15, None],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4],
    "max_features": ["sqrt", "log2", None]
}

grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    params_grid,
    cv=5,
    scoring="accuracy"
)

grid_search.fit(X_train, y_train)
print(f"Best params: {grid_search.best_params_}")
print(f"Best CV score: {grid_search.best_score_:.4f}")

5. Pruning (Обрезка)

Обрезка удаляет узлы, которые не улучшают качество на тестовом наборе.

from sklearn.tree import DecisionTreeClassifier

# Cost complexity pruning (sklearn)
dt_full = DecisionTreeClassifier(random_state=42)
dt_full.fit(X_train, y_train)

# Получаем параметры сложности
ccp_alphas = dt_full.cost_complexity_pruning_path(X_train, y_train)[0]

# Обучаем дерево с разными alphas
trees = []
for ccp_alpha in ccp_alphas:
    tree = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha)
    tree.fit(X_train, y_train)
    trees.append(tree)

# Выбираем лучший по валидации
val_scores = [tree.score(X_val, y_val) for tree in trees]
best_tree = trees[np.argmax(val_scores)]

6. Интерпретируемость и важность признаков

dt = DecisionTreeClassifier(max_depth=5)
dt.fit(X_train, y_train)

# Важность признаков (основана на Gini decrease)
feature_importance = pd.DataFrame({
    "feature": feature_names,
    "importance": dt.feature_importances_
}).sort_values("importance", ascending=False)

print(feature_importance)

# Визуализация правил
from sklearn.tree import export_text

tree_rules = export_text(dt, feature_names=feature_names)
print(tree_rules)

7. Сравнение Gini vs Entropy

ПараметрGiniEntropy
СкоростьБыстрееМедленнее
РезультатыОбычно схожиОбычно схожи
ИнтерпретацияПрощеСложнее
ИспользованиеCART (sklearn default)ID3, C4.5, C5.0

8. Плюсы и минусы деревьев

Плюсы:

  • Интерпретируемость
  • Обработка категориальных признаков
  • Не требует масштабирования
  • Захватывает нелинейные зависимости

Минусы:

  • Переобучение
  • Нестабильность (small changes → big changes в дереве)
  • Слабое качество на больших признаковых пространствах
  • Решение: использовать Random Forest или Gradient Boosting

Вывод: Дерево строится рекурсивно путём выбора лучших сплитов (максимизируя Gini gain или Information gain), разбивая пространство на чистые регионы до условий остановки. Обрезка и правильные гиперпараметры защищают от переобучения.