comparison report_utils.py @ 8:a48e750cfd25 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
author goeckslab
date Fri, 30 Jan 2026 14:20:49 +0000
parents 375c36923da1
children
comparison
equal deleted inserted replaced
7:ed2fefc8d892 8:a48e750cfd25
33 except Exception: 33 except Exception:
34 logger.warning("Could not write predictor_path.txt") 34 logger.warning("Could not write predictor_path.txt")
35 return None 35 return None
36 36
37 37
38 def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]): 38 def _replace_local_backbone_strings(obj: Any, replacements: Optional[dict]) -> Any:
39 if not replacements:
40 return obj
41 if isinstance(obj, str):
42 out = obj
43 for needle, repl in replacements.items():
44 if repl:
45 out = out.replace(needle, str(repl))
46 return out
47 if isinstance(obj, dict):
48 return {k: _replace_local_backbone_strings(v, replacements) for k, v in obj.items()}
49 if isinstance(obj, list):
50 return [_replace_local_backbone_strings(v, replacements) for v in obj]
51 if isinstance(obj, tuple):
52 return tuple(_replace_local_backbone_strings(v, replacements) for v in obj)
53 return obj
54
55
56 def _copy_config_if_available(
57 pred_path: Optional[str],
58 output_config: Optional[str],
59 replacements: Optional[dict] = None,
60 cfg_yaml: Optional[dict] = None,
61 ):
39 if not output_config: 62 if not output_config:
40 return 63 return
41 try: 64 try:
42 config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None 65 config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None
43 if config_yaml_path and os.path.isfile(config_yaml_path): 66 if config_yaml_path and os.path.isfile(config_yaml_path):
44 shutil.copy2(config_yaml_path, output_config) 67 try:
45 logger.info(f"Wrote AutoGluon config → {output_config}") 68 if isinstance(cfg_yaml, dict) and cfg_yaml:
69 cfg_raw = cfg_yaml
70 else:
71 with open(config_yaml_path, "r") as cfg_in:
72 cfg_raw = yaml.safe_load(cfg_in) or {}
73 cfg_sanitized = _replace_local_backbone_strings(cfg_raw, replacements)
74 with open(output_config, "w") as cfg_out:
75 yaml.safe_dump(cfg_sanitized, cfg_out, sort_keys=False)
76 logger.info(f"Wrote AutoGluon config → {output_config}")
77 except Exception as e:
78 logger.warning(f"Failed to sanitize config.yaml; copying raw file instead: {e}")
79 shutil.copy2(config_yaml_path, output_config)
80 logger.info(f"Wrote AutoGluon config → {output_config}")
46 else: 81 else:
47 with open(output_config, "w") as cfg_out: 82 with open(output_config, "w") as cfg_out:
48 cfg_out.write("# config.yaml not found for this run\n") 83 cfg_out.write("# config.yaml not found for this run\n")
49 logger.warning(f"AutoGluon config.yaml not found; created placeholder at {output_config}") 84 logger.warning(f"AutoGluon config.yaml not found; created placeholder at {output_config}")
50 except Exception as e: 85 except Exception as e:
76 except Exception: 111 except Exception:
77 continue 112 continue
78 return {} 113 return {}
79 114
80 115
81 def _summarize_config(cfg: dict, args) -> List[tuple[str, str]]: 116 def _is_local_checkpoint(val: Optional[str]) -> bool:
117 return isinstance(val, str) and val.startswith("local://")
118
119
120 def _read_checkpoint_from_cfg(cfg: dict, model_key: str) -> Optional[str]:
121 if not isinstance(cfg, dict):
122 return None
123 model_cfg = cfg.get("model", {})
124 if not isinstance(model_cfg, dict):
125 return None
126 block = model_cfg.get(model_key)
127 if isinstance(block, str):
128 return block
129 if isinstance(block, dict):
130 for k in ("checkpoint_name", "model_name", "name"):
131 val = block.get(k)
132 if val:
133 return val
134 return None
135
136
137 def _read_checkpoint_from_hparams(hparams: Optional[dict], model_key: str) -> Optional[str]:
138 if not isinstance(hparams, dict):
139 return None
140 model_block = hparams.get("model", {})
141 if isinstance(model_block, dict):
142 block = model_block.get(model_key)
143 if isinstance(block, str):
144 return block
145 if isinstance(block, dict):
146 for k in ("checkpoint_name", "model_name", "name"):
147 val = block.get(k)
148 if val:
149 return val
150 for k in ("checkpoint_name", "model_name", "name"):
151 dotted = hparams.get(f"model.{model_key}.{k}")
152 if dotted:
153 return dotted
154 return None
155
156
def _resolve_backbone_choice(cfg: dict, ag_config: Optional[dict], args, model_key: str, arg_attr: str) -> Optional[str]:
    """
    Pick the backbone identifier to display for ``model_key``.

    Candidates are gathered from, in priority order: the loaded config, the
    ``hyperparameters`` entry of ``ag_config``, and the ``arg_attr`` attribute
    of ``args``. The first candidate that is truthy and is not a ``local://``
    placeholder is returned as a string; ``None`` if no candidate qualifies.
    """
    from_cfg = _read_checkpoint_from_cfg(cfg, model_key)
    hparams = (ag_config or {}).get("hyperparameters")
    from_hparams = _read_checkpoint_from_hparams(hparams, model_key)
    from_args = getattr(args, arg_attr, None)

    for candidate in (from_cfg, from_hparams, from_args):
        if not candidate or _is_local_checkpoint(candidate):
            continue
        return str(candidate)
    return None
167
168
def _build_backbone_replacements(cfg: dict, ag_config: Optional[dict], args) -> dict:
    """
    Build the placeholder -> backbone-name table used to rewrite ``local://`` strings.

    ``local://hf_text`` and ``local://timm_image`` are each mapped to the
    backbone resolved for that model key; a placeholder is left out of the
    result when no backbone could be resolved for it.
    """
    resolved = {
        "local://hf_text": _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text"),
        "local://timm_image": _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image"),
    }
    return {placeholder: choice for placeholder, choice in resolved.items() if choice}
178
179
180 def _summarize_config(cfg: dict, args, ag_config: Optional[dict] = None) -> List[tuple[str, str]]:
82 """ 181 """
83 Build rows describing model components and key hyperparameters from a loaded config.yaml. 182 Build rows describing model components and key hyperparameters from a loaded config.yaml.
84 Falls back to CLI args when config values are missing. 183 Falls back to AutoGluon hyperparameters and CLI args when config values are missing.
85 """ 184 """
86 rows: List[tuple[str, str]] = [] 185 rows: List[tuple[str, str]] = []
87 model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {} 186 model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
88 names = model_cfg.get("names") or [] 187 names = model_cfg.get("names") or []
89 if names: 188 if names:
100 dt_str = ", ".join(dtypes) if dtypes else "" 199 dt_str = ", ".join(dtypes) if dtypes else ""
101 tabular_val = f"{k} ({dt_str})" if dt_str else k 200 tabular_val = f"{k} ({dt_str})" if dt_str else k
102 break 201 break
103 rows.append(("Tabular backbone", tabular_val)) 202 rows.append(("Tabular backbone", tabular_val))
104 203
105 image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—" 204 image_val = _resolve_backbone_choice(cfg, ag_config, args, "timm_image", "backbone_image")
106 rows.append(("Image backbone", image_val)) 205 rows.append(("Image backbone", image_val or "—"))
107 206
108 text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—" 207 text_val = _resolve_backbone_choice(cfg, ag_config, args, "hf_text", "backbone_text")
109 rows.append(("Text backbone", text_val)) 208 rows.append(("Text backbone", text_val or "—"))
110 209
111 fusion_val = "—" 210 fusion_val = "—"
112 for k in model_cfg.keys(): 211 for k in model_cfg.keys():
113 if str(k).startswith("fusion"): 212 if str(k).startswith("fusion"):
114 fusion_val = k 213 fusion_val = k
148 data_ctx: dict, 247 data_ctx: dict,
149 raw_folds=None, 248 raw_folds=None,
150 ag_folds=None, 249 ag_folds=None,
151 raw_metrics_std=None, 250 raw_metrics_std=None,
152 ag_by_split_std=None, 251 ag_by_split_std=None,
252 ag_config: Optional[dict] = None,
153 ): 253 ):
154 from plot_logic import ( 254 from plot_logic import (
155 build_summary_html, 255 build_summary_html,
156 build_test_html_and_plots, 256 build_test_html_and_plots,
157 build_feature_html, 257 build_feature_html,
162 from metrics_logic import aggregate_metrics 262 from metrics_logic import aggregate_metrics
163 263
164 raw_metrics = eval_results.get("raw_metrics", {}) 264 raw_metrics = eval_results.get("raw_metrics", {})
165 ag_by_split = eval_results.get("ag_eval", {}) 265 ag_by_split = eval_results.get("ag_eval", {})
166 fit_summary_obj = eval_results.get("fit_summary") 266 fit_summary_obj = eval_results.get("fit_summary")
267 roc_curves = eval_results.get("roc_curves")
268
269 cfg_yaml = _load_config_yaml(args, predictor)
270 backbone_replacements = _build_backbone_replacements(cfg_yaml, ag_config, args)
271 fit_summary_obj = _replace_local_backbone_strings(fit_summary_obj, backbone_replacements)
167 272
168 df_train = data_ctx.get("train") 273 df_train = data_ctx.get("train")
169 df_val = data_ctx.get("val") 274 df_val = data_ctx.get("val")
170 df_test_internal = data_ctx.get("test_internal") 275 df_test_internal = data_ctx.get("test_internal")
171 df_test_external = data_ctx.get("test_external") 276 df_test_external = data_ctx.get("test_external")
201 "test": raw_metrics.get("Test", {}), 306 "test": raw_metrics.get("Test", {}),
202 "test_external": raw_metrics.get("Test (external)", {}), 307 "test_external": raw_metrics.get("Test (external)", {}),
203 "ag_eval": ag_by_split, 308 "ag_eval": ag_by_split,
204 "ag_eval_std": ag_by_split_std, 309 "ag_eval_std": ag_by_split_std,
205 "fit_summary": fit_summary_obj, 310 "fit_summary": fit_summary_obj,
311 "roc_curves": roc_curves,
206 "problem_type": problem_type, 312 "problem_type": problem_type,
207 "predictor_path": getattr(predictor, "path", None), 313 "predictor_path": getattr(predictor, "path", None),
208 "threshold": args.threshold, 314 "threshold": args.threshold,
209 "threshold_test": args.threshold, 315 "threshold_test": args.threshold,
210 "preset": args.preset, 316 "preset": args.preset,
240 include_test=True, 346 include_test=True,
241 title=None, 347 title=None,
242 show_title=False, 348 show_title=False,
243 ) 349 )
244 350
245 cfg_yaml = _load_config_yaml(args, predictor) 351 config_rows = _summarize_config(cfg_yaml, args, ag_config=ag_config)
246 config_rows = _summarize_config(cfg_yaml, args)
247 threshold_rows = [] 352 threshold_rows = []
248 if problem_type == "binary" and args.threshold is not None: 353 if problem_type == "binary" and args.threshold is not None:
249 threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}")) 354 threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}"))
250 extra_run_rows = [ 355 extra_run_rows = [
251 ("Target column", label_col), 356 ("Target column", label_col),
313 418
314 is_multimodal = isinstance(predictor, MultiModalPredictor) 419 is_multimodal = isinstance(predictor, MultiModalPredictor)
315 leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor) 420 leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor)
316 inputs_html = "" 421 inputs_html = ""
317 ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full) 422 ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full)
318 presets_hparams_html = build_presets_hparams_html(predictor) 423 presets_hparams_html = build_presets_hparams_html(predictor, backbone_replacements)
319 notices: List[str] = [] 424 notices: List[str] = []
320 if args.threshold is not None and problem_type == "binary": 425 if args.threshold is not None and problem_type == "binary":
321 notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.") 426 notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.")
322 warnings_html = build_warnings_html([], notices) 427 warnings_html = build_warnings_html([], notices)
323 repro_html = build_reproducibility_html(args, {}, getattr(predictor, "path", None)) 428 repro_html = build_reproducibility_html(args, {}, getattr(predictor, "path", None))
348 with open(args.output_html, "w") as f: 453 with open(args.output_html, "w") as f:
349 f.write(full_html) 454 f.write(full_html)
350 logger.info(f"Wrote HTML report → {args.output_html}") 455 logger.info(f"Wrote HTML report → {args.output_html}")
351 456
352 pred_path = _write_predictor_path(predictor) 457 pred_path = _write_predictor_path(predictor)
353 _copy_config_if_available(pred_path, args.output_config) 458 _copy_config_if_available(
459 pred_path,
460 args.output_config,
461 backbone_replacements,
462 cfg_yaml=cfg_yaml,
463 )
354 464
355 outputs_to_check = [ 465 outputs_to_check = [
356 (args.output_json, "JSON results"), 466 (args.output_json, "JSON results"),
357 (args.output_html, "HTML report"), 467 (args.output_html, "HTML report"),
358 ] 468 ]
884 <p>The following columns were not used by the trained predictor at inference time:</p> 994 <p>The following columns were not used by the trained predictor at inference time:</p>
885 <ul>{items}</ul> 995 <ul>{items}</ul>
886 """ 996 """
887 997
888 998
def build_presets_hparams_html(predictor, replacements: Optional[dict] = None) -> str:
    """
    Render the "Training Presets & Hyperparameters" section of the report.

    The ``_config``, ``config`` and ``_fit_args`` attributes of ``predictor``
    are probed; each one present is stringified, passed through
    ``_replace_local_backbone_strings`` with ``replacements``, and shown as
    HTML-escaped JSON. If none can be collected, an "Unavailable" marker is
    shown instead.
    """
    collected = {}
    for name in ("_config", "config", "_fit_args"):
        if not hasattr(predictor, name):
            continue
        try:
            # str() gives a JSON-serializable stand-in for arbitrary config objects.
            rendered = str(getattr(predictor, name))
            collected[name] = _replace_local_backbone_strings(rendered, replacements)
        except Exception:
            # Best effort: an attribute that cannot be read or rendered is skipped.
            continue

    if collected:
        dumped = json.dumps(collected, indent=2)
        hp_html = f"<pre>{html.escape(dumped)}</pre>"
    else:
        hp_html = "<i>Unavailable</i>"
    return f"<h3>Training Presets & Hyperparameters</h3><details open><summary>Show hyperparameters</summary>{hp_html}</details>"
902 1012