comparison ludwig_experiment.py @ 7:12cd5a6fc2ba draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit e2ab4c0f9ce8b7a0a48f749ef5dd9899d6c2b1f8
author goeckslab
date Sat, 22 Nov 2025 01:17:46 +0000
parents a1b91224afb9
children
comparison
equal deleted inserted replaced
6:4a63b50a8634 7:12cd5a6fc2ba
1 import base64
2 import html
1 import json 3 import json
2 import logging 4 import logging
3 import os 5 import os
4 import pickle 6 import pickle
7 import re
5 import sys 8 import sys
9 from io import BytesIO
6 10
7 import pandas as pd 11 import pandas as pd
12 from ludwig.api import LudwigModel
8 from ludwig.experiment import cli 13 from ludwig.experiment import cli
9 from ludwig.globals import ( 14 from ludwig.globals import (
10 DESCRIPTION_FILE_NAME, 15 DESCRIPTION_FILE_NAME,
11 PREDICTIONS_PARQUET_FILE_NAME, 16 PREDICTIONS_PARQUET_FILE_NAME,
12 TEST_STATISTICS_FILE_NAME, 17 TEST_STATISTICS_FILE_NAME,
18 from utils import ( 23 from utils import (
19 encode_image_to_base64, 24 encode_image_to_base64,
20 get_html_closing, 25 get_html_closing,
21 get_html_template 26 get_html_template
22 ) 27 )
28
29 try: # pragma: no cover - optional dependency in runtime containers
30 import matplotlib.pyplot as plt
31 except ImportError: # pragma: no cover
32 plt = None
23 33
24 34
25 logging.basicConfig(level=logging.DEBUG) 35 logging.basicConfig(level=logging.DEBUG)
26 36
27 LOG = logging.getLogger(__name__) 37 LOG = logging.getLogger(__name__)
156 LOG.info(f"Converted Parquet to CSV: {csv_path}") 166 LOG.info(f"Converted Parquet to CSV: {csv_path}")
157 except Exception as e: 167 except Exception as e:
158 LOG.error(f"Error converting Parquet to CSV: {e}") 168 LOG.error(f"Error converting Parquet to CSV: {e}")
159 169
160 170
171 def _resolve_dataset_path(dataset_path):
172 if not dataset_path:
173 return None
174
175 candidates = [dataset_path]
176
177 if not os.path.isabs(dataset_path):
178 candidates.extend([
179 os.path.join(output_directory, dataset_path),
180 os.path.join(os.getcwd(), dataset_path),
181 ])
182
183 for candidate in candidates:
184 if candidate and os.path.exists(candidate):
185 return os.path.abspath(candidate)
186
187 return None
188
189
190 def _load_dataset_dataframe(dataset_path):
191 if not dataset_path:
192 return None
193
194 _, ext = os.path.splitext(dataset_path.lower())
195
196 try:
197 if ext in {".csv", ".tsv"}:
198 sep = "\t" if ext == ".tsv" else ","
199 return pd.read_csv(dataset_path, sep=sep)
200 if ext == ".parquet":
201 return pd.read_parquet(dataset_path)
202 if ext == ".json":
203 return pd.read_json(dataset_path)
204 if ext == ".h5":
205 return pd.read_hdf(dataset_path)
206 except Exception as exc:
207 LOG.warning(f"Unable to load dataset '{dataset_path}': {exc}")
208
209 LOG.warning("Unsupported dataset format for feature importance computation")
210 return None
211
212
def sanitize_feature_name(name):
    """Mirror Ludwig's get_sanitized_feature_name implementation.

    Every character Ludwig treats as unsafe in a feature name is
    replaced with an underscore; non-string input is stringified first.
    """
    unsafe_chars = re.compile(r"[(){}.:\"\"\'\'\[\]]")
    return unsafe_chars.sub("_", str(name))
216
217
def _sanitize_dataframe_columns(dataframe):
    """Rename dataframe columns to Ludwig-sanitized names for explainability."""
    mapping = {original: sanitize_feature_name(original)
               for original in dataframe.columns}

    # Two distinct columns can sanitize to the same name; warn rather than
    # fail, since the resulting importance values may then be ambiguous.
    if len(set(mapping.values())) != len(mapping):
        LOG.warning(
            "Column name collision after sanitization; feature importance may be unreliable"
        )

    return dataframe.rename(columns=mapping)
229
230
def _feature_importance_plot(label_df, label_name, top_n=10, max_abs_importance=None):
    """
    Return base64-encoded bar plot for a label's top-N feature importances.

    max_abs_importance lets us pin the x-axis across labels so readers can
    compare magnitudes.
    """
    # Bail out when matplotlib is unavailable or there is nothing to plot.
    if plt is None or label_df.empty:
        return ""

    ranked = label_df.nlargest(top_n, "abs_importance")
    if ranked.empty:
        return ""

    # Scale the figure height with the number of bars for readability.
    figure, axis = plt.subplots(figsize=(6, 3 + 0.2 * len(ranked)))
    axis.barh(ranked["feature"], ranked["abs_importance"], color="#3f8fd2")
    axis.set_xlabel("|importance|")
    if max_abs_importance and max_abs_importance > 0:
        axis.set_xlim(0, max_abs_importance * 1.05)
    axis.invert_yaxis()
    figure.tight_layout()

    image_buffer = BytesIO()
    figure.savefig(image_buffer, format="png", dpi=150)
    plt.close(figure)
    return base64.b64encode(image_buffer.getvalue()).decode("utf-8")
258
259
def render_feature_importance_table(df: pd.DataFrame) -> str:
    """Render a sortable HTML table for feature importance values."""
    if df.empty:
        return ""

    cols = list(df.columns)
    header_html = "".join(
        f"<th class='sortable'>{html.escape(str(c).replace('_', ' '))}</th>"
        for c in cols
    )

    def _format(value):
        # Fixed six-decimal precision keeps float columns aligned.
        return f"{value:.6f}" if isinstance(value, float) else str(value)

    rows_html = "".join(
        "<tr>"
        + "".join(f"<td>{html.escape(_format(row[c]))}</td>" for c in cols)
        + "</tr>"
        for _, row in df.iterrows()
    )

    return (
        "<div class='scroll-rows-30'>"
        "<table class='feature-importance-table sortable-table'>"
        f"<thead><tr>{header_html}</tr></thead>"
        f"<tbody>{rows_html}</tbody>"
        "</table>"
        "</div>"
    )
291
292
def compute_feature_importance(ludwig_output_directory_name,
                               sample_size=200,
                               random_seed=42):
    """Compute per-label feature importance with Integrated Gradients.

    Loads the trained Ludwig model from the experiment output directory,
    explains it against (a sample of) the training dataset, and writes the
    resulting table to ``feature_importance.csv`` inside that directory.
    Every failure path is logged and returns silently so the surrounding
    report generation keeps working without importance artifacts.

    Parameters:
        ludwig_output_directory_name: subdirectory (under the module-level
            ``output_directory``) holding the Ludwig run artifacts.
        sample_size: max number of rows used for both the explained subset
            and the baseline sample.
        random_seed: seed for the baseline sampling, for reproducibility.
    """
    ludwig_output_directory = os.path.join(
        output_directory, ludwig_output_directory_name)
    model_dir = os.path.join(ludwig_output_directory, "model")

    output_csv_path = os.path.join(
        ludwig_output_directory, "feature_importance.csv")

    if not os.path.exists(model_dir):
        LOG.info("Model directory not found; skipping feature importance computation")
        return

    try:
        ludwig_model = LudwigModel.load(model_dir)
    except Exception as exc:
        LOG.warning(f"Unable to load Ludwig model for explanations: {exc}")
        return

    # Grab label metadata up front; used later to map label indices to names.
    training_metadata = getattr(ludwig_model, "training_set_metadata", {})

    output_feature_name, dataset_path = get_output_feature_name(
        ludwig_output_directory)

    if not output_feature_name or not dataset_path:
        LOG.warning("Output feature or dataset path missing; skipping feature importance")
        if hasattr(ludwig_model, "close"):
            ludwig_model.close()
        return

    dataset_full_path = _resolve_dataset_path(dataset_path)
    if not dataset_full_path:
        LOG.warning(f"Unable to resolve dataset path '{dataset_path}' for explanations")
        if hasattr(ludwig_model, "close"):
            ludwig_model.close()
        return

    dataframe = _load_dataset_dataframe(dataset_full_path)
    if dataframe is None or dataframe.empty:
        LOG.warning("Dataset unavailable or empty; skipping feature importance")
        if hasattr(ludwig_model, "close"):
            ludwig_model.close()
        return

    # Column names must match Ludwig's sanitized feature names for the
    # explainer to line up features with model inputs.
    dataframe = _sanitize_dataframe_columns(dataframe)

    # NOTE(review): the explained subset is the deterministic head() of the
    # data while the baseline sample is drawn randomly — presumably
    # intentional (stable rows to explain, random baseline); confirm.
    data_subset = dataframe if len(dataframe) <= sample_size else dataframe.head(sample_size)
    sample_df = dataframe.sample(
        n=min(sample_size, len(dataframe)),
        random_state=random_seed,
        replace=False,
    ) if len(dataframe) > sample_size else dataframe

    # Imported lazily: captum is an optional heavyweight dependency.
    try:
        from ludwig.explain.captum import IntegratedGradientsExplainer
    except ImportError as exc:
        LOG.warning(f"Integrated Gradients explainer unavailable: {exc}")
        if hasattr(ludwig_model, "close"):
            ludwig_model.close()
        return

    sanitized_output_feature = sanitize_feature_name(output_feature_name)

    try:
        explainer = IntegratedGradientsExplainer(
            ludwig_model,
            data_subset,
            sample_df,
            sanitized_output_feature,
        )
        explanations = explainer.explain()
    except Exception as exc:
        LOG.warning(f"Unable to compute feature importance: {exc}")
        if hasattr(ludwig_model, "close"):
            ludwig_model.close()
        return

    # Explanations are in hand; release the model before post-processing.
    # close() failures here are non-fatal, hence the swallow.
    if hasattr(ludwig_model, "close"):
        try:
            ludwig_model.close()
        except Exception:
            pass

    # Recover human-readable label names from the training metadata,
    # preferring idx2str (index -> label) and falling back to inverting
    # str2idx (label -> index).
    label_names = []
    target_metadata = {}
    if isinstance(training_metadata, dict):
        target_metadata = training_metadata.get(sanitized_output_feature, {})

    if isinstance(target_metadata, dict):
        if "idx2str" in target_metadata:
            idx2str = target_metadata["idx2str"]
            if isinstance(idx2str, dict):
                # Sort numeric keys numerically, everything else lexically,
                # with numeric keys first.
                def _idx_key(item):
                    idx_key = item[0]
                    try:
                        return (0, int(idx_key))
                    except (TypeError, ValueError):
                        return (1, str(idx_key))

                label_names = [value for key, value in sorted(
                    idx2str.items(), key=_idx_key)]
            else:
                # Assumed to already be an index-ordered sequence of labels.
                label_names = idx2str
        elif "str2idx" in target_metadata and isinstance(
                target_metadata["str2idx"], dict):
            # invert mapping
            label_names = [label for label, _ in sorted(
                target_metadata["str2idx"].items(),
                key=lambda item: item[1])]

    # Flatten per-label attributions into tabular rows. A single label
    # explanation (e.g. regression/binary) is reported under the output
    # feature's name; otherwise unnamed labels fall back to their index.
    rows = []
    global_explanation = explanations.global_explanation
    for label_index, label_explanation in enumerate(
            global_explanation.label_explanations):
        if label_names and label_index < len(label_names):
            label_value = str(label_names[label_index])
        elif len(global_explanation.label_explanations) == 1:
            label_value = output_feature_name
        else:
            label_value = str(label_index)

        for feature in label_explanation.feature_attributions:
            rows.append({
                "label": label_value,
                "feature": feature.feature_name,
                "importance": feature.attribution,
                "abs_importance": abs(feature.attribution),
            })

    if not rows:
        LOG.warning("No feature importance rows produced")
        return

    # Most influential features first within each label.
    importance_df = pd.DataFrame(rows)
    importance_df.sort_values([
        "label",
        "abs_importance"
    ], ascending=[True, False], inplace=True)

    importance_df.to_csv(output_csv_path, index=False)

    LOG.info(f"Feature importance saved to {output_csv_path}")
436
437
161 def generate_html_report(title, ludwig_output_directory_name): 438 def generate_html_report(title, ludwig_output_directory_name):
162 # ludwig_output_directory = os.path.join(
163 # output_directory, ludwig_output_directory_name)
164
165 # test_statistics_html = ""
166 # # Read test statistics JSON and convert to HTML table
167 # try:
168 # test_statistics_path = os.path.join(
169 # ludwig_output_directory, TEST_STATISTICS_FILE_NAME)
170 # with open(test_statistics_path, "r") as f:
171 # test_statistics = json.load(f)
172 # test_statistics_html = "<h2>Test Statistics</h2>"
173 # test_statistics_html += json_to_html_table(
174 # test_statistics)
175 # except Exception as e:
176 # LOG.info(f"Error reading test statistics: {e}")
177
178 # Convert visualizations to HTML
179 plots_html = "" 439 plots_html = ""
180 if len(os.listdir(viz_output_directory)) > 0: 440 plot_files = []
441 if os.path.isdir(viz_output_directory):
442 plot_files = sorted(os.listdir(viz_output_directory))
443 if plot_files:
181 plots_html = "<h2>Visualizations</h2>" 444 plots_html = "<h2>Visualizations</h2>"
182 for plot_file in sorted(os.listdir(viz_output_directory)): 445 for plot_file in plot_files:
183 plot_path = os.path.join(viz_output_directory, plot_file) 446 plot_path = os.path.join(viz_output_directory, plot_file)
184 if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")): 447 if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")):
185 encoded_image = encode_image_to_base64(plot_path) 448 encoded_image = encode_image_to_base64(plot_path)
449 plot_title = os.path.splitext(plot_file)[0].replace("_", " ")
186 plots_html += ( 450 plots_html += (
187 f'<div class="plot">' 451 f'<div class="plot">'
188 f'<h3>{os.path.splitext(plot_file)[0]}</h3>' 452 f'<h3>{plot_title}</h3>'
189 '<img src="data:image/png;base64,' 453 '<img src="data:image/png;base64,'
190 f'{encoded_image}" alt="{plot_file}">' 454 f'{encoded_image}" alt="{plot_file}">'
191 f'</div>' 455 f'</div>'
192 ) 456 )
193 457
458 feature_importance_html = ""
459 importance_path = os.path.join(
460 output_directory,
461 ludwig_output_directory_name,
462 "feature_importance.csv",
463 )
464 if os.path.exists(importance_path):
465 try:
466 importance_df = pd.read_csv(importance_path)
467 if not importance_df.empty:
468 sorted_df = (
469 importance_df
470 .sort_values(["label", "abs_importance"], ascending=[True, False])
471 )
472 top_rows = (
473 sorted_df
474 .groupby("label", as_index=False)
475 .head(5)
476 )
477 max_abs_importance = pd.to_numeric(
478 importance_df.get("abs_importance", pd.Series(dtype=float)),
479 errors="coerce",
480 ).max()
481 if pd.isna(max_abs_importance):
482 max_abs_importance = None
483
484 plot_sections = []
485 for label in sorted(importance_df["label"].unique()):
486 encoded_plot = _feature_importance_plot(
487 importance_df[importance_df["label"] == label],
488 label,
489 max_abs_importance=max_abs_importance,
490 )
491 if encoded_plot:
492 plot_sections.append(
493 f'<div class="plot feature-importance-plot">'
494 f'<h3>Top features for {label}</h3>'
495 f'<img src="data:image/png;base64,{encoded_plot}" '
496 f'alt="Feature importance plot for {label}">'
497 f'</div>'
498 )
499 explanation_text = (
500 "<p>Feature importance scores come from Ludwig's Integrated Gradients explainer. "
501 "It interpolates between each example and a neutral baseline sample, summing "
502 "the change in the model output along that path. Higher |importance| values "
503 "indicate stronger influence. Plots share a common x-axis to make magnitudes "
504 "comparable across labels, and the table columns can be sorted for quick scans.</p>"
505 )
506 feature_importance_html = (
507 "<h2>Feature Importance</h2>"
508 + explanation_text
509 + render_feature_importance_table(top_rows)
510 + "".join(plot_sections)
511 )
512 except Exception as exc:
513 LOG.info(f"Unable to embed feature importance table: {exc}")
514
194 # Generate the full HTML content 515 # Generate the full HTML content
516 feature_section = feature_importance_html or "<p>No feature importance artifacts were generated.</p>"
517 viz_section = plots_html or "<p>No visualizations were generated.</p>"
518 tabs_style = """
519 <style>
520 .tabs {
521 display: flex;
522 border-bottom: 2px solid #ccc;
523 margin-top: 20px;
524 margin-bottom: 1rem;
525 }
526 .tablink {
527 padding: 9px 18px;
528 cursor: pointer;
529 border: 1px solid #ccc;
530 border-bottom: none;
531 background: #f9f9f9;
532 margin-right: 5px;
533 border-top-left-radius: 8px;
534 border-top-right-radius: 8px;
535 font-size: 0.95rem;
536 font-weight: 500;
537 font-family: Arial, sans-serif;
538 color: #4A4A4A;
539 }
540 .tablink.active {
541 background: #ffffff;
542 font-weight: bold;
543 }
544 .tabcontent {
545 border: 1px solid #ccc;
546 border-top: none;
547 padding: 20px;
548 display: none;
549 }
550 .tabcontent.active {
551 display: block;
552 }
553 </style>
554 """
555 tabs_script = """
556 <script>
557 function openTab(evt, tabId) {
558 var i, tabcontent, tablinks;
559 tabcontent = document.getElementsByClassName("tabcontent");
560 for (i = 0; i < tabcontent.length; i++) {
561 tabcontent[i].style.display = "none";
562 tabcontent[i].classList.remove("active");
563 }
564 tablinks = document.getElementsByClassName("tablink");
565 for (i = 0; i < tablinks.length; i++) {
566 tablinks[i].classList.remove("active");
567 }
568 var current = document.getElementById(tabId);
569 if (current) {
570 current.style.display = "block";
571 current.classList.add("active");
572 }
573 if (evt && evt.currentTarget) {
574 evt.currentTarget.classList.add("active");
575 }
576 }
577 document.addEventListener("DOMContentLoaded", function() {
578 openTab({currentTarget: document.querySelector(".tablink")}, "viz-tab");
579 });
580 </script>
581 """
582 tabs_html = f"""
583 <div class="tabs">
584 <button class="tablink active" onclick="openTab(event, 'viz-tab')">Visualizations</button>
585 <button class="tablink" onclick="openTab(event, 'feature-tab')">Feature Importance</button>
586 </div>
587 <div id="viz-tab" class="tabcontent active">
588 {viz_section}
589 </div>
590 <div id="feature-tab" class="tabcontent">
591 {feature_section}
592 </div>
593 """
195 html_content = f""" 594 html_content = f"""
196 {get_html_template()} 595 {get_html_template()}
197 <h1>{title}</h1> 596 <h1>{title}</h1>
198 {plots_html} 597 {tabs_style}
598 {tabs_html}
599 {tabs_script}
199 {get_html_closing()} 600 {get_html_closing()}
200 """ 601 """
201 602
202 # Save the HTML report 603 # Save the HTML report
203 title: str 604 title: str
215 616
216 ludwig_output_directory_name = "experiment_run" 617 ludwig_output_directory_name = "experiment_run"
217 618
218 make_visualizations(ludwig_output_directory_name) 619 make_visualizations(ludwig_output_directory_name)
219 convert_parquet_to_csv(ludwig_output_directory_name) 620 convert_parquet_to_csv(ludwig_output_directory_name)
621 compute_feature_importance(ludwig_output_directory_name)
220 generate_html_report("Ludwig Experiment", ludwig_output_directory_name) 622 generate_html_report("Ludwig Experiment", ludwig_output_directory_name)