changeset 12:e674b9e946fb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
| field | value |
|---|---|
| author | goeckslab |
| date | Mon, 08 Sep 2025 22:39:12 +0000 |
| parents | 4eca9d109de1 |
| children | |
| files | feature_importance.py pycaret_classification.py |
| diffstat | 2 files changed, 277 insertions(+), 110 deletions(-) |
--- a/feature_importance.py	Fri Aug 22 21:13:30 2025 +0000
+++ b/feature_importance.py	Mon Sep 08 22:39:12 2025 +0000
@@ -23,7 +23,6 @@
         exp=None,
         best_model=None,
     ):
-        self.task_type = task_type
         self.output_dir = output_dir
         self.exp = exp
@@ -40,8 +39,8 @@
             LOG.info("Data loaded from memory")
         else:
             self.target_col = target_col
-            self.data = pd.read_csv(data_path, sep=None, engine='python')
-            self.data.columns = self.data.columns.str.replace('.', '_')
+            self.data = pd.read_csv(data_path, sep=None, engine="python")
+            self.data.columns = self.data.columns.str.replace(".", "_")
             self.data = self.data.fillna(self.data.median(numeric_only=True))
             self.target = self.data.columns[int(target_col) - 1]
         self.exp = (
@@ -53,63 +52,58 @@
         self.plots = {}

     def setup_pycaret(self):
-        if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup:
+        if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
             LOG.info("Experiment already set up. Skipping PyCaret setup.")
             return
         LOG.info("Initializing PyCaret")
         setup_params = {
-            'target': self.target,
-            'session_id': 123,
-            'html': True,
-            'log_experiment': False,
-            'system_log': False,
+            "target": self.target,
+            "session_id": 123,
+            "html": True,
+            "log_experiment": False,
+            "system_log": False,
         }
         self.exp.setup(self.data, **setup_params)

     def save_tree_importance(self):
-        model = self.best_model or self.exp.get_config('best_model')
-        processed_features = self.exp.get_config('X_transformed').columns
+        model = self.best_model or self.exp.get_config("best_model")
+        processed_features = self.exp.get_config("X_transformed").columns

-        # Try feature_importances_ or coef_ if available
         importances = None
         model_type = model.__class__.__name__
-        self.tree_model_name = model_type  # Store the model name for reporting
+        self.tree_model_name = model_type

-        if hasattr(model, 'feature_importances_'):
+        if hasattr(model, "feature_importances_"):
             importances = model.feature_importances_
-        elif hasattr(model, 'coef_'):
-            # For linear models, flatten coef_ and take abs (importance as magnitude)
+        elif hasattr(model, "coef_"):
             importances = abs(model.coef_).flatten()
         else:
-            # Neither attribute exists; skip the plot
             LOG.warning(
-                f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot."
+                f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance."
             )
-            self.tree_model_name = None  # No plot generated
+            self.tree_model_name = None
             return

-        # Defensive: handle mismatch in number of features
         if len(importances) != len(processed_features):
             LOG.warning(
-                f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot."
+                f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance."
             )
             self.tree_model_name = None
             return

         feature_importances = pd.DataFrame(
-            {'Feature': processed_features, 'Importance': importances}
-        ).sort_values(by='Importance', ascending=False)
+            {"Feature": processed_features, "Importance": importances}
+        ).sort_values(by="Importance", ascending=False)
         plt.figure(figsize=(10, 6))
-        plt.barh(feature_importances['Feature'], feature_importances['Importance'])
-        plt.xlabel('Importance')
-        plt.title(f'Feature Importance ({model_type})')
-        plot_path = os.path.join(self.output_dir, 'tree_importance.png')
-        plt.savefig(plot_path)
+        plt.barh(feature_importances["Feature"], feature_importances["Importance"])
+        plt.xlabel("Importance")
+        plt.title(f"Feature Importance ({model_type})")
+        plot_path = os.path.join(self.output_dir, "tree_importance.png")
+        plt.savefig(plot_path, bbox_inches="tight")
         plt.close()
-        self.plots['tree_importance'] = plot_path
+        self.plots["tree_importance"] = plot_path

-    def save_shap_values(self):
-
+    def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
         model = self.best_model or self.exp.get_config("best_model")

         X_data = None
@@ -120,78 +114,119 @@
             except KeyError:
                 continue
         if X_data is None:
-            raise RuntimeError(
-                "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. "
-                "Make sure PyCaret setup/compare_models was run with feature_selection=True."
-            )
+            raise RuntimeError("No transformed dataset found for SHAP.")
+
+        # --- Adaptive feature limiting (proportional cap) ---
+        n_rows, n_features = X_data.shape
+        if max_features is None:
+            if n_features <= 200:
+                max_features = n_features
+            else:
+                max_features = min(200, max(20, int(n_features * 0.1)))

         try:
-            used_features = model.booster_.feature_name()
-        except Exception:
-            used_features = getattr(model, "feature_names_in_", X_data.columns.tolist())
-        X_data = X_data[used_features]
+            if hasattr(model, "feature_importances_"):
+                importances = pd.Series(
+                    model.feature_importances_, index=X_data.columns
+                )
+                top_features = importances.nlargest(max_features).index
+            elif hasattr(model, "coef_"):
+                coef = abs(model.coef_).flatten()
+                importances = pd.Series(coef, index=X_data.columns)
+                top_features = importances.nlargest(max_features).index
+            else:
+                variances = X_data.var()
+                top_features = variances.nlargest(max_features).index
+
+            if len(top_features) < n_features:
+                LOG.info(
+                    f"Restricted SHAP computation to top {len(top_features)} / {n_features} features"
+                )
+            X_data = X_data[top_features]
+        except Exception as e:
+            LOG.warning(
+                f"Feature limiting failed: {e}. Using all {n_features} features."
+            )

-        max_bg = min(len(X_data), 100)
-        bg = X_data.sample(max_bg, random_state=42)
+        # --- Adaptive row subsampling ---
+        if max_samples is None:
+            if n_rows <= 500:
+                max_samples = n_rows
+            elif n_rows <= 5000:
+                max_samples = 500
+            else:
+                max_samples = min(1000, int(n_rows * 0.1))
+
+        if n_rows > max_samples:
+            LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
+            X_data = X_data.sample(max_samples, random_state=42)

-        predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict
+        # --- Adaptive feature display ---
+        if max_display is None:
+            if X_data.shape[1] <= 20:
+                max_display = X_data.shape[1]
+            elif X_data.shape[1] <= 100:
+                max_display = 30
+            else:
+                max_display = 50
+
+        # Background set
+        bg = X_data.sample(min(len(X_data), 100), random_state=42)
+        predict_fn = (
+            model.predict_proba if hasattr(model, "predict_proba") else model.predict
+        )
+
+        # Optimized explainer
+        if hasattr(model, "feature_importances_"):
+            explainer = shap.TreeExplainer(
+                model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+            )
+        elif hasattr(model, "coef_"):
+            explainer = shap.LinearExplainer(model, bg)
+        else:
+            explainer = shap.Explainer(predict_fn, bg)

         try:
-            explainer = shap.Explainer(predict_fn, bg)
+            shap_values = explainer(X_data)
             self.shap_model_name = explainer.__class__.__name__
-
-            shap_values = explainer(X_data)
         except Exception as e:
             LOG.error(f"SHAP computation failed: {e}")
             self.shap_model_name = None
             return

-        output_names = getattr(shap_values, "output_names", None)
-        if output_names is None and hasattr(model, "classes_"):
-            output_names = list(model.classes_)
-        if output_names is None:
-            n_out = shap_values.values.shape[-1]
-            output_names = list(map(str, range(n_out)))
+        # --- Plot SHAP summary ---
+        out_path = os.path.join(self.output_dir, "shap_summary.png")
+        plt.figure()
+        shap.plots.beeswarm(shap_values, max_display=max_display, show=False)
+        plt.title(
+            f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)"
+        )
+        plt.savefig(out_path, bbox_inches="tight")
+        plt.close()
+        self.plots["shap_summary"] = out_path

-        values = shap_values.values
-        if values.ndim == 3:
-            for j, name in enumerate(output_names):
-                safe = name.replace(" ", "_").replace("/", "_")
-                out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png")
-                plt.figure()
-                shap.plots.beeswarm(shap_values[..., j], show=False)
-                plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}")
-                plt.savefig(out_path)
-                plt.close()
-                self.plots[f"shap_summary_{safe}"] = out_path
-        else:
-            plt.figure()
-            shap.plots.beeswarm(shap_values, show=False)
-            plt.title(f"SHAP Summary for {model.__class__.__name__}")
-            out_path = os.path.join(self.output_dir, "shap_summary.png")
-            plt.savefig(out_path)
-            plt.close()
-            self.plots["shap_summary"] = out_path
+        # --- Log summary ---
+        LOG.info(
+            f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})."
+        )

     def generate_html_report(self):
         LOG.info("Generating HTML report")
-
         plots_html = ""
         for plot_name, plot_path in self.plots.items():
-            # Special handling for tree importance: skip if no model name (not generated)
-            if plot_name == 'tree_importance' and not getattr(
-                self, 'tree_model_name', None
+            if plot_name == "tree_importance" and not getattr(
+                self, "tree_model_name", None
            ):
                 continue
             encoded_image = self.encode_image_to_base64(plot_path)
-            if plot_name == 'tree_importance' and getattr(
-                self, 'tree_model_name', None
+            if plot_name == "tree_importance" and getattr(
+                self, "tree_model_name", None
            ):
+                section_title = f"Feature importance from {self.tree_model_name}"
+            elif plot_name == "shap_summary":
                 section_title = (
-                    f"Feature importance analysis from a trained {self.tree_model_name}"
+                    f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
                 )
-            elif plot_name == 'shap_summary':
-                section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
             else:
                 section_title = plot_name
             plots_html += f"""
@@ -200,25 +235,19 @@
                 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
             </div>
             """
-
-        html_content = f"""
-        {plots_html}
-        """
-
-        return html_content
+        return f"{plots_html}"

     def encode_image_to_base64(self, img_path):
-        with open(img_path, 'rb') as img_file:
-            return base64.b64encode(img_file.read()).decode('utf-8')
+        with open(img_path, "rb") as img_file:
+            return base64.b64encode(img_file.read()).decode("utf-8")

     def run(self):
         if (
             self.exp is None
-            or not hasattr(self.exp, 'is_setup')
+            or not hasattr(self.exp, "is_setup")
             or not self.exp.is_setup
         ):
             self.setup_pycaret()
         self.save_tree_importance()
         self.save_shap_values()
-        html_content = self.generate_html_report()
-        return html_content
+        return self.generate_html_report()
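The adaptive SHAP limits added to `save_shap_values` above follow three size-based rules for features, rows, and displayed features. A minimal standalone sketch of those heuristics, for reference only; the helper names below are illustrative and not part of feature_importance.py:

```python
# Illustrative sketch of the adaptive SHAP limits introduced in the diff above.
# These helpers are hypothetical; the tool computes the same values inline.

def default_max_features(n_features: int) -> int:
    # Keep every feature up to 200; beyond that, roughly 10%, bounded to [20, 200].
    if n_features <= 200:
        return n_features
    return min(200, max(20, int(n_features * 0.1)))

def default_max_samples(n_rows: int) -> int:
    # Use all rows up to 500, cap at 500 up to 5000 rows, then ~10% bounded at 1000.
    if n_rows <= 500:
        return n_rows
    if n_rows <= 5000:
        return 500
    return min(1000, int(n_rows * 0.1))

def default_max_display(n_kept_features: int) -> int:
    # Beeswarm display width: all features up to 20, then 30 up to 100, else 50.
    if n_kept_features <= 20:
        return n_kept_features
    if n_kept_features <= 100:
        return 30
    return 50

print(default_max_features(1500), default_max_samples(20000), default_max_display(150))
# -> 150 1000 50
```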
--- a/pycaret_classification.py	Fri Aug 22 21:13:30 2025 +0000
+++ b/pycaret_classification.py	Mon Sep 08 22:39:12 2025 +0000
@@ -2,15 +2,29 @@
 import types
 from typing import Dict

+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
-from plotly.graph_objects import Figure
 from pycaret.classification import ClassificationExperiment
+from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
 from utils import predict_proba

 LOG = logging.getLogger(__name__)


+def _apply_report_layout(fig: go.Figure) -> go.Figure:
+    # Give the left side more space for y-axis title/ticks and let axes auto-reserve room
+    fig.update_xaxes(automargin=True, title_standoff=12)
+    fig.update_yaxes(automargin=True, title_standoff=12)
+    fig.update_layout(
+        autosize=True,
+        margin=dict(l=120, r=40, t=60, b=60),  # bump 'l' if you still see clipping
+    )
+    return fig
+
+
 class ClassificationModelTrainer(BaseModelTrainer):
     def __init__(
         self,
@@ -50,20 +64,19 @@
         )

         plots = [
-            'confusion_matrix',
-            'auc',
-            'threshold',
-            'pr',
-            'error',
-            'class_report',
-            'learning',
-            'calibration',
-            'vc',
-            'dimension',
-            'manifold',
-            'rfe',
-            'feature',
-            'feature_all',
+            "auc",
+            "threshold",
+            "pr",
+            "error",
+            "class_report",
+            "learning",
+            "calibration",
+            "vc",
+            "dimension",
+            "manifold",
+            "rfe",
+            "feature",
+            "feature_all",
         ]
         for plot_name in plots:
             try:
@@ -102,24 +115,146 @@

         LOG.info("Generating explainer plots")

+        # Ensure predict_proba is available here too
+        if not hasattr(self.best_model, "predict_proba"):
+            self.best_model.predict_proba = types.MethodType(
+                predict_proba, self.best_model
+            )
+            LOG.warning(
+                f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
+            )
+
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed

         explainer = ClassifierExplainer(self.best_model, X_test, y_test)

         # a dict to hold the raw Figure objects or callables
-        self.explainer_plots: Dict[str, Figure] = {}
+        self.explainer_plots: Dict[str, go.Figure] = {}
+
+        # --- Threshold-aware overrides for CM / ROC / PR ---
+        prob_thresh = getattr(self, "probability_threshold", None)
+
+        # Only for binary classification and when threshold is provided
+        if (prob_thresh is not None) and (not self.exp.is_multiclass):
+            X = self.exp.X_test_transformed
+            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+
+            # Get positive-class scores (robust defaults)
+            classes = list(getattr(self.best_model, "classes_", [0, 1]))
+            try:
+                pos_idx = classes.index(1) if 1 in classes else 1
+            except Exception:
+                pos_idx = 1
+
+            proba = self.best_model.predict_proba(X)
+            y_scores = proba[:, pos_idx]
+
+            # Derive label names consistently
+            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
+
+            # ---- Confusion Matrix @ threshold ----
+            try:
+                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
+                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
+                fig_cm = go.Figure(
+                    data=go.Heatmap(
+                        z=cm,
+                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
+                        y=[f"True {neg_label}", f"True {pos_label}"],
+                        text=cm,
+                        texttemplate="%{text}",
+                        colorscale="Blues",
+                        showscale=False,
+                    )
+                )
+                fig_cm.update_layout(
+                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
+                    xaxis_title="Predicted label",
+                    yaxis_title="True label",
+                )
+                _apply_report_layout(fig_cm)
+                self.explainer_plots["confusion_matrix"] = fig_cm
+            except Exception as e:
+                LOG.warning(
+                    f"Threshold-aware confusion matrix failed; falling back: {e}"
+                )

-        # these go into the Test tab
+            # ---- ROC with threshold marker ----
+            try:
+                fpr, tpr, thr = roc_curve(y, y_scores)
+                roc_auc = auc(fpr, tpr)
+                fig_roc = go.Figure()
+                fig_roc.add_scatter(
+                    x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
+                )
+                if len(thr):
+                    mask = np.isfinite(thr)
+                    if mask.any():
+                        idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
+                        idx = np.where(mask)[0][idx_local]
+                        if 0 <= idx < len(fpr):
+                            fig_roc.add_scatter(
+                                x=[fpr[idx]],
+                                y=[tpr[idx]],
+                                mode="markers",
+                                name=f"@ {prob_thresh:.2f}",
+                                marker=dict(size=10),
+                            )
+                fig_roc.update_layout(
+                    title=f"ROC Curve (marker at threshold={prob_thresh:.2f})",
+                    xaxis_title="False Positive Rate",
+                    yaxis_title="True Positive Rate",
+                )
+                _apply_report_layout(fig_roc)
+                self.explainer_plots["roc_auc"] = fig_roc
+            except Exception as e:
+                LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")
+
+            # ---- PR with threshold marker ----
+            try:
+                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+                pr_auc = auc(recall, precision)
+                fig_pr = go.Figure()
+                fig_pr.add_scatter(
+                    x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
+                )
+                if len(thr_pr):
+                    idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh)))
+                    # note: thr_pr has length = len(precision) - 1
+                    idx_pr = max(0, min(idx_pr, len(recall) - 1))
+                    fig_pr.add_scatter(
+                        x=[recall[idx_pr]],
+                        y=[precision[idx_pr]],
+                        mode="markers",
+                        name=f"@ {prob_thresh:.2f}",
+                        marker=dict(size=10),
+                    )
+                fig_pr.update_layout(
+                    title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})",
+                    xaxis_title="Recall",
+                    yaxis_title="Precision",
+                )
+                _apply_report_layout(fig_pr)
+                self.explainer_plots["pr_auc"] = fig_pr
+            except Exception as e:
+                LOG.warning(f"Threshold marker on PR failed; falling back: {e}")
+
+        # these go into the Test tab (don't overwrite overrides)
         for key, fn in [
             ("roc_auc", explainer.plot_roc_auc),
             ("pr_auc", explainer.plot_pr_auc),
             ("lift_curve", explainer.plot_lift_curve),
             ("confusion_matrix", explainer.plot_confusion_matrix),
-            ("threshold", explainer.plot_precision),  # Percentage vs probability
+            ("threshold", explainer.plot_precision),  # percentage vs probability
             ("cumulative_precision", explainer.plot_cumulative_precision),
         ]:
+            if key in self.explainer_plots:
+                continue
             try:
-                self.explainer_plots[key] = fn()
+                fig = fn()
+                if fig is not None:
+                    self.explainer_plots[key] = fig
             except Exception as e:
                 LOG.error(f"Error generating explainer plot {key}: {e}")
@@ -143,7 +278,9 @@
             if feat in explainer.X.columns or feat in explainer.onehot_cols:
                 valid_feats.append(feat)
             else:
-                LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data")
+                LOG.warning(
+                    f"Skipping PDP for feature {feat!r}: not found in explainer data"
+                )

         for feat in valid_feats:
             # wrap each PDP call to catch any unexpected AssertionErrors
@@ -157,6 +294,7 @@
                 except Exception as e:
                     LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
                     return None
+
             return _plot

             self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
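The ROC override above marks the curve point closest to the chosen probability threshold, masking the non-finite leading threshold that scikit-learn's `roc_curve` can return. A self-contained sketch of that lookup on synthetic data; the variable names and data are illustrative only, not part of pycaret_classification.py:

```python
# Sketch of the ROC threshold-marker lookup used in the diff above, on synthetic data.
import numpy as np
from sklearn.metrics import roc_curve

rng = np.random.default_rng(42)
y_true = rng.integers(0, 2, size=200)
# Fake positive-class scores loosely correlated with the labels.
y_scores = np.clip(y_true * 0.4 + rng.random(200) * 0.6, 0, 1)

prob_thresh = 0.5
fpr, tpr, thr = roc_curve(y_true, y_scores)

# roc_curve may emit an infinite leading threshold; mask it out before
# finding the curve point whose threshold is closest to prob_thresh.
mask = np.isfinite(thr)
idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
idx = np.where(mask)[0][idx_local]
print(f"marker at FPR={fpr[idx]:.3f}, TPR={tpr[idx]:.3f} for threshold {prob_thresh:.2f}")
```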