diff base_model_trainer.py @ 17:c5c324ac29fc (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| author | goeckslab |
|---|---|
| date | Sat, 06 Dec 2025 14:20:36 +0000 |
| parents | 4fee4504646e |
| children | |
--- a/base_model_trainer.py Fri Nov 28 22:28:26 2025 +0000 +++ b/base_model_trainer.py Sat Dec 06 14:20:36 2025 +0000 @@ -9,7 +9,16 @@ import pandas as pd from feature_help_modal import get_feature_metrics_help_modal from feature_importance import FeatureImportanceAnalyzer -from sklearn.metrics import average_precision_score +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + confusion_matrix, + f1_score, + matthews_corrcoef, + precision_score, + recall_score, + roc_auc_score, +) from utils import ( add_hr_to_html, add_plot_to_html, @@ -387,6 +396,693 @@ with open(img_path, "rb") as img_file: return base64.b64encode(img_file.read()).decode("utf-8") + def _build_dataset_overview(self): + """ + Build an HTML table showing label counts with labels as rows and splits + (Train / Validation / Test) as columns. Each cell shows count and + percentage of that split. Returns empty string for regression or when + no label data is available. + """ + if self.task_type != "classification": + return "" + + def _safe_series(obj): + try: + return pd.Series(obj).reset_index(drop=True) + except Exception: + return None + + def _get_from_config(keys): + if self.exp is None: + return None + for key in keys: + try: + val = self.exp.get_config(key) + except Exception: + val = getattr(self.exp, key, None) + if val is not None: + return val + return None + + # Prefer PyCaret-configured splits; fall back to raw inputs. + X_train = _get_from_config(["X_train_transformed", "X_train"]) + y_train = _get_from_config(["y_train_transformed", "y_train"]) + y_test_cfg = _get_from_config(["y_test_transformed", "y_test"]) + + if y_train is None and self.data is not None and self.target in self.data.columns: + y_train = self.data[self.target] + + y_train_series = _safe_series(y_train) + + # Build a cross-validation generator to derive a validation subset size. + cv_gen = self._get_cv_generator(y_train_series) + y_train_fold = y_train_series + y_val_fold = None + if cv_gen is not None and y_train_series is not None: + try: + # Use the first fold to approximate Train/Validation split sizes. + splitter = cv_gen.split( + pd.DataFrame(X_train).reset_index(drop=True) + if X_train is not None + else y_train_series, + y_train_series, + ) + train_idx, val_idx = next(iter(splitter)) + y_train_fold = y_train_series.iloc[train_idx].reset_index(drop=True) + y_val_fold = y_train_series.iloc[val_idx].reset_index(drop=True) + except Exception as exc: + LOG.warning("Could not derive validation split for dataset overview: %s", exc) + + # Test labels: prefer PyCaret transformed holdout (single file) or external test. 
+ if self.test_data is not None: + if y_test_cfg is not None: + y_test = y_test_cfg + elif self.target in self.test_data.columns: + y_test = self.test_data[self.target] + else: + y_test = None + else: + y_test = y_test_cfg + + split_map = { + "Train": _safe_series(y_train_fold), + "Validation": _safe_series(y_val_fold), + "Test": _safe_series(y_test), + } + available = {k: v for k, v in split_map.items() if v is not None and not v.empty} + if not available: + return "" + + # Collect all labels across available splits (including NaN) + label_pool = pd.concat( + available.values(), ignore_index=True + ) + labels = pd.unique(label_pool) + + def _count_for_label(series, label): + if series is None or series.empty: + return None, None + total = len(series) + if pd.isna(label): + cnt = series.isna().sum() + else: + cnt = (series == label).sum() + return int(cnt), total + + rows = [] + for label in labels: + row = ["NaN" if pd.isna(label) else str(label)] + for split_name in ["Train", "Validation", "Test"]: + cnt, total = _count_for_label(split_map.get(split_name), label) + if cnt is None or total is None: + cell = "—" + else: + pct = (cnt / total * 100) if total else 0 + cell = f"{cnt} ({pct:.1f}%)" + row.append(cell) + rows.append(row) + + df = pd.DataFrame(rows, columns=["Label", "Train", "Validation", "Test"]) + df.sort_values("Label", inplace=True) + + return ( + "<h2>Dataset Overview</h2>" + + '<div class="table-wrapper">' + + df.to_html( + index=False, + classes=["table", "sortable", "table-dataset-overview"], + ) + + "</div>" + ) + + def _predict_with_thresholds(self, X, y_true): + """ + Generate predictions/probabilities for a split, respecting an optional + probability threshold for binary tasks. Returns a dict with y_true, + y_pred, y_scores (positive-class probs when available), pos_label, + and neg_label. 
+ """ + if X is None or y_true is None: + return None + + y_true_series = pd.Series(y_true).reset_index(drop=True) + classes = list(getattr(self.best_model, "classes_", [])) + if not classes: + try: + classes = pd.unique(y_true_series).tolist() + except Exception: + classes = [] + if len(classes) > 1: + try: + pos_idx = classes.index(1) + except Exception: + pos_idx = 1 + else: + pos_idx = 0 + pos_idx = min(pos_idx, len(classes) - 1) if classes else 0 + pos_label = ( + classes[pos_idx] + if len(classes) > pos_idx and pos_idx >= 0 + else (classes[-1] if classes else 1) + ) + neg_label = None + if len(classes) >= 2: + neg_candidates = [c for c in classes if c != pos_label] + if neg_candidates: + neg_label = neg_candidates[0] + + prob_thresh = getattr(self, "probability_threshold", None) + y_scores = None + try: + proba = self.best_model.predict_proba(X) + y_scores = np.asarray(proba) if proba is not None else None + except Exception: + y_scores = None + + try: + if ( + prob_thresh is not None + and not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 + if neg_label is None and len(classes) > neg_idx: + neg_label = classes[neg_idx] + y_pred = np.where( + y_scores[:, pos_idx] >= prob_thresh, + pos_label, + neg_label if neg_label is not None else 0, + ) + y_scores = y_scores[:, pos_idx] + else: + y_pred = self.best_model.predict(X) + if ( + not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + y_scores = y_scores[:, pos_idx] + except Exception as exc: + LOG.warning( + "Falling back to raw predict while computing performance summary: %s", + exc, + ) + try: + y_pred = self.best_model.predict(X) + except Exception as exc_inner: + LOG.warning( + "Unable to score split after fallback prediction: %s", + exc_inner, + ) + return None + y_scores = None + + y_pred_series = pd.Series(y_pred).reset_index(drop=True) + if y_scores is not None: + y_scores = np.asarray(y_scores) + if y_scores.ndim > 1 and y_scores.shape[1] == 1: + y_scores = y_scores.ravel() + if getattr(self.exp, "is_multiclass", False) and y_scores.ndim > 1: + # Avoid passing multiclass score matrices to ROC/PR utilities + y_scores = None + + return { + "y_true": y_true_series, + "y_pred": y_pred_series, + "y_scores": y_scores, + "pos_label": pos_label, + "neg_label": neg_label, + } + + def _get_cv_generator(self, y_series): + """ + Build a cross-validation splitter that mirrors the experiment's + configuration. Returns None when CV is disabled or not applicable. 
+ """ + if self.task_type != "classification": + return None + + if getattr(self, "cross_validation", None) is False: + return None + + try: + cfg_gen = self.exp.get_config("fold_generator") + if cfg_gen is not None: + return cfg_gen + except Exception: + cfg_gen = None + + folds = ( + getattr(self, "cross_validation_folds", None) + or self.setup_params.get("fold") + or getattr(self.exp, "fold", None) + or 10 + ) + try: + folds = int(folds) + except Exception: + folds = 10 + + try: + y_series = pd.Series(y_series).reset_index(drop=True) + except Exception: + y_series = None + if y_series is None or y_series.empty: + return None + + if folds < 2: + return None + if len(y_series) < folds: + folds = len(y_series) + if folds < 2: + return None + + try: + from sklearn.model_selection import KFold, StratifiedKFold + + if self.task_type == "classification": + return StratifiedKFold( + n_splits=folds, + shuffle=True, + random_state=self.random_seed, + ) + return KFold( + n_splits=folds, + shuffle=True, + random_state=self.random_seed, + ) + except Exception as exc: + LOG.warning("Could not build CV generator: %s", exc) + return None + + def _get_cross_validated_predictions(self, X, y): + """ + Generate cross-validated predictions for the validation split so we + can report validation metrics for the selected best model. + """ + if self.task_type != "classification": + return None + if getattr(self, "cross_validation", None) is False: + return None + if X is None or y is None: + return None + + try: + from sklearn.model_selection import cross_val_predict + except Exception as exc: + LOG.warning("cross_val_predict unavailable: %s", exc) + return None + + y_series = pd.Series(y).reset_index(drop=True) + if y_series.empty: + return None + + cv_gen = self._get_cv_generator(y_series) + if cv_gen is None: + return None + + X_df = pd.DataFrame(X).reset_index(drop=True) + if len(X_df) != len(y_series): + X_df = X_df.iloc[: len(y_series)].reset_index(drop=True) + + classes = list(getattr(self.best_model, "classes_", [])) + if len(classes) > 1: + try: + pos_idx = classes.index(1) + except Exception: + pos_idx = 1 + else: + pos_idx = 0 + pos_idx = min(pos_idx, len(classes) - 1) if classes else 0 + pos_label = ( + classes[pos_idx] if len(classes) > pos_idx else 1 + ) + neg_label = None + if len(classes) >= 2: + neg_candidates = [c for c in classes if c != pos_label] + if neg_candidates: + neg_label = neg_candidates[0] + + prob_thresh = getattr(self, "probability_threshold", None) + n_jobs = getattr(self, "n_jobs", None) + + y_scores = None + if not getattr(self.exp, "is_multiclass", False): + try: + proba = cross_val_predict( + self.best_model, + X_df, + y_series, + cv=cv_gen, + method="predict_proba", + n_jobs=n_jobs, + ) + y_scores = np.asarray(proba) + except Exception as exc: + LOG.debug("Could not compute CV probabilities: %s", exc) + + y_pred = None + if ( + prob_thresh is not None + and not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 + if neg_label is None and len(classes) > neg_idx: + neg_label = classes[neg_idx] + y_pred = np.where( + y_scores[:, pos_idx] >= prob_thresh, + pos_label, + neg_label if neg_label is not None else 0, + ) + y_scores = y_scores[:, pos_idx] + else: + try: + y_pred = cross_val_predict( + self.best_model, + X_df, + y_series, + cv=cv_gen, + method="predict", + n_jobs=n_jobs, + ) + except 
Exception as exc: + LOG.warning( + "Could not compute cross-validated predictions: %s", + exc, + ) + return None + if ( + not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + y_scores = y_scores[:, pos_idx] + + if y_scores is not None and getattr(self.exp, "is_multiclass", False): + y_scores = None + + return { + "y_true": y_series, + "y_pred": pd.Series(y_pred).reset_index(drop=True), + "y_scores": y_scores, + "pos_label": pos_label, + "neg_label": neg_label, + } + + def _get_split_predictions_for_report(self): + """ + Collect predictions/probabilities for Train/Validation/Test splits so the + performance table can show consistent metrics across splits. + """ + if self.task_type != "classification": + return {} + + def _get_from_config(keys): + for key in keys: + try: + val = self.exp.get_config(key) + except Exception: + val = getattr(self.exp, key, None) + if val is not None: + return val + return None + + X_train = _get_from_config(["X_train_transformed", "X_train"]) + y_train = _get_from_config(["y_train_transformed", "y_train"]) + X_holdout = _get_from_config(["X_test_transformed", "X_test"]) + y_holdout = _get_from_config(["y_test_transformed", "y_test"]) + + predictions = {} + + # Train metrics (best model on training data) + if X_train is not None and y_train is not None: + try: + train_preds = self._predict_with_thresholds(X_train, y_train) + if train_preds is not None: + predictions["Train"] = train_preds + except Exception as exc: + LOG.warning( + "Could not score Train split for performance summary: %s", + exc, + ) + + # Validation metrics via cross-validation on training data + try: + val_preds = self._get_cross_validated_predictions(X_train, y_train) + if val_preds is not None: + predictions["Validation"] = val_preds + except Exception as exc: + LOG.warning( + "Could not score Validation split for performance summary: %s", + exc, + ) + + # Test metrics (holdout from single file, or provided test file) + X_test = X_holdout + y_test = y_holdout + if (X_test is None or y_test is None) and self.test_data is not None: + try: + X_test = self.test_data.drop(columns=[self.target]) + y_test = self.test_data[self.target] + except Exception as exc: + LOG.warning( + "Could not prepare external test data for performance summary: %s", + exc, + ) + + if X_test is not None and y_test is not None: + try: + test_preds = self._predict_with_thresholds(X_test, y_test) + if test_preds is not None: + predictions["Test"] = test_preds + except Exception as exc: + LOG.warning( + "Could not score Test split for performance summary: %s", + exc, + ) + return predictions + + def _compute_metric_value(self, metric_name, preds, split_name): + """ + Compute a single metric for a given split prediction bundle. 
+ """ + if preds is None: + return None + + y_true = preds["y_true"] + y_pred = preds["y_pred"] + y_scores = preds.get("y_scores") + pos_label = preds.get("pos_label") + neg_label = preds.get("neg_label") + is_multiclass = getattr(self.exp, "is_multiclass", False) + + def _format_binary_labels(series): + if pos_label is None: + return series + try: + return (series == pos_label).astype(int) + except Exception: + return series + + try: + if metric_name == "Accuracy": + return accuracy_score(y_true, y_pred) + if metric_name == "ROC-AUC": + if y_scores is None: + return None + y_true_bin = _format_binary_labels(y_true) + if len(pd.unique(y_true_bin)) < 2: + return None + return roc_auc_score(y_true_bin, y_scores) + if metric_name == "Precision": + if is_multiclass: + return precision_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + try: + return precision_score( + y_true, y_pred, pos_label=pos_label, zero_division=0 + ) + except Exception: + return precision_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + if metric_name == "Recall": + if is_multiclass: + return recall_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + try: + return recall_score( + y_true, y_pred, pos_label=pos_label, zero_division=0 + ) + except Exception: + return recall_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + if metric_name == "F1-Score": + if is_multiclass: + return f1_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + try: + return f1_score( + y_true, y_pred, pos_label=pos_label, zero_division=0 + ) + except Exception: + return f1_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + if metric_name == "PR-AUC": + if y_scores is None: + return None + y_true_bin = _format_binary_labels(y_true) + if len(pd.unique(y_true_bin)) < 2: + return None + return average_precision_score(y_true_bin, y_scores) + if metric_name == "Specificity": + labels = pd.unique(pd.concat([y_true, y_pred], ignore_index=True)) + if len(labels) != 2: + return None + if pos_label is None or pos_label not in labels: + pos_label = labels[1] + neg_candidates = [lbl for lbl in labels if lbl != pos_label] + neg_label_final = ( + neg_label if neg_label in labels else (neg_candidates[0] if neg_candidates else None) + ) + if neg_label_final is None: + return None + cm = confusion_matrix( + y_true, y_pred, labels=[neg_label_final, pos_label] + ) + if cm.shape != (2, 2): + return None + tn, fp, fn, tp = cm.ravel() + denom = tn + fp + return (tn / denom) if denom else None + if metric_name == "MCC": + return matthews_corrcoef(y_true, y_pred) + except Exception as exc: + LOG.warning( + "Could not compute %s for %s split: %s", + metric_name, + split_name, + exc, + ) + return None + return None + + def _build_performance_summary_table(self): + """ + Build a Train/Validation/Test metrics table for classification tasks. + Returns empty string when metrics are unavailable or not applicable. 
+ """ + if self.task_type != "classification": + return "" + + split_predictions = self._get_split_predictions_for_report() + validation_best_row = None + try: + if isinstance(self.results, pd.DataFrame) and not self.results.empty: + validation_best_row = self.results.iloc[0] + except Exception: + validation_best_row = None + + if not split_predictions and validation_best_row is None: + return "" + + metric_names = [ + "Accuracy", + "ROC-AUC", + "Precision", + "Recall", + "F1-Score", + "PR-AUC", + "Specificity", + "MCC", + ] + + validation_column_map = { + "Accuracy": ["Accuracy"], + "ROC-AUC": ["ROC-AUC", "AUC"], + "Precision": ["Precision", "Prec.", "Prec"], + "Recall": ["Recall"], + "F1-Score": ["F1-Score", "F1"], + "PR-AUC": ["PR-AUC", "PR-AUC-Weighted", "PRC"], + "Specificity": ["Specificity"], + "MCC": ["MCC"], + } + + def _fmt(value): + if value is None: + return "—" + try: + if isinstance(value, (float, np.floating)) and ( + np.isnan(value) or np.isinf(value) + ): + return "—" + return f"{value:.3f}" + except Exception: + return str(value) + + def _validation_metric(metric_name): + if validation_best_row is None: + return None + cols = validation_column_map.get(metric_name, []) + for col in cols: + if col in validation_best_row: + try: + return validation_best_row[col] + except Exception: + return None + return None + + rows = [] + for metric in metric_names: + row = [metric] + # Train + train_val = self._compute_metric_value( + metric, split_predictions.get("Train"), "Train" + ) + row.append(_fmt(train_val)) + + # Validation from Train & Validation Summary first row; fallback to computed CV. + val_val = _validation_metric(metric) + if val_val is None: + val_val = self._compute_metric_value( + metric, split_predictions.get("Validation"), "Validation" + ) + row.append(_fmt(val_val)) + + # Test + test_val = self._compute_metric_value( + metric, split_predictions.get("Test"), "Test" + ) + row.append(_fmt(test_val)) + rows.append(row) + + df = pd.DataFrame(rows, columns=["Metric", "Train", "Validation", "Test"]) + return ( + "<h2>Model Performance Summary</h2>" + + '<div class="table-wrapper">' + + df.to_html( + index=False, + classes=["table", "sortable", "table-perf-summary"], + ) + + "</div>" + ) + def _resolve_plot_callable(self, key, fig_or_fn, section): """ Safely execute stored plot callables so a single failure does not @@ -521,17 +1217,19 @@ # — Validation Summary & Configuration — val_df = self.results.copy() + dataset_overview_html = self._build_dataset_overview() + performance_summary_html = self._build_performance_summary_table() # mapping raw plot keys to user-friendly titles plot_title_map = { "learning": "Learning Curve", "vc": "Validation Curve", "calibration": "Calibration Curve", "dimension": "Dimensionality Reduction", - "manifold": "Manifold Learning", + "manifold": "t-SNE", "rfe": "Recursive Feature Elimination", "threshold": "Threshold Plot", "percentage_above_below": "Percentage Above vs. 
Below Cutoff", - "class_report": "Classification Report", + "class_report": "Per-Class Metrics", "pr_auc": "Precision-Recall AUC", "roc_auc": "Receiver Operating Characteristic AUC", "residuals": "Residuals Distribution", @@ -560,10 +1258,16 @@ + "</div>" ) - summary_html += ( - "<h2>Setup Parameters</h2>" + config_html = ( + header + + dataset_overview_html + + performance_summary_html + + "<h2>Setup Parameters</h2>" + '<div class="table-wrapper">' - + df_setup.to_html(index=False, classes="table sortable") + + df_setup.to_html( + index=False, + classes=["table", "sortable", "table-setup-params"], + ) + "</div>" # — Hyperparameters + "<h2>Best Model Hyperparameters</h2>" @@ -571,20 +1275,23 @@ + pd.DataFrame( self.best_model.get_params().items(), columns=["Parameter", "Value"] - ).to_html(index=False, classes="table sortable") + ).to_html( + index=False, + classes=["table", "sortable", "table-hyperparams"], + ) + "</div>" ) # choose summary plots based on task type if self.task_type == "classification": summary_plots = [ + "threshold", "learning", + "calibration", + "rfe", "vc", - "calibration", "dimension", "manifold", - "rfe", - "threshold", "percentage_above_below", ] else: @@ -649,11 +1356,13 @@ else: test_order = [ "confusion_matrix", + "class_report", "roc_auc", "pr_auc", "lift_curve", "cumulative_precision", ] + rendered_test_plots = set() for key in test_order: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: @@ -662,6 +1371,7 @@ ) if fig is None: continue + rendered_test_plots.add(key) title = plot_title_map.get( key, key.replace("_", " ").title() ) @@ -679,6 +1389,8 @@ "class_report", } ): + if name in rendered_test_plots: + continue title = plot_title_map.get( name, name.replace("_", " ").title() ) @@ -750,7 +1462,7 @@ if cap_rows: cap_table = ( "<div class='table-wrapper'>" - "<table class='table sortable'>" + "<table class='table sortable table-fi-scope'>" "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" "<tbody>" + "".join( @@ -803,7 +1515,13 @@ # 7) Assemble final HTML (three tabs) html = get_html_template() html += "<h1>Tabular Learner Model Report</h1>" - html += build_tabbed_html(summary_html, test_html, feature_html) + html += build_tabbed_html( + summary_html, + test_html, + feature_html, + explainer_html=None, + config_html=config_html, + ) html += get_feature_metrics_help_modal() html += get_html_closing() @@ -823,11 +1541,11 @@ raise NotImplementedError("Subclasses should implement this method") def generate_tree_plots(self): + from explainerdashboard.explainers import RandomForestExplainer from sklearn.ensemble import ( RandomForestClassifier, RandomForestRegressor ) from xgboost import XGBClassifier, XGBRegressor - from explainerdashboard.explainers import RandomForestExplainer LOG.info("Generating tree plots") X_test = self.exp.X_test_transformed.copy()
