← Назад к вопросам

Как определить оптимальное значение сплита?

2.0 Middle🔥 121 комментариев
#Машинное обучение#Метрики и оценка моделей

Комментарии (1)

🐱
claude-haiku-4.5PrepBro AI29 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Как определить оптимальное значение сплита

Сплит (split) — это разделение данных на тренировочный и тестовый наборы, или поиск оптимального порога для бинарной классификации. Рассмотрим оба понимания этого вопроса.

1. ОПТИМАЛЬНЫЙ РАЗМЕР TRAIN/TEST СПЛИТА

Классические рекомендации:

from sklearn.model_selection import train_test_split
import numpy as np

# Рекомендации по размеру сплита в зависимости от размера датасета

def recommend_split_ratio(n_samples):
    """
    Рекомендует оптимальный train/test раздел
    """
    if n_samples < 1000:
        # Маленький датасет: 70/30 или 80/20
        return 0.80, 0.20, "Small dataset: use 80/20"
    elif n_samples < 10000:
        # Средний датасет: 75/25 или 80/20
        return 0.80, 0.20, "Medium dataset: use 80/20"
    elif n_samples < 100000:
        # Большой датасет: 80/20 или 85/15
        return 0.85, 0.15, "Large dataset: use 85/15"
    else:
        # Очень большой датасет: 90/10 или 95/5
        return 0.95, 0.05, "Very large dataset: use 95/5"

for n in [500, 5000, 50000, 500000]:
    train, test, msg = recommend_split_ratio(n)
    print(f"{n:,} samples: {msg} (train={train*100:.0f}%, test={test*100:.0f}%)")

# Стандартное разделение
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,  # 80/20 split
    random_state=42,
    stratify=y  # для классификации
)

2. ОПТИМАЛЬНЫЙ ПОРОГ КЛАССИФИКАЦИИ (threshold)

Это более сложная задача - найти оптимальный порог для преобразования вероятностей в классы.

from sklearn.metrics import (
    precision_recall_curve, roc_curve, auc,
    f1_score, accuracy_score, precision_score, recall_score
)
import numpy as np

# Получить вероятности с модели
y_pred_proba = model.predict_proba(X_test)[:, 1]  # вероятности класса 1

# По умолчанию используется threshold = 0.5
# Но это не всегда оптимально

def find_optimal_threshold(y_true, y_pred_proba, metric='f1'):
    """
    Найти оптимальный порог для классификации
    """
    thresholds = np.arange(0.01, 1.00, 0.01)
    scores = []
    
    for threshold in thresholds:
        y_pred = (y_pred_proba >= threshold).astype(int)
        
        if metric == 'f1':
            score = f1_score(y_true, y_pred)
        elif metric == 'accuracy':
            score = accuracy_score(y_true, y_pred)
        elif metric == 'precision':
            score = precision_score(y_true, y_pred, zero_division=0)
        elif metric == 'recall':
            score = recall_score(y_true, y_pred, zero_division=0)
        
        scores.append(score)
    
    optimal_idx = np.argmax(scores)
    optimal_threshold = thresholds[optimal_idx]
    
    return optimal_threshold, thresholds, scores

# Найти оптимальный threshold
optimal_threshold, thresholds, scores = find_optimal_threshold(
    y_test, y_pred_proba, metric='f1'
)

print(f"Optimal threshold: {optimal_threshold:.2f}")
print(f"F1-score at optimal: {scores[np.argmax(scores)]:.3f}")

3. МЕТОД YOUDEN'S INDEX (J-статистика)

from sklearn.metrics import roc_curve

def find_optimal_threshold_youden(y_true, y_pred_proba):
    """
    Используем статистику Youden's Index
    J = Sensitivity + Specificity - 1
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_pred_proba)
    
    # Youden's index
    j_scores = tpr - fpr
    optimal_idx = np.argmax(j_scores)
    optimal_threshold = thresholds[optimal_idx]
    
    print(f"Optimal threshold (Youden): {optimal_threshold:.3f}")
    print(f"Sensitivity (TPR): {tpr[optimal_idx]:.3f}")
    print(f"Specificity (1-FPR): {1-fpr[optimal_idx]:.3f}")
    
    return optimal_threshold

optimal_th = find_optimal_threshold_youden(y_test, y_pred_proba)

4. МЕТОД PRECISION-RECALL CURVE

from sklearn.metrics import precision_recall_curve, f1_score

def find_optimal_threshold_pr(y_true, y_pred_proba):
    """
    Используем Precision-Recall кривую
    """
    precision, recall, thresholds = precision_recall_curve(
        y_true, y_pred_proba
    )
    
    # F1-score для каждого threshold
    # F1 = 2 * (precision * recall) / (precision + recall)
    f1_scores = 2 * (precision[:-1] * recall[:-1]) / (
        precision[:-1] + recall[:-1] + 1e-10
    )
    
    optimal_idx = np.argmax(f1_scores)
    optimal_threshold = thresholds[optimal_idx]
    
    print(f"Optimal threshold (PR): {optimal_threshold:.3f}")
    print(f"Precision: {precision[optimal_idx]:.3f}")
    print(f"Recall: {recall[optimal_idx]:.3f}")
    print(f"F1-score: {f1_scores[optimal_idx]:.3f}")
    
    return optimal_threshold

optimal_th = find_optimal_threshold_pr(y_test, y_pred_proba)

5. НЕЛИНЕЙНАЯ ОПТИМИЗАЦИЯ

from scipy.optimize import minimize_scalar
from sklearn.metrics import f1_score

def objective_function(threshold, y_true, y_pred_proba):
    """
    Функция для минимизации (минус F1, т.к. minimize ищет минимум)
    """
    y_pred = (y_pred_proba >= threshold).astype(int)
    return -f1_score(y_true, y_pred)  # минус для минимизации

# Оптимизация
result = minimize_scalar(
    objective_function,
    bounds=(0, 1),
    args=(y_test, y_pred_proba),
    method='bounded'
)

optimal_threshold = result.x
max_f1 = -result.fun

print(f"Optimal threshold (optimization): {optimal_threshold:.3f}")
print(f"Max F1-score: {max_f1:.3f}")

6. ВИЗУАЛИЗАЦИЯ И ВЫБОР

import matplotlib.pyplot as plt

def visualize_threshold_analysis(y_true, y_pred_proba):
    """
    Визуализирует различные метрики в зависимости от threshold
    """
    thresholds = np.arange(0.01, 1.00, 0.01)
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []
    specificities = []
    
    for threshold in thresholds:
        y_pred = (y_pred_proba >= threshold).astype(int)
        
        accuracies.append(accuracy_score(y_true, y_pred))
        precisions.append(precision_score(y_true, y_pred, zero_division=0))
        recalls.append(recall_score(y_true, y_pred, zero_division=0))
        f1_scores.append(f1_score(y_true, y_pred, zero_division=0))
        
        # Specificity = TN / (TN + FP)
        tn = np.sum((y_pred == 0) & (y_true == 0))
        fp = np.sum((y_pred == 1) & (y_true == 0))
        spec = tn / (tn + fp) if (tn + fp) > 0 else 0
        specificities.append(spec)
    
    # Визуализация
    fig, ax = plt.subplots(figsize=(12, 6))
    
    ax.plot(thresholds, accuracies, label='Accuracy', linewidth=2)
    ax.plot(thresholds, precisions, label='Precision', linewidth=2)
    ax.plot(thresholds, recalls, label='Recall', linewidth=2)
    ax.plot(thresholds, f1_scores, label='F1-score', linewidth=2)
    ax.plot(thresholds, specificities, label='Specificity', linewidth=2)
    
    ax.axvline(x=0.5, color='r', linestyle='--', alpha=0.5, label='Default (0.5)')
    
    # Найти оптимальный
    optimal_idx = np.argmax(f1_scores)
    optimal_threshold = thresholds[optimal_idx]
    ax.axvline(x=optimal_threshold, color='g', linestyle='--', alpha=0.5, label=f'Optimal ({optimal_threshold:.2f})')
    
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Score')
    ax.set_title('Metrics vs Threshold')
    ax.legend()
    ax.grid()
    
    plt.tight_layout()
    return fig, optimal_threshold

fig, optimal = visualize_threshold_analysis(y_test, y_pred_proba)
plt.show()

7. ВЫБОР В ЗАВИСИМОСТИ ОТ ЗАДАЧИ

def choose_threshold_by_use_case(use_case):
    """
    Рекомендует threshold в зависимости от задачи
    """
    recommendations = {
        'spam_detection': {
            'metric': 'precision',
            'priority': 'Минимизировать false positives',
            'typical_threshold': 0.7,
            'reason': 'Лучше пропустить спам, чем заблокировать письмо'
        },
        'disease_diagnosis': {
            'metric': 'recall',
            'priority': 'Минимизировать false negatives',
            'typical_threshold': 0.3,
            'reason': 'Лучше переконсультировать, чем пропустить болезнь'
        },
        'credit_approval': {
            'metric': 'f1',
            'priority': 'Баланс false positives и false negatives',
            'typical_threshold': 0.5,
            'reason': 'Баланс между отклонением и одобрением'
        },
        'fraud_detection': {
            'metric': 'recall',
            'priority': 'Минимизировать false negatives',
            'typical_threshold': 0.4,
            'reason': 'Лучше проверить, чем пропустить мошенничество'
        }
    }
    
    if use_case in recommendations:
        rec = recommendations[use_case]
        print(f"\nUse case: {use_case}")
        print(f"Metric to optimize: {rec['metric']}")
        print(f"Priority: {rec['priority']}")
        print(f"Typical threshold: {rec['typical_threshold']}")
        print(f"Reason: {rec['reason']}")
    
    return recommendations.get(use_case)

# Примеры
for use_case in ['spam_detection', 'disease_diagnosis', 'credit_approval', 'fraud_detection']:
    choose_threshold_by_use_case(use_case)

8. КРОСС-ВАЛИДАЦИЯ ДЛЯ ВЫБОРА СПЛИТА

from sklearn.model_selection import cross_validate

def evaluate_split_with_cv(model, X, y, test_size=0.2, cv=5):
    """
    Проверить стабильность выбранного сплита через CV
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42, stratify=y
    )
    
    # CV на тренировочном наборе
    cv_results = cross_validate(
        model, X_train, y_train,
        cv=cv,
        scoring=['accuracy', 'precision', 'recall', 'f1']
    )
    
    # Обучение на полном training set
    model.fit(X_train, y_train)
    test_score = model.score(X_test, y_test)
    
    print(f"\nTest size: {test_size*100:.0f}%")
    print(f"CV mean accuracy: {cv_results['test_accuracy'].mean():.3f}")
    print(f"Test accuracy: {test_score:.3f}")
    print(f"Difference: {abs(cv_results['test_accuracy'].mean() - test_score):.3f}")
    
    if abs(cv_results['test_accuracy'].mean() - test_score) < 0.03:
        print("✅ Split is stable and representative")
    else:
        print("⚠️  Split may not be representative")

Итоги

TRAIN/TEST СПЛИТ:

  • Маленький датасет (< 1K): 80/20 или 70/30
  • Средний датасет (1K-100K): 80/20
  • Большой датасет (> 100K): 85/15 или 90/10
  • Используй stratify для классификации

ОПТИМАЛЬНЫЙ THRESHOLD:

  • F1-score: для сбалансированных данных
  • Precision: если важны false positives
  • Recall: если важны false negatives
  • Youden's Index: для медицины
  • Всегда тестируй на validation set

ВЫБОР:

  1. Визуализируй метрики
  2. Выбери в зависимости от бизнес-требований
  3. Валидируй на независимом тестовом наборе
Как определить оптимальное значение сплита? | PrepBro