# HG changeset patch
# User goeckslab
# Date 1765030836 0
# Node ID c5c324ac29fcea2797dd5ca9f8acda0e70b041e5
# Parent 4fee4504646ebea215697ce15060a08b30f20283
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
diff -r 4fee4504646e -r c5c324ac29fc base_model_trainer.py
--- a/base_model_trainer.py Fri Nov 28 22:28:26 2025 +0000
+++ b/base_model_trainer.py Sat Dec 06 14:20:36 2025 +0000
@@ -9,7 +9,16 @@
import pandas as pd
from feature_help_modal import get_feature_metrics_help_modal
from feature_importance import FeatureImportanceAnalyzer
-from sklearn.metrics import average_precision_score
+from sklearn.metrics import (
+ accuracy_score,
+ average_precision_score,
+ confusion_matrix,
+ f1_score,
+ matthews_corrcoef,
+ precision_score,
+ recall_score,
+ roc_auc_score,
+)
from utils import (
add_hr_to_html,
add_plot_to_html,
@@ -387,6 +396,693 @@
with open(img_path, "rb") as img_file:
return base64.b64encode(img_file.read()).decode("utf-8")
+ def _build_dataset_overview(self):
+ """
+ Build an HTML table showing label counts with labels as rows and splits
+ (Train / Validation / Test) as columns. Each cell shows count and
+ percentage of that split. Returns empty string for regression or when
+ no label data is available.
+ """
+ if self.task_type != "classification":
+ return ""
+
+ def _safe_series(obj):
+ try:
+ return pd.Series(obj).reset_index(drop=True)
+ except Exception:
+ return None
+
+ def _get_from_config(keys):
+ if self.exp is None:
+ return None
+ for key in keys:
+ try:
+ val = self.exp.get_config(key)
+ except Exception:
+ val = getattr(self.exp, key, None)
+ if val is not None:
+ return val
+ return None
+
+ # Prefer PyCaret-configured splits; fall back to raw inputs.
+ X_train = _get_from_config(["X_train_transformed", "X_train"])
+ y_train = _get_from_config(["y_train_transformed", "y_train"])
+ y_test_cfg = _get_from_config(["y_test_transformed", "y_test"])
+
+ if y_train is None and self.data is not None and self.target in self.data.columns:
+ y_train = self.data[self.target]
+
+ y_train_series = _safe_series(y_train)
+
+ # Build a cross-validation generator to derive a validation subset size.
+ cv_gen = self._get_cv_generator(y_train_series)
+ y_train_fold = y_train_series
+ y_val_fold = None
+ if cv_gen is not None and y_train_series is not None:
+ try:
+ # Use the first fold to approximate Train/Validation split sizes.
+ splitter = cv_gen.split(
+ pd.DataFrame(X_train).reset_index(drop=True)
+ if X_train is not None
+ else y_train_series,
+ y_train_series,
+ )
+ train_idx, val_idx = next(iter(splitter))
+ y_train_fold = y_train_series.iloc[train_idx].reset_index(drop=True)
+ y_val_fold = y_train_series.iloc[val_idx].reset_index(drop=True)
+ except Exception as exc:
+ LOG.warning("Could not derive validation split for dataset overview: %s", exc)
+
+ # Test labels: prefer PyCaret transformed holdout (single file) or external test.
+ if self.test_data is not None:
+ if y_test_cfg is not None:
+ y_test = y_test_cfg
+ elif self.target in self.test_data.columns:
+ y_test = self.test_data[self.target]
+ else:
+ y_test = None
+ else:
+ y_test = y_test_cfg
+
+ split_map = {
+ "Train": _safe_series(y_train_fold),
+ "Validation": _safe_series(y_val_fold),
+ "Test": _safe_series(y_test),
+ }
+ available = {k: v for k, v in split_map.items() if v is not None and not v.empty}
+ if not available:
+ return ""
+
+ # Collect all labels across available splits (including NaN)
+ label_pool = pd.concat(
+ available.values(), ignore_index=True
+ )
+ labels = pd.unique(label_pool)
+
+ def _count_for_label(series, label):
+ if series is None or series.empty:
+ return None, None
+ total = len(series)
+ if pd.isna(label):
+ cnt = series.isna().sum()
+ else:
+ cnt = (series == label).sum()
+ return int(cnt), total
+
+ rows = []
+ for label in labels:
+ row = ["NaN" if pd.isna(label) else str(label)]
+ for split_name in ["Train", "Validation", "Test"]:
+ cnt, total = _count_for_label(split_map.get(split_name), label)
+ if cnt is None or total is None:
+ cell = "—"
+ else:
+ pct = (cnt / total * 100) if total else 0
+ cell = f"{cnt} ({pct:.1f}%)"
+ row.append(cell)
+ rows.append(row)
+
+ df = pd.DataFrame(rows, columns=["Label", "Train", "Validation", "Test"])
+ df.sort_values("Label", inplace=True)
+
+ return (
+ "
Dataset Overview
"
+ + ''
+ + df.to_html(
+ index=False,
+ classes=["table", "sortable", "table-dataset-overview"],
+ )
+ + "
"
+ )
+
+ def _predict_with_thresholds(self, X, y_true):
+ """
+ Generate predictions/probabilities for a split, respecting an optional
+ probability threshold for binary tasks. Returns a dict with y_true,
+ y_pred, y_scores (positive-class probs when available), pos_label,
+ and neg_label.
+ """
+ if X is None or y_true is None:
+ return None
+
+ y_true_series = pd.Series(y_true).reset_index(drop=True)
+ classes = list(getattr(self.best_model, "classes_", []))
+ if not classes:
+ try:
+ classes = pd.unique(y_true_series).tolist()
+ except Exception:
+ classes = []
+ if len(classes) > 1:
+ try:
+ pos_idx = classes.index(1)
+ except Exception:
+ pos_idx = 1
+ else:
+ pos_idx = 0
+ pos_idx = min(pos_idx, len(classes) - 1) if classes else 0
+ pos_label = (
+ classes[pos_idx]
+ if len(classes) > pos_idx and pos_idx >= 0
+ else (classes[-1] if classes else 1)
+ )
+ neg_label = None
+ if len(classes) >= 2:
+ neg_candidates = [c for c in classes if c != pos_label]
+ if neg_candidates:
+ neg_label = neg_candidates[0]
+
+ prob_thresh = getattr(self, "probability_threshold", None)
+ y_scores = None
+ try:
+ proba = self.best_model.predict_proba(X)
+ y_scores = np.asarray(proba) if proba is not None else None
+ except Exception:
+ y_scores = None
+
+ try:
+ if (
+ prob_thresh is not None
+ and not getattr(self.exp, "is_multiclass", False)
+ and y_scores is not None
+ and y_scores.ndim == 2
+ and y_scores.shape[1] > 1
+ ):
+ pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+ neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0
+ if neg_label is None and len(classes) > neg_idx:
+ neg_label = classes[neg_idx]
+ y_pred = np.where(
+ y_scores[:, pos_idx] >= prob_thresh,
+ pos_label,
+ neg_label if neg_label is not None else 0,
+ )
+ y_scores = y_scores[:, pos_idx]
+ else:
+ y_pred = self.best_model.predict(X)
+ if (
+ not getattr(self.exp, "is_multiclass", False)
+ and y_scores is not None
+ and y_scores.ndim == 2
+ and y_scores.shape[1] > 1
+ ):
+ pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+ y_scores = y_scores[:, pos_idx]
+ except Exception as exc:
+ LOG.warning(
+ "Falling back to raw predict while computing performance summary: %s",
+ exc,
+ )
+ try:
+ y_pred = self.best_model.predict(X)
+ except Exception as exc_inner:
+ LOG.warning(
+ "Unable to score split after fallback prediction: %s",
+ exc_inner,
+ )
+ return None
+ y_scores = None
+
+ y_pred_series = pd.Series(y_pred).reset_index(drop=True)
+ if y_scores is not None:
+ y_scores = np.asarray(y_scores)
+ if y_scores.ndim > 1 and y_scores.shape[1] == 1:
+ y_scores = y_scores.ravel()
+ if getattr(self.exp, "is_multiclass", False) and y_scores.ndim > 1:
+ # Avoid passing multiclass score matrices to ROC/PR utilities
+ y_scores = None
+
+ return {
+ "y_true": y_true_series,
+ "y_pred": y_pred_series,
+ "y_scores": y_scores,
+ "pos_label": pos_label,
+ "neg_label": neg_label,
+ }
+
+ def _get_cv_generator(self, y_series):
+ """
+ Build a cross-validation splitter that mirrors the experiment's
+ configuration. Returns None when CV is disabled or not applicable.
+ """
+ if self.task_type != "classification":
+ return None
+
+ if getattr(self, "cross_validation", None) is False:
+ return None
+
+ try:
+ cfg_gen = self.exp.get_config("fold_generator")
+ if cfg_gen is not None:
+ return cfg_gen
+ except Exception:
+ cfg_gen = None
+
+ folds = (
+ getattr(self, "cross_validation_folds", None)
+ or self.setup_params.get("fold")
+ or getattr(self.exp, "fold", None)
+ or 10
+ )
+ try:
+ folds = int(folds)
+ except Exception:
+ folds = 10
+
+ try:
+ y_series = pd.Series(y_series).reset_index(drop=True)
+ except Exception:
+ y_series = None
+ if y_series is None or y_series.empty:
+ return None
+
+ if folds < 2:
+ return None
+ if len(y_series) < folds:
+ folds = len(y_series)
+ if folds < 2:
+ return None
+
+ try:
+ from sklearn.model_selection import KFold, StratifiedKFold
+
+ if self.task_type == "classification":
+ return StratifiedKFold(
+ n_splits=folds,
+ shuffle=True,
+ random_state=self.random_seed,
+ )
+ return KFold(
+ n_splits=folds,
+ shuffle=True,
+ random_state=self.random_seed,
+ )
+ except Exception as exc:
+ LOG.warning("Could not build CV generator: %s", exc)
+ return None
+
+ def _get_cross_validated_predictions(self, X, y):
+ """
+ Generate cross-validated predictions for the validation split so we
+ can report validation metrics for the selected best model.
+ """
+ if self.task_type != "classification":
+ return None
+ if getattr(self, "cross_validation", None) is False:
+ return None
+ if X is None or y is None:
+ return None
+
+ try:
+ from sklearn.model_selection import cross_val_predict
+ except Exception as exc:
+ LOG.warning("cross_val_predict unavailable: %s", exc)
+ return None
+
+ y_series = pd.Series(y).reset_index(drop=True)
+ if y_series.empty:
+ return None
+
+ cv_gen = self._get_cv_generator(y_series)
+ if cv_gen is None:
+ return None
+
+ X_df = pd.DataFrame(X).reset_index(drop=True)
+ if len(X_df) != len(y_series):
+ X_df = X_df.iloc[: len(y_series)].reset_index(drop=True)
+
+ classes = list(getattr(self.best_model, "classes_", []))
+ if len(classes) > 1:
+ try:
+ pos_idx = classes.index(1)
+ except Exception:
+ pos_idx = 1
+ else:
+ pos_idx = 0
+ pos_idx = min(pos_idx, len(classes) - 1) if classes else 0
+ pos_label = (
+ classes[pos_idx] if len(classes) > pos_idx else 1
+ )
+ neg_label = None
+ if len(classes) >= 2:
+ neg_candidates = [c for c in classes if c != pos_label]
+ if neg_candidates:
+ neg_label = neg_candidates[0]
+
+ prob_thresh = getattr(self, "probability_threshold", None)
+ n_jobs = getattr(self, "n_jobs", None)
+
+ y_scores = None
+ if not getattr(self.exp, "is_multiclass", False):
+ try:
+ proba = cross_val_predict(
+ self.best_model,
+ X_df,
+ y_series,
+ cv=cv_gen,
+ method="predict_proba",
+ n_jobs=n_jobs,
+ )
+ y_scores = np.asarray(proba)
+ except Exception as exc:
+ LOG.debug("Could not compute CV probabilities: %s", exc)
+
+ y_pred = None
+ if (
+ prob_thresh is not None
+ and not getattr(self.exp, "is_multiclass", False)
+ and y_scores is not None
+ and y_scores.ndim == 2
+ and y_scores.shape[1] > 1
+ ):
+ pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+ neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0
+ if neg_label is None and len(classes) > neg_idx:
+ neg_label = classes[neg_idx]
+ y_pred = np.where(
+ y_scores[:, pos_idx] >= prob_thresh,
+ pos_label,
+ neg_label if neg_label is not None else 0,
+ )
+ y_scores = y_scores[:, pos_idx]
+ else:
+ try:
+ y_pred = cross_val_predict(
+ self.best_model,
+ X_df,
+ y_series,
+ cv=cv_gen,
+ method="predict",
+ n_jobs=n_jobs,
+ )
+ except Exception as exc:
+ LOG.warning(
+ "Could not compute cross-validated predictions: %s",
+ exc,
+ )
+ return None
+ if (
+ not getattr(self.exp, "is_multiclass", False)
+ and y_scores is not None
+ and y_scores.ndim == 2
+ and y_scores.shape[1] > 1
+ ):
+ pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+ y_scores = y_scores[:, pos_idx]
+
+ if y_scores is not None and getattr(self.exp, "is_multiclass", False):
+ y_scores = None
+
+ return {
+ "y_true": y_series,
+ "y_pred": pd.Series(y_pred).reset_index(drop=True),
+ "y_scores": y_scores,
+ "pos_label": pos_label,
+ "neg_label": neg_label,
+ }
+
+ def _get_split_predictions_for_report(self):
+ """
+ Collect predictions/probabilities for Train/Validation/Test splits so the
+ performance table can show consistent metrics across splits.
+ """
+ if self.task_type != "classification":
+ return {}
+
+ def _get_from_config(keys):
+ for key in keys:
+ try:
+ val = self.exp.get_config(key)
+ except Exception:
+ val = getattr(self.exp, key, None)
+ if val is not None:
+ return val
+ return None
+
+ X_train = _get_from_config(["X_train_transformed", "X_train"])
+ y_train = _get_from_config(["y_train_transformed", "y_train"])
+ X_holdout = _get_from_config(["X_test_transformed", "X_test"])
+ y_holdout = _get_from_config(["y_test_transformed", "y_test"])
+
+ predictions = {}
+
+ # Train metrics (best model on training data)
+ if X_train is not None and y_train is not None:
+ try:
+ train_preds = self._predict_with_thresholds(X_train, y_train)
+ if train_preds is not None:
+ predictions["Train"] = train_preds
+ except Exception as exc:
+ LOG.warning(
+ "Could not score Train split for performance summary: %s",
+ exc,
+ )
+
+ # Validation metrics via cross-validation on training data
+ try:
+ val_preds = self._get_cross_validated_predictions(X_train, y_train)
+ if val_preds is not None:
+ predictions["Validation"] = val_preds
+ except Exception as exc:
+ LOG.warning(
+ "Could not score Validation split for performance summary: %s",
+ exc,
+ )
+
+ # Test metrics (holdout from single file, or provided test file)
+ X_test = X_holdout
+ y_test = y_holdout
+ if (X_test is None or y_test is None) and self.test_data is not None:
+ try:
+ X_test = self.test_data.drop(columns=[self.target])
+ y_test = self.test_data[self.target]
+ except Exception as exc:
+ LOG.warning(
+ "Could not prepare external test data for performance summary: %s",
+ exc,
+ )
+
+ if X_test is not None and y_test is not None:
+ try:
+ test_preds = self._predict_with_thresholds(X_test, y_test)
+ if test_preds is not None:
+ predictions["Test"] = test_preds
+ except Exception as exc:
+ LOG.warning(
+ "Could not score Test split for performance summary: %s",
+ exc,
+ )
+ return predictions
+
+ def _compute_metric_value(self, metric_name, preds, split_name):
+ """
+ Compute a single metric for a given split prediction bundle.
+ """
+ if preds is None:
+ return None
+
+ y_true = preds["y_true"]
+ y_pred = preds["y_pred"]
+ y_scores = preds.get("y_scores")
+ pos_label = preds.get("pos_label")
+ neg_label = preds.get("neg_label")
+ is_multiclass = getattr(self.exp, "is_multiclass", False)
+
+ def _format_binary_labels(series):
+ if pos_label is None:
+ return series
+ try:
+ return (series == pos_label).astype(int)
+ except Exception:
+ return series
+
+ try:
+ if metric_name == "Accuracy":
+ return accuracy_score(y_true, y_pred)
+ if metric_name == "ROC-AUC":
+ if y_scores is None:
+ return None
+ y_true_bin = _format_binary_labels(y_true)
+ if len(pd.unique(y_true_bin)) < 2:
+ return None
+ return roc_auc_score(y_true_bin, y_scores)
+ if metric_name == "Precision":
+ if is_multiclass:
+ return precision_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+ try:
+ return precision_score(
+ y_true, y_pred, pos_label=pos_label, zero_division=0
+ )
+ except Exception:
+ return precision_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+ if metric_name == "Recall":
+ if is_multiclass:
+ return recall_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+ try:
+ return recall_score(
+ y_true, y_pred, pos_label=pos_label, zero_division=0
+ )
+ except Exception:
+ return recall_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+ if metric_name == "F1-Score":
+ if is_multiclass:
+ return f1_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+ try:
+ return f1_score(
+ y_true, y_pred, pos_label=pos_label, zero_division=0
+ )
+ except Exception:
+ return f1_score(
+ y_true, y_pred, average="weighted", zero_division=0
+ )
+ if metric_name == "PR-AUC":
+ if y_scores is None:
+ return None
+ y_true_bin = _format_binary_labels(y_true)
+ if len(pd.unique(y_true_bin)) < 2:
+ return None
+ return average_precision_score(y_true_bin, y_scores)
+ if metric_name == "Specificity":
+ labels = pd.unique(pd.concat([y_true, y_pred], ignore_index=True))
+ if len(labels) != 2:
+ return None
+ if pos_label is None or pos_label not in labels:
+ pos_label = labels[1]
+ neg_candidates = [lbl for lbl in labels if lbl != pos_label]
+ neg_label_final = (
+ neg_label if neg_label in labels else (neg_candidates[0] if neg_candidates else None)
+ )
+ if neg_label_final is None:
+ return None
+ cm = confusion_matrix(
+ y_true, y_pred, labels=[neg_label_final, pos_label]
+ )
+ if cm.shape != (2, 2):
+ return None
+ tn, fp, fn, tp = cm.ravel()
+ denom = tn + fp
+ return (tn / denom) if denom else None
+ if metric_name == "MCC":
+ return matthews_corrcoef(y_true, y_pred)
+ except Exception as exc:
+ LOG.warning(
+ "Could not compute %s for %s split: %s",
+ metric_name,
+ split_name,
+ exc,
+ )
+ return None
+ return None
+
+ def _build_performance_summary_table(self):
+ """
+ Build a Train/Validation/Test metrics table for classification tasks.
+ Returns empty string when metrics are unavailable or not applicable.
+ """
+ if self.task_type != "classification":
+ return ""
+
+ split_predictions = self._get_split_predictions_for_report()
+ validation_best_row = None
+ try:
+ if isinstance(self.results, pd.DataFrame) and not self.results.empty:
+ validation_best_row = self.results.iloc[0]
+ except Exception:
+ validation_best_row = None
+
+ if not split_predictions and validation_best_row is None:
+ return ""
+
+ metric_names = [
+ "Accuracy",
+ "ROC-AUC",
+ "Precision",
+ "Recall",
+ "F1-Score",
+ "PR-AUC",
+ "Specificity",
+ "MCC",
+ ]
+
+ validation_column_map = {
+ "Accuracy": ["Accuracy"],
+ "ROC-AUC": ["ROC-AUC", "AUC"],
+ "Precision": ["Precision", "Prec.", "Prec"],
+ "Recall": ["Recall"],
+ "F1-Score": ["F1-Score", "F1"],
+ "PR-AUC": ["PR-AUC", "PR-AUC-Weighted", "PRC"],
+ "Specificity": ["Specificity"],
+ "MCC": ["MCC"],
+ }
+
+ def _fmt(value):
+ if value is None:
+ return "—"
+ try:
+ if isinstance(value, (float, np.floating)) and (
+ np.isnan(value) or np.isinf(value)
+ ):
+ return "—"
+ return f"{value:.3f}"
+ except Exception:
+ return str(value)
+
+ def _validation_metric(metric_name):
+ if validation_best_row is None:
+ return None
+ cols = validation_column_map.get(metric_name, [])
+ for col in cols:
+ if col in validation_best_row:
+ try:
+ return validation_best_row[col]
+ except Exception:
+ return None
+ return None
+
+ rows = []
+ for metric in metric_names:
+ row = [metric]
+ # Train
+ train_val = self._compute_metric_value(
+ metric, split_predictions.get("Train"), "Train"
+ )
+ row.append(_fmt(train_val))
+
+ # Validation from Train & Validation Summary first row; fallback to computed CV.
+ val_val = _validation_metric(metric)
+ if val_val is None:
+ val_val = self._compute_metric_value(
+ metric, split_predictions.get("Validation"), "Validation"
+ )
+ row.append(_fmt(val_val))
+
+ # Test
+ test_val = self._compute_metric_value(
+ metric, split_predictions.get("Test"), "Test"
+ )
+ row.append(_fmt(test_val))
+ rows.append(row)
+
+ df = pd.DataFrame(rows, columns=["Metric", "Train", "Validation", "Test"])
+ return (
+ "Model Performance Summary
"
+ + ''
+ + df.to_html(
+ index=False,
+ classes=["table", "sortable", "table-perf-summary"],
+ )
+ + "
"
+ )
+
def _resolve_plot_callable(self, key, fig_or_fn, section):
"""
Safely execute stored plot callables so a single failure does not
@@ -521,17 +1217,19 @@
# — Validation Summary & Configuration —
val_df = self.results.copy()
+ dataset_overview_html = self._build_dataset_overview()
+ performance_summary_html = self._build_performance_summary_table()
# mapping raw plot keys to user-friendly titles
plot_title_map = {
"learning": "Learning Curve",
"vc": "Validation Curve",
"calibration": "Calibration Curve",
"dimension": "Dimensionality Reduction",
- "manifold": "Manifold Learning",
+ "manifold": "t-SNE",
"rfe": "Recursive Feature Elimination",
"threshold": "Threshold Plot",
"percentage_above_below": "Percentage Above vs. Below Cutoff",
- "class_report": "Classification Report",
+ "class_report": "Per-Class Metrics",
"pr_auc": "Precision-Recall AUC",
"roc_auc": "Receiver Operating Characteristic AUC",
"residuals": "Residuals Distribution",
@@ -560,10 +1258,16 @@
+ ""
)
- summary_html += (
- "Setup Parameters
"
+ config_html = (
+ header
+ + dataset_overview_html
+ + performance_summary_html
+ + "Setup Parameters
"
+ ''
- + df_setup.to_html(index=False, classes="table sortable")
+ + df_setup.to_html(
+ index=False,
+ classes=["table", "sortable", "table-setup-params"],
+ )
+ "
"
# — Hyperparameters
+ "Best Model Hyperparameters
"
@@ -571,20 +1275,23 @@
+ pd.DataFrame(
self.best_model.get_params().items(),
columns=["Parameter", "Value"]
- ).to_html(index=False, classes="table sortable")
+ ).to_html(
+ index=False,
+ classes=["table", "sortable", "table-hyperparams"],
+ )
+ ""
)
# choose summary plots based on task type
if self.task_type == "classification":
summary_plots = [
+ "threshold",
"learning",
+ "calibration",
+ "rfe",
"vc",
- "calibration",
"dimension",
"manifold",
- "rfe",
- "threshold",
"percentage_above_below",
]
else:
@@ -649,11 +1356,13 @@
else:
test_order = [
"confusion_matrix",
+ "class_report",
"roc_auc",
"pr_auc",
"lift_curve",
"cumulative_precision",
]
+ rendered_test_plots = set()
for key in test_order:
fig_or_fn = self.explainer_plots.pop(key, None)
if fig_or_fn is not None:
@@ -662,6 +1371,7 @@
)
if fig is None:
continue
+ rendered_test_plots.add(key)
title = plot_title_map.get(
key, key.replace("_", " ").title()
)
@@ -679,6 +1389,8 @@
"class_report",
}
):
+ if name in rendered_test_plots:
+ continue
title = plot_title_map.get(
name, name.replace("_", " ").title()
)
@@ -750,7 +1462,7 @@
if cap_rows:
cap_table = (
""
- "
"
+ ""
"| Feature Importance Scope | Count |
"
""
+ "".join(
@@ -803,7 +1515,13 @@
# 7) Assemble final HTML (three tabs)
html = get_html_template()
html += "Tabular Learner Model Report
"
- html += build_tabbed_html(summary_html, test_html, feature_html)
+ html += build_tabbed_html(
+ summary_html,
+ test_html,
+ feature_html,
+ explainer_html=None,
+ config_html=config_html,
+ )
html += get_feature_metrics_help_modal()
html += get_html_closing()
@@ -823,11 +1541,11 @@
raise NotImplementedError("Subclasses should implement this method")
def generate_tree_plots(self):
+ from explainerdashboard.explainers import RandomForestExplainer
from sklearn.ensemble import (
RandomForestClassifier, RandomForestRegressor
)
from xgboost import XGBClassifier, XGBRegressor
- from explainerdashboard.explainers import RandomForestExplainer
LOG.info("Generating tree plots")
X_test = self.exp.X_test_transformed.copy()
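For reference, a minimal standalone sketch of how `_build_dataset_overview` approximates the Train/Validation counts from the first cross-validation fold; the toy labels and variable names below are illustrative only, not part of the patch:

import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Toy labels standing in for y_train_series in _build_dataset_overview.
y = pd.Series(["pos", "neg", "pos", "neg", "pos", "neg", "pos", "neg"])
X = pd.DataFrame({"feature": range(len(y))})

cv = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)
# Only the first fold is used to approximate Train/Validation split sizes.
train_idx, val_idx = next(iter(cv.split(X, y)))
y_train_fold = y.iloc[train_idx].reset_index(drop=True)
y_val_fold = y.iloc[val_idx].reset_index(drop=True)

for name, split in {"Train": y_train_fold, "Validation": y_val_fold}.items():
    counts = split.value_counts()
    for label, cnt in counts.items():
        print(f"{name} {label}: {cnt} ({cnt / len(split) * 100:.1f}%)")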
diff -r 4fee4504646e -r c5c324ac29fc feature_importance.py
--- a/feature_importance.py Fri Nov 28 22:28:26 2025 +0000
+++ b/feature_importance.py Sat Dec 06 14:20:36 2025 +0000
@@ -287,24 +287,16 @@
# Background set
bg = X_data.sample(min(len(X_data), 100), random_state=42)
- predict_fn = (
- model.predict_proba if hasattr(model, "predict_proba") else model.predict
- )
+ predict_fn = self._get_predict_fn(model)
- # Optimized explainer
- explainer = None
- explainer_label = None
- if hasattr(model, "feature_importances_"):
- explainer = shap.TreeExplainer(
- model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
- )
- explainer_label = "tree_path_dependent"
- elif hasattr(model, "coef_"):
- explainer = shap.LinearExplainer(model, bg)
- explainer_label = "linear"
- else:
- explainer = shap.Explainer(predict_fn, bg)
- explainer_label = explainer.__class__.__name__
+ # Optimized explainer based on model type
+ explainer, explainer_label, tree_based = self._choose_shap_explainer(
+ model, bg, predict_fn
+ )
+ if explainer is None:
+ LOG.warning("No suitable SHAP explainer for model %s; skipping SHAP.", model)
+ self.shap_model_name = None
+ return
try:
shap_values = explainer(X_data)
@@ -312,7 +304,7 @@
except Exception as e:
error_message = str(e)
needs_tree_fallback = (
- hasattr(model, "feature_importances_")
+ tree_based
and "does not cover all the leaves" in error_message.lower()
)
feature_name_mismatch = "feature names should match" in error_message.lower()
@@ -348,7 +340,9 @@
error_message,
)
try:
- agnostic_explainer = shap.Explainer(predict_fn, bg)
+ agnostic_explainer = shap.Explainer(
+ predict_fn, bg, algorithm="permutation"
+ )
shap_values = agnostic_explainer(X_data)
self.shap_model_name = (
f"{agnostic_explainer.__class__.__name__} (fallback)"
@@ -485,6 +479,241 @@
with open(img_path, "rb") as img_file:
return base64.b64encode(img_file.read()).decode("utf-8")
+ def _get_predict_fn(self, model):
+ if hasattr(model, "predict_proba"):
+ return model.predict_proba
+ if hasattr(model, "decision_function"):
+ return model.decision_function
+ return model.predict
+
+ def _choose_shap_explainer(self, model, bg, predict_fn):
+ """
+ Select a SHAP explainer following the prescribed priority order for
+ algorithms. Returns (explainer, label, is_tree_based).
+ """
+ if model is None:
+ return None, None, False
+
+ name = model.__class__.__name__
+ lname = name.lower()
+ task = getattr(self, "task_type", None)
+
+ def _permutation(fn):
+ return shap.Explainer(fn, bg, algorithm="permutation")
+
+ if task == "classification":
+ # 1) Logistic Regression
+ if "logisticregression" in lname:
+ return _permutation(model.predict_proba), "permutation-proba", False
+
+ # 2) Ridge Classifier
+ if "ridgeclassifier" in lname:
+ fn = (
+ model.decision_function
+ if hasattr(model, "decision_function")
+ else predict_fn
+ )
+ return _permutation(fn), "permutation-decision_function", False
+
+ # 3) LDA
+ if "lineardiscriminantanalysis" in lname:
+ return _permutation(model.predict_proba), "permutation-proba", False
+
+ # 4) Random Forest
+ if "randomforestclassifier" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 5) Gradient Boosting
+ if "gradientboostingclassifier" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 6) AdaBoost
+ if "adaboostclassifier" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 7) Extra Trees
+ if "extratreesclassifier" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 8) LightGBM
+ if "lgbmclassifier" in lname:
+ return (
+ shap.TreeExplainer(
+ model,
+ bg,
+ model_output="raw",
+ feature_perturbation="tree_path_dependent",
+ n_jobs=-1,
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 9) XGBoost
+ if "xgbclassifier" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 10) CatBoost (classifier)
+ if "catboost" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 11) KNN
+ if "kneighborsclassifier" in lname:
+ return _permutation(model.predict_proba), "permutation-proba", False
+
+ # 12) SVM - linear kernel
+ if "svc" in lname or "svm" in lname:
+ kernel = getattr(model, "kernel", None)
+ if kernel == "linear":
+ return shap.LinearExplainer(model, bg), "linear", False
+ return _permutation(predict_fn), "permutation-svm", False
+
+ # 13) Decision Tree
+ if "decisiontreeclassifier" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # 14) Naive Bayes
+ if "naive_bayes" in lname or lname.endswith("nb"):
+ fn = model.predict_proba if hasattr(model, "predict_proba") else predict_fn
+ return _permutation(fn), "permutation-proba", False
+
+ # 15) QDA
+ if "quadraticdiscriminantanalysis" in lname:
+ return _permutation(model.predict_proba), "permutation-proba", False
+
+ # 16) Dummy
+ if "dummyclassifier" in lname:
+ return None, None, False
+
+ # Default classification: permutation on predict_fn
+ return _permutation(predict_fn), "permutation-default", False
+
+ # Regression path
+ # Linear family
+ linear_keys = [
+ "linearregression",
+ "lasso",
+ "ridge",
+ "elasticnet",
+ "lars",
+ "lassolars",
+ "orthogonalmatchingpursuit",
+ "bayesianridge",
+ "ardregression",
+ "passiveaggressiveregressor",
+ "theilsenregressor",
+ "huberregressor",
+ ]
+ if any(k in lname for k in linear_keys):
+ return shap.LinearExplainer(model, bg), "linear", False
+
+ # Kernel ridge / SVR / KNN / MLP / RANSAC (model-agnostic)
+ if "kernelridge" in lname:
+ return _permutation(predict_fn), "permutation-kernelridge", False
+ if "svr" in lname or "svm" in lname:
+ kernel = getattr(model, "kernel", None)
+ if kernel == "linear":
+ return shap.LinearExplainer(model, bg), "linear", False
+ return _permutation(predict_fn), "permutation-svr", False
+ if "kneighborsregressor" in lname:
+ return _permutation(predict_fn), "permutation-knn", False
+ if "mlpregressor" in lname:
+ return _permutation(predict_fn), "permutation-mlp", False
+ if "ransacregressor" in lname:
+ return _permutation(predict_fn), "permutation-ransac", False
+
+ # Tree-based regressors
+ tree_class_names = [
+ "decisiontreeregressor",
+ "randomforestregressor",
+ "extratreesregressor",
+ "adaboostregressor",
+ "gradientboostingregressor",
+ ]
+ if any(k in lname for k in tree_class_names):
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # Boosting libraries
+ if "lgbmregressor" in lname or "lightgbm" in lname:
+ return (
+ shap.TreeExplainer(
+ model,
+ bg,
+ model_output="raw",
+ feature_perturbation="tree_path_dependent",
+ n_jobs=-1,
+ ),
+ "tree_path_dependent",
+ True,
+ )
+ if "xgbregressor" in lname or "xgboost" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+ if "catboost" in lname:
+ return (
+ shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ ),
+ "tree_path_dependent",
+ True,
+ )
+
+ # Default regression: model-agnostic permutation explainer
+ return _permutation(predict_fn), "permutation-default", False
+
def run(self):
if (
self.exp is None
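A hedged, self-contained sketch of the model-agnostic permutation fallback that `_choose_shap_explainer` relies on when no dedicated Tree/Linear explainer applies; the dataset and classifier below are illustrative stand-ins:

import shap
from sklearn.datasets import load_breast_cancer
from sklearn.neighbors import KNeighborsClassifier

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = KNeighborsClassifier().fit(X, y)

# Small background sample, mirroring the bg = X_data.sample(...) pattern above.
bg = X.sample(min(len(X), 100), random_state=42)
explainer = shap.Explainer(model.predict_proba, bg, algorithm="permutation")
shap_values = explainer(X.head(20))  # explain a handful of rows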
diff -r 4fee4504646e -r c5c324ac29fc pycaret_classification.py
--- a/pycaret_classification.py Fri Nov 28 22:28:26 2025 +0000
+++ b/pycaret_classification.py Sat Dec 06 14:20:36 2025 +0000
@@ -8,7 +8,14 @@
from base_model_trainer import BaseModelTrainer
from dashboard import generate_classifier_explainer_dashboard
from pycaret.classification import ClassificationExperiment
-from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
+from sklearn.metrics import (
+ auc,
+ confusion_matrix,
+ matthews_corrcoef,
+ precision_recall_curve,
+ precision_recall_fscore_support,
+ roc_curve,
+)
from utils import predict_proba
LOG = logging.getLogger(__name__)
@@ -137,58 +144,36 @@
# a dict to hold the raw Figure objects or callables
self.explainer_plots: Dict[str, go.Figure] = {}
+ y_true, y_pred, label_values, y_scores = self._get_test_predictions()
+
+ # — Classification report (Plotly table) —
+ try:
+ fig_report = self._build_classification_report_fig(
+ y_true, y_pred, label_values
+ )
+ if fig_report is not None:
+ self.explainer_plots["class_report"] = fig_report
+ except Exception as e:
+ LOG.warning(f"Could not generate Plotly classification report: {e}")
+
+ # — Confusion matrix with actual labels —
+ try:
+ fig_cm = self._build_confusion_matrix_fig(y_true, y_pred, label_values)
+ if fig_cm is not None:
+ self.explainer_plots["confusion_matrix"] = fig_cm
+ except Exception as e:
+ LOG.warning(f"Could not generate Plotly confusion matrix: {e}")
+
# --- Threshold-aware overrides for CM / ROC / PR ---
prob_thresh = getattr(self, "probability_threshold", None)
# Only for binary classification and when threshold is provided
if (prob_thresh is not None) and (not self.exp.is_multiclass):
- X = self.exp.X_test_transformed
- y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
-
- # Get positive-class scores (robust defaults)
- classes = list(getattr(self.best_model, "classes_", [0, 1]))
- try:
- pos_idx = classes.index(1) if 1 in classes else 1
- except Exception:
- pos_idx = 1
-
- proba = self.best_model.predict_proba(X)
- y_scores = proba[:, pos_idx]
-
- # Derive label names consistently
- pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
- neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
-
- # ---- Confusion Matrix @ threshold ----
- try:
- y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
- cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
- fig_cm = go.Figure(
- data=go.Heatmap(
- z=cm,
- x=[f"Pred {neg_label}", f"Pred {pos_label}"],
- y=[f"True {neg_label}", f"True {pos_label}"],
- text=cm,
- texttemplate="%{text}",
- colorscale="Blues",
- showscale=False,
- )
- )
- fig_cm.update_layout(
- title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
- xaxis_title="Predicted label",
- yaxis_title="True label",
- )
- _apply_report_layout(fig_cm)
- self.explainer_plots["confusion_matrix"] = fig_cm
- except Exception as e:
- LOG.warning(
- f"Threshold-aware confusion matrix failed; falling back: {e}"
- )
-
# ---- ROC with threshold marker ----
try:
- fpr, tpr, thr = roc_curve(y, y_scores)
+ if y_scores is None:
+ raise ValueError("Predicted probabilities unavailable")
+ fpr, tpr, thr = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)
fig_roc = go.Figure()
fig_roc.add_scatter(
@@ -219,7 +204,9 @@
# ---- PR with threshold marker ----
try:
- precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+ if y_scores is None:
+ raise ValueError("Predicted probabilities unavailable")
+ precision, recall, thr_pr = precision_recall_curve(y_true, y_scores)
pr_auc = auc(recall, precision)
fig_pr = go.Figure()
fig_pr.add_scatter(
@@ -304,3 +291,182 @@
return _plot
self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
+
+ def _get_test_predictions(self):
+ """
+ Return y_true, y_pred, label list, and (optionally) positive-class
+ probabilities when available. Ensures predictions respect the optional
+ probability threshold for binary tasks.
+ """
+ y_true = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+ X_test = self.exp.X_test_transformed
+ prob_thresh = getattr(self, "probability_threshold", None)
+
+ y_scores = None
+ try:
+ proba = self.best_model.predict_proba(X_test)
+ y_scores = proba
+ except Exception:
+ LOG.debug("predict_proba unavailable for test predictions.")
+
+ try:
+ if (
+ prob_thresh is not None
+ and not self.exp.is_multiclass
+ and y_scores is not None
+ and y_scores.ndim == 2
+ and y_scores.shape[1] > 1
+ ):
+ classes = list(getattr(self.best_model, "classes_", []))
+ try:
+ pos_idx = classes.index(1) if 1 in classes else 1
+ except Exception:
+ pos_idx = 1
+ neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0
+ pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+ neg_label = classes[neg_idx] if len(classes) > neg_idx else 0
+ y_pred = np.where(y_scores[:, pos_idx] >= prob_thresh, pos_label, neg_label)
+ y_scores = y_scores[:, pos_idx]
+ else:
+ y_pred = self.best_model.predict(X_test)
+ except Exception as exc:
+ LOG.warning("Falling back to raw predict for test predictions: %s", exc)
+ y_pred = self.best_model.predict(X_test)
+
+ y_pred = pd.Series(y_pred).reset_index(drop=True)
+ if y_scores is not None:
+ y_scores = np.asarray(y_scores)
+ if y_scores.ndim > 1 and y_scores.shape[1] == 1:
+ y_scores = y_scores.ravel()
+ if self.exp.is_multiclass and y_scores.ndim > 1:
+ # Avoid passing multiclass score matrices to ROC/PR utilities
+ y_scores = None
+ label_values = pd.unique(pd.concat([y_true, y_pred], ignore_index=True))
+ return y_true, y_pred, label_values.tolist(), y_scores
+
+ def _threshold_suffix(self) -> str:
+ """
+ Build a suffix like ' (threshold=0.50)' for binary tasks; omit for
+ multiclass where thresholds are not applied.
+ """
+ if getattr(self, "task_type", None) != "classification":
+ return ""
+ if getattr(self.exp, "is_multiclass", False):
+ return ""
+ prob_thresh = getattr(self, "probability_threshold", None)
+ if prob_thresh is None:
+ return " (threshold=0.50)"
+ try:
+ return f" (threshold={float(prob_thresh):.2f})"
+ except Exception:
+ return f" (threshold={prob_thresh})"
+
+ def _build_confusion_matrix_fig(self, y_true, y_pred, labels):
+ def _label_sort_key(lbl):
+ try:
+ return (0, float(lbl))
+ except Exception:
+ return (1, str(lbl))
+
+ ordered_labels = sorted(labels, key=_label_sort_key)
+ cm = confusion_matrix(y_true, y_pred, labels=ordered_labels)
+ label_names = [str(lbl) for lbl in ordered_labels]
+ fig_cm = go.Figure(
+ data=go.Heatmap(
+ z=cm,
+ x=[f"Pred {lbl}" for lbl in label_names],
+ y=[f"True {lbl}" for lbl in label_names],
+ text=cm,
+ texttemplate="%{text}",
+ colorscale="Blues",
+ showscale=False,
+ )
+ )
+ fig_cm.update_layout(
+ title=f"Confusion Matrix{self._threshold_suffix()}",
+ xaxis_title=f"Predicted label ({self.target})",
+ yaxis_title=f"True label ({self.target})",
+ )
+ fig_cm.update_xaxes(
+ type="category",
+ categoryorder="array",
+ categoryarray=[f"Pred {lbl}" for lbl in label_names],
+ )
+ fig_cm.update_yaxes(
+ type="category",
+ categoryorder="array",
+ categoryarray=[f"True {lbl}" for lbl in label_names],
+ autorange="reversed",
+ )
+ _apply_report_layout(fig_cm)
+ return fig_cm
+
+ def _build_classification_report_fig(self, y_true, y_pred, labels):
+ precision, recall, f1, support = precision_recall_fscore_support(
+ y_true, y_pred, labels=labels, zero_division=0
+ )
+ mcc_scores = []
+ for lbl in labels:
+ y_true_bin = (y_true == lbl).astype(int)
+ y_pred_bin = (y_pred == lbl).astype(int)
+ try:
+ mcc_val = matthews_corrcoef(y_true_bin, y_pred_bin)
+ except Exception:
+ mcc_val = 0.0
+ mcc_scores.append(mcc_val)
+
+ label_names = [str(lbl) for lbl in labels]
+ metrics = ["precision", "recall", "f1", "support"]
+
+ max_support = float(max(support) if len(support) else 0)
+ z_rows = []
+ text_rows = []
+ for i, lbl in enumerate(label_names):
+ norm_support = (support[i] / max_support) if max_support else 0.0
+ z_rows.append(
+ [
+ precision[i],
+ recall[i],
+ f1[i],
+ norm_support,
+ ]
+ )
+ text_rows.append(
+ [
+ f"{precision[i]:.3f}",
+ f"{recall[i]:.3f}",
+ f"{f1[i]:.3f}",
+ f"{int(support[i])}",
+ ]
+ )
+
+ fig = go.Figure(
+ data=go.Heatmap(
+ z=z_rows,
+ x=metrics,
+ y=label_names,
+ colorscale="YlOrRd",
+ zmin=0,
+ zmax=1,
+ colorbar=dict(title="Scale"),
+ text=text_rows,
+ texttemplate="%{text}",
+ hovertemplate="Label=%{y}
Metric=%{x}
Value=%{text}",
+ )
+ )
+ fig.update_yaxes(
+ title_text=f"Label ({self.target})",
+ autorange="reversed",
+ type="category",
+ tickmode="array",
+ tickvals=label_names,
+ ticktext=label_names,
+ showgrid=False,
+ )
+ fig.update_xaxes(title_text="", tickangle=45)
+ fig.update_layout(
+ title=f"Per-Class Metrics{self._threshold_suffix()}",
+ margin=dict(l=70, r=60, t=70, b=80),
+ )
+ _apply_report_layout(fig)
+ return fig
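A minimal sketch of the threshold-aware prediction rule that `_get_test_predictions` applies for binary tasks, assuming a fitted scikit-learn-style classifier; the toy model and the 0.3 cutoff are purely illustrative:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

classes = list(clf.classes_)
pos_idx = classes.index(1) if 1 in classes else 1
pos_label, neg_label = classes[pos_idx], classes[1 - pos_idx]

prob_thresh = 0.3  # example cutoff; without a threshold, predict() uses 0.5
scores = clf.predict_proba(X)[:, pos_idx]  # positive-class probabilities
y_pred = np.where(scores >= prob_thresh, pos_label, neg_label)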
diff -r 4fee4504646e -r c5c324ac29fc utils.py
--- a/utils.py Fri Nov 28 22:28:26 2025 +0000
+++ b/utils.py Sat Dec 06 14:20:36 2025 +0000
@@ -65,6 +65,28 @@
color: white;
}
+ /* Center specific numeric columns */
+ .table-dataset-overview td:nth-child(n+2),
+ .table-dataset-overview th:nth-child(n+2) {
+ text-align: center;
+ }
+ .table-perf-summary td:nth-child(n+2),
+ .table-perf-summary th:nth-child(n+2) {
+ text-align: center;
+ }
+ .table-setup-params td:nth-child(2),
+ .table-setup-params th:nth-child(2) {
+ text-align: center;
+ }
+ .table-hyperparams td:nth-child(2),
+ .table-hyperparams th:nth-child(2) {
+ text-align: center;
+ }
+ .table-fi-scope td:nth-child(2),
+ .table-fi-scope th:nth-child(2) {
+ text-align: center;
+ }
+
.plot {
text-align: center;
margin: 20px 0;
@@ -194,6 +216,7 @@
test_html: str,
feature_html: str,
explainer_html: Optional[str] = None,
+ config_html: Optional[str] = None,
) -> str:
"""
Render the tabbed sections and an always-visible Help button.
@@ -202,12 +225,24 @@
css = get_html_template().split("<style>")[1].rsplit("</style>", 1)[0] + "</style>"
# Tabs header
- tabs = [
- '',
- 'Validation Summary and Config',
+ tabs = ['']
+ default_active = "summary"
+ if config_html:
+ default_active = "config"
+ tabs.append(
+ 'Model Config Summary'
+ )
+ tabs.append(
+ 'Validation Summary'
+ )
+ else:
+ tabs.append(
+ 'Validation Summary'
+ )
+ tabs.extend([
'Test Summary',
'Feature Importance',
- ]
+ ])
if explainer_html:
tabs.append(
'Explainer Plots'
@@ -217,11 +252,16 @@
tabs_section = "\n".join(tabs)
# Content
- contents = [
- f'{summary_html}',
- f'{test_html}',
- f'{feature_html}',
- ]
+ contents = []
+ if config_html:
+ contents.append(
+ f'{config_html}'
+ )
+ contents.append(
+ f'{summary_html}'
+ )
+ contents.append(f'{test_html}')
+ contents.append(f'{feature_html}')
if explainer_html:
contents.append(
f'{explainer_html}'