# HG changeset patch
# User goeckslab
# Date 1763774172 0
# Node ID b7ed0e483e4d4daa49a5bf39dcc757f2b247a2f8
# Parent ec4b9244f930c758b83d765048a5f157483739ce
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit e2ab4c0f9ce8b7a0a48f749ef5dd9899d6c2b1f8
diff -r ec4b9244f930 -r b7ed0e483e4d ludwig_experiment.py
--- a/ludwig_experiment.py Sat Sep 06 01:54:14 2025 +0000
+++ b/ludwig_experiment.py Sat Nov 22 01:16:12 2025 +0000
@@ -1,10 +1,15 @@
+import base64
+import html
import json
import logging
import os
import pickle
+import re
import sys
+from io import BytesIO
import pandas as pd
+from ludwig.api import LudwigModel
from ludwig.experiment import cli
from ludwig.globals import (
DESCRIPTION_FILE_NAME,
@@ -21,6 +26,11 @@
get_html_template
)
+try: # pragma: no cover - optional dependency in runtime containers
+ import matplotlib.pyplot as plt
+except ImportError: # pragma: no cover
+ plt = None
+
logging.basicConfig(level=logging.DEBUG)
@@ -158,44 +168,435 @@
LOG.error(f"Error converting Parquet to CSV: {e}")
-def generate_html_report(title, ludwig_output_directory_name):
- # ludwig_output_directory = os.path.join(
- # output_directory, ludwig_output_directory_name)
+def _resolve_dataset_path(dataset_path):
+ if not dataset_path:
+ return None
+
+ candidates = [dataset_path]
+
+ if not os.path.isabs(dataset_path):
+ candidates.extend([
+ os.path.join(output_directory, dataset_path),
+ os.path.join(os.getcwd(), dataset_path),
+ ])
+
+ for candidate in candidates:
+ if candidate and os.path.exists(candidate):
+ return os.path.abspath(candidate)
+
+ return None
+
+
+def _load_dataset_dataframe(dataset_path):
+ if not dataset_path:
+ return None
+
+ _, ext = os.path.splitext(dataset_path.lower())
+
+ try:
+ if ext in {".csv", ".tsv"}:
+ sep = "\t" if ext == ".tsv" else ","
+ return pd.read_csv(dataset_path, sep=sep)
+ if ext == ".parquet":
+ return pd.read_parquet(dataset_path)
+ if ext == ".json":
+ return pd.read_json(dataset_path)
+ if ext == ".h5":
+ return pd.read_hdf(dataset_path)
+ except Exception as exc:
+ LOG.warning(f"Unable to load dataset '{dataset_path}': {exc}")
+
+ LOG.warning("Unsupported dataset format for feature importance computation")
+ return None
+
+
+def sanitize_feature_name(name):
+ """Mirror Ludwig's get_sanitized_feature_name implementation."""
+ return re.sub(r"[(){}.:\"\"\'\'\[\]]", "_", str(name))
+
+
+def _sanitize_dataframe_columns(dataframe):
+ """Rename dataframe columns to Ludwig-sanitized names for explainability."""
+ column_map = {col: sanitize_feature_name(col) for col in dataframe.columns}
+
+ sanitized_df = dataframe.rename(columns=column_map)
+ if len(set(column_map.values())) != len(column_map.values()):
+ LOG.warning(
+ "Column name collision after sanitization; feature importance may be unreliable"
+ )
+
+ return sanitized_df
+
+
+def _feature_importance_plot(label_df, label_name, top_n=10, max_abs_importance=None):
+ """
+ Return base64-encoded bar plot for a label's top-N feature importances.
+
+ max_abs_importance lets us pin the x-axis across labels so readers can
+ compare magnitudes.
+ """
+ if plt is None or label_df.empty:
+ return ""
+
+ top_features = label_df.nlargest(top_n, "abs_importance")
+ if top_features.empty:
+ return ""
+
+ fig, ax = plt.subplots(figsize=(6, 3 + 0.2 * len(top_features)))
+ ax.barh(top_features["feature"], top_features["abs_importance"], color="#3f8fd2")
+ ax.set_xlabel("|importance|")
+ if max_abs_importance and max_abs_importance > 0:
+ ax.set_xlim(0, max_abs_importance * 1.05)
+ ax.invert_yaxis()
+ fig.tight_layout()
+
+ buf = BytesIO()
+ fig.savefig(buf, format="png", dpi=150)
+ plt.close(fig)
+ encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
+ return encoded
+
+
+def render_feature_importance_table(df: pd.DataFrame) -> str:
+ """Render a sortable HTML table for feature importance values."""
+ if df.empty:
+ return ""
+
+ columns = list(df.columns)
+    headers = "".join(
+        f"<th>{html.escape(str(col).replace('_', ' '))}</th>"
+        for col in columns
+    )
+
+ body_rows = []
+ for _, row in df.iterrows():
+ cells = []
+ for col in columns:
+ val = row[col]
+ if isinstance(val, float):
+ val_str = f"{val:.6f}"
+ else:
+ val_str = str(val)
+            cells.append(f"<td>{html.escape(val_str)}</td>")
- # test_statistics_html += json_to_html_table(
- # test_statistics)
- # except Exception as e:
- # LOG.info(f"Error reading test statistics: {e}")
+ try:
+ ludwig_model = LudwigModel.load(model_dir)
+ except Exception as exc:
+ LOG.warning(f"Unable to load Ludwig model for explanations: {exc}")
+ return
+
+ training_metadata = getattr(ludwig_model, "training_set_metadata", {})
+
+ output_feature_name, dataset_path = get_output_feature_name(
+ ludwig_output_directory)
+
+ if not output_feature_name or not dataset_path:
+ LOG.warning("Output feature or dataset path missing; skipping feature importance")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ dataset_full_path = _resolve_dataset_path(dataset_path)
+ if not dataset_full_path:
+ LOG.warning(f"Unable to resolve dataset path '{dataset_path}' for explanations")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ dataframe = _load_dataset_dataframe(dataset_full_path)
+ if dataframe is None or dataframe.empty:
+ LOG.warning("Dataset unavailable or empty; skipping feature importance")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ dataframe = _sanitize_dataframe_columns(dataframe)
+
+ data_subset = dataframe if len(dataframe) <= sample_size else dataframe.head(sample_size)
+ sample_df = dataframe.sample(
+ n=min(sample_size, len(dataframe)),
+ random_state=random_seed,
+ replace=False,
+ ) if len(dataframe) > sample_size else dataframe
+
+ try:
+ from ludwig.explain.captum import IntegratedGradientsExplainer
+ except ImportError as exc:
+ LOG.warning(f"Integrated Gradients explainer unavailable: {exc}")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ sanitized_output_feature = sanitize_feature_name(output_feature_name)
+
+ try:
+ explainer = IntegratedGradientsExplainer(
+ ludwig_model,
+ data_subset,
+ sample_df,
+ sanitized_output_feature,
+ )
+ explanations = explainer.explain()
+ except Exception as exc:
+ LOG.warning(f"Unable to compute feature importance: {exc}")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ if hasattr(ludwig_model, "close"):
+ try:
+ ludwig_model.close()
+ except Exception:
+ pass
- # Convert visualizations to HTML
+ label_names = []
+ target_metadata = {}
+ if isinstance(training_metadata, dict):
+ target_metadata = training_metadata.get(sanitized_output_feature, {})
+
+ if isinstance(target_metadata, dict):
+ if "idx2str" in target_metadata:
+ idx2str = target_metadata["idx2str"]
+ if isinstance(idx2str, dict):
+ def _idx_key(item):
+ idx_key = item[0]
+ try:
+ return (0, int(idx_key))
+ except (TypeError, ValueError):
+ return (1, str(idx_key))
+
+ label_names = [value for key, value in sorted(
+ idx2str.items(), key=_idx_key)]
+ else:
+ label_names = idx2str
+ elif "str2idx" in target_metadata and isinstance(
+ target_metadata["str2idx"], dict):
+ # invert mapping
+ label_names = [label for label, _ in sorted(
+ target_metadata["str2idx"].items(),
+ key=lambda item: item[1])]
+
+ rows = []
+ global_explanation = explanations.global_explanation
+ for label_index, label_explanation in enumerate(
+ global_explanation.label_explanations):
+ if label_names and label_index < len(label_names):
+ label_value = str(label_names[label_index])
+ elif len(global_explanation.label_explanations) == 1:
+ label_value = output_feature_name
+ else:
+ label_value = str(label_index)
+
+ for feature in label_explanation.feature_attributions:
+ rows.append({
+ "label": label_value,
+ "feature": feature.feature_name,
+ "importance": feature.attribution,
+ "abs_importance": abs(feature.attribution),
+ })
+
+ if not rows:
+ LOG.warning("No feature importance rows produced")
+ return
+
+ importance_df = pd.DataFrame(rows)
+ importance_df.sort_values([
+ "label",
+ "abs_importance"
+ ], ascending=[True, False], inplace=True)
+
+ importance_df.to_csv(output_csv_path, index=False)
+
+ LOG.info(f"Feature importance saved to {output_csv_path}")
+
+
+def generate_html_report(title, ludwig_output_directory_name):
plots_html = ""
- if len(os.listdir(viz_output_directory)) > 0:
+ plot_files = []
+ if os.path.isdir(viz_output_directory):
+ plot_files = sorted(os.listdir(viz_output_directory))
+ if plot_files:
plots_html = "<h2>Visualizations</h2>"
- for plot_file in sorted(os.listdir(viz_output_directory)):
+ for plot_file in plot_files:
plot_path = os.path.join(viz_output_directory, plot_file)
if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")):
encoded_image = encode_image_to_base64(plot_path)
+ plot_title = os.path.splitext(plot_file)[0].replace("_", " ")
plots_html += (
f'
Feature importance scores come from Ludwig's Integrated Gradients explainer. "
+ "It interpolates between each example and a neutral baseline sample, summing "
+ "the change in the model output along that path. Higher |importance| values "
+ "indicate stronger influence. Plots share a common x-axis to make magnitudes "
+ "comparable across labels, and the table columns can be sorted for quick scans.
"
+ )
+ feature_importance_html = (
+        "<h2>Feature Importance</h2>"
+ + explanation_text
+ + render_feature_importance_table(top_rows)
+ + "".join(plot_sections)
+ )
+ except Exception as exc:
+ LOG.info(f"Unable to embed feature importance table: {exc}")
+
# Generate the full HTML content
+ feature_section = feature_importance_html or "
Feature importance scores come from Ludwig's Integrated Gradients explainer. It interpolates between each example and a neutral baseline sample, summing the change in the model output along that path. Higher |importance| values indicate stronger influence. Plots share a common x-axis to make magnitudes comparable across labels, and the table columns can be sorted for quick scans.
Feature importance scores come from Ludwig's Integrated Gradients explainer. It interpolates between each example and a neutral baseline sample, summing the change in the model output along that path. Higher |importance| values indicate stronger influence. Plots share a common x-axis to make magnitudes comparable across labels, and the table columns can be sorted for quick scans.