changeset 8:a48e750cfd25 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
author goeckslab
date Fri, 30 Jan 2026 14:20:49 +0000
parents ed2fefc8d892
children
files metrics_logic.py multimodal_learner.py multimodal_learner.xml report_utils.py test_pipeline.py training_pipeline.py
diffstat 6 files changed, 241 insertions(+), 60 deletions(-)
--- a/metrics_logic.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/metrics_logic.py	Fri Jan 30 14:20:49 2026 +0000
@@ -18,6 +18,7 @@
     r2_score,
     recall_score,
     roc_auc_score,
+    roc_curve
 )
 
 
@@ -69,16 +70,40 @@
     return metrics
 
 
+def _get_binary_scores(
+    y_true: pd.Series,
+    y_proba: Optional[np.ndarray],
+    predictor,
+) -> Tuple[np.ndarray, object, Optional[np.ndarray]]:
+    classes_sorted = np.sort(pd.unique(y_true))
+    pos_label = classes_sorted[-1]
+    pos_scores = None
+    if y_proba is not None:
+        if y_proba.ndim == 1:
+            pos_scores = y_proba
+        else:
+            pos_col_idx = -1
+            try:
+                if hasattr(predictor, "class_labels") and predictor.class_labels:
+                    pos_col_idx = list(predictor.class_labels).index(pos_label)
+            except Exception:
+                pos_col_idx = -1
+            pos_scores = y_proba[:, pos_col_idx]
+    return classes_sorted, pos_label, pos_scores
+
+
 def _compute_binary_metrics(
     y_true: pd.Series,
     y_pred: pd.Series,
     y_proba: Optional[np.ndarray],
-    predictor
+    predictor,
+    classes_sorted: Optional[np.ndarray] = None,
+    pos_label: Optional[object] = None,
+    pos_scores: Optional[np.ndarray] = None,
 ) -> "OrderedDict[str, float]":
     metrics = OrderedDict()
-    classes_sorted = np.sort(pd.unique(y_true))
-    # Choose the lexicographically larger class as "positive"
-    pos_label = classes_sorted[-1]
+    if classes_sorted is None or pos_label is None or pos_scores is None:
+        classes_sorted, pos_label, pos_scores = _get_binary_scores(y_true, y_proba, predictor)
 
     metrics["Accuracy"] = accuracy_score(y_true, y_pred)
     metrics["Precision"] = precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
@@ -92,18 +117,7 @@
         metrics["Specificity_(TNR)"] = np.nan
 
     # Probabilistic metrics
-    if y_proba is not None:
-        # pick column of positive class
-        if y_proba.ndim == 1:
-            pos_scores = y_proba
-        else:
-            pos_col_idx = -1
-            try:
-                if hasattr(predictor, "class_labels") and predictor.class_labels:
-                    pos_col_idx = list(predictor.class_labels).index(pos_label)
-            except Exception:
-                pos_col_idx = -1
-            pos_scores = y_proba[:, pos_col_idx]
+    if y_proba is not None and pos_scores is not None:
         try:
             metrics["ROC-AUC"] = roc_auc_score(y_true == pos_label, pos_scores)
         except Exception:
@@ -221,7 +235,8 @@
     target_col: str,
     problem_type: str,
     threshold: Optional[float] = None,    # <— NEW
-) -> "OrderedDict[str, float]":
+    return_curve: bool = False,
+) -> "OrderedDict[str, float] | Tuple[OrderedDict[str, float], Optional[dict]]":
     """Compute transparency metrics for one split (Train/Val/Test) based on task type."""
     # Prepare inputs
     features = df.drop(columns=[target_col], errors="ignore")
@@ -235,22 +250,14 @@
     except Exception:
         y_proba = None
 
+    classes_sorted = pos_label = pos_scores = None
+    if problem_type == "binary":
+        classes_sorted, pos_label, pos_scores = _get_binary_scores(y_true_series, y_proba, predictor)
+
     # Labels (optionally thresholded for binary)
     y_pred_series = None
-    if problem_type == "binary" and (threshold is not None) and (y_proba is not None):
-        classes_sorted = np.sort(pd.unique(y_true_series))
-        pos_label = classes_sorted[-1]
+    if problem_type == "binary" and (threshold is not None) and (pos_scores is not None):
         neg_label = classes_sorted[0]
-        if y_proba.ndim == 1:
-            pos_scores = y_proba
-        else:
-            pos_col_idx = -1
-            try:
-                if hasattr(predictor, "class_labels") and predictor.class_labels:
-                    pos_col_idx = list(predictor.class_labels).index(pos_label)
-            except Exception:
-                pos_col_idx = -1
-            pos_scores = y_proba[:, pos_col_idx]
         y_pred_series = pd.Series(np.where(pos_scores >= float(threshold), pos_label, neg_label)).reset_index(drop=True)
     else:
         # Fall back to model's default label prediction (argmax / 0.5 equivalent)
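
A worked sketch of the thresholding rule above (values and labels are hypothetical; the positive class follows the same lexicographic convention as in the code):

    import numpy as np

    pos_scores = np.array([0.10, 0.40, 0.80])   # P(positive class) per row
    threshold = 0.35
    y_pred = np.where(pos_scores >= threshold, "malignant", "benign")
    # array(['benign', 'malignant', 'malignant'], dtype='<U9')
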
@@ -259,13 +266,35 @@
     if problem_type == "regression":
         y_true_arr = np.asarray(y_true_series, dtype=float)
         y_pred_arr = np.asarray(y_pred_series, dtype=float)
-        return _compute_regression_metrics(y_true_arr, y_pred_arr)
+        metrics = _compute_regression_metrics(y_true_arr, y_pred_arr)
+        return (metrics, None) if return_curve else metrics
 
     if problem_type == "binary":
-        return _compute_binary_metrics(y_true_series, y_pred_series, y_proba, predictor)
+        metrics = _compute_binary_metrics(
+            y_true_series,
+            y_pred_series,
+            y_proba,
+            predictor,
+            classes_sorted=classes_sorted,
+            pos_label=pos_label,
+            pos_scores=pos_scores,
+        )
+        roc_curve_data = None
+        if return_curve and pos_scores is not None and pos_label is not None:
+            try:
+                fpr, tpr, thresholds = roc_curve(y_true_series == pos_label, pos_scores)
+                roc_curve_data = {
+                    "fpr": fpr.tolist(),
+                    "tpr": tpr.tolist(),
+                    "thresholds": thresholds.tolist(),
+                }
+            except Exception:
+                roc_curve_data = None
+        return (metrics, roc_curve_data) if return_curve else metrics
 
     # multiclass
-    return _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)
+    metrics = _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)
+    return (metrics, None) if return_curve else metrics
 
 
 def evaluate_all_transparency(
@@ -276,25 +305,57 @@
     target_col: str,
     problem_type: str,
     threshold: Optional[float] = None,
-) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]]]:
+) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]], Dict[str, dict]]:
     """
     Evaluate Train/Val/Test with the transparent metrics suite.
     Returns:
       - metrics_table: DataFrame with index=Metric, columns subset of [Train, Validation, Test]
       - raw_dict: nested dict {split -> {metric -> value}}
+      - roc_curves: nested dict {split -> {fpr, tpr, thresholds}} (binary only)
     """
     split_results: Dict[str, Dict[str, float]] = {}
+    roc_curves: Dict[str, dict] = {}
     splits = []
 
     # IMPORTANT: do NOT apply threshold to Train/Val
     if train_df is not None and len(train_df):
-        split_results["Train"] = compute_metrics_for_split(predictor, train_df, target_col, problem_type, threshold=None)
+        train_metrics, train_curve = compute_metrics_for_split(
+            predictor,
+            train_df,
+            target_col,
+            problem_type,
+            threshold=None,
+            return_curve=True,
+        )
+        split_results["Train"] = train_metrics
+        if train_curve:
+            roc_curves["Train"] = train_curve
         splits.append("Train")
     if val_df is not None and len(val_df):
-        split_results["Validation"] = compute_metrics_for_split(predictor, val_df, target_col, problem_type, threshold=None)
+        val_metrics, val_curve = compute_metrics_for_split(
+            predictor,
+            val_df,
+            target_col,
+            problem_type,
+            threshold=None,
+            return_curve=True,
+        )
+        split_results["Validation"] = val_metrics
+        if val_curve:
+            roc_curves["Validation"] = val_curve
         splits.append("Validation")
     if test_df is not None and len(test_df):
-        split_results["Test"] = compute_metrics_for_split(predictor, test_df, target_col, problem_type, threshold=threshold)
+        test_metrics, test_curve = compute_metrics_for_split(
+            predictor,
+            test_df,
+            target_col,
+            problem_type,
+            threshold=threshold,
+            return_curve=True,
+        )
+        split_results["Test"] = test_metrics
+        if test_curve:
+            roc_curves["Test"] = test_curve
         splits.append("Test")
 
     # Preserve order from the first split; include any extras from others
@@ -310,4 +371,4 @@
         for m, v in split_results[s].items():
             metrics_table.loc[m, s] = v
 
-    return metrics_table, split_results
+    return metrics_table, split_results, roc_curves
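
Illustrative call sketch, assuming an already-fitted predictor and the split DataFrames used elsewhere in this changeset; the keyword names are taken from the signature and body above, with test_df inferred from the function body.

    metrics_table, raw_dict, roc_curves = evaluate_all_transparency(
        predictor=predictor,
        train_df=df_train,
        val_df=df_val,
        test_df=df_test,
        target_col="label",
        problem_type="binary",
        threshold=0.35,          # applied to the Test split only
    )
    test_curve = roc_curves.get("Test")
    if test_curve:
        fpr, tpr = test_curve["fpr"], test_curve["tpr"]   # plain lists, ready for plotting or JSON
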
--- a/multimodal_learner.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/multimodal_learner.py	Fri Jan 30 14:20:49 2026 +0000
@@ -457,6 +457,7 @@
         ag_folds=ag_folds,
         raw_metrics_std=raw_metrics_std,
         ag_by_split_std=ag_by_split_std,
+        ag_config=ag_config,
     )
 
 
--- a/multimodal_learner.xml	Wed Jan 28 19:56:37 2026 +0000
+++ b/multimodal_learner.xml	Fri Jan 30 14:20:49 2026 +0000
@@ -1,4 +1,4 @@
-<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.4" profile="22.01">
+<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.5" profile="22.01">
   <description>Train and evaluate an AutoGluon Multimodal model (tabular + image + text)</description>
 
   <requirements>
--- a/report_utils.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/report_utils.py	Fri Jan 30 14:20:49 2026 +0000
@@ -35,14 +35,49 @@
         return None
 
 
-def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]):
+def _replace_local_backbone_strings(obj: Any, replacements: Optional[dict]) -> Any:
+    if not replacements:
+        return obj
+    if isinstance(obj, str):
+        out = obj
+        for needle, repl in replacements.items():
+            if repl:
+                out = out.replace(needle, str(repl))
+        return out
+    if isinstance(obj, dict):
+        return {k: _replace_local_backbone_strings(v, replacements) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [_replace_local_backbone_strings(v, replacements) for v in obj]
+    if isinstance(obj, tuple):
+        return tuple(_replace_local_backbone_strings(v, replacements) for v in obj)
+    return obj
+
+
+def _copy_config_if_available(
+    pred_path: Optional[str],
+    output_config: Optional[str],
+    replacements: Optional[dict] = None,
+    cfg_yaml: Optional[dict] = None,
+):
     if not output_config:
         return
     try:
         config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None
         if config_yaml_path and os.path.isfile(config_yaml_path):
-            shutil.copy2(config_yaml_path, output_config)
-            logger.info(f"Wrote AutoGluon config → {output_config}")
+            try:
+                if isinstance(cfg_yaml, dict) and cfg_yaml:
+                    cfg_raw = cfg_yaml
+                else:
+                    with open(config_yaml_path, "r") as cfg_in:
+                        cfg_raw = yaml.safe_load(cfg_in) or {}
+                cfg_sanitized = _replace_local_backbone_strings(cfg_raw, replacements)
+                with open(output_config, "w") as cfg_out:
+                    yaml.safe_dump(cfg_sanitized, cfg_out, sort_keys=False)
+                logger.info(f"Wrote AutoGluon config → {output_config}")
+            except Exception as e:
+                logger.warning(f"Failed to sanitize config.yaml; copying raw file instead: {e}")
+                shutil.copy2(config_yaml_path, output_config)
+                logger.info(f"Wrote AutoGluon config → {output_config}")
         else:
             with open(output_config, "w") as cfg_out:
                 cfg_out.write("# config.yaml not found for this run\n")
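
A minimal sketch of what _replace_local_backbone_strings does to a nested config structure; the checkpoint names here are only examples, not values from the repository.

    from report_utils import _replace_local_backbone_strings

    cfg = {
        "model": {
            "hf_text": {"checkpoint_name": "local://hf_text"},
            "timm_image": {"checkpoint_name": "local://timm_image"},
            "names": ["hf_text", "timm_image", "fusion_mlp"],
        }
    }
    replacements = {
        "local://hf_text": "google/electra-small-discriminator",
        "local://timm_image": "swin_small_patch4_window7_224",
    }
    sanitized = _replace_local_backbone_strings(cfg, replacements)
    # sanitized["model"]["hf_text"]["checkpoint_name"] == "google/electra-small-discriminator"
    # strings without a "local://..." needle (e.g. the entries of "names") are returned unchanged
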
@@ -78,10 +113,74 @@
     return {}
 
 
-def _summarize_config(cfg: dict, args) -> List[tuple[str, str]]:
+def _is_local_checkpoint(val: Optional[str]) -> bool:
+    return isinstance(val, str) and val.startswith("local://")
+
+
+def _read_checkpoint_from_cfg(cfg: dict, model_key: str) -> Optional[str]:
+    if not isinstance(cfg, dict):
+        return None
+    model_cfg = cfg.get("model", {})
+    if not isinstance(model_cfg, dict):
+        return None
+    block = model_cfg.get(model_key)
+    if isinstance(block, str):
+        return block
+    if isinstance(block, dict):
+        for k in ("checkpoint_name", "model_name", "name"):
+            val = block.get(k)
+            if val:
+                return val
+    return None
+
+
+def _read_checkpoint_from_hparams(hparams: Optional[dict], model_key: str) -> Optional[str]:
+    if not isinstance(hparams, dict):
+        return None
+    model_block = hparams.get("model", {})
+    if isinstance(model_block, dict):
+        block = model_block.get(model_key)
+        if isinstance(block, str):
+            return block
+        if isinstance(block, dict):
+            for k in ("checkpoint_name", "model_name", "name"):
+                val = block.get(k)
+                if val:
+                    return val
+    for k in ("checkpoint_name", "model_name", "name"):
+        dotted = hparams.get(f"model.{model_key}.{k}")
+        if dotted:
+            return dotted
+    return None
+
+
+def _resolve_backbone_choice(cfg: dict, ag_config: Optional[dict], args, model_key: str, arg_attr: str) -> Optional[str]:
+    candidates = [
+        _read_checkpoint_from_cfg(cfg, model_key),
+        _read_checkpoint_from_hparams((ag_config or {}).get("hyperparameters"), model_key),
+        getattr(args, arg_attr, None),
+    ]
+    for val in candidates:
+        if val and not _is_local_checkpoint(val):
+            return str(val)
+    return None
+
+
+def _build_backbone_replacements(cfg: dict, ag_config: Optional[dict], args) -> dict:
+    replacements = {}
+    text_choice = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text")
+    image_choice = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image")
+    if text_choice:
+        replacements["local://hf_text"] = text_choice
+    if image_choice:
+        replacements["local://timm_image"] = image_choice
+    return replacements
+
+
+def _summarize_config(cfg: dict, args, ag_config: Optional[dict] = None) -> List[tuple[str, str]]:
     """
     Build rows describing model components and key hyperparameters from a loaded config.yaml.
-    Falls back to CLI args when config values are missing.
+    Falls back to AutoGluon hyperparameters and CLI args when config values are missing.
     """
     rows: List[tuple[str, str]] = []
     model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
@@ -102,11 +201,11 @@
                 break
     rows.append(("Tabular backbone", tabular_val))
 
-    image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—"
-    rows.append(("Image backbone", image_val))
+    image_val = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image")
+    rows.append(("Image backbone", image_val or "—"))
 
-    text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—"
-    rows.append(("Text backbone", text_val))
+    text_val = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text")
+    rows.append(("Text backbone", text_val or "—"))
 
     fusion_val = "—"
     for k in model_cfg.keys():
@@ -150,6 +249,7 @@
     ag_folds=None,
     raw_metrics_std=None,
     ag_by_split_std=None,
+    ag_config: Optional[dict] = None,
 ):
     from plot_logic import (
         build_summary_html,
@@ -164,6 +264,11 @@
     raw_metrics = eval_results.get("raw_metrics", {})
     ag_by_split = eval_results.get("ag_eval", {})
     fit_summary_obj = eval_results.get("fit_summary")
+    roc_curves = eval_results.get("roc_curves")
+
+    cfg_yaml = _load_config_yaml(args, predictor)
+    backbone_replacements = _build_backbone_replacements(cfg_yaml, ag_config, args)
+    fit_summary_obj = _replace_local_backbone_strings(fit_summary_obj, backbone_replacements)
 
     df_train = data_ctx.get("train")
     df_val = data_ctx.get("val")
@@ -203,6 +308,7 @@
                 "ag_eval": ag_by_split,
                 "ag_eval_std": ag_by_split_std,
                 "fit_summary": fit_summary_obj,
+                "roc_curves": roc_curves,
                 "problem_type": problem_type,
                 "predictor_path": getattr(predictor, "path", None),
                 "threshold": args.threshold,
@@ -242,8 +348,7 @@
         show_title=False,
     )
 
-    cfg_yaml = _load_config_yaml(args, predictor)
-    config_rows = _summarize_config(cfg_yaml, args)
+    config_rows = _summarize_config(cfg_yaml, args, ag_config=ag_config)
     threshold_rows = []
     if problem_type == "binary" and args.threshold is not None:
         threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}"))
@@ -315,7 +420,7 @@
     leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor)
     inputs_html = ""
     ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full)
-    presets_hparams_html = build_presets_hparams_html(predictor)
+    presets_hparams_html = build_presets_hparams_html(predictor, backbone_replacements)
     notices: List[str] = []
     if args.threshold is not None and problem_type == "binary":
         notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.")
@@ -350,7 +455,12 @@
     logger.info(f"Wrote HTML report → {args.output_html}")
 
     pred_path = _write_predictor_path(predictor)
-    _copy_config_if_available(pred_path, args.output_config)
+    _copy_config_if_available(
+        pred_path,
+        args.output_config,
+        backbone_replacements,
+        cfg_yaml=cfg_yaml,
+    )
 
     outputs_to_check = [
         (args.output_json, "JSON results"),
@@ -886,7 +996,7 @@
     """
 
 
-def build_presets_hparams_html(predictor) -> str:
+def build_presets_hparams_html(predictor, replacements: Optional[dict] = None) -> str:
     # MultiModalPredictor path
     mm_hp = {}
     for attr in ("_config", "config", "_fit_args"):
@@ -894,7 +1004,7 @@
             try:
                 val = getattr(predictor, attr)
                 # make it JSON-ish
-                mm_hp[attr] = str(val)
+                mm_hp[attr] = _replace_local_backbone_strings(str(val), replacements)
             except Exception:
                 continue
     hp_html = f"<pre>{html.escape(json.dumps(mm_hp, indent=2))}</pre>" if mm_hp else "<i>Unavailable</i>"
--- a/test_pipeline.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/test_pipeline.py	Fri Jan 30 14:20:49 2026 +0000
@@ -50,7 +50,7 @@
         problem_type = infer_problem_type(predictor, base_df, target_column)
 
     df_test_final = df_test_external if df_test_external is not None else df_test_internal
-    raw_metrics, ag_by_split = evaluate_predictor_all_splits(
+    raw_metrics, ag_by_split, roc_curves = evaluate_predictor_all_splits(
         predictor=predictor,
         df_train=df_train,
         df_val=df_val,
@@ -69,6 +69,7 @@
         "raw_metrics": raw_metrics,
         "ag_eval": ag_by_split,
         "fit_summary": summary,
+        "roc_curves": roc_curves,
     }
     logger.info("Evaluation complete; splits: %s", list(raw_metrics.keys()))
     return result
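
Illustrative only: the shape of the new roc_curves entry carried in the result dict above. Split names and point values are hypothetical; the lists come from numpy .tolist(), so they are plain JSON-serializable floats.

    result = {
        # ... other keys as built above ...
        "roc_curves": {
            "Test": {
                "fpr": [0.0, 0.2, 1.0],
                "tpr": [0.0, 0.8, 1.0],
                "thresholds": [0.9, 0.6, 0.1],
            }
        }
    }
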
--- a/training_pipeline.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/training_pipeline.py	Fri Jan 30 14:20:49 2026 +0000
@@ -316,7 +316,7 @@
     eval_metric: Optional[str],
     threshold_test: Optional[float],
     df_test_external: Optional[pd.DataFrame] = None,
-) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
+) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]], Dict[str, dict]]:
     """
     Returns (raw_metrics, ag_scores_by_split)
       - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic)
@@ -335,7 +335,7 @@
         ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req)
 
     # Transparent suite (threshold on Test handled inside metrics_logic)
-    _, raw_metrics = evaluate_all_transparency(
+    _, raw_metrics, roc_curves = evaluate_all_transparency(
         predictor=predictor,
         train_df=df_train,
         val_df=df_val,
@@ -346,12 +346,20 @@
     )
 
     if df_test_external is not None and df_test_external is not df_test and len(df_test_external):
-        raw_metrics["Test (external)"] = compute_metrics_for_split(
-            predictor, df_test_external, label_col, problem_type, threshold=threshold_test
+        ext_metrics, ext_curve = compute_metrics_for_split(
+            predictor,
+            df_test_external,
+            label_col,
+            problem_type,
+            threshold=threshold_test,
+            return_curve=True,
         )
+        raw_metrics["Test (external)"] = ext_metrics
+        if ext_curve:
+            roc_curves["Test (external)"] = ext_curve
         ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req)
 
-    return raw_metrics, ag_by_split
+    return raw_metrics, ag_by_split, roc_curves
 
 
 def fit_summary_safely(predictor) -> Optional[dict]:
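
The files shown here only collect and store the ROC points; a downstream consumer could plot them, for example with matplotlib. This sketch is an assumption, not code from the repository.

    import matplotlib.pyplot as plt

    def plot_roc_overlay(roc_curves: dict, out_path: str) -> None:
        """Overlay one ROC curve per split from a {split: {fpr, tpr, thresholds}} dict."""
        fig, ax = plt.subplots(figsize=(5, 5))
        for split, curve in roc_curves.items():
            ax.plot(curve["fpr"], curve["tpr"], label=split)
        ax.plot([0, 1], [0, 1], linestyle="--", color="grey", linewidth=1)  # chance line
        ax.set_xlabel("False positive rate")
        ax.set_ylabel("True positive rate")
        ax.legend(loc="lower right")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
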