# HG changeset patch
# User goeckslab
# Date 1764368906 0
# Node ID 4fee4504646ebea215697ce15060a08b30f20283
# Parent a2aeeb754d760c7cb58e57f7b208ac8f38ff4c83
planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
diff -r a2aeeb754d76 -r 4fee4504646e base_model_trainer.py
--- a/base_model_trainer.py Fri Nov 28 15:46:17 2025 +0000
+++ b/base_model_trainer.py Fri Nov 28 22:28:26 2025 +0000
@@ -46,6 +46,7 @@
self.results = None
self.tuning_results = None
self.features_name = None
+ self.plot_feature_names = None
self.plots = {}
self.explainer_plots = {}
self.plots_explainer_html = None
@@ -53,6 +54,24 @@
self.user_kwargs = kwargs.copy()
for key, value in self.user_kwargs.items():
setattr(self, key, value)
+ if not hasattr(self, "plot_feature_limit"):
+ self.plot_feature_limit = 30
+ self._shap_row_cap = None
+ if getattr(self, "polynomial_features", False):
+ # Keep feature importance responsive by trimming plots/SHAP rows
+ try:
+ limit_val = int(self.plot_feature_limit)
+ except (TypeError, ValueError):
+ limit_val = 30
+ self.plot_feature_limit = min(limit_val, 15)
+ self._shap_row_cap = 200
+ LOG.info(
+ "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s",
+ self.plot_feature_limit,
+ self._shap_row_cap,
+ )
+ self.imputed_training_data = None
+ self._best_model_metric_used = None
self.setup_params = {}
self.test_file = test_file
self.test_data = None
@@ -127,23 +146,7 @@
LOG.info(f"Dataset columns after processing: {names}")
self.features_name = [n for n in names if n != self.target]
-
- if getattr(self, "missing_value_strategy", None):
- strat = self.missing_value_strategy
- if strat == "mean":
- self.data = self.data.fillna(
- self.data.mean(numeric_only=True)
- )
- elif strat == "median":
- self.data = self.data.fillna(
- self.data.median(numeric_only=True)
- )
- elif strat == "drop":
- self.data = self.data.dropna()
- else:
- self.data = self.data.fillna(
- self.data.median(numeric_only=True)
- )
+ self.plot_feature_names = self._select_plot_features(self.features_name)
if self.test_file:
LOG.info(f"Loading test data from {self.test_file}")
@@ -153,6 +156,52 @@
df_test.columns = df_test.columns.str.replace(".", "_")
self.test_data = df_test
+ def _select_plot_features(self, all_features):
+ limit = getattr(self, "plot_feature_limit", 30)
+ if not isinstance(limit, int) or limit <= 0:
+ LOG.info(
+ "Feature plotting limit disabled (plot_feature_limit=%s).", limit
+ )
+ return all_features
+ if len(all_features) <= limit:
+ LOG.info(
+ "Feature plotting limit not needed (%s features <= limit %s).",
+ len(all_features),
+ limit,
+ )
+ return all_features
+ df = self.data[all_features].copy()
+ numeric_cols = df.select_dtypes(include=["number"]).columns
+ ranked = []
+ if len(numeric_cols) > 0:
+ variances = (
+ df[numeric_cols]
+ .var()
+ .fillna(0)
+ .abs()
+ .sort_values(ascending=False)
+ )
+ ranked = variances.index.tolist()
+ selected = []
+ for col in ranked:
+ if len(selected) >= limit:
+ break
+ selected.append(col)
+ if len(selected) < limit:
+ for col in all_features:
+ if col in selected:
+ continue
+ selected.append(col)
+ if len(selected) >= limit:
+ break
+ LOG.info(
+ "Limiting feature-level plots to %s of %s available features (limit=%s).",
+ len(selected),
+ len(all_features),
+ limit,
+ )
+ return selected
+
def setup_pycaret(self):
LOG.info("Initializing PyCaret")
self.setup_params = {
@@ -198,29 +247,41 @@
)
self.exp.setup(self.data, **self.setup_params)
+ self._capture_imputed_training_data()
self.setup_params.update(self.user_kwargs)
- def _normalize_metric(self, m: str) -> str:
- if not m:
- return "R2" if self.task_type == "regression" else "Accuracy"
- m_low = str(m).strip().lower()
- alias = {
- "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
- "accuracy": "Accuracy",
- "precision": "Precision",
- "recall": "Recall",
- "f1": "F1",
- "kappa": "Kappa",
- "logloss": "Log Loss", "log_loss": "Log Loss",
- "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
- "r2": "R2",
- "mae": "MAE",
- "mse": "MSE",
- "rmse": "RMSE",
- "rmsle": "RMSLE",
- "mape": "MAPE",
- }
- return alias.get(m_low, m)
+ def _capture_imputed_training_data(self):
+ """
+ Cache the dataset as transformed/imputed by PyCaret so downstream
+ components (e.g., feature importance) can operate on the exact data
+ used for training.
+ """
+ if self.exp is None:
+ return
+ try:
+ X_processed = self.exp.get_config("X_transformed").copy()
+ y_processed = self.exp.get_config("y")
+ if isinstance(y_processed, pd.Series):
+ y_series = y_processed.reset_index(drop=True)
+ else:
+ y_series = pd.Series(y_processed)
+ y_series.name = self.target
+ X_processed = X_processed.reset_index(drop=True)
+ self.imputed_training_data = pd.concat(
+ [X_processed, y_series], axis=1
+ )
+ LOG.info(
+ "Captured imputed training dataset from PyCaret "
+ "(%s rows, %s features).",
+ self.imputed_training_data.shape[0],
+ self.imputed_training_data.shape[1] - 1,
+ )
+ except Exception as exc:
+ LOG.warning(
+ "Unable to capture processed training data from PyCaret: %s",
+ exc,
+ )
+ self.imputed_training_data = None
def train_model(self):
LOG.info("Training and selecting the best model")
@@ -245,17 +306,16 @@
if getattr(self, "cross_validation_folds", None) is not None:
compare_kwargs["fold"] = self.cross_validation_folds
- chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None))
- if chosen_metric:
- compare_kwargs["sort"] = chosen_metric
- self.chosen_metric_label = chosen_metric
- try:
- setattr(self.exp, "_fold_metric", chosen_metric)
- except Exception as e:
- LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
+ best_metric = getattr(self, "best_model_metric", None)
+ if best_metric:
+ compare_kwargs["sort"] = best_metric
+ self._best_model_metric_used = best_metric
+ LOG.info(f"Ranking models using metric: {best_metric}")
LOG.info(f"compare_models kwargs: {compare_kwargs}")
self.best_model = self.exp.compare_models(**compare_kwargs)
+ if self._best_model_metric_used is None:
+ self._best_model_metric_used = getattr(self.exp, "_fold_metric", None)
self.results = self.exp.pull()
if getattr(self, "tune_model", False):
LOG.info("Tuning hyperparameters of the best model")
@@ -327,6 +387,31 @@
with open(img_path, "rb") as img_file:
return base64.b64encode(img_file.read()).decode("utf-8")
+ def _resolve_plot_callable(self, key, fig_or_fn, section):
+ """
+ Safely execute stored plot callables so a single failure does not
+ abort the entire HTML report generation.
+ """
+ if fig_or_fn is None:
+ return None
+ try:
+ return fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+ except Exception as exc:
+ extra = ""
+ if isinstance(exc, ValueError) and "Input contains NaN" in str(exc):
+ extra = (
+ " (model returned NaN probabilities; "
+ "consider checking data preprocessing)"
+ )
+ LOG.warning(
+ "Skipping %s plot '%s' due to error: %s%s",
+ section,
+ key,
+ exc,
+ extra,
+ )
+ return None
+
def save_html_report(self):
LOG.info("Saving HTML report")
@@ -401,8 +486,11 @@
else:
dv = v if v is not None else "None"
setup_rows.append([key, dv])
- if getattr(self, "chosen_metric_label", None):
- setup_rows.append(["Best Model Metric", self.chosen_metric_label])
+ metric_label = self._best_model_metric_used or getattr(
+ self.exp, "_fold_metric", None
+ )
+ if metric_label:
+ setup_rows.append(["Best Model Metric", metric_label])
df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
df_setup.to_csv(
@@ -564,13 +652,16 @@
"roc_auc",
"pr_auc",
"lift_curve",
- "threshold",
"cumulative_precision",
]
for key in test_order:
fig_or_fn = self.explainer_plots.pop(key, None)
if fig_or_fn is not None:
- fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+ fig = self._resolve_plot_callable(
+ key, fig_or_fn, section="test/explainer"
+ )
+ if fig is None:
+ continue
title = plot_title_map.get(
key, key.replace("_", " ").title()
)
@@ -584,7 +675,6 @@
# skipping anything
if self.task_type == "classification" and (
name in {
- "threshold",
"pr_auc",
"class_report",
}
@@ -630,20 +720,57 @@
feature_html = header
# 6a) PyCaret’s default feature importances
- feature_html += FeatureImportanceAnalyzer(
- data=self.data,
+ imputed_data = (
+ self.imputed_training_data
+ if self.imputed_training_data is not None
+ else self.data
+ )
+ fi_analyzer = FeatureImportanceAnalyzer(
+ data=imputed_data,
target_col=self.target_col,
task_type=self.task_type,
output_dir=self.output_dir,
exp=self.exp,
best_model=self.best_model,
- ).run()
+ max_plot_features=self.plot_feature_limit,
+ processed_data=self.imputed_training_data,
+ max_shap_rows=self._shap_row_cap,
+ )
+ fi_html = fi_analyzer.run()
+ # Add a small table to show SHAP feature caps near the Best Model header.
+ cap_rows = []
+ if fi_analyzer.shap_total_features is not None:
+ cap_rows.append(
+ ("Total transformed features", fi_analyzer.shap_total_features)
+ )
+ if fi_analyzer.shap_used_features is not None:
+ cap_rows.append(
+ ("Features used in SHAP", fi_analyzer.shap_used_features)
+ )
+ if cap_rows:
+ cap_table = (
+ "
"
+ "
"
+ "| Feature Importance Scope | Count |
"
+ ""
+ + "".join(
+ f"| {label} | {value} |
"
+ for label, value in cap_rows
+ )
+ + "
"
+ )
+ feature_html += cap_table
+ feature_html += fi_html
# 6b) Explainer SHAP importances
for key in ["shap_mean", "shap_perm"]:
fig_or_fn = self.explainer_plots.pop(key, None)
if fig_or_fn is not None:
- fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+ fig = self._resolve_plot_callable(
+ key, fig_or_fn, section="feature importance"
+ )
+ if fig is None:
+ continue
# give SHAP plots explicit titles
title = (
"Mean Absolute SHAP Value Impact"
@@ -661,7 +788,11 @@
)
for k in pdp_keys:
fig_or_fn = self.explainer_plots[k]
- fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+ fig = self._resolve_plot_callable(
+ k, fig_or_fn, section="pdp"
+ )
+ if fig is None:
+ continue
# extract feature name
feature = k.split("__", 1)[1]
title = f"Partial Dependence for {feature}"
diff -r a2aeeb754d76 -r 4fee4504646e feature_importance.py
--- a/feature_importance.py Fri Nov 28 15:46:17 2025 +0000
+++ b/feature_importance.py Fri Nov 28 22:28:26 2025 +0000
@@ -22,11 +22,23 @@
target_col=None,
exp=None,
best_model=None,
+ max_plot_features=None,
+ processed_data=None,
+ max_shap_rows=None,
):
self.task_type = task_type
self.output_dir = output_dir
self.exp = exp
self.best_model = best_model
+ self._skip_messages = []
+ self.shap_total_features = None
+ self.shap_used_features = None
+ if isinstance(max_plot_features, int) and max_plot_features > 0:
+ self.max_plot_features = max_plot_features
+ elif max_plot_features is None:
+ self.max_plot_features = 30
+ else:
+ self.max_plot_features = None
if exp is not None:
# Assume all configs (data, target) are in exp
@@ -48,8 +60,55 @@
if task_type == "classification"
else RegressionExperiment()
)
+ if processed_data is not None:
+ self.data = processed_data
self.plots = {}
+ self.max_shap_rows = max_shap_rows
+
+ def _get_feature_names_from_model(self, model):
+ """Best-effort extraction of feature names seen by the estimator."""
+ if model is None:
+ return None
+
+ candidates = [model]
+ if hasattr(model, "named_steps"):
+ candidates.extend(model.named_steps.values())
+ elif hasattr(model, "steps"):
+ candidates.extend(step for _, step in model.steps)
+
+ for candidate in candidates:
+ names = getattr(candidate, "feature_names_in_", None)
+ if names is not None:
+ return list(names)
+ return None
+
+ def _get_transformed_frame(self, model=None, prefer_test=True):
+ """Return a DataFrame that mirrors the matrix fed to the estimator."""
+ key_order = ["X_test_transformed", "X_train_transformed"]
+ if not prefer_test:
+ key_order.reverse()
+ key_order.append("X_transformed")
+
+ feature_names = self._get_feature_names_from_model(model)
+ for key in key_order:
+ try:
+ frame = self.exp.get_config(key)
+ except KeyError:
+ continue
+ if frame is None:
+ continue
+ if isinstance(frame, pd.DataFrame):
+ return frame.copy()
+ try:
+ n_features = frame.shape[1]
+ except Exception:
+ continue
+ if feature_names and len(feature_names) == n_features:
+ return pd.DataFrame(frame, columns=feature_names)
+ # Fallback to positional names so downstream logic still works
+ return pd.DataFrame(frame, columns=[f"f{i}" for i in range(n_features)])
+ return None
def setup_pycaret(self):
if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
@@ -67,7 +126,14 @@
def save_tree_importance(self):
model = self.best_model or self.exp.get_config("best_model")
- processed_features = self.exp.get_config("X_transformed").columns
+ processed_frame = self._get_transformed_frame(model, prefer_test=False)
+ if processed_frame is None:
+ LOG.warning(
+ "Unable to determine transformed feature names; skipping tree importance plot."
+ )
+ self.tree_model_name = None
+ return
+ processed_features = list(processed_frame.columns)
importances = None
model_type = model.__class__.__name__
@@ -85,20 +151,42 @@
return
if len(importances) != len(processed_features):
- LOG.warning(
- f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance."
- )
- self.tree_model_name = None
- return
+ model_feature_names = self._get_feature_names_from_model(model)
+ if model_feature_names and len(model_feature_names) == len(importances):
+ processed_features = model_feature_names
+ else:
+ LOG.warning(
+ "Importances (%s) != features (%s). Skipping tree importance.",
+ len(importances),
+ len(processed_features),
+ )
+ self.tree_model_name = None
+ return
feature_importances = pd.DataFrame(
{"Feature": processed_features, "Importance": importances}
).sort_values(by="Importance", ascending=False)
+ cap = (
+ min(self.max_plot_features, len(feature_importances))
+ if self.max_plot_features is not None
+ else len(feature_importances)
+ )
+ plot_importances = feature_importances.head(cap)
+ if cap < len(feature_importances):
+ LOG.info(
+ "Tree importance plot limited to top %s of %s features",
+ cap,
+ len(feature_importances),
+ )
plt.figure(figsize=(10, 6))
- plt.barh(feature_importances["Feature"], feature_importances["Importance"])
+ plt.barh(
+ plot_importances["Feature"],
+ plot_importances["Importance"],
+ )
plt.xlabel("Importance")
- plt.title(f"Feature Importance ({model_type})")
+ plt.title(f"Feature Importance ({model_type}) (top {cap})")
plot_path = os.path.join(self.output_dir, "tree_importance.png")
+ plt.tight_layout()
plt.savefig(plot_path, bbox_inches="tight")
plt.close()
self.plots["tree_importance"] = plot_path
@@ -106,23 +194,22 @@
def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
model = self.best_model or self.exp.get_config("best_model")
- X_data = None
- for key in ("X_test_transformed", "X_train_transformed"):
- try:
- X_data = self.exp.get_config(key)
- break
- except KeyError:
- continue
+ X_data = self._get_transformed_frame(model)
if X_data is None:
raise RuntimeError("No transformed dataset found for SHAP.")
- # --- Adaptive feature limiting (proportional cap) ---
n_rows, n_features = X_data.shape
+ self.shap_total_features = n_features
+ feature_cap = (
+ min(self.max_plot_features, n_features)
+ if self.max_plot_features is not None
+ else n_features
+ )
if max_features is None:
- if n_features <= 200:
- max_features = n_features
- else:
- max_features = min(200, max(20, int(n_features * 0.1)))
+ max_features = feature_cap
+ else:
+ max_features = min(max_features, feature_cap)
+ display_features = list(X_data.columns)
try:
if hasattr(model, "feature_importances_"):
@@ -138,15 +225,35 @@
variances = X_data.var()
top_features = variances.nlargest(max_features).index
- if len(top_features) < n_features:
+ candidate_features = list(top_features)
+ missing = [f for f in candidate_features if f not in X_data.columns]
+ display_features = [f for f in candidate_features if f in X_data.columns]
+ if missing:
+ LOG.warning(
+ "Dropping %s transformed feature(s) not present in SHAP frame: %s",
+ len(missing),
+ missing[:5],
+ )
+ if display_features and len(display_features) < n_features:
LOG.info(
- f"Restricted SHAP computation to top {len(top_features)} / {n_features} features"
+ "Restricting SHAP display to top %s of %s features",
+ len(display_features),
+ n_features,
)
- X_data = X_data[top_features]
+ elif not display_features:
+ display_features = list(X_data.columns)
except Exception as e:
LOG.warning(
f"Feature limiting failed: {e}. Using all {n_features} features."
)
+ display_features = list(X_data.columns)
+
+ self.shap_used_features = len(display_features)
+
+ # Apply the column restriction so SHAP only runs on the selected features.
+ if display_features:
+ X_data = X_data[display_features]
+ n_rows, n_features = X_data.shape
# --- Adaptive row subsampling ---
if max_samples is None:
@@ -157,18 +264,26 @@
else:
max_samples = min(1000, int(n_rows * 0.1))
+ if self.max_shap_rows is not None:
+ max_samples = min(max_samples, self.max_shap_rows)
+
if n_rows > max_samples:
LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
X_data = X_data.sample(max_samples, random_state=42)
# --- Adaptive feature display ---
+ display_cap = (
+ min(self.max_plot_features, len(display_features))
+ if self.max_plot_features is not None
+ else len(display_features)
+ )
if max_display is None:
- if X_data.shape[1] <= 20:
- max_display = X_data.shape[1]
- elif X_data.shape[1] <= 100:
- max_display = 30
- else:
- max_display = 50
+ max_display = display_cap
+ else:
+ max_display = min(max_display, display_cap)
+ if not display_features:
+ display_features = list(X_data.columns)
+ max_display = len(display_features)
# Background set
bg = X_data.sample(min(len(X_data), 100), random_state=42)
@@ -177,37 +292,159 @@
)
# Optimized explainer
+ explainer = None
+ explainer_label = None
if hasattr(model, "feature_importances_"):
explainer = shap.TreeExplainer(
model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
)
+ explainer_label = "tree_path_dependent"
elif hasattr(model, "coef_"):
explainer = shap.LinearExplainer(model, bg)
+ explainer_label = "linear"
else:
explainer = shap.Explainer(predict_fn, bg)
+ explainer_label = explainer.__class__.__name__
try:
shap_values = explainer(X_data)
self.shap_model_name = explainer.__class__.__name__
except Exception as e:
- LOG.error(f"SHAP computation failed: {e}")
+ error_message = str(e)
+ needs_tree_fallback = (
+ hasattr(model, "feature_importances_")
+ and "does not cover all the leaves" in error_message.lower()
+ )
+ feature_name_mismatch = "feature names should match" in error_message.lower()
+ if needs_tree_fallback:
+ LOG.warning(
+ "SHAP computation failed using '%s' perturbation (%s). "
+ "Retrying with interventional perturbation.",
+ explainer_label,
+ error_message,
+ )
+ try:
+ explainer = shap.TreeExplainer(
+ model,
+ bg,
+ feature_perturbation="interventional",
+ n_jobs=-1,
+ )
+ shap_values = explainer(X_data)
+ self.shap_model_name = (
+ f"{explainer.__class__.__name__} (interventional)"
+ )
+ except Exception as retry_exc:
+ LOG.error(
+ "SHAP computation failed even after fallback: %s",
+ retry_exc,
+ )
+ self.shap_model_name = None
+ return
+ elif feature_name_mismatch:
+ LOG.warning(
+ "SHAP computation failed due to feature-name mismatch (%s). "
+ "Falling back to model-agnostic SHAP explainer.",
+ error_message,
+ )
+ try:
+ agnostic_explainer = shap.Explainer(predict_fn, bg)
+ shap_values = agnostic_explainer(X_data)
+ self.shap_model_name = (
+ f"{agnostic_explainer.__class__.__name__} (fallback)"
+ )
+ except Exception as fallback_exc:
+ LOG.error(
+ "Model-agnostic SHAP fallback also failed: %s",
+ fallback_exc,
+ )
+ self.shap_model_name = None
+ return
+ else:
+ LOG.error(f"SHAP computation failed: {e}")
+ self.shap_model_name = None
+ return
+
+ def _limit_explanation_features(explanation):
+ if len(display_features) >= n_features:
+ return explanation
+ try:
+ limited = explanation[:, display_features]
+ LOG.info(
+ "SHAP explanation trimmed to %s display features.",
+ len(display_features),
+ )
+ return limited
+ except Exception as exc:
+ LOG.warning(
+ "Failed to restrict SHAP explanation to top features "
+ "(sample=%s); plot will include all features. Error: %s",
+ display_features[:5],
+ exc,
+ )
+ # Keep using full feature list if trimming fails
+ return explanation
+
+ shap_shape = getattr(shap_values, "shape", None)
+ class_labels = list(getattr(model, "classes_", []))
+ shap_outputs = []
+ if shap_shape is not None and len(shap_shape) == 3:
+ output_count = shap_shape[2]
+ LOG.info("Detected multi-output SHAP explanation with %s classes.", output_count)
+ for class_idx in range(output_count):
+ try:
+ class_expl = shap_values[..., class_idx]
+ except Exception as exc:
+ LOG.warning(
+ "Failed to extract SHAP explanation for class index %s: %s",
+ class_idx,
+ exc,
+ )
+ continue
+ label = (
+ class_labels[class_idx]
+ if class_labels and class_idx < len(class_labels)
+ else class_idx
+ )
+ shap_outputs.append((class_idx, label, class_expl))
+ else:
+ shap_outputs.append((None, None, shap_values))
+
+ if not shap_outputs:
+ LOG.error("No SHAP outputs available for plotting.")
self.shap_model_name = None
return
- # --- Plot SHAP summary ---
- out_path = os.path.join(self.output_dir, "shap_summary.png")
- plt.figure()
- shap.plots.beeswarm(shap_values, max_display=max_display, show=False)
- plt.title(
- f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)"
- )
- plt.savefig(out_path, bbox_inches="tight")
- plt.close()
- self.plots["shap_summary"] = out_path
+ # --- Plot SHAP summary (one per class if needed) ---
+ for class_idx, class_label, class_expl in shap_outputs:
+ expl_to_plot = _limit_explanation_features(class_expl)
+ suffix = ""
+ plot_key = "shap_summary"
+ if class_idx is not None:
+ safe_label = str(class_label).replace("/", "_").replace(" ", "_")
+ suffix = f"_class_{safe_label}"
+ plot_key = f"shap_summary_class_{safe_label}"
+ out_filename = f"shap_summary{suffix}.png"
+ out_path = os.path.join(self.output_dir, out_filename)
+ plt.figure()
+ shap.plots.beeswarm(expl_to_plot, max_display=max_display, show=False)
+ title = f"SHAP Summary for {model.__class__.__name__}"
+ if class_idx is not None:
+ title += f" (class {class_label})"
+ plt.title(f"{title} (top {max_display} features)")
+ plt.tight_layout()
+ plt.savefig(out_path, bbox_inches="tight")
+ plt.close()
+ self.plots[plot_key] = out_path
# --- Log summary ---
LOG.info(
- f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})."
+ "SHAP summary completed with %s rows and %s features "
+ "(displaying top %s) across %s output(s).",
+ X_data.shape[0],
+ X_data.shape[1],
+ max_display,
+ len(shap_outputs),
)
def generate_html_report(self):
@@ -227,12 +464,19 @@
section_title = (
f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
)
+ elif plot_name.startswith("shap_summary_class_"):
+ class_label = plot_name.replace("shap_summary_class_", "")
+ section_title = (
+ f"SHAP Summary for class {class_label} "
+ f"({getattr(self, 'shap_model_name', 'model')})"
+ )
else:
section_title = plot_name
plots_html += f"""
-
+
{section_title}
-

+
"""
return f"{plots_html}"
diff -r a2aeeb754d76 -r 4fee4504646e pycaret_macros.xml
--- a/pycaret_macros.xml Fri Nov 28 15:46:17 2025 +0000
+++ b/pycaret_macros.xml Fri Nov 28 22:28:26 2025 +0000
@@ -1,5 +1,5 @@
- 0.1.2
+ 0.1.3
3.3.2
2
@PYCARET_VERSION@+@SUFFIX@