← Назад к вопросам

Как решаешь проблему сильного дисбаланса данных для мультиклассовой классификации?

1.0 Junior🔥 152 комментариев
#Машинное обучение

Комментарии (2)

🐱
claude-haiku-4.5PrepBro AI29 мар. 2026 г.(ред.)

Ответ сгенерирован нейросетью и может содержать ошибки

Решение проблемы сильного дисбаланса в мультиклассовой классификации

Дисбаланс классов в мультиклассовой задаче — часто встречающаяся и сложная проблема. Нужен комплексный подход, учитывающий все классы.

1. Анализ дисбаланса

Оценка распределения классов

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

class_dist = pd.Series(y).value_counts().sort_index()
print("Class Distribution:")
print(class_dist)
print("\nClass Distribution (%):")
print((class_dist / len(y) * 100).round(2))

# Коэффициент дисбаланса
imbalance_ratio = class_dist.max() / class_dist.min()
print(f"\nImbalance Ratio: {imbalance_ratio:.2f}")

# Визуализация
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
class_dist.plot(kind="bar", ax=axes[0], color="steelblue")
axes[0].set_title("Class Distribution (Absolute)")
axes[0].set_ylabel("Count")
axes[0].set_xlabel("Class")

(class_dist / len(y) * 100).plot(kind="bar", ax=axes[1], color="coral")
axes[1].set_title("Class Distribution (%)")
axes[1].set_ylabel("Percentage")
axes[1].set_xlabel("Class")

plt.tight_layout()
plt.show()

2. Техники переsampling

SMOTE для мультиклассовой классификации

from imblearn.over_sampling import SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline as ImbPipeline

# SMOTE для мультиклассовой задачи
smote = SMOTE(
    sampling_strategy="not majority",
    k_neighbors=5,
    random_state=42
)

X_train_smote, y_train_smote = smote.fit_resample(X_train, y_train)

print("After SMOTE:")
print(pd.Series(y_train_smote).value_counts().sort_index())

# ADASYN (адаптивный SMOTE)
adasyn = ADASYN(
    sampling_strategy="not majority",
    n_neighbors=5,
    random_state=42
)

X_train_adasyn, y_train_adasyn = adasyn.fit_resample(X_train, y_train)

# Комбинированная стратегия SMOTE + UnderSampling
pipeline = ImbPipeline([
    ("smote", SMOTE(sampling_strategy=0.7, random_state=42)),
    ("under", RandomUnderSampler(sampling_strategy=0.8, random_state=42))
])

X_train_combined, y_train_combined = pipeline.fit_resample(X_train, y_train)

print("\nAfter SMOTE + UnderSampling:")
print(pd.Series(y_train_combined).value_counts().sort_index())

3. Взвешивание классов

Class Weights в градиентном бустинге

from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(
    "balanced",
    classes=np.unique(y_train),
    y=y_train
)

weights_dict = dict(zip(np.unique(y_train), class_weights))
print("Class Weights:")
print(weights_dict)

# XGBoost с взвешиванием
sample_weights = np.array([weights_dict[label] for label in y_train])

model_xgb = XGBClassifier(
    n_estimators=100,
    max_depth=5,
    learning_rate=0.1,
    random_state=42
)

model_xgb.fit(
    X_train, y_train,
    sample_weight=sample_weights
)

# LightGBM с class_weight
model_lgb = LGBMClassifier(
    n_estimators=100,
    num_leaves=31,
    class_weight="balanced",
    random_state=42
)

model_lgb.fit(X_train, y_train)

# CatBoost с auto_class_weights
model_cat = CatBoostClassifier(
    iterations=100,
    depth=5,
    auto_class_weights="balanced",
    random_state=42,
    verbose=False
)

model_cat.fit(X_train, y_train)

4. Стратифицированная кроссвалидация

from sklearn.model_selection import StratifiedKFold, cross_validate

skfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Проверка распределения в каждом фолде
for fold, (train_idx, val_idx) in enumerate(skfold.split(X, y)):
    print(f"Fold {fold+1}:")
    train_dist = pd.Series(y[train_idx]).value_counts(normalize=True)
    val_dist = pd.Series(y[val_idx]).value_counts(normalize=True)
    print("Train:", train_dist.to_dict())
    print("Val:", val_dist.to_dict())

# Метрики для дисбаланса
scoring = {
    "accuracy": "accuracy",
    "macro_f1": "f1_macro",
    "weighted_f1": "f1_weighted",
    "macro_precision": "precision_macro",
    "macro_recall": "recall_macro"
}

cv_results = cross_validate(
    model_xgb,
    X_train,
    y_train,
    cv=skfold,
    scoring=scoring,
    n_jobs=-1
)

results_df = pd.DataFrame(cv_results)
print("\nCross-Validation Results:")
print(results_df[["test_accuracy", "test_macro_f1", "test_weighted_f1"]].describe())

5. Выбор правильных метрик оценки

from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score
)

y_pred = model_xgb.predict(X_test)
y_pred_proba = model_xgb.predict_proba(X_test)

print("Classification Report:")
print(classification_report(y_test, y_pred, digits=3))

# Макро- и взвешенные метрики
print(f"\nMacro F1-Score: {f1_score(y_test, y_pred, average=\"macro\"):.3f}")
print(f"Weighted F1-Score: {f1_score(y_test, y_pred, average=\"weighted\"):.3f}")
print(f"Macro Precision: {precision_score(y_test, y_pred, average=\"macro\"):.3f}")
print(f"Macro Recall: {recall_score(y_test, y_pred, average=\"macro\"):.3f}")

# One-vs-Rest ROC-AUC
try:
    roc_auc_ovr = roc_auc_score(
        y_test,
        y_pred_proba,
        multi_class=\"ovr\",
        average=\"macro\"
    )
    print(f"ROC-AUC (OvR, Macro): {roc_auc_ovr:.3f}")
except:
    print("ROC-AUC not available")

# Матрица ошибок
cm = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:")
print(cm)

# Нормализованная матрица
cm_normalized = cm.astype(\"float\") / cm.sum(axis=1)[:, np.newaxis]
print("\nNormalized Confusion Matrix:")
print(cm_normalized.round(3))

import seaborn as sns
plt.figure(figsize=(8, 6))
sns.heatmap(cm_normalized, annot=True, fmt=".2f", cmap="Blues")
plt.title("Normalized Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.show()

6. Комплексный пайплайн

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipeline = ImbPipeline([
    ("scaler", StandardScaler()),
    ("smote", SMOTE(sampling_strategy=0.6, random_state=42)),
    ("model", XGBClassifier(
        n_estimators=100,
        max_depth=5,
        learning_rate=0.1,
        tree_method=\"hist\",
        random_state=42
    ))
])

pipeline.fit(X_train, y_train)

y_pred = pipeline.predict(X_test)
y_pred_proba = pipeline.predict_proba(X_test)

print("Pipeline Performance:")
print(classification_report(y_test, y_pred))

# Важность признаков
feature_importance = pd.DataFrame({
    "feature": X_train.columns,
    "importance": pipeline.named_steps[\"model\"].feature_importances_
}).sort_values(\"importance\", ascending=False)

print("\nTop Features:")
print(feature_importance.head(10))

7. Обработка экстремального дисбаланса

from imblearn.combine import SMOTETomek

# SMOTETomek = SMOTE + Tomek Links
smotetomek = SMOTETomek(random_state=42)
X_train_st, y_train_st = smotetomek.fit_resample(X_train, y_train)

# Focal Loss через sample_weight
focal_alpha = 0.25
focal_gamma = 2.0

class_sample_count = np.array(
    [(y_train == c).sum() for c in np.unique(y_train)]
)
weight = 1 / (class_sample_count ** focal_gamma)
samples_weight = np.array([weight[int(t)] for t in y_train])

model = XGBClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train, sample_weight=samples_weight)

Рекомендации по выбору методов

Стратегия в зависимости от уровня дисбаланса:

  • Слабый (1:5): взвешивание + стратифицированная CV
  • Средний (1:20): SMOTE + взвешивание + стратификация
  • Сильный (1:50+): SMOTE+Undersampling + взвешивание + focal loss идея

Обязательные шаги:

  1. Анализируй распределение в train/val/test
  2. Используй стратифицированную CV для надёжной оценки
  3. Выбирай метрики: macro F1, weighted F1, ROC-AUC
  4. Комбинируй методы: переsampling + взвешивание
  5. Проверяй каждый класс отдельно в classification_report
  6. Используй confusion matrix для анализа ошибок
  7. Документируй выводы для стейкхолдеров
Как решаешь проблему сильного дисбаланса данных для мультиклассовой классификации? | PrepBro