comparison ludwig_backend.py @ 12:bcfa2e234a80 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
author goeckslab
date Fri, 21 Nov 2025 15:58:13 +0000
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Protocol, Tuple

import pandas as pd
import pandas.api.types as ptypes
import yaml
from constants import (
    IMAGE_PATH_COLUMN_NAME,
    LABEL_COLUMN_NAME,
    MODEL_ENCODER_TEMPLATES,
    SPLIT_COLUMN_NAME,
)
from html_structure import (
    build_tabbed_html,
    encode_image_to_base64,
    format_config_table_html,
    format_stats_table_html,
    format_test_merged_stats_table_html,
    format_train_val_stats_table_html,
    get_html_closing,
    get_html_template,
    get_metrics_help_modal,
)
from ludwig.globals import (
    DESCRIPTION_FILE_NAME,
    PREDICTIONS_PARQUET_FILE_NAME,
    TEST_STATISTICS_FILE_NAME,
    TRAIN_SET_METADATA_FILE_NAME,
)
from ludwig.utils.data_utils import get_split_path
from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
from plotly_plots import build_classification_plots
from utils import detect_output_type, extract_metrics_from_json

logger = logging.getLogger("ImageLearner")


class Backend(Protocol):
    """Interface for a machine learning backend."""

    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any],
    ) -> str:
        ...

    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int,
    ) -> None:
        ...

    def generate_plots(self, output_dir: Path) -> None:
        ...

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: Dict[str, Any],
        split_info: str,
    ) -> Path:
        ...

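# A minimal sketch of how the Backend protocol is meant to be consumed: any
# object with these methods satisfies it structurally, with no inheritance
# required. `run_pipeline`, `params`, and `split_cfg` below are hypothetical
# placeholders, not names defined in this module.
#
#     def run_pipeline(backend: Backend, dataset: Path, workdir: Path) -> Path:
#         cfg_path = workdir / "config.yaml"
#         cfg_path.write_text(backend.prepare_config(params, split_cfg))
#         backend.run_experiment(dataset, cfg_path, workdir, random_seed=42)
#         backend.generate_plots(workdir)
#         return backend.generate_html_report("Report", str(workdir), params, "split info")
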
class LudwigDirectBackend:
    """Backend for running Ludwig experiments directly via the internal experiment_cli function."""

    def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]:
        """Detect image dimensions from the first image in the dataset."""
        try:
            import zipfile
            from PIL import Image
            import io

            # Check whether an image zip was provided
            if not image_zip_path:
                logger.warning("No image zip provided, using default 224x224")
                return 224, 224

            # Open the first image in the archive to detect dimensions
            with zipfile.ZipFile(image_zip_path, 'r') as z:
                image_files = [
                    f for f in z.namelist()
                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
                ]
                if not image_files:
                    logger.warning("No image files found in zip, using default 224x224")
                    return 224, 224

                with z.open(image_files[0]) as f:
                    img = Image.open(io.BytesIO(f.read()))
                    width, height = img.size
                    logger.info(f"Detected image dimensions: {width}x{height}")
                    return height, width  # Return as (height, width) to match encoder config

        except Exception as e:
            logger.warning(f"Error detecting image dimensions: {e}, using default 224x224")
            return 224, 224

    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any],
    ) -> str:
        logger.info("LudwigDirectBackend: Preparing YAML configuration.")

        model_name = config_params.get("model_name", "resnet18")
        use_pretrained = config_params.get("use_pretrained", False)
        fine_tune = config_params.get("fine_tune", False)
        if use_pretrained:
            trainable = bool(fine_tune)
        else:
            trainable = True
        epochs = config_params.get("epochs", 10)
        batch_size = config_params.get("batch_size")
        num_processes = config_params.get("preprocessing_num_processes", 1)
        early_stop = config_params.get("early_stop", None)
        learning_rate = config_params.get("learning_rate")
        learning_rate = "auto" if learning_rate is None else float(learning_rate)
        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)

        # --- MetaFormer detection and config logic ---
        def _is_metaformer(name: str) -> bool:
            return isinstance(name, str) and name.startswith(
                (
                    "identityformer_",
                    "randformer_",
                    "poolformerv2_",
                    "convformer_",
                    "caformer_",
                )
            )

        # Check whether this is a MetaFormer model (either a direct name or a custom_model entry)
        is_metaformer = (
            _is_metaformer(model_name)
            or (
                isinstance(raw_encoder, dict)
                and "custom_model" in raw_encoder
                and _is_metaformer(raw_encoder["custom_model"])
            )
        )

        metaformer_resize: Optional[Tuple[int, int]] = None
        metaformer_channels = 3

        if is_metaformer:
            # Handle MetaFormer models
            if isinstance(raw_encoder, dict) and "custom_model" in raw_encoder:
                custom_model = raw_encoder["custom_model"]
            else:
                custom_model = model_name

            logger.info(f"DETECTED MetaFormer model: {custom_model}")
            cfg_channels, cfg_height, cfg_width = 3, 224, 224
            if META_DEFAULT_CFGS:
                model_cfg = META_DEFAULT_CFGS.get(custom_model, {})
                input_size = model_cfg.get("input_size")
                if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
                    cfg_channels, cfg_height, cfg_width = (
                        int(input_size[0]),
                        int(input_size[1]),
                        int(input_size[2]),
                    )

            target_height, target_width = cfg_height, cfg_width
            resize_value = config_params.get("image_resize")
            if resize_value and resize_value != "original":
                try:
                    dimensions = resize_value.split("x")
                    if len(dimensions) == 2:
                        target_height, target_width = int(dimensions[0]), int(dimensions[1])
                        if target_height <= 0 or target_width <= 0:
                            raise ValueError(
                                f"Image resize must be positive integers, received {resize_value}."
                            )
                        logger.info(f"MetaFormer explicit resize: {target_height}x{target_width}")
                    else:
                        raise ValueError(f"Expected HxW resize format, received {resize_value}.")
                except (ValueError, IndexError):
                    logger.warning(
                        "Invalid image resize format '%s'; falling back to model default %sx%s",
                        resize_value,
                        cfg_height,
                        cfg_width,
                    )
                    target_height, target_width = cfg_height, cfg_width
            else:
                image_zip_path = config_params.get("image_zip", "")
                detected_height, detected_width = self._detect_image_dimensions(image_zip_path)
                if use_pretrained:
                    if (detected_height, detected_width) != (cfg_height, cfg_width):
                        logger.info(
                            "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s",
                            cfg_height,
                            cfg_width,
                            detected_height,
                            detected_width,
                        )
                else:
                    target_height, target_width = detected_height, detected_width
                if target_height <= 0 or target_width <= 0:
                    raise ValueError(
                        f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}."
                    )

            metaformer_channels = cfg_channels
            metaformer_resize = (target_height, target_width)

            encoder_config = {
                "type": "stacked_cnn",
                "height": target_height,
                "width": target_width,
                "num_channels": metaformer_channels,
                "output_size": 128,
                "use_pretrained": use_pretrained,
                "trainable": trainable,
                "custom_model": custom_model,
            }

        elif isinstance(raw_encoder, dict):
            # Handle image resize for regular encoders.
            # Note: standard encoders such as ResNet don't support height/width
            # parameters; resize is handled at Ludwig's preprocessing level.
            if config_params.get("image_resize") and config_params["image_resize"] != "original":
                logger.info(
                    f"Resize requested: {config_params['image_resize']} for standard encoder. "
                    "Resize will be handled at preprocessing level."
                )

            encoder_config = {
                **raw_encoder,
                "use_pretrained": use_pretrained,
                "trainable": trainable,
            }
        else:
            encoder_config = {"type": raw_encoder}

        batch_size_cfg = batch_size or "auto"

        label_column_path = config_params.get("label_column_data_path")
        label_series = None
        label_metadata_hint = config_params.get("label_metadata") or {}
        output_type_hint = config_params.get("output_type_hint")
        num_unique_labels = int(label_metadata_hint.get("num_unique", 2))
        numeric_binary_labels = bool(label_metadata_hint.get("is_numeric_binary", False))
        likely_regression = bool(label_metadata_hint.get("likely_regression", False))
        if label_column_path is not None and Path(label_column_path).exists():
            try:
                label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
                non_na = label_series.dropna()
                if not non_na.empty:
                    num_unique_labels = non_na.nunique()
                    is_numeric = ptypes.is_numeric_dtype(label_series.dtype)
                    numeric_binary_labels = is_numeric and num_unique_labels == 2
                    likely_regression = (
                        is_numeric and not numeric_binary_labels and num_unique_labels > 10
                    )
                    if numeric_binary_labels:
                        logger.info(
                            "Detected numeric binary labels in '%s'; configuring Ludwig for binary classification.",
                            LABEL_COLUMN_NAME,
                        )
            except Exception as e:
                logger.warning(f"Could not read label column for task detection: {e}")

        if output_type_hint == "binary":
            num_unique_labels = 2
            numeric_binary_labels = numeric_binary_labels or bool(
                label_metadata_hint.get("is_numeric", False)
            )

        if numeric_binary_labels:
            task_type = "classification"
        elif likely_regression:
            task_type = "regression"
        else:
            task_type = "classification"

        if task_type == "regression" and numeric_binary_labels:
            logger.warning(
                "Numeric binary labels detected but regression task chosen; forcing classification to avoid invalid Ludwig config."
            )
            task_type = "classification"

        config_params["task_type"] = task_type

        image_feat: Dict[str, Any] = {
            "name": IMAGE_PATH_COLUMN_NAME,
            "type": "image",
        }
        # CRITICAL: set preprocessing dimensions FIRST for MetaFormer models;
        # they do not work properly without explicit height/width.
        if is_metaformer:
            if metaformer_resize is None:
                metaformer_resize = (224, 224)
            height, width = metaformer_resize

            if "preprocessing" not in image_feat:
                image_feat["preprocessing"] = {}
            image_feat["preprocessing"]["height"] = height
            image_feat["preprocessing"]["width"] = width
            # Use infer_image_dimensions=True to allow Ludwig to read images for
            # validation, but set explicit max dimensions to control the output size
            image_feat["preprocessing"]["infer_image_dimensions"] = True
            image_feat["preprocessing"]["infer_image_max_height"] = height
            image_feat["preprocessing"]["infer_image_max_width"] = width
            image_feat["preprocessing"]["num_channels"] = metaformer_channels
            image_feat["preprocessing"]["resize_method"] = "interpolate"  # interpolation for better quality
            image_feat["preprocessing"]["standardize_image"] = "imagenet1k"  # ImageNet standardization
            # Force Ludwig to respect our dimensions
            image_feat["preprocessing"]["requires_equal_dimensions"] = False
            logger.info(
                f"Set preprocessing dimensions for MetaFormer: {height}x{width} "
                "(infer_image_dimensions=True with max dimensions to allow validation)"
            )
        # Now set the encoder configuration
        image_feat["encoder"] = encoder_config

        if config_params.get("augmentation") is not None:
            image_feat["augmentation"] = config_params["augmentation"]

        # Add resize configuration for standard encoders (ResNet, etc.).
        # MetaFormer models now respect user dimensions completely; previously a
        # double-resize issue forced MetaFormer to 224x224. Both MetaFormer and
        # standard encoders now honor the user's resize choice.
        if (
            not is_metaformer
            and config_params.get("image_resize")
            and config_params["image_resize"] != "original"
        ):
            try:
                dimensions = config_params["image_resize"].split("x")
                if len(dimensions) == 2:
                    height, width = int(dimensions[0]), int(dimensions[1])
                    if height <= 0 or width <= 0:
                        raise ValueError(
                            f"Image resize must be positive integers, received {config_params['image_resize']}."
                        )

                    # Add resize to preprocessing for standard encoders
                    if "preprocessing" not in image_feat:
                        image_feat["preprocessing"] = {}
                    image_feat["preprocessing"]["height"] = height
                    image_feat["preprocessing"]["width"] = width
                    # Use infer_image_dimensions=True to allow Ludwig to read images for
                    # validation, but set explicit max dimensions to control the output size
                    image_feat["preprocessing"]["infer_image_dimensions"] = True
                    image_feat["preprocessing"]["infer_image_max_height"] = height
                    image_feat["preprocessing"]["infer_image_max_width"] = width
                    logger.info(
                        f"Added resize preprocessing: {height}x{width} for standard encoder "
                        "with infer_image_dimensions=True and max dimensions"
                    )
            except (ValueError, IndexError):
                logger.warning(
                    f"Invalid image resize format: {config_params['image_resize']}, "
                    "skipping resize preprocessing"
                )

        if task_type == "regression":
            output_feat = {
                "name": LABEL_COLUMN_NAME,
                "type": "number",
                "decoder": {"type": "regressor"},
                "loss": {"type": "mean_squared_error"},
            }
            val_metric = config_params.get("validation_metric", "mean_squared_error")

        else:
            if num_unique_labels == 2:
                output_feat = {
                    "name": LABEL_COLUMN_NAME,
                    "type": "binary",
                    "loss": {"type": "binary_weighted_cross_entropy"},
                }
                if config_params.get("threshold") is not None:
                    output_feat["threshold"] = float(config_params["threshold"])
            else:
                output_feat = {
                    "name": LABEL_COLUMN_NAME,
                    "type": "category",
                    "loss": {"type": "softmax_cross_entropy"},
                }
            val_metric = None

        conf: Dict[str, Any] = {
            "model_type": "ecd",
            "input_features": [image_feat],
            "output_features": [output_feat],
            "combiner": {"type": "concat"},
            "trainer": {
                "epochs": epochs,
                "early_stop": early_stop,
                "batch_size": batch_size_cfg,
                "learning_rate": learning_rate,
                # only set validation_metric for regression
                **({"validation_metric": val_metric} if val_metric else {}),
            },
            "preprocessing": {
                "split": split_config,
                "num_processes": num_processes,
                "in_memory": False,
            },
        }

        logger.debug("LudwigDirectBackend: Config dict built.")
        try:
            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
            logger.info("LudwigDirectBackend: YAML config generated.")
            return yaml_str
        except Exception:
            logger.error(
                "LudwigDirectBackend: Failed to serialize YAML.",
                exc_info=True,
            )
            raise

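    # For reference, an abbreviated sketch of the YAML that prepare_config()
    # emits for a standard-encoder classification run. Values depend on
    # config_params; the feature names stand in for IMAGE_PATH_COLUMN_NAME and
    # LABEL_COLUMN_NAME.
    #
    #     model_type: ecd
    #     input_features:
    #       - name: <image path column>
    #         type: image
    #         encoder: {type: resnet, use_pretrained: true, trainable: false}
    #     output_features:
    #       - name: <label column>
    #         type: category
    #         loss: {type: softmax_cross_entropy}
    #     trainer: {epochs: 10, early_stop: null, batch_size: auto, learning_rate: auto}
    #     preprocessing: {split: ..., num_processes: 1, in_memory: false}
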
    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int = 42,
    ) -> None:
        """Invoke Ludwig's internal experiment_cli function to run the experiment."""
        logger.info("LudwigDirectBackend: Starting experiment execution.")

        try:
            from ludwig.experiment import experiment_cli
        except ImportError as e:
            logger.error(
                "LudwigDirectBackend: Could not import experiment_cli.",
                exc_info=True,
            )
            raise RuntimeError("Ludwig import failed.") from e

        output_dir.mkdir(parents=True, exist_ok=True)

        try:
            experiment_cli(
                dataset=str(dataset_path),
                config=str(config_path),
                output_directory=str(output_dir),
                random_seed=random_seed,
                skip_preprocessing=True,
            )
            logger.info(
                f"LudwigDirectBackend: Experiment completed. Results in {output_dir}"
            )
        except TypeError as e:
            logger.error(
                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
                exc_info=True,
            )
            raise RuntimeError("Ludwig argument error.") from e
        except Exception:
            logger.error(
                "LudwigDirectBackend: Experiment execution error.",
                exc_info=True,
            )
            raise

    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
        """Retrieve the learning rate, batch size, and epoch from the most recent Ludwig run."""
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )

        if not exp_dirs:
            logger.warning(f"No experiment run directories found in {output_dir}")
            return None

        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
        if not progress_file.exists():
            logger.warning(f"No training_progress.json found at {progress_file}")
            return None

        try:
            with progress_file.open("r", encoding="utf-8") as f:
                data = json.load(f)
            return {
                "learning_rate": data.get("learning_rate"),
                "batch_size": data.get("batch_size"),
                "epoch": data.get("epoch"),
            }
        except Exception as e:
            logger.warning(f"Failed to read training progress info: {e}")
            return {}

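    # The keys read above are assumed to be present in Ludwig's
    # training_progress.json; abbreviated, hypothetical contents:
    #
    #     {"learning_rate": 0.001, "batch_size": 16, "epoch": 9, ...}
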
    def convert_parquet_to_csv(self, output_dir: Path):
        """Convert the predictions Parquet file to CSV."""
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
        csv_path = exp_dir / "predictions.csv"

        # Check that the parquet file exists before trying to convert
        if not parquet_path.exists():
            logger.info(f"Predictions parquet file not found at {parquet_path}, skipping conversion")
            return

        try:
            df = pd.read_parquet(parquet_path)
            df.to_csv(csv_path, index=False)
            logger.info(f"Converted Parquet to CSV: {csv_path}")
        except Exception as e:
            logger.error(f"Error converting Parquet to CSV: {e}")

    def generate_plots(self, output_dir: Path) -> None:
        """Generate all registered Ludwig visualizations for the latest experiment run."""
        logger.info("Generating all Ludwig visualizations…")

        test_plots = {
            "compare_performance",
            "compare_classifiers_performance_from_prob",
            "compare_classifiers_performance_from_pred",
            "compare_classifiers_performance_changing_k",
            "compare_classifiers_multiclass_multimetric",
            "compare_classifiers_predictions",
            "confidence_thresholding_2thresholds_2d",
            "confidence_thresholding_2thresholds_3d",
            "confidence_thresholding",
            "confidence_thresholding_data_vs_acc",
            "binary_threshold_vs_metric",
            "roc_curves",
            "roc_curves_from_test_statistics",
            "calibration_1_vs_all",
            "calibration_multiclass",
            "confusion_matrix",
            "frequency_vs_f1",
        }
        train_plots = {
            "learning_curves",
            "compare_classifiers_performance_subset",
        }

        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]

        viz_dir = exp_dir / "visualizations"
        viz_dir.mkdir(exist_ok=True)
        train_viz = viz_dir / "train"
        test_viz = viz_dir / "test"
        train_viz.mkdir(parents=True, exist_ok=True)
        test_viz.mkdir(parents=True, exist_ok=True)

        def _check(p: Path) -> Optional[str]:
            return str(p) if p.exists() else None

        training_stats = _check(exp_dir / "training_statistics.json")
        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)

        dataset_path = None
        split_file = None
        desc = exp_dir / DESCRIPTION_FILE_NAME
        if desc.exists():
            with open(desc, "r") as f:
                cfg = json.load(f)
            dataset_path = _check(Path(cfg.get("dataset", "")))
            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))

        output_feature = ""
        if desc.exists():
            try:
                output_feature = cfg["config"]["output_features"][0]["name"]
            except Exception:
                pass
        if not output_feature and test_stats:
            with open(test_stats, "r") as f:
                stats = json.load(f)
            output_feature = next(iter(stats.keys()), "")

        viz_registry = get_visualizations_registry()
        for viz_name, viz_func in viz_registry.items():
            if viz_name in train_plots:
                viz_dir_plot = train_viz
            elif viz_name in test_plots:
                viz_dir_plot = test_viz
            else:
                continue

            try:
                viz_func(
                    training_statistics=[training_stats] if training_stats else [],
                    test_statistics=[test_stats] if test_stats else [],
                    probabilities=[probs_path] if probs_path else [],
                    output_feature_name=output_feature,
                    ground_truth_split=2,
                    top_n_classes=[0],
                    top_k=3,
                    ground_truth_metadata=gt_metadata,
                    ground_truth=dataset_path,
                    split_file=split_file,
                    output_directory=str(viz_dir_plot),
                    normalize=False,
                    file_format="png",
                )
                logger.info(f"✔ Generated {viz_name}")
            except Exception as e:
                logger.warning(f"✘ Skipped {viz_name}: {e}")

        logger.info(f"All visualizations written to {viz_dir}")

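    # Layout produced under the latest experiment_run* directory (PNG names
    # vary with task type and the registered visualizations):
    #
    #     experiment_run*/
    #         visualizations/
    #             train/  e.g. learning_curves_<feature>_<metric>.png
    #             test/   e.g. confusion_matrix__label_top5.png, roc_curves.png
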
    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: Dict[str, Any],
        split_info: str,
    ) -> Path:
        """Assemble an HTML report from visualizations under the train/ and test/ folders."""
        cwd = Path.cwd()
        report_name = title.lower().replace(" ", "_") + "_report.html"
        report_path = cwd / report_name
        output_dir = Path(output_dir)
        output_type = None

        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )
        if not exp_dirs:
            raise RuntimeError(f"No 'experiment_run*' dirs found in {output_dir}")
        exp_dir = exp_dirs[-1]

        base_viz_dir = exp_dir / "visualizations"
        train_viz_dir = base_viz_dir / "train"
        test_viz_dir = base_viz_dir / "test"

        html = get_html_template()

        # Extra CSS & JS: center Plotly figures and enable CSV download for the predictions table
638 html += """
639 <style>
640 /* Center Plotly figures (both wrapper and native classes) */
641 .plotly-center { display: flex; justify-content: center; }
642 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
643 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
644
645 /* Download button for predictions table */
646 .download-btn {
647 padding: 8px 12px;
648 border: 1px solid #4CAF50;
649 background: #4CAF50;
650 color: white;
651 border-radius: 6px;
652 cursor: pointer;
653 }
654 .download-btn:hover { filter: brightness(0.95); }
655 .preds-controls {
656 display: flex;
657 justify-content: flex-end;
658 gap: 8px;
659 margin: 8px 0;
660 }
661 </style>
662 <script>
663 function tableToCSV(table){
664 const rows = Array.from(table.querySelectorAll('tr'));
665 return rows.map(row =>
666 Array.from(row.querySelectorAll('th,td')).map(cell => {
667 let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim();
668 if (text.includes('"') || text.includes(',')) {
669 text = '"' + text.replace(/"/g,'""') + '"';
670 }
671 return text;
672 }).join(',')
673 ).join('\\n');
674 }
675 document.addEventListener('DOMContentLoaded', function(){
676 const btn = document.getElementById('downloadPredsCsv');
677 if(btn){
678 btn.addEventListener('click', function(){
679 const tbl = document.querySelector('.predictions-table');
680 if(!tbl){ alert('Predictions table not found.'); return; }
681 const csv = tableToCSV(tbl);
682 const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'});
683 const url = URL.createObjectURL(blob);
684 const a = document.createElement('a');
685 a.href = url;
686 a.download = 'ground_truth_vs_predictions.csv';
687 document.body.appendChild(a);
688 a.click();
689 document.body.removeChild(a);
690 URL.revokeObjectURL(url);
691 });
692 }
693 });
694 </script>
695 """
696 html += f"<h1>{title}</h1>"
697
        metrics_html = ""
        train_val_metrics_html = ""
        test_metrics_html = ""
        try:
            train_stats_path = exp_dir / "training_statistics.json"
            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
            if train_stats_path.exists() and test_stats_path.exists():
                with open(train_stats_path) as f:
                    train_stats = json.load(f)
                with open(test_stats_path) as f:
                    test_stats = json.load(f)
                output_type = detect_output_type(test_stats)
                metrics_html = format_stats_table_html(train_stats, test_stats, output_type)
                train_val_metrics_html = format_train_val_stats_table_html(
                    train_stats, test_stats
                )
                test_metrics_html = format_test_merged_stats_table_html(
                    extract_metrics_from_json(train_stats, test_stats, output_type)["test"],
                    output_type,
                )
        except Exception as e:
            logger.warning(
                f"Could not load stats for HTML report: {type(e).__name__}: {e}"
            )

        config_html = ""
        training_progress = self.get_training_process(output_dir)
        try:
            config_html = format_config_table_html(
                config, split_info, training_progress, output_type
            )
        except Exception as e:
            logger.warning(f"Could not load config for HTML report: {e}")

        # ---------- image rendering with exclusions ----------
        def render_img_section(
            title: str,
            dir_path: Path,
            output_type: str = None,
            exclude_names: Optional[set] = None,
        ) -> str:
            if not dir_path.exists():
                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"

            exclude_names = exclude_names or set()

            imgs = list(dir_path.glob("*.png"))

            # Exclude ROC curves and standard confusion matrices (keep only the entropy top-5 version)
            default_exclude = {
                # "roc_curves.png",  # Remove ROC curves from test tab
                "confusion_matrix__label_top5.png",  # Remove standard confusion matrix
                "confusion_matrix__label_top10.png",  # Remove duplicate
                "confusion_matrix__label_top6.png",  # Remove duplicate
                "confusion_matrix_entropy__label_top10.png",  # Keep only top5
                "confusion_matrix_entropy__label_top6.png",  # Keep only top5
            }
            title_is_test = title.lower().startswith("test")
            if title_is_test and output_type == "binary":
                default_exclude.update(
                    {
                        "confusion_matrix__label_top2.png",
                        "confusion_matrix_entropy__label_top2.png",
                        "roc_curves_from_prediction_statistics.png",
                    }
                )
            elif title_is_test and output_type == "category":
                default_exclude.update(
                    {
                        "compare_classifiers_multiclass_multimetric__label_best10.png",
                        "compare_classifiers_multiclass_multimetric__label_sorted.png",
                        "compare_classifiers_multiclass_multimetric__label_worst10.png",
                    }
                )

            imgs = [
                img
                for img in imgs
                if img.name not in default_exclude
                and img.name not in exclude_names
            ]

            if not imgs:
                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

            # Sort images by name for consistent ordering (works with string and numeric labels)
            imgs = sorted(imgs, key=lambda x: x.name)

            html_section = ""
            custom_titles = {
                "compare_classifiers_multiclass_multimetric__label_top10": "Metric Comparison by Label",
                "compare_classifiers_performance_from_prob": "Label Metric Comparison by Probability",
            }
            for img in imgs:
                b64 = encode_image_to_base64(str(img))
                default_title = img.stem.replace("_", " ").title()
                img_title = custom_titles.get(img.stem, default_title)
                html_section += (
                    f"<h2 style='text-align: center;'>{img_title}</h2>"
                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
                    f'<img src="data:image/png;base64,{b64}" '
                    f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
                    f"</div>"
                )
            return html_section

        tab1_content = config_html + metrics_html

        tab2_content = train_val_metrics_html + render_img_section(
            "Training and Validation Visualizations",
            train_viz_dir,
            output_type,
            exclude_names={
                "compare_classifiers_performance_from_prob.png",
                "roc_curves_from_prediction_statistics.png",
                "precision_recall_curves_from_prediction_statistics.png",
                "precision_recall_curve.png",
            },
        )

        # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
        preds_section = ""
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
        if output_type == "regression" and parquet_path.exists():
            try:
                # 1) Load predictions from Parquet; assume the model's prediction
                #    column is named "prediction" or contains that substring
                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
                pred_col = next(
                    (c for c in df_preds.columns if "prediction" in c.lower()),
                    None,
                )
                if pred_col is None:
                    raise ValueError("No prediction column found in Parquet output")
                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})

                # 2) Load ground truth for the test split from the prepared CSV
                df_all = pd.read_csv(config["label_column_data_path"])
                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
                    LABEL_COLUMN_NAME
                ].reset_index(drop=True)

                # 3) Concatenate side by side
                df_table = pd.concat([df_gt, df_pred], axis=1)
                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]

                # 4) Render as HTML
                preds_html = df_table.to_html(index=False, classes="predictions-table")
                preds_section = (
                    "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
                    "<div class='preds-controls'>"
                    "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
                    "</div>"
                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
                    + preds_html
                    + "</div>"
                )
            except Exception as e:
                logger.warning(f"Could not build Predictions vs GT table: {e}")

        tab3_content = test_metrics_html + preds_section

        if output_type in ("binary", "category") and test_stats_path.exists():
            try:
                interactive_plots = build_classification_plots(
                    str(test_stats_path),
                    str(train_stats_path) if train_stats_path.exists() else None,
                )
                for plot in interactive_plots:
                    tab3_content += (
                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
                        f"<div class='plotly-center'>{plot['html']}</div>"
                    )
                logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
            except Exception as e:
                logger.warning(f"Could not generate Plotly plots: {e}")

        # Add static TEST PNGs (with default dedupe/exclusions)
        tab3_content += render_img_section(
            "Test Visualizations", test_viz_dir, output_type
        )

        tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
        modal_html = get_metrics_help_modal()
        html += tabbed_html + modal_html + get_html_closing()

        try:
            with open(report_path, "w") as f:
                f.write(html)
            logger.info(f"HTML report generated at: {report_path}")
        except Exception as e:
            logger.error(f"Failed to write HTML report: {e}")
            raise

        return report_path
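

# A minimal, end-to-end usage sketch, assuming a prepared dataset CSV and image
# data compatible with this tool. The paths, parameter values, and split
# probabilities below are hypothetical placeholders, not part of the Galaxy
# tool's actual invocation.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    backend = LudwigDirectBackend()
    # Build the Ludwig YAML config and write it to disk.
    yaml_str = backend.prepare_config(
        {"model_name": "resnet18", "use_pretrained": True, "epochs": 5},
        {"type": "random", "probabilities": [0.7, 0.1, 0.2]},
    )
    config_path = Path("config.yaml")
    config_path.write_text(yaml_str)
    # Run the experiment, then derive plots, CSV predictions, and the report.
    out_dir = Path("results")
    backend.run_experiment(Path("dataset.csv"), config_path, out_dir, random_seed=42)
    backend.generate_plots(out_dir)
    backend.convert_parquet_to_csv(out_dir)
    report = backend.generate_html_report(
        "Image Learner",
        str(out_dir),
        {"label_column_data_path": "dataset.csv"},
        "random 70/10/20 split",
    )
    print(f"Report written to {report}")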