diff report_utils.py @ 8:a48e750cfd25 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
author goeckslab
date Fri, 30 Jan 2026 14:20:49 +0000
parents 375c36923da1
children
line wrap: on
line diff
--- a/report_utils.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/report_utils.py	Fri Jan 30 14:20:49 2026 +0000
@@ -35,14 +35,49 @@
         return None
 
 
-def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]):
+def _replace_local_backbone_strings(obj: Any, replacements: Optional[dict]) -> Any:
+    if not replacements:
+        return obj
+    if isinstance(obj, str):
+        out = obj
+        for needle, repl in replacements.items():
+            if repl:
+                out = out.replace(needle, str(repl))
+        return out
+    if isinstance(obj, dict):
+        return {k: _replace_local_backbone_strings(v, replacements) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [_replace_local_backbone_strings(v, replacements) for v in obj]
+    if isinstance(obj, tuple):
+        return tuple(_replace_local_backbone_strings(v, replacements) for v in obj)
+    return obj
+
+
+def _copy_config_if_available(
+    pred_path: Optional[str],
+    output_config: Optional[str],
+    replacements: Optional[dict] = None,
+    cfg_yaml: Optional[dict] = None,
+):
     if not output_config:
         return
     try:
         config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None
         if config_yaml_path and os.path.isfile(config_yaml_path):
-            shutil.copy2(config_yaml_path, output_config)
-            logger.info(f"Wrote AutoGluon config → {output_config}")
+            try:
+                if isinstance(cfg_yaml, dict) and cfg_yaml:
+                    cfg_raw = cfg_yaml
+                else:
+                    with open(config_yaml_path, "r") as cfg_in:
+                        cfg_raw = yaml.safe_load(cfg_in) or {}
+                cfg_sanitized = _replace_local_backbone_strings(cfg_raw, replacements)
+                with open(output_config, "w") as cfg_out:
+                    yaml.safe_dump(cfg_sanitized, cfg_out, sort_keys=False)
+                logger.info(f"Wrote AutoGluon config → {output_config}")
+            except Exception as e:
+                logger.warning(f"Failed to sanitize config.yaml; copying raw file instead: {e}")
+                shutil.copy2(config_yaml_path, output_config)
+                logger.info(f"Wrote AutoGluon config → {output_config}")
         else:
             with open(output_config, "w") as cfg_out:
                 cfg_out.write("# config.yaml not found for this run\n")
@@ -78,10 +113,74 @@
     return {}
 
 
-def _summarize_config(cfg: dict, args) -> List[tuple[str, str]]:
+def _is_local_checkpoint(val: Optional[str]) -> bool:
+    return isinstance(val, str) and val.startswith("local://")
+
+
+def _read_checkpoint_from_cfg(cfg: dict, model_key: str) -> Optional[str]:
+    if not isinstance(cfg, dict):
+        return None
+    model_cfg = cfg.get("model", {})
+    if not isinstance(model_cfg, dict):
+        return None
+    block = model_cfg.get(model_key)
+    if isinstance(block, str):
+        return block
+    if isinstance(block, dict):
+        for k in ("checkpoint_name", "model_name", "name"):
+            val = block.get(k)
+            if val:
+                return val
+    return None
+
+
+def _read_checkpoint_from_hparams(hparams: Optional[dict], model_key: str) -> Optional[str]:
+    if not isinstance(hparams, dict):
+        return None
+    model_block = hparams.get("model", {})
+    if isinstance(model_block, dict):
+        block = model_block.get(model_key)
+        if isinstance(block, str):
+            return block
+        if isinstance(block, dict):
+            for k in ("checkpoint_name", "model_name", "name"):
+                val = block.get(k)
+                if val:
+                    return val
+    for k in ("checkpoint_name", "model_name", "name"):
+        dotted = hparams.get(f"model.{model_key}.{k}")
+        if dotted:
+            return dotted
+    return None
+
+
+def _resolve_backbone_choice(cfg: dict, ag_config: Optional[dict], args, model_key: str, arg_attr: str) -> Optional[str]:
+    candidates = [
+        _read_checkpoint_from_cfg(cfg, model_key),
+        _read_checkpoint_from_hparams((ag_config or {}).get("hyperparameters"), model_key),
+        getattr(args, arg_attr, None),
+    ]
+    for val in candidates:
+        if val and not _is_local_checkpoint(val):
+            return str(val)
+    return None
+
+
+def _build_backbone_replacements(cfg: dict, ag_config: Optional[dict], args) -> dict:
+    replacements = {}
+    text_choice = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text")
+    image_choice = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image")
+    if text_choice:
+        replacements["local://hf_text"] = text_choice
+    if image_choice:
+        replacements["local://timm_image"] = image_choice
+    return replacements
+
+
+def _summarize_config(cfg: dict, args, ag_config: Optional[dict] = None) -> List[tuple[str, str]]:
     """
     Build rows describing model components and key hyperparameters from a loaded config.yaml.
-    Falls back to CLI args when config values are missing.
+    Falls back to AutoGluon hyperparameters and CLI args when config values are missing.
     """
     rows: List[tuple[str, str]] = []
     model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
@@ -102,11 +201,11 @@
                 break
     rows.append(("Tabular backbone", tabular_val))
 
-    image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—"
-    rows.append(("Image backbone", image_val))
+    image_val = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image")
+    rows.append(("Image backbone", image_val or "—"))
 
-    text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—"
-    rows.append(("Text backbone", text_val))
+    text_val = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text")
+    rows.append(("Text backbone", text_val or "—"))
 
     fusion_val = "—"
     for k in model_cfg.keys():
@@ -150,6 +249,7 @@
     ag_folds=None,
     raw_metrics_std=None,
     ag_by_split_std=None,
+    ag_config: Optional[dict] = None,
 ):
     from plot_logic import (
         build_summary_html,
@@ -164,6 +264,11 @@
     raw_metrics = eval_results.get("raw_metrics", {})
     ag_by_split = eval_results.get("ag_eval", {})
     fit_summary_obj = eval_results.get("fit_summary")
+    roc_curves = eval_results.get("roc_curves")
+
+    cfg_yaml = _load_config_yaml(args, predictor)
+    backbone_replacements = _build_backbone_replacements(cfg_yaml, ag_config, args)
+    fit_summary_obj = _replace_local_backbone_strings(fit_summary_obj, backbone_replacements)
 
     df_train = data_ctx.get("train")
     df_val = data_ctx.get("val")
@@ -203,6 +308,7 @@
                 "ag_eval": ag_by_split,
                 "ag_eval_std": ag_by_split_std,
                 "fit_summary": fit_summary_obj,
+                "roc_curves": roc_curves,
                 "problem_type": problem_type,
                 "predictor_path": getattr(predictor, "path", None),
                 "threshold": args.threshold,
@@ -242,8 +348,7 @@
         show_title=False,
     )
 
-    cfg_yaml = _load_config_yaml(args, predictor)
-    config_rows = _summarize_config(cfg_yaml, args)
+    config_rows = _summarize_config(cfg_yaml, args, ag_config=ag_config)
     threshold_rows = []
     if problem_type == "binary" and args.threshold is not None:
         threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}"))
@@ -315,7 +420,7 @@
     leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor)
     inputs_html = ""
     ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full)
-    presets_hparams_html = build_presets_hparams_html(predictor)
+    presets_hparams_html = build_presets_hparams_html(predictor, backbone_replacements)
     notices: List[str] = []
     if args.threshold is not None and problem_type == "binary":
         notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.")
@@ -350,7 +455,12 @@
     logger.info(f"Wrote HTML report → {args.output_html}")
 
     pred_path = _write_predictor_path(predictor)
-    _copy_config_if_available(pred_path, args.output_config)
+    _copy_config_if_available(
+        pred_path,
+        args.output_config,
+        backbone_replacements,
+        cfg_yaml=cfg_yaml,
+    )
 
     outputs_to_check = [
         (args.output_json, "JSON results"),
@@ -886,7 +996,7 @@
     """
 
 
-def build_presets_hparams_html(predictor) -> str:
+def build_presets_hparams_html(predictor, replacements: Optional[dict] = None) -> str:
     # MultiModalPredictor path
     mm_hp = {}
     for attr in ("_config", "config", "_fit_args"):
@@ -894,7 +1004,7 @@
             try:
                 val = getattr(predictor, attr)
                 # make it JSON-ish
-                mm_hp[attr] = str(val)
+                mm_hp[attr] = _replace_local_backbone_strings(str(val), replacements)
             except Exception:
                 continue
     hp_html = f"<pre>{html.escape(json.dumps(mm_hp, indent=2))}</pre>" if mm_hp else "<i>Unavailable</i>"