Mercurial > repos > goeckslab > pycaret_predict
changeset 16:4fee4504646e draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 22:28:26 +0000 |
| parents | a2aeeb754d76 |
| children | |
| files | base_model_trainer.py feature_importance.py pycaret_macros.xml |
| diffstat | 3 files changed, 475 insertions(+), 100 deletions(-) [+] |
line wrap: on
line diff
--- a/base_model_trainer.py Fri Nov 28 15:46:17 2025 +0000 +++ b/base_model_trainer.py Fri Nov 28 22:28:26 2025 +0000 @@ -46,6 +46,7 @@ self.results = None self.tuning_results = None self.features_name = None + self.plot_feature_names = None self.plots = {} self.explainer_plots = {} self.plots_explainer_html = None @@ -53,6 +54,24 @@ self.user_kwargs = kwargs.copy() for key, value in self.user_kwargs.items(): setattr(self, key, value) + if not hasattr(self, "plot_feature_limit"): + self.plot_feature_limit = 30 + self._shap_row_cap = None + if getattr(self, "polynomial_features", False): + # Keep feature importance responsive by trimming plots/SHAP rows + try: + limit_val = int(self.plot_feature_limit) + except (TypeError, ValueError): + limit_val = 30 + self.plot_feature_limit = min(limit_val, 15) + self._shap_row_cap = 200 + LOG.info( + "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s", + self.plot_feature_limit, + self._shap_row_cap, + ) + self.imputed_training_data = None + self._best_model_metric_used = None self.setup_params = {} self.test_file = test_file self.test_data = None @@ -127,23 +146,7 @@ LOG.info(f"Dataset columns after processing: {names}") self.features_name = [n for n in names if n != self.target] - - if getattr(self, "missing_value_strategy", None): - strat = self.missing_value_strategy - if strat == "mean": - self.data = self.data.fillna( - self.data.mean(numeric_only=True) - ) - elif strat == "median": - self.data = self.data.fillna( - self.data.median(numeric_only=True) - ) - elif strat == "drop": - self.data = self.data.dropna() - else: - self.data = self.data.fillna( - self.data.median(numeric_only=True) - ) + self.plot_feature_names = self._select_plot_features(self.features_name) if self.test_file: LOG.info(f"Loading test data from {self.test_file}") @@ -153,6 +156,52 @@ df_test.columns = df_test.columns.str.replace(".", "_") self.test_data = df_test + def _select_plot_features(self, all_features): + limit = getattr(self, "plot_feature_limit", 30) + if not isinstance(limit, int) or limit <= 0: + LOG.info( + "Feature plotting limit disabled (plot_feature_limit=%s).", limit + ) + return all_features + if len(all_features) <= limit: + LOG.info( + "Feature plotting limit not needed (%s features <= limit %s).", + len(all_features), + limit, + ) + return all_features + df = self.data[all_features].copy() + numeric_cols = df.select_dtypes(include=["number"]).columns + ranked = [] + if len(numeric_cols) > 0: + variances = ( + df[numeric_cols] + .var() + .fillna(0) + .abs() + .sort_values(ascending=False) + ) + ranked = variances.index.tolist() + selected = [] + for col in ranked: + if len(selected) >= limit: + break + selected.append(col) + if len(selected) < limit: + for col in all_features: + if col in selected: + continue + selected.append(col) + if len(selected) >= limit: + break + LOG.info( + "Limiting feature-level plots to %s of %s available features (limit=%s).", + len(selected), + len(all_features), + limit, + ) + return selected + def setup_pycaret(self): LOG.info("Initializing PyCaret") self.setup_params = { @@ -198,29 +247,41 @@ ) self.exp.setup(self.data, **self.setup_params) + self._capture_imputed_training_data() self.setup_params.update(self.user_kwargs) - def _normalize_metric(self, m: str) -> str: - if not m: - return "R2" if self.task_type == "regression" else "Accuracy" - m_low = str(m).strip().lower() - alias = { - "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC", - "accuracy": "Accuracy", - "precision": "Precision", - "recall": "Recall", - "f1": "F1", - "kappa": "Kappa", - "logloss": "Log Loss", "log_loss": "Log Loss", - "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted", - "r2": "R2", - "mae": "MAE", - "mse": "MSE", - "rmse": "RMSE", - "rmsle": "RMSLE", - "mape": "MAPE", - } - return alias.get(m_low, m) + def _capture_imputed_training_data(self): + """ + Cache the dataset as transformed/imputed by PyCaret so downstream + components (e.g., feature importance) can operate on the exact data + used for training. + """ + if self.exp is None: + return + try: + X_processed = self.exp.get_config("X_transformed").copy() + y_processed = self.exp.get_config("y") + if isinstance(y_processed, pd.Series): + y_series = y_processed.reset_index(drop=True) + else: + y_series = pd.Series(y_processed) + y_series.name = self.target + X_processed = X_processed.reset_index(drop=True) + self.imputed_training_data = pd.concat( + [X_processed, y_series], axis=1 + ) + LOG.info( + "Captured imputed training dataset from PyCaret " + "(%s rows, %s features).", + self.imputed_training_data.shape[0], + self.imputed_training_data.shape[1] - 1, + ) + except Exception as exc: + LOG.warning( + "Unable to capture processed training data from PyCaret: %s", + exc, + ) + self.imputed_training_data = None def train_model(self): LOG.info("Training and selecting the best model") @@ -245,17 +306,16 @@ if getattr(self, "cross_validation_folds", None) is not None: compare_kwargs["fold"] = self.cross_validation_folds - chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None)) - if chosen_metric: - compare_kwargs["sort"] = chosen_metric - self.chosen_metric_label = chosen_metric - try: - setattr(self.exp, "_fold_metric", chosen_metric) - except Exception as e: - LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True) + best_metric = getattr(self, "best_model_metric", None) + if best_metric: + compare_kwargs["sort"] = best_metric + self._best_model_metric_used = best_metric + LOG.info(f"Ranking models using metric: {best_metric}") LOG.info(f"compare_models kwargs: {compare_kwargs}") self.best_model = self.exp.compare_models(**compare_kwargs) + if self._best_model_metric_used is None: + self._best_model_metric_used = getattr(self.exp, "_fold_metric", None) self.results = self.exp.pull() if getattr(self, "tune_model", False): LOG.info("Tuning hyperparameters of the best model") @@ -327,6 +387,31 @@ with open(img_path, "rb") as img_file: return base64.b64encode(img_file.read()).decode("utf-8") + def _resolve_plot_callable(self, key, fig_or_fn, section): + """ + Safely execute stored plot callables so a single failure does not + abort the entire HTML report generation. + """ + if fig_or_fn is None: + return None + try: + return fig_or_fn() if callable(fig_or_fn) else fig_or_fn + except Exception as exc: + extra = "" + if isinstance(exc, ValueError) and "Input contains NaN" in str(exc): + extra = ( + " (model returned NaN probabilities; " + "consider checking data preprocessing)" + ) + LOG.warning( + "Skipping %s plot '%s' due to error: %s%s", + section, + key, + exc, + extra, + ) + return None + def save_html_report(self): LOG.info("Saving HTML report") @@ -401,8 +486,11 @@ else: dv = v if v is not None else "None" setup_rows.append([key, dv]) - if getattr(self, "chosen_metric_label", None): - setup_rows.append(["Best Model Metric", self.chosen_metric_label]) + metric_label = self._best_model_metric_used or getattr( + self.exp, "_fold_metric", None + ) + if metric_label: + setup_rows.append(["Best Model Metric", metric_label]) df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) df_setup.to_csv( @@ -564,13 +652,16 @@ "roc_auc", "pr_auc", "lift_curve", - "threshold", "cumulative_precision", ] for key in test_order: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: - fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + fig = self._resolve_plot_callable( + key, fig_or_fn, section="test/explainer" + ) + if fig is None: + continue title = plot_title_map.get( key, key.replace("_", " ").title() ) @@ -584,7 +675,6 @@ # skipping anything if self.task_type == "classification" and ( name in { - "threshold", "pr_auc", "class_report", } @@ -630,20 +720,57 @@ feature_html = header # 6a) PyCaret’s default feature importances - feature_html += FeatureImportanceAnalyzer( - data=self.data, + imputed_data = ( + self.imputed_training_data + if self.imputed_training_data is not None + else self.data + ) + fi_analyzer = FeatureImportanceAnalyzer( + data=imputed_data, target_col=self.target_col, task_type=self.task_type, output_dir=self.output_dir, exp=self.exp, best_model=self.best_model, - ).run() + max_plot_features=self.plot_feature_limit, + processed_data=self.imputed_training_data, + max_shap_rows=self._shap_row_cap, + ) + fi_html = fi_analyzer.run() + # Add a small table to show SHAP feature caps near the Best Model header. + cap_rows = [] + if fi_analyzer.shap_total_features is not None: + cap_rows.append( + ("Total transformed features", fi_analyzer.shap_total_features) + ) + if fi_analyzer.shap_used_features is not None: + cap_rows.append( + ("Features used in SHAP", fi_analyzer.shap_used_features) + ) + if cap_rows: + cap_table = ( + "<div class='table-wrapper'>" + "<table class='table sortable'>" + "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" + "<tbody>" + + "".join( + f"<tr><td>{label}</td><td>{value}</td></tr>" + for label, value in cap_rows + ) + + "</tbody></table></div>" + ) + feature_html += cap_table + feature_html += fi_html # 6b) Explainer SHAP importances for key in ["shap_mean", "shap_perm"]: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: - fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + fig = self._resolve_plot_callable( + key, fig_or_fn, section="feature importance" + ) + if fig is None: + continue # give SHAP plots explicit titles title = ( "Mean Absolute SHAP Value Impact" @@ -661,7 +788,11 @@ ) for k in pdp_keys: fig_or_fn = self.explainer_plots[k] - fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + fig = self._resolve_plot_callable( + k, fig_or_fn, section="pdp" + ) + if fig is None: + continue # extract feature name feature = k.split("__", 1)[1] title = f"Partial Dependence for {feature}"
--- a/feature_importance.py Fri Nov 28 15:46:17 2025 +0000 +++ b/feature_importance.py Fri Nov 28 22:28:26 2025 +0000 @@ -22,11 +22,23 @@ target_col=None, exp=None, best_model=None, + max_plot_features=None, + processed_data=None, + max_shap_rows=None, ): self.task_type = task_type self.output_dir = output_dir self.exp = exp self.best_model = best_model + self._skip_messages = [] + self.shap_total_features = None + self.shap_used_features = None + if isinstance(max_plot_features, int) and max_plot_features > 0: + self.max_plot_features = max_plot_features + elif max_plot_features is None: + self.max_plot_features = 30 + else: + self.max_plot_features = None if exp is not None: # Assume all configs (data, target) are in exp @@ -48,8 +60,55 @@ if task_type == "classification" else RegressionExperiment() ) + if processed_data is not None: + self.data = processed_data self.plots = {} + self.max_shap_rows = max_shap_rows + + def _get_feature_names_from_model(self, model): + """Best-effort extraction of feature names seen by the estimator.""" + if model is None: + return None + + candidates = [model] + if hasattr(model, "named_steps"): + candidates.extend(model.named_steps.values()) + elif hasattr(model, "steps"): + candidates.extend(step for _, step in model.steps) + + for candidate in candidates: + names = getattr(candidate, "feature_names_in_", None) + if names is not None: + return list(names) + return None + + def _get_transformed_frame(self, model=None, prefer_test=True): + """Return a DataFrame that mirrors the matrix fed to the estimator.""" + key_order = ["X_test_transformed", "X_train_transformed"] + if not prefer_test: + key_order.reverse() + key_order.append("X_transformed") + + feature_names = self._get_feature_names_from_model(model) + for key in key_order: + try: + frame = self.exp.get_config(key) + except KeyError: + continue + if frame is None: + continue + if isinstance(frame, pd.DataFrame): + return frame.copy() + try: + n_features = frame.shape[1] + except Exception: + continue + if feature_names and len(feature_names) == n_features: + return pd.DataFrame(frame, columns=feature_names) + # Fallback to positional names so downstream logic still works + return pd.DataFrame(frame, columns=[f"f{i}" for i in range(n_features)]) + return None def setup_pycaret(self): if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup: @@ -67,7 +126,14 @@ def save_tree_importance(self): model = self.best_model or self.exp.get_config("best_model") - processed_features = self.exp.get_config("X_transformed").columns + processed_frame = self._get_transformed_frame(model, prefer_test=False) + if processed_frame is None: + LOG.warning( + "Unable to determine transformed feature names; skipping tree importance plot." + ) + self.tree_model_name = None + return + processed_features = list(processed_frame.columns) importances = None model_type = model.__class__.__name__ @@ -85,20 +151,42 @@ return if len(importances) != len(processed_features): - LOG.warning( - f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance." - ) - self.tree_model_name = None - return + model_feature_names = self._get_feature_names_from_model(model) + if model_feature_names and len(model_feature_names) == len(importances): + processed_features = model_feature_names + else: + LOG.warning( + "Importances (%s) != features (%s). Skipping tree importance.", + len(importances), + len(processed_features), + ) + self.tree_model_name = None + return feature_importances = pd.DataFrame( {"Feature": processed_features, "Importance": importances} ).sort_values(by="Importance", ascending=False) + cap = ( + min(self.max_plot_features, len(feature_importances)) + if self.max_plot_features is not None + else len(feature_importances) + ) + plot_importances = feature_importances.head(cap) + if cap < len(feature_importances): + LOG.info( + "Tree importance plot limited to top %s of %s features", + cap, + len(feature_importances), + ) plt.figure(figsize=(10, 6)) - plt.barh(feature_importances["Feature"], feature_importances["Importance"]) + plt.barh( + plot_importances["Feature"], + plot_importances["Importance"], + ) plt.xlabel("Importance") - plt.title(f"Feature Importance ({model_type})") + plt.title(f"Feature Importance ({model_type}) (top {cap})") plot_path = os.path.join(self.output_dir, "tree_importance.png") + plt.tight_layout() plt.savefig(plot_path, bbox_inches="tight") plt.close() self.plots["tree_importance"] = plot_path @@ -106,23 +194,22 @@ def save_shap_values(self, max_samples=None, max_display=None, max_features=None): model = self.best_model or self.exp.get_config("best_model") - X_data = None - for key in ("X_test_transformed", "X_train_transformed"): - try: - X_data = self.exp.get_config(key) - break - except KeyError: - continue + X_data = self._get_transformed_frame(model) if X_data is None: raise RuntimeError("No transformed dataset found for SHAP.") - # --- Adaptive feature limiting (proportional cap) --- n_rows, n_features = X_data.shape + self.shap_total_features = n_features + feature_cap = ( + min(self.max_plot_features, n_features) + if self.max_plot_features is not None + else n_features + ) if max_features is None: - if n_features <= 200: - max_features = n_features - else: - max_features = min(200, max(20, int(n_features * 0.1))) + max_features = feature_cap + else: + max_features = min(max_features, feature_cap) + display_features = list(X_data.columns) try: if hasattr(model, "feature_importances_"): @@ -138,15 +225,35 @@ variances = X_data.var() top_features = variances.nlargest(max_features).index - if len(top_features) < n_features: + candidate_features = list(top_features) + missing = [f for f in candidate_features if f not in X_data.columns] + display_features = [f for f in candidate_features if f in X_data.columns] + if missing: + LOG.warning( + "Dropping %s transformed feature(s) not present in SHAP frame: %s", + len(missing), + missing[:5], + ) + if display_features and len(display_features) < n_features: LOG.info( - f"Restricted SHAP computation to top {len(top_features)} / {n_features} features" + "Restricting SHAP display to top %s of %s features", + len(display_features), + n_features, ) - X_data = X_data[top_features] + elif not display_features: + display_features = list(X_data.columns) except Exception as e: LOG.warning( f"Feature limiting failed: {e}. Using all {n_features} features." ) + display_features = list(X_data.columns) + + self.shap_used_features = len(display_features) + + # Apply the column restriction so SHAP only runs on the selected features. + if display_features: + X_data = X_data[display_features] + n_rows, n_features = X_data.shape # --- Adaptive row subsampling --- if max_samples is None: @@ -157,18 +264,26 @@ else: max_samples = min(1000, int(n_rows * 0.1)) + if self.max_shap_rows is not None: + max_samples = min(max_samples, self.max_shap_rows) + if n_rows > max_samples: LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}") X_data = X_data.sample(max_samples, random_state=42) # --- Adaptive feature display --- + display_cap = ( + min(self.max_plot_features, len(display_features)) + if self.max_plot_features is not None + else len(display_features) + ) if max_display is None: - if X_data.shape[1] <= 20: - max_display = X_data.shape[1] - elif X_data.shape[1] <= 100: - max_display = 30 - else: - max_display = 50 + max_display = display_cap + else: + max_display = min(max_display, display_cap) + if not display_features: + display_features = list(X_data.columns) + max_display = len(display_features) # Background set bg = X_data.sample(min(len(X_data), 100), random_state=42) @@ -177,37 +292,159 @@ ) # Optimized explainer + explainer = None + explainer_label = None if hasattr(model, "feature_importances_"): explainer = shap.TreeExplainer( model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 ) + explainer_label = "tree_path_dependent" elif hasattr(model, "coef_"): explainer = shap.LinearExplainer(model, bg) + explainer_label = "linear" else: explainer = shap.Explainer(predict_fn, bg) + explainer_label = explainer.__class__.__name__ try: shap_values = explainer(X_data) self.shap_model_name = explainer.__class__.__name__ except Exception as e: - LOG.error(f"SHAP computation failed: {e}") + error_message = str(e) + needs_tree_fallback = ( + hasattr(model, "feature_importances_") + and "does not cover all the leaves" in error_message.lower() + ) + feature_name_mismatch = "feature names should match" in error_message.lower() + if needs_tree_fallback: + LOG.warning( + "SHAP computation failed using '%s' perturbation (%s). " + "Retrying with interventional perturbation.", + explainer_label, + error_message, + ) + try: + explainer = shap.TreeExplainer( + model, + bg, + feature_perturbation="interventional", + n_jobs=-1, + ) + shap_values = explainer(X_data) + self.shap_model_name = ( + f"{explainer.__class__.__name__} (interventional)" + ) + except Exception as retry_exc: + LOG.error( + "SHAP computation failed even after fallback: %s", + retry_exc, + ) + self.shap_model_name = None + return + elif feature_name_mismatch: + LOG.warning( + "SHAP computation failed due to feature-name mismatch (%s). " + "Falling back to model-agnostic SHAP explainer.", + error_message, + ) + try: + agnostic_explainer = shap.Explainer(predict_fn, bg) + shap_values = agnostic_explainer(X_data) + self.shap_model_name = ( + f"{agnostic_explainer.__class__.__name__} (fallback)" + ) + except Exception as fallback_exc: + LOG.error( + "Model-agnostic SHAP fallback also failed: %s", + fallback_exc, + ) + self.shap_model_name = None + return + else: + LOG.error(f"SHAP computation failed: {e}") + self.shap_model_name = None + return + + def _limit_explanation_features(explanation): + if len(display_features) >= n_features: + return explanation + try: + limited = explanation[:, display_features] + LOG.info( + "SHAP explanation trimmed to %s display features.", + len(display_features), + ) + return limited + except Exception as exc: + LOG.warning( + "Failed to restrict SHAP explanation to top features " + "(sample=%s); plot will include all features. Error: %s", + display_features[:5], + exc, + ) + # Keep using full feature list if trimming fails + return explanation + + shap_shape = getattr(shap_values, "shape", None) + class_labels = list(getattr(model, "classes_", [])) + shap_outputs = [] + if shap_shape is not None and len(shap_shape) == 3: + output_count = shap_shape[2] + LOG.info("Detected multi-output SHAP explanation with %s classes.", output_count) + for class_idx in range(output_count): + try: + class_expl = shap_values[..., class_idx] + except Exception as exc: + LOG.warning( + "Failed to extract SHAP explanation for class index %s: %s", + class_idx, + exc, + ) + continue + label = ( + class_labels[class_idx] + if class_labels and class_idx < len(class_labels) + else class_idx + ) + shap_outputs.append((class_idx, label, class_expl)) + else: + shap_outputs.append((None, None, shap_values)) + + if not shap_outputs: + LOG.error("No SHAP outputs available for plotting.") self.shap_model_name = None return - # --- Plot SHAP summary --- - out_path = os.path.join(self.output_dir, "shap_summary.png") - plt.figure() - shap.plots.beeswarm(shap_values, max_display=max_display, show=False) - plt.title( - f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)" - ) - plt.savefig(out_path, bbox_inches="tight") - plt.close() - self.plots["shap_summary"] = out_path + # --- Plot SHAP summary (one per class if needed) --- + for class_idx, class_label, class_expl in shap_outputs: + expl_to_plot = _limit_explanation_features(class_expl) + suffix = "" + plot_key = "shap_summary" + if class_idx is not None: + safe_label = str(class_label).replace("/", "_").replace(" ", "_") + suffix = f"_class_{safe_label}" + plot_key = f"shap_summary_class_{safe_label}" + out_filename = f"shap_summary{suffix}.png" + out_path = os.path.join(self.output_dir, out_filename) + plt.figure() + shap.plots.beeswarm(expl_to_plot, max_display=max_display, show=False) + title = f"SHAP Summary for {model.__class__.__name__}" + if class_idx is not None: + title += f" (class {class_label})" + plt.title(f"{title} (top {max_display} features)") + plt.tight_layout() + plt.savefig(out_path, bbox_inches="tight") + plt.close() + self.plots[plot_key] = out_path # --- Log summary --- LOG.info( - f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})." + "SHAP summary completed with %s rows and %s features " + "(displaying top %s) across %s output(s).", + X_data.shape[0], + X_data.shape[1], + max_display, + len(shap_outputs), ) def generate_html_report(self): @@ -227,12 +464,19 @@ section_title = ( f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}" ) + elif plot_name.startswith("shap_summary_class_"): + class_label = plot_name.replace("shap_summary_class_", "") + section_title = ( + f"SHAP Summary for class {class_label} " + f"({getattr(self, 'shap_model_name', 'model')})" + ) else: section_title = plot_name plots_html += f""" - <div class="plot" id="{plot_name}"> + <div class="plot" id="{plot_name}" style="text-align:center;margin-bottom:24px;"> <h2>{section_title}</h2> - <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> + <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}" + style="max-width:95%;height:auto;display:block;margin:0 auto;border:1px solid #ddd;padding:8px;background:#fff;"> </div> """ return f"{plots_html}"
--- a/pycaret_macros.xml Fri Nov 28 15:46:17 2025 +0000 +++ b/pycaret_macros.xml Fri Nov 28 22:28:26 2025 +0000 @@ -1,5 +1,5 @@ <macros> - <token name="@TABULAR_LEARNER_VERSION@">0.1.2</token> + <token name="@TABULAR_LEARNER_VERSION@">0.1.3</token> <token name="@PYCARET_VERSION@">3.3.2</token> <token name="@SUFFIX@">2</token> <token name="@PYCARET_PREDICT_VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>
