Mercurial > repos > goeckslab > multimodal_learner
diff report_utils.py @ 8:a48e750cfd25 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
| author | goeckslab |
|---|---|
| date | Fri, 30 Jan 2026 14:20:49 +0000 |
| parents | 375c36923da1 |
| children |
line wrap: on
line diff
```diff
--- a/report_utils.py Wed Jan 28 19:56:37 2026 +0000
+++ b/report_utils.py Fri Jan 30 14:20:49 2026 +0000
@@ -35,14 +35,49 @@
     return None
 
 
-def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]):
+def _replace_local_backbone_strings(obj: Any, replacements: Optional[dict]) -> Any:
+    if not replacements:
+        return obj
+    if isinstance(obj, str):
+        out = obj
+        for needle, repl in replacements.items():
+            if repl:
+                out = out.replace(needle, str(repl))
+        return out
+    if isinstance(obj, dict):
+        return {k: _replace_local_backbone_strings(v, replacements) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [_replace_local_backbone_strings(v, replacements) for v in obj]
+    if isinstance(obj, tuple):
+        return tuple(_replace_local_backbone_strings(v, replacements) for v in obj)
+    return obj
+
+
+def _copy_config_if_available(
+    pred_path: Optional[str],
+    output_config: Optional[str],
+    replacements: Optional[dict] = None,
+    cfg_yaml: Optional[dict] = None,
+):
     if not output_config:
         return
     try:
         config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None
         if config_yaml_path and os.path.isfile(config_yaml_path):
-            shutil.copy2(config_yaml_path, output_config)
-            logger.info(f"Wrote AutoGluon config → {output_config}")
+            try:
+                if isinstance(cfg_yaml, dict) and cfg_yaml:
+                    cfg_raw = cfg_yaml
+                else:
+                    with open(config_yaml_path, "r") as cfg_in:
+                        cfg_raw = yaml.safe_load(cfg_in) or {}
+                cfg_sanitized = _replace_local_backbone_strings(cfg_raw, replacements)
+                with open(output_config, "w") as cfg_out:
+                    yaml.safe_dump(cfg_sanitized, cfg_out, sort_keys=False)
+                logger.info(f"Wrote AutoGluon config → {output_config}")
+            except Exception as e:
+                logger.warning(f"Failed to sanitize config.yaml; copying raw file instead: {e}")
+                shutil.copy2(config_yaml_path, output_config)
+                logger.info(f"Wrote AutoGluon config → {output_config}")
         else:
             with open(output_config, "w") as cfg_out:
                 cfg_out.write("# config.yaml not found for this run\n")
@@ -78,10 +113,74 @@
     return {}
 
 
-def _summarize_config(cfg: dict, args) -> List[tuple[str, str]]:
+def _is_local_checkpoint(val: Optional[str]) -> bool:
+    return isinstance(val, str) and val.startswith("local://")
+
+
+def _read_checkpoint_from_cfg(cfg: dict, model_key: str) -> Optional[str]:
+    if not isinstance(cfg, dict):
+        return None
+    model_cfg = cfg.get("model", {})
+    if not isinstance(model_cfg, dict):
+        return None
+    block = model_cfg.get(model_key)
+    if isinstance(block, str):
+        return block
+    if isinstance(block, dict):
+        for k in ("checkpoint_name", "model_name", "name"):
+            val = block.get(k)
+            if val:
+                return val
+    return None
+
+
+def _read_checkpoint_from_hparams(hparams: Optional[dict], model_key: str) -> Optional[str]:
+    if not isinstance(hparams, dict):
+        return None
+    model_block = hparams.get("model", {})
+    if isinstance(model_block, dict):
+        block = model_block.get(model_key)
+        if isinstance(block, str):
+            return block
+        if isinstance(block, dict):
+            for k in ("checkpoint_name", "model_name", "name"):
+                val = block.get(k)
+                if val:
+                    return val
+    for k in ("checkpoint_name", "model_name", "name"):
+        dotted = hparams.get(f"model.{model_key}.{k}")
+        if dotted:
+            return dotted
+    return None
+
+
+def _resolve_backbone_choice(cfg: dict, ag_config: Optional[dict], args, model_key: str, arg_attr: str) -> Optional[str]:
+    candidates = [
+        _read_checkpoint_from_cfg(cfg, model_key),
+        _read_checkpoint_from_hparams((ag_config or {}).get("hyperparameters"), model_key),
+        getattr(args, arg_attr, None),
+    ]
+    for val in candidates:
+        if val and not _is_local_checkpoint(val):
+            return str(val)
+    return None
+
+
+def _build_backbone_replacements(cfg: dict, ag_config: Optional[dict], args) -> dict:
+    replacements = {}
+    text_choice = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text")
+    image_choice = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image")
+    if text_choice:
+        replacements["local://hf_text"] = text_choice
+    if image_choice:
+        replacements["local://timm_image"] = image_choice
+    return replacements
+
+
+def _summarize_config(cfg: dict, args, ag_config: Optional[dict] = None) -> List[tuple[str, str]]:
     """
     Build rows describing model components and key hyperparameters from a loaded config.yaml.
-    Falls back to CLI args when config values are missing.
+    Falls back to AutoGluon hyperparameters and CLI args when config values are missing.
     """
     rows: List[tuple[str, str]] = []
     model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
@@ -102,11 +201,11 @@
             break
     rows.append(("Tabular backbone", tabular_val))
 
-    image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—"
-    rows.append(("Image backbone", image_val))
+    image_val = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image")
+    rows.append(("Image backbone", image_val or "—"))
 
-    text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—"
-    rows.append(("Text backbone", text_val))
+    text_val = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text")
+    rows.append(("Text backbone", text_val or "—"))
 
     fusion_val = "—"
     for k in model_cfg.keys():
@@ -150,6 +249,7 @@
     ag_folds=None,
     raw_metrics_std=None,
     ag_by_split_std=None,
+    ag_config: Optional[dict] = None,
 ):
     from plot_logic import (
         build_summary_html,
@@ -164,6 +264,11 @@
     raw_metrics = eval_results.get("raw_metrics", {})
     ag_by_split = eval_results.get("ag_eval", {})
     fit_summary_obj = eval_results.get("fit_summary")
+    roc_curves = eval_results.get("roc_curves")
+
+    cfg_yaml = _load_config_yaml(args, predictor)
+    backbone_replacements = _build_backbone_replacements(cfg_yaml, ag_config, args)
+    fit_summary_obj = _replace_local_backbone_strings(fit_summary_obj, backbone_replacements)
 
     df_train = data_ctx.get("train")
     df_val = data_ctx.get("val")
@@ -203,6 +308,7 @@
         "ag_eval": ag_by_split,
         "ag_eval_std": ag_by_split_std,
         "fit_summary": fit_summary_obj,
+        "roc_curves": roc_curves,
         "problem_type": problem_type,
         "predictor_path": getattr(predictor, "path", None),
         "threshold": args.threshold,
@@ -242,8 +348,7 @@
         show_title=False,
     )
 
-    cfg_yaml = _load_config_yaml(args, predictor)
-    config_rows = _summarize_config(cfg_yaml, args)
+    config_rows = _summarize_config(cfg_yaml, args, ag_config=ag_config)
     threshold_rows = []
     if problem_type == "binary" and args.threshold is not None:
         threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}"))
@@ -315,7 +420,7 @@
     leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor)
     inputs_html = ""
     ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full)
-    presets_hparams_html = build_presets_hparams_html(predictor)
+    presets_hparams_html = build_presets_hparams_html(predictor, backbone_replacements)
     notices: List[str] = []
     if args.threshold is not None and problem_type == "binary":
         notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.")
@@ -350,7 +455,12 @@
     logger.info(f"Wrote HTML report → {args.output_html}")
 
     pred_path = _write_predictor_path(predictor)
-    _copy_config_if_available(pred_path, args.output_config)
+    _copy_config_if_available(
+        pred_path,
+        args.output_config,
+        backbone_replacements,
+        cfg_yaml=cfg_yaml,
+    )
 
     outputs_to_check = [
         (args.output_json, "JSON results"),
@@ -886,7 +996,7 @@
     """
 
 
-def build_presets_hparams_html(predictor) -> str:
+def build_presets_hparams_html(predictor, replacements: Optional[dict] = None) -> str:
     # MultiModalPredictor path
     mm_hp = {}
     for attr in ("_config", "config", "_fit_args"):
@@ -894,7 +1004,7 @@
             try:
                 val = getattr(predictor, attr)
                 # make it JSON-ish
-                mm_hp[attr] = str(val)
+                mm_hp[attr] = _replace_local_backbone_strings(str(val), replacements)
             except Exception:
                 continue
     hp_html = f"<pre>{html.escape(json.dumps(mm_hp, indent=2))}</pre>" if mm_hp else "<i>Unavailable</i>"
```
