Mercurial > repos > goeckslab > multimodal_learner
comparison report_utils.py @ 8:a48e750cfd25 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
| author | goeckslab |
|---|---|
| date | Fri, 30 Jan 2026 14:20:49 +0000 |
| parents | 375c36923da1 |
| children |
comparison
equal
deleted
inserted
replaced
| 7:ed2fefc8d892 | 8:a48e750cfd25 |
|---|---|
| 33 except Exception: | 33 except Exception: |
| 34 logger.warning("Could not write predictor_path.txt") | 34 logger.warning("Could not write predictor_path.txt") |
| 35 return None | 35 return None |
| 36 | 36 |
| 37 | 37 |
| 38 def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]): | 38 def _replace_local_backbone_strings(obj: Any, replacements: Optional[dict]) -> Any: |
| 39 if not replacements: | |
| 40 return obj | |
| 41 if isinstance(obj, str): | |
| 42 out = obj | |
| 43 for needle, repl in replacements.items(): | |
| 44 if repl: | |
| 45 out = out.replace(needle, str(repl)) | |
| 46 return out | |
| 47 if isinstance(obj, dict): | |
| 48 return {k: _replace_local_backbone_strings(v, replacements) for k, v in obj.items()} | |
| 49 if isinstance(obj, list): | |
| 50 return [_replace_local_backbone_strings(v, replacements) for v in obj] | |
| 51 if isinstance(obj, tuple): | |
| 52 return tuple(_replace_local_backbone_strings(v, replacements) for v in obj) | |
| 53 return obj | |
| 54 | |
| 55 | |
| 56 def _copy_config_if_available( | |
| 57 pred_path: Optional[str], | |
| 58 output_config: Optional[str], | |
| 59 replacements: Optional[dict] = None, | |
| 60 cfg_yaml: Optional[dict] = None, | |
| 61 ): | |
| 39 if not output_config: | 62 if not output_config: |
| 40 return | 63 return |
| 41 try: | 64 try: |
| 42 config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None | 65 config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None |
| 43 if config_yaml_path and os.path.isfile(config_yaml_path): | 66 if config_yaml_path and os.path.isfile(config_yaml_path): |
| 44 shutil.copy2(config_yaml_path, output_config) | 67 try: |
| 45 logger.info(f"Wrote AutoGluon config → {output_config}") | 68 if isinstance(cfg_yaml, dict) and cfg_yaml: |
| 69 cfg_raw = cfg_yaml | |
| 70 else: | |
| 71 with open(config_yaml_path, "r") as cfg_in: | |
| 72 cfg_raw = yaml.safe_load(cfg_in) or {} | |
| 73 cfg_sanitized = _replace_local_backbone_strings(cfg_raw, replacements) | |
| 74 with open(output_config, "w") as cfg_out: | |
| 75 yaml.safe_dump(cfg_sanitized, cfg_out, sort_keys=False) | |
| 76 logger.info(f"Wrote AutoGluon config → {output_config}") | |
| 77 except Exception as e: | |
| 78 logger.warning(f"Failed to sanitize config.yaml; copying raw file instead: {e}") | |
| 79 shutil.copy2(config_yaml_path, output_config) | |
| 80 logger.info(f"Wrote AutoGluon config → {output_config}") | |
| 46 else: | 81 else: |
| 47 with open(output_config, "w") as cfg_out: | 82 with open(output_config, "w") as cfg_out: |
| 48 cfg_out.write("# config.yaml not found for this run\n") | 83 cfg_out.write("# config.yaml not found for this run\n") |
| 49 logger.warning(f"AutoGluon config.yaml not found; created placeholder at {output_config}") | 84 logger.warning(f"AutoGluon config.yaml not found; created placeholder at {output_config}") |
| 50 except Exception as e: | 85 except Exception as e: |
| 76 except Exception: | 111 except Exception: |
| 77 continue | 112 continue |
| 78 return {} | 113 return {} |
| 79 | 114 |
| 80 | 115 |
| 81 def _summarize_config(cfg: dict, args) -> List[tuple[str, str]]: | 116 def _is_local_checkpoint(val: Optional[str]) -> bool: |
| 117 return isinstance(val, str) and val.startswith("local://") | |
| 118 | |
| 119 | |
| 120 def _read_checkpoint_from_cfg(cfg: dict, model_key: str) -> Optional[str]: | |
| 121 if not isinstance(cfg, dict): | |
| 122 return None | |
| 123 model_cfg = cfg.get("model", {}) | |
| 124 if not isinstance(model_cfg, dict): | |
| 125 return None | |
| 126 block = model_cfg.get(model_key) | |
| 127 if isinstance(block, str): | |
| 128 return block | |
| 129 if isinstance(block, dict): | |
| 130 for k in ("checkpoint_name", "model_name", "name"): | |
| 131 val = block.get(k) | |
| 132 if val: | |
| 133 return val | |
| 134 return None | |
| 135 | |
| 136 | |
| 137 def _read_checkpoint_from_hparams(hparams: Optional[dict], model_key: str) -> Optional[str]: | |
| 138 if not isinstance(hparams, dict): | |
| 139 return None | |
| 140 model_block = hparams.get("model", {}) | |
| 141 if isinstance(model_block, dict): | |
| 142 block = model_block.get(model_key) | |
| 143 if isinstance(block, str): | |
| 144 return block | |
| 145 if isinstance(block, dict): | |
| 146 for k in ("checkpoint_name", "model_name", "name"): | |
| 147 val = block.get(k) | |
| 148 if val: | |
| 149 return val | |
| 150 for k in ("checkpoint_name", "model_name", "name"): | |
| 151 dotted = hparams.get(f"model.{model_key}.{k}") | |
| 152 if dotted: | |
| 153 return dotted | |
| 154 return None | |
| 155 | |
| 156 | |
| 157 def _resolve_backbone_choice(cfg: dict, ag_config: Optional[dict], args, model_key: str, arg_attr: str) -> Optional[str]: | |
| 158 candidates = [ | |
| 159 _read_checkpoint_from_cfg(cfg, model_key), | |
| 160 _read_checkpoint_from_hparams((ag_config or {}).get("hyperparameters"), model_key), | |
| 161 getattr(args, arg_attr, None), | |
| 162 ] | |
| 163 for val in candidates: | |
| 164 if val and not _is_local_checkpoint(val): | |
| 165 return str(val) | |
| 166 return None | |
| 167 | |
| 168 | |
| 169 def _build_backbone_replacements(cfg: dict, ag_config: Optional[dict], args) -> dict: | |
| 170 replacements = {} | |
| 171 text_choice = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text") | |
| 172 image_choice = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image") | |
| 173 if text_choice: | |
| 174 replacements["local://hf_text"] = text_choice | |
| 175 if image_choice: | |
| 176 replacements["local://timm_image"] = image_choice | |
| 177 return replacements | |
| 178 | |
| 179 | |
| 180 def _summarize_config(cfg: dict, args, ag_config: Optional[dict] = None) -> List[tuple[str, str]]: | |
| 82 """ | 181 """ |
| 83 Build rows describing model components and key hyperparameters from a loaded config.yaml. | 182 Build rows describing model components and key hyperparameters from a loaded config.yaml. |
| 84 Falls back to CLI args when config values are missing. | 183 Falls back to AutoGluon hyperparameters and CLI args when config values are missing. |
| 85 """ | 184 """ |
| 86 rows: List[tuple[str, str]] = [] | 185 rows: List[tuple[str, str]] = [] |
| 87 model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {} | 186 model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {} |
| 88 names = model_cfg.get("names") or [] | 187 names = model_cfg.get("names") or [] |
| 89 if names: | 188 if names: |
| 100 dt_str = ", ".join(dtypes) if dtypes else "" | 199 dt_str = ", ".join(dtypes) if dtypes else "" |
| 101 tabular_val = f"{k} ({dt_str})" if dt_str else k | 200 tabular_val = f"{k} ({dt_str})" if dt_str else k |
| 102 break | 201 break |
| 103 rows.append(("Tabular backbone", tabular_val)) | 202 rows.append(("Tabular backbone", tabular_val)) |
| 104 | 203 |
| 105 image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—" | 204 image_val = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image") |
| 106 rows.append(("Image backbone", image_val)) | 205 rows.append(("Image backbone", image_val or "—")) |
| 107 | 206 |
| 108 text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—" | 207 text_val = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text") |
| 109 rows.append(("Text backbone", text_val)) | 208 rows.append(("Text backbone", text_val or "—")) |
| 110 | 209 |
| 111 fusion_val = "—" | 210 fusion_val = "—" |
| 112 for k in model_cfg.keys(): | 211 for k in model_cfg.keys(): |
| 113 if str(k).startswith("fusion"): | 212 if str(k).startswith("fusion"): |
| 114 fusion_val = k | 213 fusion_val = k |
| 148 data_ctx: dict, | 247 data_ctx: dict, |
| 149 raw_folds=None, | 248 raw_folds=None, |
| 150 ag_folds=None, | 249 ag_folds=None, |
| 151 raw_metrics_std=None, | 250 raw_metrics_std=None, |
| 152 ag_by_split_std=None, | 251 ag_by_split_std=None, |
| 252 ag_config: Optional[dict] = None, | |
| 153 ): | 253 ): |
| 154 from plot_logic import ( | 254 from plot_logic import ( |
| 155 build_summary_html, | 255 build_summary_html, |
| 156 build_test_html_and_plots, | 256 build_test_html_and_plots, |
| 157 build_feature_html, | 257 build_feature_html, |
| 162 from metrics_logic import aggregate_metrics | 262 from metrics_logic import aggregate_metrics |
| 163 | 263 |
| 164 raw_metrics = eval_results.get("raw_metrics", {}) | 264 raw_metrics = eval_results.get("raw_metrics", {}) |
| 165 ag_by_split = eval_results.get("ag_eval", {}) | 265 ag_by_split = eval_results.get("ag_eval", {}) |
| 166 fit_summary_obj = eval_results.get("fit_summary") | 266 fit_summary_obj = eval_results.get("fit_summary") |
| 267 roc_curves = eval_results.get("roc_curves") | |
| 268 | |
| 269 cfg_yaml = _load_config_yaml(args, predictor) | |
| 270 backbone_replacements = _build_backbone_replacements(cfg_yaml, ag_config, args) | |
| 271 fit_summary_obj = _replace_local_backbone_strings(fit_summary_obj, backbone_replacements) | |
| 167 | 272 |
| 168 df_train = data_ctx.get("train") | 273 df_train = data_ctx.get("train") |
| 169 df_val = data_ctx.get("val") | 274 df_val = data_ctx.get("val") |
| 170 df_test_internal = data_ctx.get("test_internal") | 275 df_test_internal = data_ctx.get("test_internal") |
| 171 df_test_external = data_ctx.get("test_external") | 276 df_test_external = data_ctx.get("test_external") |
| 201 "test": raw_metrics.get("Test", {}), | 306 "test": raw_metrics.get("Test", {}), |
| 202 "test_external": raw_metrics.get("Test (external)", {}), | 307 "test_external": raw_metrics.get("Test (external)", {}), |
| 203 "ag_eval": ag_by_split, | 308 "ag_eval": ag_by_split, |
| 204 "ag_eval_std": ag_by_split_std, | 309 "ag_eval_std": ag_by_split_std, |
| 205 "fit_summary": fit_summary_obj, | 310 "fit_summary": fit_summary_obj, |
| 311 "roc_curves": roc_curves, | |
| 206 "problem_type": problem_type, | 312 "problem_type": problem_type, |
| 207 "predictor_path": getattr(predictor, "path", None), | 313 "predictor_path": getattr(predictor, "path", None), |
| 208 "threshold": args.threshold, | 314 "threshold": args.threshold, |
| 209 "threshold_test": args.threshold, | 315 "threshold_test": args.threshold, |
| 210 "preset": args.preset, | 316 "preset": args.preset, |
| 240 include_test=True, | 346 include_test=True, |
| 241 title=None, | 347 title=None, |
| 242 show_title=False, | 348 show_title=False, |
| 243 ) | 349 ) |
| 244 | 350 |
| 245 cfg_yaml = _load_config_yaml(args, predictor) | 351 config_rows = _summarize_config(cfg_yaml, args, ag_config=ag_config) |
| 246 config_rows = _summarize_config(cfg_yaml, args) | |
| 247 threshold_rows = [] | 352 threshold_rows = [] |
| 248 if problem_type == "binary" and args.threshold is not None: | 353 if problem_type == "binary" and args.threshold is not None: |
| 249 threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}")) | 354 threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}")) |
| 250 extra_run_rows = [ | 355 extra_run_rows = [ |
| 251 ("Target column", label_col), | 356 ("Target column", label_col), |
| 313 | 418 |
| 314 is_multimodal = isinstance(predictor, MultiModalPredictor) | 419 is_multimodal = isinstance(predictor, MultiModalPredictor) |
| 315 leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor) | 420 leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor) |
| 316 inputs_html = "" | 421 inputs_html = "" |
| 317 ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full) | 422 ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full) |
| 318 presets_hparams_html = build_presets_hparams_html(predictor) | 423 presets_hparams_html = build_presets_hparams_html(predictor, backbone_replacements) |
| 319 notices: List[str] = [] | 424 notices: List[str] = [] |
| 320 if args.threshold is not None and problem_type == "binary": | 425 if args.threshold is not None and problem_type == "binary": |
| 321 notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.") | 426 notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.") |
| 322 warnings_html = build_warnings_html([], notices) | 427 warnings_html = build_warnings_html([], notices) |
| 323 repro_html = build_reproducibility_html(args, {}, getattr(predictor, "path", None)) | 428 repro_html = build_reproducibility_html(args, {}, getattr(predictor, "path", None)) |
| 348 with open(args.output_html, "w") as f: | 453 with open(args.output_html, "w") as f: |
| 349 f.write(full_html) | 454 f.write(full_html) |
| 350 logger.info(f"Wrote HTML report → {args.output_html}") | 455 logger.info(f"Wrote HTML report → {args.output_html}") |
| 351 | 456 |
| 352 pred_path = _write_predictor_path(predictor) | 457 pred_path = _write_predictor_path(predictor) |
| 353 _copy_config_if_available(pred_path, args.output_config) | 458 _copy_config_if_available( |
| 459 pred_path, | |
| 460 args.output_config, | |
| 461 backbone_replacements, | |
| 462 cfg_yaml=cfg_yaml, | |
| 463 ) | |
| 354 | 464 |
| 355 outputs_to_check = [ | 465 outputs_to_check = [ |
| 356 (args.output_json, "JSON results"), | 466 (args.output_json, "JSON results"), |
| 357 (args.output_html, "HTML report"), | 467 (args.output_html, "HTML report"), |
| 358 ] | 468 ] |
| 884 <p>The following columns were not used by the trained predictor at inference time:</p> | 994 <p>The following columns were not used by the trained predictor at inference time:</p> |
| 885 <ul>{items}</ul> | 995 <ul>{items}</ul> |
| 886 """ | 996 """ |
| 887 | 997 |
| 888 | 998 |
| 889 def build_presets_hparams_html(predictor) -> str: | 999 def build_presets_hparams_html(predictor, replacements: Optional[dict] = None) -> str: |
| 890 # MultiModalPredictor path | 1000 # MultiModalPredictor path |
| 891 mm_hp = {} | 1001 mm_hp = {} |
| 892 for attr in ("_config", "config", "_fit_args"): | 1002 for attr in ("_config", "config", "_fit_args"): |
| 893 if hasattr(predictor, attr): | 1003 if hasattr(predictor, attr): |
| 894 try: | 1004 try: |
| 895 val = getattr(predictor, attr) | 1005 val = getattr(predictor, attr) |
| 896 # make it JSON-ish | 1006 # make it JSON-ish |
| 897 mm_hp[attr] = str(val) | 1007 mm_hp[attr] = _replace_local_backbone_strings(str(val), replacements) |
| 898 except Exception: | 1008 except Exception: |
| 899 continue | 1009 continue |
| 900 hp_html = f"<pre>{html.escape(json.dumps(mm_hp, indent=2))}</pre>" if mm_hp else "<i>Unavailable</i>" | 1010 hp_html = f"<pre>{html.escape(json.dumps(mm_hp, indent=2))}</pre>" if mm_hp else "<i>Unavailable</i>" |
| 901 return f"<h3>Training Presets & Hyperparameters</h3><details open><summary>Show hyperparameters</summary>{hp_html}</details>" | 1011 return f"<h3>Training Presets & Hyperparameters</h3><details open><summary>Show hyperparameters</summary>{hp_html}</details>" |
| 902 | 1012 |
