Multiclass Classification¶

In [1]:
import os
import warnings
import logging

# Root logging: INFO and above, with a compact "time [level] message" layout.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The warning filter action is taken from the environment: "ignore" is meant for
# rendering the HTML exports, while the default "once" is for interactive runs.
WARNING_FILTER_POLICY = os.environ.get("WARNING_FILTER_POLICY", "once")
logger.info(f"{WARNING_FILTER_POLICY = }")
warnings.filterwarnings(WARNING_FILTER_POLICY)
21:11:50 [INFO] WARNING_FILTER_POLICY = 'ignore'
In [2]:
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
import shap

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from xgboost import XGBClassifier

pd.set_option("display.max_columns", None)  # never truncate wide frames
pd.set_option("display.float_format", "{:,.2f}".format)  # thousands separator, 2 decimals
In [3]:
from utils.constants import RANDOM_SEED
from utils.common import (
    get_data_folder_path,
    set_plotting_config,
    plot_boxplot_by_class,
    plot_correlation_matrix,
)
from utils.evals import (
    describe_input_features,
    plot_confusion_matrix,
    plot_target_rate,
    compute_multiclass_classification_metrics,
    build_coefficients_table,
    plot_coefficients_values,
    plot_coefficients_significance,
    plot_eval_metrics_xgb,
    plot_gain_metric_xgb,
    plot_shap_importance,
    plot_shap_beeswarm,
    plot_roc_curve,
)
from utils.feature_selection import run_feature_selection_steps
In [4]:
# plots configuration: seaborn style and palette first, then the project-wide
# plotting settings from utils.common. set_plotting_config() runs last, so it may
# override the seaborn defaults. NOTE(review): confirm what it changes.
sns.set_style("darkgrid")
sns.set_palette("colorblind")
set_plotting_config()
# render matplotlib figures inline in the notebook output
%matplotlib inline

1. Load Data¶

In this notebook, we will use the Fetal Health Dataset. This dataset comprises 2,126 records of features extracted from cardiotocogram (CTG) exams, each classified by experts as Normal, Suspect, or Pathological, with the aim of assessing fetal health and helping to reduce child and maternal mortality.

Sources:

  1. Kaggle: https://www.kaggle.com/datasets/andrewmvd/fetal-health-classification
  2. Original article: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6822315
In [5]:
# Read the raw CSV from the project data folder. Spaces in column names are replaced
# with underscores so columns can be used as plain identifiers.
data_path = get_data_folder_path()
csv_path = os.path.join(data_path, "fetal_health.csv")
df_input = pd.read_csv(csv_path).rename(columns=lambda name: name.replace(" ", "_"))

2. Process Data¶

Target column¶

Fetal health (target column) can have the following values:

  • 1: Normal
  • 2: Suspect
  • 3: Pathological

However, XGBoost expects class labels to be consecutive integers starting at 0 (i.e. 0 to num_class − 1). Therefore, we will use the following values in this notebook:

  • 0: Normal
  • 1: Suspect
  • 2: Pathological
In [6]:
target_col = "fetal_health"
# 0-based class id -> human-readable label
target_classes_dict = dict(enumerate(["Normal", "Suspect", "Pathological"]))
# fraction of records held out for the test split
test_size = 0.20
In [7]:
# Convert the raw 1-based labels (1 = Normal, 2 = Suspect, 3 = Pathological) into the
# 0-based integer classes (0, 1, 2) that XGBoost expects.
# This overwrites the column in place. Re-running the cell without reloading the data
# would silently shift the labels again (to -1, 0, 1), so the raw values are validated first.
raw_class_labels = [clss + 1 for clss in target_classes_dict]
if not df_input[target_col].isin(raw_class_labels).all():
    raise ValueError(
        f"'{target_col}' must only contain the raw labels {raw_class_labels}; "
        "re-run the data loading cell before converting the target."
    )
df_input[target_col] = df_input[target_col].astype(np.int8) - np.int8(1)  # subtract 1 to make it 0-indexed

Train test split¶

In [8]:
# Hold out a stratified, seeded test split so class proportions match across splits
df_input_train, df_input_test = train_test_split(
    df_input,
    random_state=RANDOM_SEED,
    stratify=df_input[target_col],
    test_size=test_size,
)
In [9]:
# Class balance per split: absolute counts and proportions next to each class label
pd.concat(
    [pd.Series(target_classes_dict, name="label")]
    + [
        split_df[target_col]
        .value_counts(dropna=False, normalize=normalize)
        .rename(f"{split_name}_target_{stat_name}")
        for split_name, split_df in (("train", df_input_train), ("test", df_input_test))
        for stat_name, normalize in (("count", False), ("pct", True))
    ],
    axis=1,
)
Out[9]:
label train_target_count train_target_pct test_target_count test_target_pct
0 Normal 1323 0.78 332 0.78
1 Suspect 236 0.14 59 0.14
2 Pathological 141 0.08 35 0.08
In [10]:
# Summary statistics for every column of the full dataset, with the train and test
# means side by side (the returned table is shown as the cell output)
describe_input_features(df_input, df_input_train, df_input_test)
Out[10]:
data_type count null_count min 25% 50% 75% max std mean mean_train mean_test train_test_pct_diff
baseline_value numeric 2126 0 106.00 126.00 133.00 140.00 160.00 9.84 133.30 133.09 134.17 0.01
accelerations numeric 2126 0 0.00 0.00 0.00 0.01 0.02 0.00 0.00 0.00 0.00 -0.03
fetal_movement numeric 2126 0 0.00 0.00 0.00 0.00 0.48 0.05 0.01 0.01 0.01 -0.23
uterine_contractions numeric 2126 0 0.00 0.00 0.00 0.01 0.01 0.00 0.00 0.00 0.00 0.01
light_decelerations numeric 2126 0 0.00 0.00 0.00 0.00 0.01 0.00 0.00 0.00 0.00 0.09
severe_decelerations numeric 2126 0 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.60
prolongued_decelerations numeric 2126 0 0.00 0.00 0.00 0.00 0.01 0.00 0.00 0.00 0.00 -0.10
abnormal_short_term_variability numeric 2126 0 12.00 32.00 49.00 61.00 87.00 17.19 46.99 46.72 48.06 0.03
mean_value_of_short_term_variability numeric 2126 0 0.20 0.70 1.20 1.70 7.00 0.88 1.33 1.34 1.32 -0.01
percentage_of_time_with_abnormal_long_term_variability numeric 2126 0 0.00 0.00 0.00 11.00 91.00 18.40 9.85 9.95 9.42 -0.05
mean_value_of_long_term_variability numeric 2126 0 0.00 4.60 7.40 10.80 50.70 5.63 8.19 8.17 8.27 0.01
histogram_width numeric 2126 0 3.00 37.00 67.50 100.00 180.00 38.96 70.45 70.40 70.61 0.00
histogram_min numeric 2126 0 50.00 67.00 93.00 120.00 159.00 29.56 93.58 93.43 94.19 0.01
histogram_max numeric 2126 0 122.00 152.00 162.00 174.00 238.00 17.94 164.03 163.83 164.81 0.01
histogram_number_of_peaks numeric 2126 0 0.00 2.00 3.00 6.00 18.00 2.95 4.07 4.08 4.03 -0.01
histogram_number_of_zeroes numeric 2126 0 0.00 0.00 0.00 0.00 10.00 0.71 0.32 0.31 0.36 0.15
histogram_mode numeric 2126 0 60.00 129.00 139.00 148.00 187.00 16.38 137.45 137.25 138.24 0.01
histogram_mean numeric 2126 0 73.00 125.00 136.00 145.00 182.00 15.59 134.61 134.44 135.30 0.01
histogram_median numeric 2126 0 77.00 129.00 139.00 148.00 186.00 14.47 138.09 137.92 138.77 0.01
histogram_variance numeric 2126 0 0.00 2.00 7.00 24.00 269.00 28.98 18.81 18.89 18.49 -0.02
histogram_tendency numeric 2126 0 -1.00 0.00 0.00 1.00 1.00 0.61 0.32 0.32 0.32 0.01
fetal_health numeric 2126 0 0.00 0.00 0.00 0.00 2.00 0.61 0.30 0.30 0.30 -0.01

Scaling (Standardization)¶

In [11]:
# Standardize the input features. The scaler is fit on the training split only and
# then reused on the test split, so no test statistics enter the transform.
# The target column is excluded *before* scaling. Otherwise the fitted scaler would
# require a target column at transform time and could not be applied to unlabeled data.
# StandardScaler works column by column, so the feature values are unaffected.
feature_cols = [col for col in df_input_train.columns if col != target_col]
stdscaler = StandardScaler()

# training data
X_train_all = pd.DataFrame(
    # fit scaler on training features (and then transform them)
    data=stdscaler.fit_transform(df_input_train[feature_cols]),
    columns=feature_cols,
    index=df_input_train.index,
)
y_train = df_input_train[target_col]
y_train_ohe = pd.get_dummies(y_train, dtype=np.int8)  # one-hot encoding for plots

# test data
X_test_all = pd.DataFrame(
    # use scaler fitted on training data to transform test features
    data=stdscaler.transform(df_input_test[feature_cols]),
    columns=feature_cols,
    index=df_input_test.index,
)
y_test = df_input_test[target_col]
y_test_ohe = pd.get_dummies(y_test, dtype=np.int8)  # one-hot encoding for plots

3. Exploratory Data Analysis (EDA)¶

Boxplots by Target Class¶

In [12]:
# Distribution of each feature per fetal-health class. Only the training split is
# used, so the test set does not influence any modelling choice.
boxplots = plot_boxplot_by_class(
    df_input=df_input_train,
    class_col=target_col,
    class_mapping=target_classes_dict,
    title="Features in input dataset",
    plots_per_line=5,
)
display(boxplots)
No description has been provided for this image

Pearson's Correlation¶

In [13]:
# Pairwise Pearson correlations, computed on the training split only
correlation_plot = plot_correlation_matrix(
    df=df_input_train,
    method="pearson",
    fig_height=10,
)
display(correlation_plot)
No description has been provided for this image

4. Feature Selection¶

In [14]:
# Feature-selection pipeline. The steps run in insertion order (manual -> null variance
# -> correlation -> VIF -> L1 regularization), and each value configures its step.
fs_steps = dict(
    manual=dict(cols_to_exclude=["severe_decelerations"]),
    null_variance=None,
    correlation=dict(threshold=0.75),
    vif=dict(threshold=2),
    l1_regularization=dict(
        problem_type="classification",
        train_test_split_params=dict(test_size=test_size),
        logspace_search=dict(start=-5, stop=1, num=20, base=10),
        # tolerance over minimum error with which to search for the best model
        error_tolerance_pct=0.05,
        # minimum features to keep in final selection
        min_feats_to_keep=3,
        random_seed=RANDOM_SEED,
    ),
)
In [15]:
# One-vs-Rest feature selection: each class is treated as a binary target against all
# the others. Only the training split is used. The result is one feature list (and one
# set of selection tables) per class.
selected_feats_ovr, fs_tables_ovr = {}, {}

for class_id, class_label in target_classes_dict.items():
    logger.info(f"Running Feature Selection for Class '{class_label}' (vs Rest)")
    class_feats, class_tables = run_feature_selection_steps(
        X=X_train_all,
        y=y_train_ohe[class_id],
        fs_steps=fs_steps,
    )
    selected_feats_ovr[class_id] = class_feats
    fs_tables_ovr[class_id] = class_tables
    logger.info("-" * 100)
21:11:54 [INFO] Running Feature Selection for Class 'Normal' (vs Rest)
21:11:54 [INFO] --> Starting Feature Selection with 21 features
21:11:54 [INFO] 1. MANUAL FILTER
21:11:54 [INFO]  - Removing 1 (4.8%) feature(s) manually: ['severe_decelerations']
21:11:54 [INFO] 2. NULL_VARIANCE FILTER
21:11:54 [INFO]  - Removing 0 (0.0%) feature(s) with null variance (var == 0): []
21:11:54 [INFO] 3. CORRELATION FILTER
21:11:54 [INFO]   Running Correlation filter with threshold of 0.75
21:11:54 [INFO]  - Removing feature 'histogram_median' with correlation +0.9478 to 'histogram_mean'
21:11:54 [INFO]  - Removing feature 'histogram_mean' with correlation +0.8915 to 'histogram_mode'
21:11:54 [INFO]  - Removing feature 'histogram_width' with correlation -0.8997 to 'histogram_min'
21:11:54 [INFO]  - Removing 3 (15.0%) feature(s) with abs(correlation) > 0.75
21:11:54 [INFO] 4. VIF FILTER
21:11:54 [INFO] Computing the Variance Inflation Factor (VIF) for 17 features...
21:11:54 [INFO]   1. Removing feature: histogram_min .......................................... VIF: 4.84
21:11:55 [INFO]   2. Removing feature: histogram_mode ......................................... VIF: 4.27
21:11:55 [INFO]   3. Removing feature: histogram_max .......................................... VIF: 3.20
21:11:55 [INFO]   4. Removing feature: light_decelerations .................................... VIF: 2.53
21:11:55 [INFO]   5. Removing feature: mean_value_of_short_term_variability ................... VIF: 2.17
21:11:55 [INFO]   >> Stopping at feat: histogram_variance ..................................... VIF: 1.80  (threshold: 2)
21:11:55 [INFO]  - Removing 5 (29.4%) feature(s) with VIF >= 2
21:11:55 [INFO] 5. L1_REGULARIZATION FILTER
21:11:55 [INFO]  - Removing 8 (66.7%) feature(s) with null coefficient after L1 regularization: ['baseline_value', 'fetal_movement', 'uterine_contractions', 'mean_value_of_long_term_variability', 'histogram_number_of_peaks', 'histogram_number_of_zeroes', 'histogram_variance', 'histogram_tendency']
21:11:55 [INFO] --> Completed Feature Selection with 4 selected features (19.0% of the original 21 features): ['accelerations', 'prolongued_decelerations', 'abnormal_short_term_variability', 'percentage_of_time_with_abnormal_long_term_variability']
21:11:55 [INFO] ----------------------------------------------------------------------------------------------------
21:11:55 [INFO] Running Feature Selection for Class 'Suspect' (vs Rest)
21:11:55 [INFO] --> Starting Feature Selection with 21 features
21:11:55 [INFO] 1. MANUAL FILTER
21:11:55 [INFO]  - Removing 1 (4.8%) feature(s) manually: ['severe_decelerations']
21:11:55 [INFO] 2. NULL_VARIANCE FILTER
21:11:55 [INFO]  - Removing 0 (0.0%) feature(s) with null variance (var == 0): []
21:11:55 [INFO] 3. CORRELATION FILTER
21:11:55 [INFO]   Running Correlation filter with threshold of 0.75
21:11:55 [INFO]  - Removing feature 'histogram_mode' with correlation +0.9309 to 'histogram_median'
21:11:55 [INFO]  - Removing feature 'histogram_width' with correlation -0.8997 to 'histogram_min'
21:11:55 [INFO]  - Removing feature 'histogram_median' with correlation +0.9478 to 'histogram_mean'
21:11:55 [INFO]  - Removing 3 (15.0%) feature(s) with abs(correlation) > 0.75
21:11:55 [INFO] 4. VIF FILTER
21:11:55 [INFO] Computing the Variance Inflation Factor (VIF) for 17 features...
21:11:55 [INFO]   1. Removing feature: histogram_mean ......................................... VIF: 10.51
21:11:55 [INFO]   2. Removing feature: histogram_min .......................................... VIF: 4.62
21:11:55 [INFO]   3. Removing feature: histogram_max .......................................... VIF: 3.20
21:11:55 [INFO]   4. Removing feature: light_decelerations .................................... VIF: 2.53
21:11:55 [INFO]   5. Removing feature: mean_value_of_short_term_variability ................... VIF: 2.17
21:11:55 [INFO]   >> Stopping at feat: histogram_variance ..................................... VIF: 1.80  (threshold: 2)
21:11:55 [INFO]  - Removing 5 (29.4%) feature(s) with VIF >= 2
21:11:55 [INFO] 5. L1_REGULARIZATION FILTER
21:11:55 [INFO]  - Removing 0 (0.0%) feature(s) with null coefficient after L1 regularization: []
21:11:55 [INFO] --> Completed Feature Selection with 12 selected features (57.1% of the original 21 features): ['baseline_value', 'accelerations', 'fetal_movement', 'uterine_contractions', 'prolongued_decelerations', 'abnormal_short_term_variability', 'percentage_of_time_with_abnormal_long_term_variability', 'mean_value_of_long_term_variability', 'histogram_number_of_peaks', 'histogram_number_of_zeroes', 'histogram_variance', 'histogram_tendency']
21:11:55 [INFO] ----------------------------------------------------------------------------------------------------
21:11:55 [INFO] Running Feature Selection for Class 'Pathological' (vs Rest)
21:11:55 [INFO] --> Starting Feature Selection with 21 features
21:11:55 [INFO] 1. MANUAL FILTER
21:11:55 [INFO]  - Removing 1 (4.8%) feature(s) manually: ['severe_decelerations']
21:11:55 [INFO] 2. NULL_VARIANCE FILTER
21:11:55 [INFO]  - Removing 0 (0.0%) feature(s) with null variance (var == 0): []
21:11:55 [INFO] 3. CORRELATION FILTER
21:11:55 [INFO]   Running Correlation filter with threshold of 0.75
21:11:55 [INFO]  - Removing feature 'baseline_value' with correlation +0.7864 to 'histogram_median'
21:11:55 [INFO]  - Removing feature 'histogram_width' with correlation -0.8997 to 'histogram_min'
21:11:55 [INFO]  - Removing feature 'histogram_median' with correlation +0.9478 to 'histogram_mean'
21:11:55 [INFO]  - Removing feature 'histogram_mode' with correlation +0.8915 to 'histogram_mean'
21:11:55 [INFO]  - Removing 4 (20.0%) feature(s) with abs(correlation) > 0.75
21:11:55 [INFO] 4. VIF FILTER
21:11:55 [INFO] Computing the Variance Inflation Factor (VIF) for 16 features...
21:11:55 [INFO]   1. Removing feature: histogram_mean ......................................... VIF: 5.57
21:11:55 [INFO]   2. Removing feature: histogram_min .......................................... VIF: 3.46
21:11:55 [INFO]   3. Removing feature: light_decelerations .................................... VIF: 2.53
21:11:55 [INFO]   4. Removing feature: histogram_variance ..................................... VIF: 2.18
21:11:55 [INFO]   >> Stopping at feat: mean_value_of_short_term_variability ................... VIF: 1.96  (threshold: 2)
21:11:55 [INFO]  - Removing 4 (25.0%) feature(s) with VIF >= 2
21:11:55 [INFO] 5. L1_REGULARIZATION FILTER
21:11:55 [INFO]  - Removing 2 (16.7%) feature(s) with null coefficient after L1 regularization: ['accelerations', 'histogram_number_of_peaks']
21:11:55 [INFO] --> Completed Feature Selection with 10 selected features (47.6% of the original 21 features): ['fetal_movement', 'uterine_contractions', 'prolongued_decelerations', 'abnormal_short_term_variability', 'mean_value_of_short_term_variability', 'percentage_of_time_with_abnormal_long_term_variability', 'mean_value_of_long_term_variability', 'histogram_max', 'histogram_number_of_zeroes', 'histogram_tendency']
21:11:55 [INFO] ----------------------------------------------------------------------------------------------------
In [16]:
# Keep only the features that the One-vs-Rest selection picked for at least
# MIN_NUM_SELECTIONS classes. This is the union of the intersections over every
# group of MIN_NUM_SELECTIONS classes.
MIN_NUM_SELECTIONS = 2

if not 1 <= MIN_NUM_SELECTIONS <= len(selected_feats_ovr):
    raise ValueError(
        f"MIN_NUM_SELECTIONS must be between 1 and {len(selected_feats_ovr)}, "
        f"got {MIN_NUM_SELECTIONS}."
    )

classes_intersections = [
    set.intersection(*(set(selected_feats_ovr[clss]) for clss in classes_group))
    for classes_group in itertools.combinations(selected_feats_ovr.keys(), MIN_NUM_SELECTIONS)
]

# Sort the result. Iteration order of a set of strings depends on the per-process hash
# seed, so `list(set)` would give X_train/X_test a different column order on every run.
# set().union(...) also returns an empty set (rather than failing) when there is nothing to merge.
selected_feats = sorted(set().union(*classes_intersections))
print(f"Final selection ({len(selected_feats)} features selected):")
for feat in selected_feats:
    print(f"  - {feat}")
Final selection (9 features selected):
  - abnormal_short_term_variability
  - accelerations
  - fetal_movement
  - histogram_number_of_zeroes
  - histogram_tendency
  - mean_value_of_long_term_variability
  - percentage_of_time_with_abnormal_long_term_variability
  - prolongued_decelerations
  - uterine_contractions
In [17]:
# Model inputs: the standardized features restricted to the final selection
X_train, X_test = (X_all[selected_feats] for X_all in (X_train_all, X_test_all))

Correlation check¶

In [18]:
# Pearson correlations among the selected features and the target (training split only)
selected_corr_plot = plot_correlation_matrix(
    df=df_input_train[[*selected_feats, target_col]],
    method="pearson",
    fig_height=5,
)
display(selected_corr_plot)
No description has been provided for this image

Multicollinearity check¶

In [19]:
# Variance Inflation Factor of each selected feature against all the others.
# No constant column is added: X_train was standardized with a scaler fit on the
# training split, so its columns are mean-zero there.
vif_values = [
    variance_inflation_factor(X_train.values, col_idx)
    for col_idx in range(len(selected_feats))
]
df_vif = (
    pd.DataFrame({"VIF": vif_values}, index=selected_feats)
    .sort_values("VIF", ascending=False)
)

df_vif
Out[19]:
VIF
percentage_of_time_with_abnormal_long_term_variability 1.62
abnormal_short_term_variability 1.53
mean_value_of_long_term_variability 1.41
accelerations 1.39
prolongued_decelerations 1.36
uterine_contractions 1.18
fetal_movement 1.14
histogram_tendency 1.08
histogram_number_of_zeroes 1.06

5. Classifier Model¶

Select classifier: Logistic Regression or XGBoost¶

In [20]:
# Classifier family used by the rest of the notebook:
# either "logistic_regression" or "xgboost".
MODEL_SELECTION = "xgboost"

# Raised by the model-specific branches below when MODEL_SELECTION holds any other value.
model_selection_error = ValueError(
    "'MODEL_SELECTION' must be either 'logistic_regression' or 'xgboost'. "
    f"Got {MODEL_SELECTION} instead."
)

Hyperparameter tuning with K-Fold Cross Validation¶

For a detailed explanation of XGBoost's parameters, refer to: https://www.kaggle.com/code/prashant111/a-guide-on-xgboost-hyperparameters-tuning/notebook

In [21]:
# Estimator class and hyperparameter grid (for GridSearchCV) of the selected model family.
if MODEL_SELECTION == "logistic_regression":
    Estimator = LogisticRegression
    # shared by every penalty; "saga" supports l1, l2 and elasticnet
    lr_common_space = {
        "solver": ["saga"],
        "C": np.logspace(-3, 1, num=9, base=10.0),
        "class_weight": [None],
    }
    # "elasticnet" requires an explicit l1_ratio. Without it, every elasticnet candidate
    # fails to fit and GridSearchCV only records NaN scores (the FitFailedWarning may be
    # hidden by WARNING_FILTER_POLICY). A list of grids keeps l1_ratio off the other penalties.
    cv_search_space = [
        {"penalty": ["l1", "l2"], **lr_common_space},
        {"penalty": ["elasticnet"], "l1_ratio": [0.25, 0.5, 0.75], **lr_common_space},
    ]
elif MODEL_SELECTION == "xgboost":
    Estimator = XGBClassifier
    cv_search_space = {
        "objective": ["multi:softmax"],
        'num_class': [len(target_classes_dict)],
        "n_estimators": [30, 40, 50],
        "learning_rate": [0.1],
        "max_depth": [3, 4, 6],
        "min_child_weight": [2, 4],
        "gamma": [0, 0.5],
        "alpha": [0, 0.3],
        # NOTE(review): XGBoost documents scale_pos_weight for binary class imbalance;
        # it is likely a no-op with a multiclass objective
        "scale_pos_weight": [1],
        "lambda": [1],
        # optional row/column subsampling, currently disabled:
        ## "subsample": [0.8, 1.0],
        ## "colsample_bytree": [0.8, 1.0],
        "verbosity": [0],
    }
else:
    raise model_selection_error

For the full list of scikit-learn's scoring string names, refer to: https://scikit-learn.org/stable/modules/model_evaluation.html#string-name-scorers

In [22]:
# scikit-learn scorer name -> human-readable label used in the printed CV report
cv_scoring_metrics = dict(
    accuracy="Accuracy",
    precision_macro="Precision (macro)",
    recall_macro="Recall (macro)",
    f1_macro="F1 Score (macro)",
    precision_weighted="Precision (weighted)",
    recall_weighted="Recall (weighted)",
    f1_weighted="F1 Score (weighted)",
    roc_auc_ovr="ROC AUC One-vs-Rest (macro)",
    roc_auc_ovo="ROC AUC One-vs-One (macro)",
    roc_auc_ovr_weighted="ROC AUC One-vs-Rest (weighted)",
    roc_auc_ovo_weighted="ROC AUC One-vs-One (weighted)",
)
# scorer that GridSearchCV uses to pick (and refit) the best candidate
refit_metric = "f1_weighted"
In [23]:
%%time
# define evaluation: seeded 3-fold stratified CV (a single repeat)
kfold_cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=1, random_state=RANDOM_SEED)
# define search: exhaustive grid scored on every metric in cv_scoring_metrics. The best
# candidate by `refit_metric` is refit on the whole training split.
# The base estimator gets the same random_state as the final model instantiated later,
# so the candidates evaluated here are reproducible and match the model that is trained
# (e.g. LogisticRegression's "saga" solver shuffles data using random_state).
grid_search = GridSearchCV(
    estimator=Estimator(random_state=RANDOM_SEED),
    param_grid=cv_search_space,
    scoring=list(cv_scoring_metrics.keys()),
    cv=kfold_cv,
    refit=refit_metric,
    verbose=1,
)
# execute search
result_cv = grid_search.fit(X_train, y_train)
Fitting 3 folds for each of 72 candidates, totalling 216 fits
CPU times: user 19.4 s, sys: 43.4 s, total: 1min 2s
Wall time: 34 s
In [24]:
# Mean cross-validated score of the best grid-search candidate, one line per metric.
# Quotes inside the f-string replacement fields differ from the outer quotes. Reusing
# the same quote type there only parses on Python 3.12+ (PEP 701).
print("Grid Search CV Best Model - Scoring Metrics (averaging method):")
for i, (metric_key, metric_name) in enumerate(cv_scoring_metrics.items(), start=1):
    mean_test_score = result_cv.cv_results_[f"mean_test_{metric_key}"][result_cv.best_index_]
    print(f" {str(i) + '.':>3} {metric_name:.<31} {mean_test_score:.3f}")
print(f"\nBest Hyperparameters: {result_cv.best_params_}")
Grid Search CV Best Model - Scoring Metrics (averaging method):
  1. Accuracy....................... 0.931
  2. Precision (macro).............. 0.903
  3. Recall (macro)................. 0.856
  4. F1 Score (macro)............... 0.875
  5. Precision (weighted)........... 0.929
  6. Recall (weighted).............. 0.931
  7. F1 Score (weighted)............ 0.928
  8. ROC AUC One-vs-Rest (macro).... 0.969
  9. ROC AUC One-vs-One (macro)..... 0.964
 10. ROC AUC One-vs-Rest (weighted). 0.968
 11. ROC AUC One-vs-One (weighted).. 0.969

Best Hyperparameters: {'alpha': 0, 'gamma': 0, 'lambda': 1, 'learning_rate': 0.1, 'max_depth': 6, 'min_child_weight': 4, 'n_estimators': 50, 'num_class': 3, 'objective': 'multi:softmax', 'scale_pos_weight': 1, 'verbosity': 0}

Final Model¶

In [25]:
# instantiate model with best hyperparameters and additional kwargs
if MODEL_SELECTION == "logistic_regression":
    model_kwargs = dict()
    model_fit_kwargs = dict()
elif MODEL_SELECTION == "xgboost":
    # XGBoost eval metrics tracked during training -> labels used in the learning-curve
    # plots. This is a 3-class problem, so "mlogloss"/"merror" are the multiclass log-loss
    # and error rate. XGBoost computes multiclass "auc" as one-vs-rest, weighted by
    # class prevalence.
    eval_metrics = dict(
        mlogloss="Multiclass Cross-entropy Loss (Log-loss)",
        merror="Multiclass Classification Error Rate",
        auc="ROC AUC One-vs-Rest (weighted)",
    )
    model_kwargs = dict(eval_metric=list(eval_metrics.keys()))
    # Evaluate on both splits after every boosting round. This is for plotting only:
    # no early stopping is configured, so the test split does not steer training.
    model_fit_kwargs = dict(
        eval_set=[(X_train, y_train), (X_test, y_test)],
        verbose=False
    )
else:
    raise model_selection_error

model = Estimator(**result_cv.best_params_, **model_kwargs, random_state=RANDOM_SEED)
In [26]:
# Fit the final model, then score both splits. Probabilities get one column per entry of
# model.classes_; hard class predictions are named after the target. Everything is
# indexed like the corresponding inputs.
model.fit(X_train, y_train, **model_fit_kwargs)

y_pred_proba_train, y_pred_proba = (
    pd.DataFrame(model.predict_proba(features), index=features.index, columns=model.classes_)
    for features in (X_train, X_test)
)
y_pred_train, y_pred = (
    pd.Series(model.predict(features), index=features.index, name=target_col)
    for features in (X_train, X_test)
)
In [27]:
# Per-round eval metrics on the train and test eval sets. evals_result() is only
# available on the XGBoost model, hence the guard.
if MODEL_SELECTION == "xgboost":
    display(plot_eval_metrics_xgb(model.evals_result(), eval_metrics))
No description has been provided for this image

Plot target rate per group of predicted probability

A good model should show a target rate that increases from one group of predicted probability to the next (e.g. across quartiles or deciles).

In [28]:
# For each class (One-vs-Rest): observed target rate on the test split across groups
# of predicted probability for that class
for class_id, class_label in target_classes_dict.items():
    display(
        plot_target_rate(
            y_test=y_test_ohe[class_id],
            y_pred_proba=y_pred_proba[class_id],
            title=f"Class '{class_label}': Target rate per group of predicted probability",
        )
    )
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Feature Importance¶

  • For Logistic Regression: coefficients values and statistical significance
  • For XGBoost: SHAP analysis and Gain Metric
In [29]:
# Feature importance of the final model:
#  - Logistic Regression: per-class coefficient values (95% CI) and p-values
#  - XGBoost: per-class SHAP importance/beeswarm plots on the test split, plus a single
#    model-wide Gain importance plot
if MODEL_SELECTION == "logistic_regression":
    # one row of coef_/intercept_ per class, aligned with model.classes_
    for clss, coefficients, intercept in zip(
        model.classes_, model.coef_, model.intercept_
    ):
        label = f"Class '{target_classes_dict[clss]}'"
        print(label)
        df_coefficients = build_coefficients_table(
            coefficients=coefficients,
            intercept=intercept,
            X_train=X_train,
            y_pred_train=y_pred_proba_train[clss],
            y_train=y_train_ohe[clss],
            problem_type="classification",
        )
        display(
            plot_coefficients_values(
                df_coefficients,
                title=f"{label}: Coefficient Values with 95% CI (±1.96 Std Error)"
            ),
            plot_coefficients_significance(
                df_coefficients,
                # use a log axis when every p-value is below 2e-4
                log_scale=bool(df_coefficients["p-values"].max() < 2e-4),
                title=f"{label}: Coefficients' Significance (p-values)"
            ),
        )

elif MODEL_SELECTION == "xgboost":
    # compute SHAP values on the test split
    explainer = shap.Explainer(model)
    shap_values = explainer(X_test)
    # per-class SHAP plots: the last axis of shap_values indexes the classes,
    # in the order of model.classes_
    for i, clss in enumerate(model.classes_):
        label = f"Class '{target_classes_dict[clss]}'"
        print(label)
        display(
            plot_shap_importance(
                shap_values[:, :, i], title=f"{label}: SHAP Feature Importance"
            ),
            plot_shap_beeswarm(
                shap_values[:, :, i], title=f"{label}: SHAP Summary Plot"
            ),
        )
    # Gain importance is not class-specific: plot_gain_metric_xgb receives no class
    # argument, so calling it once per class would only repeat the same plot.
    display(
        plot_gain_metric_xgb(
            model, X_test, title="XGBoost Feature Importance (Gain metric, all classes)"
        )
    )

else:
    raise model_selection_error
Class 'Normal'
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Class 'Suspect'
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Class 'Pathological'
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Performance Metrics¶

In [30]:
# Compute the same metric suite on each split and show the two columns side by side
df_train_metrics, df_test_metrics = (
    pd.Series(
        compute_multiclass_classification_metrics(y_true, y_hat, y_hat_proba)
    ).to_frame(name=column_name)
    for column_name, y_true, y_hat, y_hat_proba in (
        ("Train Metrics", y_train, y_pred_train, y_pred_proba_train),
        ("Test Metrics", y_test, y_pred, y_pred_proba),
    )
)

print("Final Model - Scoring Metrics on Train & Test Datasets:")
df_metrics = df_train_metrics.join(df_test_metrics)
display(df_metrics)
Final Model - Scoring Metrics on Train & Test Datasets:
Train Metrics Test Metrics
Accuracy 0.97 0.92
Precision (macro) 0.97 0.89
Recall (macro) 0.92 0.82
F1 Score (macro) 0.94 0.85
ROC AUC One-vs-Rest (macro) 1.00 0.96
ROC AUC One-vs-One (macro) 0.99 0.95
Precision (weighted) 0.97 0.91
Recall (weighted) 0.97 0.92
F1 Score (weighted) 0.97 0.91
ROC AUC One-vs-Rest (weighted) 0.99 0.96
ROC AUC One-vs-One (weighted) 1.00 0.96

Confusion Matrix¶

In [31]:
# Confusion matrix on the test split. normalize="true" presumably normalizes over the
# true labels (rows), as in scikit-learn. NOTE(review): confirm in plot_confusion_matrix.
confusion_plot = plot_confusion_matrix(
    y_test,
    y_pred,
    target_classes_dict=target_classes_dict,
    estimator=model,
    normalize="true",
)
display(confusion_plot)
No description has been provided for this image

ROC AUC¶

In [32]:
# One-vs-Rest ROC curve for each class on the test split
for class_id, class_label in target_classes_dict.items():
    roc_plot = plot_roc_curve(
        y_true=y_test_ohe[class_id],
        y_pred_proba=y_pred_proba[class_id],
        title=f"Class '{class_label}': ROC Curve One-vs-Rest",
    )
    display(roc_plot)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image