Multiclass Classification¶
import os
import warnings
import logging
# Root logger: INFO level, lines formatted as "HH:MM:SS [LEVEL] message".
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The warning filter policy is read from the environment: set
# WARNING_FILTER_POLICY to "ignore" when rendering the HTMLs; when the
# variable is unset the policy falls back to "once".
WARNING_FILTER_POLICY = os.environ.get("WARNING_FILTER_POLICY", "once")
logger.info(f"{WARNING_FILTER_POLICY = }")
warnings.filterwarnings(action=WARNING_FILTER_POLICY)
21:11:50 [INFO] WARNING_FILTER_POLICY = 'ignore'
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from xgboost import XGBClassifier
# pandas display settings: show every column of wide tables and render
# floats with a thousands separator and two decimals
pd.options.display.max_columns = None
pd.set_option("display.float_format", "{:,.2f}".format)
from utils.constants import RANDOM_SEED
from utils.common import (
get_data_folder_path,
set_plotting_config,
plot_boxplot_by_class,
plot_correlation_matrix,
)
from utils.evals import (
describe_input_features,
plot_confusion_matrix,
plot_target_rate,
compute_multiclass_classification_metrics,
build_coefficients_table,
plot_coefficients_values,
plot_coefficients_significance,
plot_eval_metrics_xgb,
plot_gain_metric_xgb,
plot_shap_importance,
plot_shap_beeswarm,
plot_roc_curve,
)
from utils.feature_selection import run_feature_selection_steps
# plots configuration
# seaborn theme: "darkgrid" axes style with the "colorblind" colour palette
sns.set_style("darkgrid")
sns.set_palette("colorblind")
# project helper imported from utils.common (its settings are not visible here)
set_plotting_config()
# IPython line magic (not plain Python): render matplotlib figures inline
%matplotlib inline
1. Load Data¶
In this notebook, we will use the Fetal Health Dataset. This dataset comprises 2126 records of features from Cardiotocogram exams, classified by experts into Normal, Suspect, and Pathological to assess fetal health and help reduce child and maternal mortality.
Sources:
# Read the Fetal Health CSV from the project's data folder and normalise the
# column names by replacing spaces with underscores.
data_path = get_data_folder_path()
df_input = pd.read_csv(os.path.join(data_path, "fetal_health.csv")).rename(
    columns=lambda col: col.replace(" ", "_")
)
2. Process Data¶
Target column¶
Fetal health (target column) can have the following values:
- 1: Normal
- 2: Suspect
- 3: Pathological
However, XGBoost expects the class labels to be consecutive non-negative integers starting at 0 (i.e. 0, 1, ..., num_class - 1). Therefore, we will use the following values in this notebook:
- 0: Normal
- 1: Suspect
- 2: Pathological
# name of the label column in the input table
target_col = "fetal_health"
# fraction of rows held out for testing (also reused by the feature selection)
test_size = 0.20
# 0-indexed class id -> human-readable label
target_classes_dict = dict(enumerate(["Normal", "Suspect", "Pathological"]))

# cast the target to a small integer type and shift it from 1..3 to 0..2
df_input[target_col] = df_input[target_col].astype(np.int8).sub(np.int8(1))
Train test split¶
# Hold out `test_size` of the rows as the test set. Stratifying on the target
# keeps the class proportions the same in both splits, and the fixed seed
# makes the split reproducible.
df_input_train, df_input_test = train_test_split(
df_input,
test_size=test_size,
stratify=df_input[target_col],
random_state=RANDOM_SEED,
)
# Class balance per split: for train and test, the absolute count and the
# relative frequency of every target class, next to the class label.
target_summary_cols = [pd.Series(target_classes_dict, name="label")]
for split_name, df_split in (("train", df_input_train), ("test", df_input_test)):
    for normalize, suffix in ((False, "count"), (True, "pct")):
        target_summary_cols.append(
            df_split[target_col]
            .value_counts(dropna=False, normalize=normalize)
            .rename(f"{split_name}_target_{suffix}")
        )
# last expression of the cell -> rendered as the cell output
pd.concat(target_summary_cols, axis=1)
label | train_target_count | train_target_pct | test_target_count | test_target_pct | |
---|---|---|---|---|---|
0 | Normal | 1323 | 0.78 | 332 | 0.78 |
1 | Suspect | 236 | 0.14 | 59 | 0.14 |
2 | Pathological | 141 | 0.08 | 35 | 0.08 |
# Summary table for every column (target included), built by the project
# helper from utils.evals; being the last expression of the cell, the result
# is rendered as the cell output.
describe_input_features(df_input, df_input_train, df_input_test)
data_type | count | null_count | min | 25% | 50% | 75% | max | std | mean | mean_train | mean_test | train_test_pct_diff | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
baseline_value | numeric | 2126 | 0 | 106.00 | 126.00 | 133.00 | 140.00 | 160.00 | 9.84 | 133.30 | 133.09 | 134.17 | 0.01 |
accelerations | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.01 | 0.02 | 0.00 | 0.00 | 0.00 | 0.00 | -0.03 |
fetal_movement | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.00 | 0.48 | 0.05 | 0.01 | 0.01 | 0.01 | -0.23 |
uterine_contractions | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.01 | 0.01 | 0.00 | 0.00 | 0.00 | 0.00 | 0.01 |
light_decelerations | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.00 | 0.01 | 0.00 | 0.00 | 0.00 | 0.00 | 0.09 |
severe_decelerations | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.60 |
prolongued_decelerations | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.00 | 0.01 | 0.00 | 0.00 | 0.00 | 0.00 | -0.10 |
abnormal_short_term_variability | numeric | 2126 | 0 | 12.00 | 32.00 | 49.00 | 61.00 | 87.00 | 17.19 | 46.99 | 46.72 | 48.06 | 0.03 |
mean_value_of_short_term_variability | numeric | 2126 | 0 | 0.20 | 0.70 | 1.20 | 1.70 | 7.00 | 0.88 | 1.33 | 1.34 | 1.32 | -0.01 |
percentage_of_time_with_abnormal_long_term_variability | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 11.00 | 91.00 | 18.40 | 9.85 | 9.95 | 9.42 | -0.05 |
mean_value_of_long_term_variability | numeric | 2126 | 0 | 0.00 | 4.60 | 7.40 | 10.80 | 50.70 | 5.63 | 8.19 | 8.17 | 8.27 | 0.01 |
histogram_width | numeric | 2126 | 0 | 3.00 | 37.00 | 67.50 | 100.00 | 180.00 | 38.96 | 70.45 | 70.40 | 70.61 | 0.00 |
histogram_min | numeric | 2126 | 0 | 50.00 | 67.00 | 93.00 | 120.00 | 159.00 | 29.56 | 93.58 | 93.43 | 94.19 | 0.01 |
histogram_max | numeric | 2126 | 0 | 122.00 | 152.00 | 162.00 | 174.00 | 238.00 | 17.94 | 164.03 | 163.83 | 164.81 | 0.01 |
histogram_number_of_peaks | numeric | 2126 | 0 | 0.00 | 2.00 | 3.00 | 6.00 | 18.00 | 2.95 | 4.07 | 4.08 | 4.03 | -0.01 |
histogram_number_of_zeroes | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.00 | 10.00 | 0.71 | 0.32 | 0.31 | 0.36 | 0.15 |
histogram_mode | numeric | 2126 | 0 | 60.00 | 129.00 | 139.00 | 148.00 | 187.00 | 16.38 | 137.45 | 137.25 | 138.24 | 0.01 |
histogram_mean | numeric | 2126 | 0 | 73.00 | 125.00 | 136.00 | 145.00 | 182.00 | 15.59 | 134.61 | 134.44 | 135.30 | 0.01 |
histogram_median | numeric | 2126 | 0 | 77.00 | 129.00 | 139.00 | 148.00 | 186.00 | 14.47 | 138.09 | 137.92 | 138.77 | 0.01 |
histogram_variance | numeric | 2126 | 0 | 0.00 | 2.00 | 7.00 | 24.00 | 269.00 | 28.98 | 18.81 | 18.89 | 18.49 | -0.02 |
histogram_tendency | numeric | 2126 | 0 | -1.00 | 0.00 | 0.00 | 1.00 | 1.00 | 0.61 | 0.32 | 0.32 | 0.32 | 0.01 |
fetal_health | numeric | 2126 | 0 | 0.00 | 0.00 | 0.00 | 0.00 | 2.00 | 0.61 | 0.30 | 0.30 | 0.30 | -0.01 |
Scaling (Standardization)¶
# Standardize training and test data (per-column zero mean / unit variance).
# The scaler is fitted on the *feature* columns of the training split only:
# - the target is excluded, so the fitted scaler can also transform unlabeled
#   data (it no longer expects a "fetal_health" column);
# - the test split is only transformed, never used to fit the statistics.
stdscaler = StandardScaler()
feature_cols = df_input_train.columns.drop(target_col)

# training data
X_train_all = pd.DataFrame(
    # fit scaler on training data (and then transform training data)
    data=stdscaler.fit_transform(df_input_train[feature_cols]),
    columns=feature_cols,
    index=df_input_train.index,
)
y_train = df_input_train[target_col]
y_train_ohe = pd.get_dummies(y_train, dtype=np.int8)  # one-hot encoding for plots

# test data
y_test = df_input_test[target_col]
X_test_all = pd.DataFrame(
    # use scaler fitted on training data to transform test data
    data=stdscaler.transform(df_input_test[feature_cols]),
    columns=feature_cols,
    index=df_input_test.index,
)
y_test_ohe = pd.get_dummies(y_test, dtype=np.int8)  # one-hot encoding for plots
3. Exploratory Data Analysis (EDA)¶
Boxplots by Target Class¶
# Boxplots of the training columns grouped by target class, five plots per
# row. plot_boxplot_by_class is a project helper from utils.common; this
# description is inferred from its arguments, so confirm against its definition.
display(
plot_boxplot_by_class(
df_input=df_input_train, # use only training data to avoid bias in test results
class_col=target_col,
class_mapping=target_classes_dict,
plots_per_line=5,
title="Features in input dataset",
)
)
Pearson's Correlation¶
# Pearson correlation matrix over all training columns; the frame passed in
# still contains the target column.
display(
plot_correlation_matrix(
# use only training data to avoid bias in test results
df=df_input_train, method="pearson", fig_height=10
)
)
4. Feature Selection¶
# Feature-selection configuration: one entry per step, keyed by the step name
# and mapped to that step's parameters (None -> the step takes no parameters).
fs_steps = dict(
    manual=dict(cols_to_exclude=["severe_decelerations"]),
    null_variance=None,
    correlation=dict(threshold=0.75),
    vif=dict(threshold=2),
    l1_regularization=dict(
        problem_type="classification",
        train_test_split_params=dict(test_size=test_size),
        logspace_search=dict(start=-5, stop=1, num=20, base=10),
        # tolerance over minimum error with which to search for the best model
        error_tolerance_pct=0.05,
        # minimum features to keep in final selection
        min_feats_to_keep=3,
        random_seed=RANDOM_SEED,
    ),
)

# Run the selection once per class as a binary (one-vs-rest) problem, using
# that class' one-hot column as the target. Only training data is used, to
# avoid bias in test results.
selected_feats_ovr, fs_tables_ovr = {}, {}
for clss, label in target_classes_dict.items():
    logger.info(f"Running Feature Selection for Class '{label}' (vs Rest)")
    class_feats, class_fs_table = run_feature_selection_steps(
        X=X_train_all,
        y=y_train_ohe[clss],
        fs_steps=fs_steps,
    )
    selected_feats_ovr[clss] = class_feats
    fs_tables_ovr[clss] = class_fs_table
    logger.info("-" * 100)
21:11:54 [INFO] Running Feature Selection for Class 'Normal' (vs Rest)
21:11:54 [INFO] --> Starting Feature Selection with 21 features
21:11:54 [INFO] 1. MANUAL FILTER
21:11:54 [INFO] - Removing 1 (4.8%) feature(s) manually: ['severe_decelerations']
21:11:54 [INFO] 2. NULL_VARIANCE FILTER
21:11:54 [INFO] - Removing 0 (0.0%) feature(s) with null variance (var == 0): []
21:11:54 [INFO] 3. CORRELATION FILTER
21:11:54 [INFO] Running Correlation filter with threshold of 0.75
21:11:54 [INFO] - Removing feature 'histogram_median' with correlation +0.9478 to 'histogram_mean'
21:11:54 [INFO] - Removing feature 'histogram_mean' with correlation +0.8915 to 'histogram_mode'
21:11:54 [INFO] - Removing feature 'histogram_width' with correlation -0.8997 to 'histogram_min'
21:11:54 [INFO] - Removing 3 (15.0%) feature(s) with abs(correlation) > 0.75
21:11:54 [INFO] 4. VIF FILTER
21:11:54 [INFO] Computing the Variance Inflation Factor (VIF) for 17 features...
21:11:54 [INFO] 1. Removing feature: histogram_min .......................................... VIF: 4.84
21:11:55 [INFO] 2. Removing feature: histogram_mode ......................................... VIF: 4.27
21:11:55 [INFO] 3. Removing feature: histogram_max .......................................... VIF: 3.20
21:11:55 [INFO] 4. Removing feature: light_decelerations .................................... VIF: 2.53
21:11:55 [INFO] 5. Removing feature: mean_value_of_short_term_variability ................... VIF: 2.17
21:11:55 [INFO] >> Stopping at feat: histogram_variance ..................................... VIF: 1.80 (threshold: 2)
21:11:55 [INFO] - Removing 5 (29.4%) feature(s) with VIF >= 2
21:11:55 [INFO] 5. L1_REGULARIZATION FILTER
21:11:55 [INFO] - Removing 8 (66.7%) feature(s) with null coefficient after L1 regularization: ['baseline_value', 'fetal_movement', 'uterine_contractions', 'mean_value_of_long_term_variability', 'histogram_number_of_peaks', 'histogram_number_of_zeroes', 'histogram_variance', 'histogram_tendency']
21:11:55 [INFO] --> Completed Feature Selection with 4 selected features (19.0% of the original 21 features): ['accelerations', 'prolongued_decelerations', 'abnormal_short_term_variability', 'percentage_of_time_with_abnormal_long_term_variability']
21:11:55 [INFO] ----------------------------------------------------------------------------------------------------
21:11:55 [INFO] Running Feature Selection for Class 'Suspect' (vs Rest)
21:11:55 [INFO] --> Starting Feature Selection with 21 features
21:11:55 [INFO] 1. MANUAL FILTER
21:11:55 [INFO] - Removing 1 (4.8%) feature(s) manually: ['severe_decelerations']
21:11:55 [INFO] 2. NULL_VARIANCE FILTER
21:11:55 [INFO] - Removing 0 (0.0%) feature(s) with null variance (var == 0): []
21:11:55 [INFO] 3. CORRELATION FILTER
21:11:55 [INFO] Running Correlation filter with threshold of 0.75
21:11:55 [INFO] - Removing feature 'histogram_mode' with correlation +0.9309 to 'histogram_median'
21:11:55 [INFO] - Removing feature 'histogram_width' with correlation -0.8997 to 'histogram_min'
21:11:55 [INFO] - Removing feature 'histogram_median' with correlation +0.9478 to 'histogram_mean'
21:11:55 [INFO] - Removing 3 (15.0%) feature(s) with abs(correlation) > 0.75
21:11:55 [INFO] 4. VIF FILTER
21:11:55 [INFO] Computing the Variance Inflation Factor (VIF) for 17 features...
21:11:55 [INFO] 1. Removing feature: histogram_mean ......................................... VIF: 10.51
21:11:55 [INFO] 2. Removing feature: histogram_min .......................................... VIF: 4.62
21:11:55 [INFO] 3. Removing feature: histogram_max .......................................... VIF: 3.20
21:11:55 [INFO] 4. Removing feature: light_decelerations .................................... VIF: 2.53
21:11:55 [INFO] 5. Removing feature: mean_value_of_short_term_variability ................... VIF: 2.17
21:11:55 [INFO] >> Stopping at feat: histogram_variance ..................................... VIF: 1.80 (threshold: 2)
21:11:55 [INFO] - Removing 5 (29.4%) feature(s) with VIF >= 2
21:11:55 [INFO] 5. L1_REGULARIZATION FILTER
21:11:55 [INFO] - Removing 0 (0.0%) feature(s) with null coefficient after L1 regularization: []
21:11:55 [INFO] --> Completed Feature Selection with 12 selected features (57.1% of the original 21 features): ['baseline_value', 'accelerations', 'fetal_movement', 'uterine_contractions', 'prolongued_decelerations', 'abnormal_short_term_variability', 'percentage_of_time_with_abnormal_long_term_variability', 'mean_value_of_long_term_variability', 'histogram_number_of_peaks', 'histogram_number_of_zeroes', 'histogram_variance', 'histogram_tendency']
21:11:55 [INFO] ----------------------------------------------------------------------------------------------------
21:11:55 [INFO] Running Feature Selection for Class 'Pathological' (vs Rest)
21:11:55 [INFO] --> Starting Feature Selection with 21 features
21:11:55 [INFO] 1. MANUAL FILTER
21:11:55 [INFO] - Removing 1 (4.8%) feature(s) manually: ['severe_decelerations']
21:11:55 [INFO] 2. NULL_VARIANCE FILTER
21:11:55 [INFO] - Removing 0 (0.0%) feature(s) with null variance (var == 0): []
21:11:55 [INFO] 3. CORRELATION FILTER
21:11:55 [INFO] Running Correlation filter with threshold of 0.75
21:11:55 [INFO] - Removing feature 'baseline_value' with correlation +0.7864 to 'histogram_median'
21:11:55 [INFO] - Removing feature 'histogram_width' with correlation -0.8997 to 'histogram_min'
21:11:55 [INFO] - Removing feature 'histogram_median' with correlation +0.9478 to 'histogram_mean'
21:11:55 [INFO] - Removing feature 'histogram_mode' with correlation +0.8915 to 'histogram_mean'
21:11:55 [INFO] - Removing 4 (20.0%) feature(s) with abs(correlation) > 0.75
21:11:55 [INFO] 4. VIF FILTER
21:11:55 [INFO] Computing the Variance Inflation Factor (VIF) for 16 features...
21:11:55 [INFO] 1. Removing feature: histogram_mean ......................................... VIF: 5.57
21:11:55 [INFO] 2. Removing feature: histogram_min .......................................... VIF: 3.46
21:11:55 [INFO] 3. Removing feature: light_decelerations .................................... VIF: 2.53
21:11:55 [INFO] 4. Removing feature: histogram_variance ..................................... VIF: 2.18
21:11:55 [INFO] >> Stopping at feat: mean_value_of_short_term_variability ................... VIF: 1.96 (threshold: 2)
21:11:55 [INFO] - Removing 4 (25.0%) feature(s) with VIF >= 2
21:11:55 [INFO] 5. L1_REGULARIZATION FILTER
21:11:55 [INFO] - Removing 2 (16.7%) feature(s) with null coefficient after L1 regularization: ['accelerations', 'histogram_number_of_peaks']
21:11:55 [INFO] --> Completed Feature Selection with 10 selected features (47.6% of the original 21 features): ['fetal_movement', 'uterine_contractions', 'prolongued_decelerations', 'abnormal_short_term_variability', 'mean_value_of_short_term_variability', 'percentage_of_time_with_abnormal_long_term_variability', 'mean_value_of_long_term_variability', 'histogram_max', 'histogram_number_of_zeroes', 'histogram_tendency']
21:11:55 [INFO] ----------------------------------------------------------------------------------------------------
# keep only the features that were selected for at least 2 classes
from collections import Counter

MIN_NUM_SELECTIONS = 2

# Number of one-vs-rest selections in which each feature appears (set() guards
# against a feature being listed twice for the same class).
selection_counts = Counter(
    feat for class_feats in selected_feats_ovr.values() for feat in set(class_feats)
)
# Sorted so that the column order of the model inputs is deterministic: the
# iteration order of a set of strings can change from one interpreter run to
# the next. An empty list is produced when no feature reaches the threshold.
selected_feats = sorted(
    feat
    for feat, n_selections in selection_counts.items()
    if n_selections >= MIN_NUM_SELECTIONS
)

print(f"Final selection ({len(selected_feats)} features selected):")
for feat in selected_feats:
    print(f" - {feat}")
Final selection (9 features selected): - abnormal_short_term_variability - accelerations - fetal_movement - histogram_number_of_zeroes - histogram_tendency - mean_value_of_long_term_variability - percentage_of_time_with_abnormal_long_term_variability - prolongued_decelerations - uterine_contractions
# build model input datasets: restrict both (scaled) splits to the selected features
X_train = X_train_all.loc[:, selected_feats]
X_test = X_test_all.loc[:, selected_feats]
Correlation check¶
# Pearson correlation matrix restricted to the selected features plus the
# target, computed on the unscaled training frame.
display(
plot_correlation_matrix(
# use only training data to avoid bias in test results
df=df_input_train[selected_feats + [target_col]], method="pearson", fig_height=5
)
)
Multicollinearity check¶
# compute the Variance Inflation Factor (VIF) for each selected feature,
# using the scaled training matrix, and list the features from highest to
# lowest VIF
vif_by_feature = {
    feat: variance_inflation_factor(X_train.values, col_idx)
    for col_idx, feat in enumerate(selected_feats)
}
df_vif = (
    pd.Series(vif_by_feature, name="VIF")
    .to_frame()
    .sort_values("VIF", ascending=False)
)
df_vif
VIF | |
---|---|
percentage_of_time_with_abnormal_long_term_variability | 1.62 |
abnormal_short_term_variability | 1.53 |
mean_value_of_long_term_variability | 1.41 |
accelerations | 1.39 |
prolongued_decelerations | 1.36 |
uterine_contractions | 1.18 |
fetal_movement | 1.14 |
histogram_tendency | 1.08 |
histogram_number_of_zeroes | 1.06 |
5. Classifier Model¶
Select classifier: Logistic Regression or XGBoost¶
# Classifier switch: exactly one of the two assignments should be active.
# MODEL_SELECTION = "logistic_regression"
MODEL_SELECTION = "xgboost"

# Exception instance built once here and raised by every later cell that
# branches on MODEL_SELECTION and meets an unsupported value.
model_selection_error = ValueError(
    f"'MODEL_SELECTION' must be either 'logistic_regression' or 'xgboost'. Got {MODEL_SELECTION} instead."
)
Hyperparameter tuning with K-Fold Cross Validation¶
For a detailed explanation of XGBoost's parameters, refer to: https://www.kaggle.com/code/prashant111/a-guide-on-xgboost-hyperparameters-tuning/notebook
# Pick the estimator class and the hyperparameter grid explored by the grid
# search, according to MODEL_SELECTION.
if MODEL_SELECTION == "logistic_regression":
    Estimator = LogisticRegression
    # settings shared by every penalty
    lr_shared_grid = {
        "solver": ["saga"],
        "C": np.logspace(-3, 1, num=9, base=10.0),
        "class_weight": [None],
    }
    # A list of grids is needed because the elastic-net penalty requires the
    # extra `l1_ratio` hyperparameter: without it every "elasticnet" candidate
    # fails to fit, while "l1"/"l2" do not use it.
    cv_search_space = [
        {"penalty": ["l1", "l2"], **lr_shared_grid},
        {
            "penalty": ["elasticnet"],
            "l1_ratio": [0.25, 0.5, 0.75],
            **lr_shared_grid,
        },
    ]
elif MODEL_SELECTION == "xgboost":
    Estimator = XGBClassifier
    # NOTE(review): `scale_pos_weight` is documented by XGBoost as a
    # binary-classification setting; it is kept at its neutral value of 1 here
    # -- confirm whether it can simply be dropped for this multiclass problem.
    cv_search_space = {
        "objective": ["multi:softmax"],
        "num_class": [len(target_classes_dict)],
        "n_estimators": [30, 40, 50],
        "learning_rate": [0.1],
        "max_depth": [3, 4, 6],
        "min_child_weight": [2, 4],
        "gamma": [0, 0.5],
        "alpha": [0, 0.3],
        "scale_pos_weight": [1],
        "lambda": [1],
        # optional extra dimensions, disabled to keep the grid small:
        ## "subsample": [0.8, 1.0],
        ## "colsample_bytree": [0.8, 1.0],
        "verbosity": [0],
    }
else:
    raise model_selection_error
For the full list of scikit-learn's scoring string names, refer to: https://scikit-learn.org/stable/modules/model_evaluation.html#string-name-scorers
# scikit-learn scorer name -> display name, in the order in which the metrics
# are reported after the grid search
cv_scoring_metrics = dict(
    accuracy="Accuracy",
    # macro-averaged (every class weighs the same)
    precision_macro="Precision (macro)",
    recall_macro="Recall (macro)",
    f1_macro="F1 Score (macro)",
    # weighted by class support
    precision_weighted="Precision (weighted)",
    recall_weighted="Recall (weighted)",
    f1_weighted="F1 Score (weighted)",
    # ROC AUC, one-vs-rest / one-vs-one
    roc_auc_ovr="ROC AUC One-vs-Rest (macro)",
    roc_auc_ovo="ROC AUC One-vs-One (macro)",
    roc_auc_ovr_weighted="ROC AUC One-vs-Rest (weighted)",
    roc_auc_ovo_weighted="ROC AUC One-vs-One (weighted)",
)
# metric to optimize for the final model (must be one of the keys above)
refit_metric = "f1_weighted"
%%time
# define evaluation
kfold_cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=1, random_state=RANDOM_SEED)
# define search
grid_search = GridSearchCV(
estimator=Estimator(),
param_grid=cv_search_space,
scoring=list(cv_scoring_metrics.keys()),
cv=kfold_cv,
refit=refit_metric,
verbose=1,
)
# execute search
result_cv = grid_search.fit(X_train, y_train)
Fitting 3 folds for each of 72 candidates, totalling 216 fits
CPU times: user 19.4 s, sys: 43.4 s, total: 1min 2s Wall time: 34 s
# Report the mean cross-validated score of the best candidate for every
# scoring metric, followed by its hyperparameters.
# The pieces are formatted separately so that no f-string reuses its own
# quote character inside a replacement field (a SyntaxError before Python 3.12).
print("Grid Search CV Best Model - Scoring Metrics (averaging method):")
best_idx = result_cv.best_index_
for i, (metric_key, metric_name) in enumerate(cv_scoring_metrics.items(), start=1):
    rank = f"{i}."
    best_score = result_cv.cv_results_[f"mean_test_{metric_key}"][best_idx]
    print(f" {rank:>3} {metric_name:.<31} {best_score:.3f}")
print(f"\nBest Hyperparameters: {result_cv.best_params_}")
Grid Search CV Best Model - Scoring Metrics (averaging method): 1. Accuracy....................... 0.931 2. Precision (macro).............. 0.903 3. Recall (macro)................. 0.856 4. F1 Score (macro)............... 0.875 5. Precision (weighted)........... 0.929 6. Recall (weighted).............. 0.931 7. F1 Score (weighted)............ 0.928 8. ROC AUC One-vs-Rest (macro).... 0.969 9. ROC AUC One-vs-One (macro)..... 0.964 10. ROC AUC One-vs-Rest (weighted). 0.968 11. ROC AUC One-vs-One (weighted).. 0.969 Best Hyperparameters: {'alpha': 0, 'gamma': 0, 'lambda': 1, 'learning_rate': 0.1, 'max_depth': 6, 'min_child_weight': 4, 'n_estimators': 50, 'num_class': 3, 'objective': 'multi:softmax', 'scale_pos_weight': 1, 'verbosity': 0}
Final Model¶
# instantiate model with best hyperparameters and additional kwargs
if MODEL_SELECTION == "logistic_regression":
    model_kwargs = {}
    model_fit_kwargs = {}
elif MODEL_SELECTION == "xgboost":
    # XGBoost evaluation metric -> display name. `mlogloss` and `merror` are
    # the multiclass log-loss and error rate, so they are labelled as such.
    eval_metrics = dict(
        mlogloss="Multiclass Cross-entropy Loss (Log-loss)",
        merror="Multiclass Classification Error Rate",
        auc="ROC AUC",
    )
    model_kwargs = dict(eval_metric=list(eval_metrics.keys()))
    # both splits are passed as evaluation sets so that their metric curves
    # can be plotted at the end of this cell
    model_fit_kwargs = dict(
        eval_set=[(X_train, y_train), (X_test, y_test)],
        verbose=False,
    )
else:
    raise model_selection_error

model = Estimator(**result_cv.best_params_, **model_kwargs, random_state=RANDOM_SEED)

# Fit model and make predictions
model.fit(X_train, y_train, **model_fit_kwargs)

# Make probabilities predictions (one column per class, rows aligned with X)
y_pred_proba_train = pd.DataFrame(
    data=model.predict_proba(X_train), columns=model.classes_, index=X_train.index
)
y_pred_proba = pd.DataFrame(
    data=model.predict_proba(X_test), columns=model.classes_, index=X_test.index
)

# Make class predictions
y_pred_train = pd.Series(
    data=model.predict(X_train), index=X_train.index, name=target_col
)
y_pred = pd.Series(
    data=model.predict(X_test), index=X_test.index, name=target_col
)

if MODEL_SELECTION == "xgboost":
    display(plot_eval_metrics_xgb(model.evals_result(), eval_metrics))
Plot target rate per group of predicted probability
A good model should show a target rate that increases monotonically from the lowest to the highest group of predicted probability (e.g. quartiles, deciles).
# One target-rate plot per class, treating that class as the positive label
# (one-hot test target against the predicted probability of the class).
for clss, label in target_classes_dict.items():
    target_rate_fig = plot_target_rate(
        y_test=y_test_ohe[clss],
        y_pred_proba=y_pred_proba[clss],
        title=f"Class '{label}': Target rate per group of predicted probability",
    )
    display(target_rate_fig)
Feature Importance¶
- For Logistic Regression: coefficients values and statistical significance
- For XGBoost: SHAP analysis and Gain Metric
if MODEL_SELECTION == "logistic_regression":
    # the fitted model exposes one coefficient row and one intercept per class
    for clss, coefficients, intercept in zip(
        model.classes_, model.coef_, model.intercept_
    ):
        label = f"Class '{target_classes_dict[clss]}'"
        print(label)
        df_coefficients = build_coefficients_table(
            coefficients=coefficients,
            intercept=intercept,
            X_train=X_train,
            y_pred_train=y_pred_proba_train[clss],
            y_train=y_train_ohe[clss],
            problem_type="classification",
        )
        values_fig = plot_coefficients_values(
            df_coefficients,
            title=f"{label}: Coefficient Values with 95% CI (±1.96 Std Error)",
        )
        # switch to a log scale only when even the largest p-value is tiny
        use_log_scale = bool(df_coefficients["p-values"].max() < 2e-4)
        significance_fig = plot_coefficients_significance(
            df_coefficients,
            log_scale=use_log_scale,
            title=f"{label}: Coefficients' Significance (p-values)",
        )
        display(values_fig, significance_fig)
elif MODEL_SELECTION == "xgboost":
    # compute SHAP values
    explainer = shap.Explainer(model)
    shap_values = explainer(X_test)
    # shap plots: the last axis of the SHAP values is indexed by class position
    for i, clss in enumerate(model.classes_):
        label = f"Class '{target_classes_dict[clss]}'"
        print(label)
        importance_fig = plot_shap_importance(
            shap_values[:, :, i], title=f"{label}: SHAP Feature Importance"
        )
        beeswarm_fig = plot_shap_beeswarm(
            shap_values[:, :, i], title=f"{label}: SHAP Summary Plot"
        )
        # NOTE(review): apart from the title, this call receives the same
        # arguments for every class -- confirm whether the gain plot is meant
        # to be class-specific.
        gain_fig = plot_gain_metric_xgb(
            model, X_test, title=f"{label}: XGBoost Feature Importance (Gain metric)"
        )
        display(importance_fig, beeswarm_fig, gain_fig)
else:
    raise model_selection_error
Class 'Normal'
Class 'Suspect'
Class 'Pathological'
Performance Metrics¶
# Same set of metrics computed on both splits (true labels, predicted classes
# and predicted probabilities), shown side by side.
train_metrics = compute_multiclass_classification_metrics(
    y_train, y_pred_train, y_pred_proba_train
)
df_train_metrics = pd.Series(train_metrics).to_frame(name="Train Metrics")

test_metrics = compute_multiclass_classification_metrics(
    y_test, y_pred, y_pred_proba
)
df_test_metrics = pd.Series(test_metrics).to_frame(name="Test Metrics")

print("Final Model - Scoring Metrics on Train & Test Datasets:")
df_metrics = df_train_metrics.join(df_test_metrics)
display(df_metrics)
Final Model - Scoring Metrics on Train & Test Datasets:
Train Metrics | Test Metrics | |
---|---|---|
Accuracy | 0.97 | 0.92 |
Precision (macro) | 0.97 | 0.89 |
Recall (macro) | 0.92 | 0.82 |
F1 Score (macro) | 0.94 | 0.85 |
ROC AUC One-vs-Rest (macro) | 1.00 | 0.96 |
ROC AUC One-vs-One (macro) | 0.99 | 0.95 |
Precision (weighted) | 0.97 | 0.91 |
Recall (weighted) | 0.97 | 0.92 |
F1 Score (weighted) | 0.97 | 0.91 |
ROC AUC One-vs-Rest (weighted) | 0.99 | 0.96 |
ROC AUC One-vs-One (weighted) | 1.00 | 0.96 |
Confusion Matrix¶
# Confusion Matrix
# Test-set confusion matrix drawn by the project helper from utils.evals.
# NOTE(review): normalize="true" presumably follows the scikit-learn
# convention (each row, i.e. true class, sums to 1) -- confirm in the helper.
display(
plot_confusion_matrix(
y_test,
y_pred,
estimator=model,
target_classes_dict=target_classes_dict,
normalize="true",
)
)
ROC AUC¶
# One ROC curve per class in a one-vs-rest setting: the one-hot test target of
# the class against the predicted probability of that class.
for clss, label in target_classes_dict.items():
    roc_fig = plot_roc_curve(
        y_true=y_test_ohe[clss],
        y_pred_proba=y_pred_proba[clss],
        title=f"Class '{label}': ROC Curve One-vs-Rest",
    )
    display(roc_fig)