comparison: ludwig_backend.py @ 12:bcfa2e234a80 (draft)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2

| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 15:58:13 +0000 |
Comparing revision 11:c5150cceab47 with 12:bcfa2e234a80; the full file contents at revision 12 are shown below.
```python
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Protocol, Tuple

import pandas as pd
import pandas.api.types as ptypes
import yaml
from constants import (
    IMAGE_PATH_COLUMN_NAME,
    LABEL_COLUMN_NAME,
    MODEL_ENCODER_TEMPLATES,
    SPLIT_COLUMN_NAME,
)
from html_structure import (
    build_tabbed_html,
    encode_image_to_base64,
    format_config_table_html,
    format_stats_table_html,
    format_test_merged_stats_table_html,
    format_train_val_stats_table_html,
    get_html_closing,
    get_html_template,
    get_metrics_help_modal,
)
from ludwig.globals import (
    DESCRIPTION_FILE_NAME,
    PREDICTIONS_PARQUET_FILE_NAME,
    TEST_STATISTICS_FILE_NAME,
    TRAIN_SET_METADATA_FILE_NAME,
)
from ludwig.utils.data_utils import get_split_path
from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
from plotly_plots import build_classification_plots
from utils import detect_output_type, extract_metrics_from_json

logger = logging.getLogger("ImageLearner")


class Backend(Protocol):
    """Interface for a machine learning backend."""

    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any],
    ) -> str:
        ...

    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int,
    ) -> None:
        ...

    def generate_plots(self, output_dir: Path) -> None:
        ...

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: Dict[str, Any],
        split_info: str,
    ) -> Path:
        ...
```
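Because `Backend` is a `typing.Protocol`, conformance is structural: any class providing these four method signatures type-checks as a `Backend`, with no subclassing required. A minimal sketch, using a hypothetical `NoOpBackend` that is not part of this module:

```python
from pathlib import Path
from typing import Any, Dict


class NoOpBackend:
    """Hypothetical stand-in that satisfies Backend purely by its shape."""

    def prepare_config(self, config_params: Dict[str, Any], split_config: Dict[str, Any]) -> str:
        return ""

    def run_experiment(self, dataset_path: Path, config_path: Path, output_dir: Path, random_seed: int) -> None:
        pass

    def generate_plots(self, output_dir: Path) -> None:
        pass

    def generate_html_report(self, title: str, output_dir: str, config: Dict[str, Any], split_info: str) -> Path:
        return Path("report.html")


backend: Backend = NoOpBackend()  # accepted by static type checkers: the structure matches
```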
```python
class LudwigDirectBackend:
    """Backend for running Ludwig experiments directly via the internal experiment_cli function."""

    def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]:
        """Detect image dimensions from the first image in the dataset."""
        try:
            import zipfile
            from PIL import Image
            import io

            # Check if image_zip is provided
            if not image_zip_path:
                logger.warning("No image zip provided, using default 224x224")
                return 224, 224

            # Extract first image to detect dimensions
            with zipfile.ZipFile(image_zip_path, 'r') as z:
                image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                if not image_files:
                    logger.warning("No image files found in zip, using default 224x224")
                    return 224, 224

                # Check first image
                with z.open(image_files[0]) as f:
                    img = Image.open(io.BytesIO(f.read()))
                    width, height = img.size
                    logger.info(f"Detected image dimensions: {width}x{height}")
                    return height, width  # Return as (height, width) to match encoder config

        except Exception as e:
            logger.warning(f"Error detecting image dimensions: {e}, using default 224x224")
            return 224, 224
```
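Note that the probe inspects only the first image in the archive (in `namelist()` order), so a dataset with mixed sizes inherits the first image's shape. A quick sketch exercising it against a synthetic zip, assuming Pillow is installed and the module is importable:

```python
import io
import zipfile

from PIL import Image

# Build a tiny dataset zip containing a single 64x48 PNG.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    img_bytes = io.BytesIO()
    Image.new("RGB", (64, 48)).save(img_bytes, format="PNG")
    z.writestr("img_0.png", img_bytes.getvalue())
with open("toy_images.zip", "wb") as f:
    f.write(buf.getvalue())

backend = LudwigDirectBackend()
print(backend._detect_image_dimensions("toy_images.zip"))  # (48, 64), i.e. (height, width)
```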
```python
    def prepare_config(
        self,
        config_params: Dict[str, Any],
        split_config: Dict[str, Any],
    ) -> str:
        logger.info("LudwigDirectBackend: Preparing YAML configuration.")

        model_name = config_params.get("model_name", "resnet18")
        use_pretrained = config_params.get("use_pretrained", False)
        fine_tune = config_params.get("fine_tune", False)
        if use_pretrained:
            trainable = bool(fine_tune)
        else:
            trainable = True
        epochs = config_params.get("epochs", 10)
        batch_size = config_params.get("batch_size")
        num_processes = config_params.get("preprocessing_num_processes", 1)
        early_stop = config_params.get("early_stop", None)
        learning_rate = config_params.get("learning_rate")
        learning_rate = "auto" if learning_rate is None else float(learning_rate)
        raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)

        # --- MetaFormer detection and config logic ---
        def _is_metaformer(name: str) -> bool:
            return isinstance(name, str) and name.startswith(
                (
                    "identityformer_",
                    "randformer_",
                    "poolformerv2_",
                    "convformer_",
                    "caformer_",
                )
            )

        # Check if this is a MetaFormer model (either direct name or in custom_model)
        is_metaformer = (
            _is_metaformer(model_name)
            or (isinstance(raw_encoder, dict) and "custom_model" in raw_encoder and _is_metaformer(raw_encoder["custom_model"]))
        )
```
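The prefix check means a bare model name and a `custom_model` entry resolve the same way. Restating the local helper in isolation, for illustration only:

```python
def _is_metaformer(name: str) -> bool:  # copy of the closure above, for illustration
    return isinstance(name, str) and name.startswith(
        ("identityformer_", "randformer_", "poolformerv2_", "convformer_", "caformer_")
    )


assert _is_metaformer("caformer_s18")
assert not _is_metaformer("resnet18")
assert not _is_metaformer(None)  # non-strings are rejected rather than raising
```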
```python
        metaformer_resize: Optional[Tuple[int, int]] = None
        metaformer_channels = 3

        if is_metaformer:
            # Handle MetaFormer models
            custom_model = None
            if isinstance(raw_encoder, dict) and "custom_model" in raw_encoder:
                custom_model = raw_encoder["custom_model"]
            else:
                custom_model = model_name

            logger.info(f"DETECTED MetaFormer model: {custom_model}")
            cfg_channels, cfg_height, cfg_width = 3, 224, 224
            if META_DEFAULT_CFGS:
                model_cfg = META_DEFAULT_CFGS.get(custom_model, {})
                input_size = model_cfg.get("input_size")
                if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
                    cfg_channels, cfg_height, cfg_width = (
                        int(input_size[0]),
                        int(input_size[1]),
                        int(input_size[2]),
                    )

            target_height, target_width = cfg_height, cfg_width
            resize_value = config_params.get("image_resize")
            if resize_value and resize_value != "original":
                try:
                    dimensions = resize_value.split("x")
                    if len(dimensions) == 2:
                        target_height, target_width = int(dimensions[0]), int(dimensions[1])
                        if target_height <= 0 or target_width <= 0:
                            raise ValueError(
                                f"Image resize must be positive integers, received {resize_value}."
                            )
                        logger.info(f"MetaFormer explicit resize: {target_height}x{target_width}")
                    else:
                        raise ValueError(resize_value)
                except (ValueError, IndexError):
                    logger.warning(
                        "Invalid image resize format '%s'; falling back to model default %sx%s",
                        resize_value,
                        cfg_height,
                        cfg_width,
                    )
                    target_height, target_width = cfg_height, cfg_width
            else:
                image_zip_path = config_params.get("image_zip", "")
                detected_height, detected_width = self._detect_image_dimensions(image_zip_path)
                if use_pretrained:
                    if (detected_height, detected_width) != (cfg_height, cfg_width):
                        logger.info(
                            "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s",
                            cfg_height,
                            cfg_width,
                            detected_height,
                            detected_width,
                        )
                else:
                    target_height, target_width = detected_height, detected_width
                if target_height <= 0 or target_width <= 0:
                    raise ValueError(
                        f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}."
                    )

            metaformer_channels = cfg_channels
            metaformer_resize = (target_height, target_width)

            encoder_config = {
                "type": "stacked_cnn",
                "height": target_height,
                "width": target_width,
                "num_channels": metaformer_channels,
                "output_size": 128,
                "use_pretrained": use_pretrained,
                "trainable": trainable,
                "custom_model": custom_model,
            }
```
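The resize value is expected as a `HEIGHTxWIDTH` string (e.g. `"256x192"`); anything malformed or non-positive falls back to the model default. The same parse-and-fallback step in isolation, as a hypothetical helper:

```python
from typing import Tuple


def parse_resize(value, default: Tuple[int, int] = (224, 224)) -> Tuple[int, int]:
    """Hypothetical helper mirroring the HxW parsing and fallback above."""
    try:
        h, w = (int(part) for part in value.split("x"))
        if h <= 0 or w <= 0:
            raise ValueError(value)
        return h, w
    except (ValueError, AttributeError):
        return default


print(parse_resize("256x192"))   # (256, 192)
print(parse_resize("original"))  # (224, 224) -- fallback
print(parse_resize("0x64"))      # (224, 224) -- non-positive rejected
```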
```python
        elif isinstance(raw_encoder, dict):
            # Handle image resize for regular encoders
            # Note: Standard encoders like ResNet don't support height/width parameters
            # Resize will be handled at the preprocessing level by Ludwig
            if config_params.get("image_resize") and config_params["image_resize"] != "original":
                logger.info(f"Resize requested: {config_params['image_resize']} for standard encoder. Resize will be handled at preprocessing level.")

            encoder_config = {
                **raw_encoder,
                "use_pretrained": use_pretrained,
                "trainable": trainable,
            }
        else:
            encoder_config = {"type": raw_encoder}

        batch_size_cfg = batch_size or "auto"

        label_column_path = config_params.get("label_column_data_path")
        label_series = None
        label_metadata_hint = config_params.get("label_metadata") or {}
        output_type_hint = config_params.get("output_type_hint")
        num_unique_labels = int(label_metadata_hint.get("num_unique", 2))
        numeric_binary_labels = bool(label_metadata_hint.get("is_numeric_binary", False))
        likely_regression = bool(label_metadata_hint.get("likely_regression", False))
        if label_column_path is not None and Path(label_column_path).exists():
            try:
                label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
                non_na = label_series.dropna()
                if not non_na.empty:
                    num_unique_labels = non_na.nunique()
                    is_numeric = ptypes.is_numeric_dtype(label_series.dtype)
                    numeric_binary_labels = is_numeric and num_unique_labels == 2
                    likely_regression = (
                        is_numeric and not numeric_binary_labels and num_unique_labels > 10
                    )
                    if numeric_binary_labels:
                        logger.info(
                            "Detected numeric binary labels in '%s'; configuring Ludwig for binary classification.",
                            LABEL_COLUMN_NAME,
                        )
            except Exception as e:
                logger.warning(f"Could not read label column for task detection: {e}")

        if output_type_hint == "binary":
            num_unique_labels = 2
            numeric_binary_labels = numeric_binary_labels or bool(
                label_metadata_hint.get("is_numeric", False)
            )

        if numeric_binary_labels:
            task_type = "classification"
        elif likely_regression:
            task_type = "regression"
        else:
            task_type = "classification"

        if task_type == "regression" and numeric_binary_labels:
            logger.warning(
                "Numeric binary labels detected but regression task chosen; forcing classification to avoid invalid Ludwig config."
            )
            task_type = "classification"

        config_params["task_type"] = task_type
```
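Restated, the heuristic is: numeric labels with exactly two distinct values → binary classification; numeric labels with more than 10 distinct values → regression; everything else → (multi-class) classification. A small sketch of the same rules on a pandas Series:

```python
import pandas as pd
import pandas.api.types as ptypes


def infer_task(labels: pd.Series) -> str:
    """Hypothetical restatement of the label heuristic above."""
    n_unique = labels.dropna().nunique()
    is_numeric = ptypes.is_numeric_dtype(labels.dtype)
    if is_numeric and n_unique == 2:
        return "classification"  # numeric binary -> binary classification
    if is_numeric and n_unique > 10:
        return "regression"
    return "classification"


print(infer_task(pd.Series([0, 1, 1, 0])))             # classification (binary)
print(infer_task(pd.Series(range(100), dtype=float)))  # regression
print(infer_task(pd.Series(["cat", "dog", "bird"])))   # classification
```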
```python
        image_feat: Dict[str, Any] = {
            "name": IMAGE_PATH_COLUMN_NAME,
            "type": "image",
        }
        # Set preprocessing dimensions FIRST for MetaFormer models; the encoder
        # config attached below relies on them already being in place.
        if is_metaformer:
            if metaformer_resize is None:
                metaformer_resize = (224, 224)
            height, width = metaformer_resize

            if "preprocessing" not in image_feat:
                image_feat["preprocessing"] = {}
            image_feat["preprocessing"]["height"] = height
            image_feat["preprocessing"]["width"] = width
            # Use infer_image_dimensions=True so Ludwig can read images for validation,
            # but set explicit max dimensions to control the output size.
            image_feat["preprocessing"]["infer_image_dimensions"] = True
            image_feat["preprocessing"]["infer_image_max_height"] = height
            image_feat["preprocessing"]["infer_image_max_width"] = width
            image_feat["preprocessing"]["num_channels"] = metaformer_channels
            image_feat["preprocessing"]["resize_method"] = "interpolate"  # interpolation for better quality
            image_feat["preprocessing"]["standardize_image"] = "imagenet1k"  # ImageNet standardization
            # Force Ludwig to respect our dimensions by setting additional parameters
            image_feat["preprocessing"]["requires_equal_dimensions"] = False
            logger.info(
                f"Set preprocessing dimensions for MetaFormer: {height}x{width} "
                "(infer_dimensions=True with max dimensions to allow validation)"
            )
        # Now set the encoder configuration
        image_feat["encoder"] = encoder_config

        if config_params.get("augmentation") is not None:
            image_feat["augmentation"] = config_params["augmentation"]
```
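For a MetaFormer run, `image_feat` therefore ends up shaped roughly as follows (values illustrative; `"image_path"` and `"caformer_s18"` are assumed stand-ins for `IMAGE_PATH_COLUMN_NAME` and the chosen model):

```python
image_feat = {
    "name": "image_path",
    "type": "image",
    "preprocessing": {
        "height": 224, "width": 224,
        "infer_image_dimensions": True,
        "infer_image_max_height": 224, "infer_image_max_width": 224,
        "num_channels": 3,
        "resize_method": "interpolate",
        "standardize_image": "imagenet1k",
        "requires_equal_dimensions": False,
    },
    "encoder": {
        "type": "stacked_cnn",
        "height": 224, "width": 224,
        "num_channels": 3,
        "output_size": 128,
        "use_pretrained": True, "trainable": False,
        "custom_model": "caformer_s18",
    },
}
```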
```python
        # Add resize configuration for standard encoders (ResNet, etc.)
        # FIXED: MetaFormer models now respect user dimensions completely
        # Previously there was a double resize issue where MetaFormer would force 224x224
        # Now both MetaFormer and standard encoders respect user's resize choice
        if (not is_metaformer) and config_params.get("image_resize") and config_params["image_resize"] != "original":
            try:
                dimensions = config_params["image_resize"].split("x")
                if len(dimensions) == 2:
                    height, width = int(dimensions[0]), int(dimensions[1])
                    if height <= 0 or width <= 0:
                        raise ValueError(
                            f"Image resize must be positive integers, received {config_params['image_resize']}."
                        )

                    # Add resize to preprocessing for standard encoders
                    if "preprocessing" not in image_feat:
                        image_feat["preprocessing"] = {}
                    image_feat["preprocessing"]["height"] = height
                    image_feat["preprocessing"]["width"] = width
                    # Use infer_image_dimensions=True to allow Ludwig to read images for validation
                    # but set explicit max dimensions to control the output size
                    image_feat["preprocessing"]["infer_image_dimensions"] = True
                    image_feat["preprocessing"]["infer_image_max_height"] = height
                    image_feat["preprocessing"]["infer_image_max_width"] = width
                    logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
            except (ValueError, IndexError):
                logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
        if task_type == "regression":
            output_feat = {
                "name": LABEL_COLUMN_NAME,
                "type": "number",
                "decoder": {"type": "regressor"},
                "loss": {"type": "mean_squared_error"},
            }
            val_metric = config_params.get("validation_metric", "mean_squared_error")

        else:
            if num_unique_labels == 2:
                output_feat = {
                    "name": LABEL_COLUMN_NAME,
                    "type": "binary",
                    "loss": {"type": "binary_weighted_cross_entropy"},
                }
                if config_params.get("threshold") is not None:
                    output_feat["threshold"] = float(config_params["threshold"])
            else:
                output_feat = {
                    "name": LABEL_COLUMN_NAME,
                    "type": "category",
                    "loss": {"type": "softmax_cross_entropy"},
                }
            val_metric = None

        conf: Dict[str, Any] = {
            "model_type": "ecd",
            "input_features": [image_feat],
            "output_features": [output_feat],
            "combiner": {"type": "concat"},
            "trainer": {
                "epochs": epochs,
                "early_stop": early_stop,
                "batch_size": batch_size_cfg,
                "learning_rate": learning_rate,
                # only set validation_metric for regression
                **({"validation_metric": val_metric} if val_metric else {}),
            },
            "preprocessing": {
                "split": split_config,
                "num_processes": num_processes,
                "in_memory": False,
            },
        }

        logger.debug("LudwigDirectBackend: Config dict built.")
        try:
            yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
            logger.info("LudwigDirectBackend: YAML config generated.")
            return yaml_str
        except Exception:
            logger.error(
                "LudwigDirectBackend: Failed to serialize YAML.",
                exc_info=True,
            )
            raise
```
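End to end, a caller passes plain dictionaries and gets back a YAML string ready to write to disk. A usage sketch; the parameter values and file names below are illustrative, not prescribed by this module:

```python
backend = LudwigDirectBackend()
yaml_str = backend.prepare_config(
    config_params={
        "model_name": "resnet18",
        "use_pretrained": True,
        "fine_tune": False,
        "epochs": 10,
        "batch_size": None,     # serialized as "auto"
        "learning_rate": None,  # serialized as "auto"
        "label_column_data_path": "prepared_data.csv",  # hypothetical path
    },
    split_config={"type": "fixed", "column": "split"},  # assumed split spec
)
with open("ludwig_config.yaml", "w") as f:
    f.write(yaml_str)
```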
```python
    def run_experiment(
        self,
        dataset_path: Path,
        config_path: Path,
        output_dir: Path,
        random_seed: int = 42,
    ) -> None:
        """Invoke Ludwig's internal experiment_cli function to run the experiment."""
        logger.info("LudwigDirectBackend: Starting experiment execution.")

        try:
            from ludwig.experiment import experiment_cli
        except ImportError as e:
            logger.error(
                "LudwigDirectBackend: Could not import experiment_cli.",
                exc_info=True,
            )
            raise RuntimeError("Ludwig import failed.") from e

        output_dir.mkdir(parents=True, exist_ok=True)

        try:
            experiment_cli(
                dataset=str(dataset_path),
                config=str(config_path),
                output_directory=str(output_dir),
                random_seed=random_seed,
                skip_preprocessing=True,
            )
            logger.info(
                f"LudwigDirectBackend: Experiment completed. Results in {output_dir}"
            )
        except TypeError as e:
            logger.error(
                "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
                exc_info=True,
            )
            raise RuntimeError("Ludwig argument error.") from e
        except Exception:
            logger.error(
                "LudwigDirectBackend: Experiment execution error.",
                exc_info=True,
            )
            raise
```
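A typical driver then chains the steps: run the experiment, convert predictions, render plots, and build the report. A sketch under the assumption that Ludwig and the prepared dataset are in place (paths and the `config` dict are illustrative):

```python
from pathlib import Path

out = Path("results")
backend.run_experiment(
    dataset_path=Path("prepared_data.csv"),
    config_path=Path("ludwig_config.yaml"),
    output_dir=out,
    random_seed=42,
)
backend.convert_parquet_to_csv(out)
backend.generate_plots(out)
report = backend.generate_html_report(
    "Image Learner",
    str(out),
    config={"label_column_data_path": "prepared_data.csv"},
    split_info="random 70/10/20 split",
)
print(f"report written to {report}")
```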
```python
    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
        """Retrieve the learning rate used in the most recent Ludwig run."""
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )

        if not exp_dirs:
            logger.warning(f"No experiment run directories found in {output_dir}")
            return None

        progress_file = exp_dirs[-1] / "model" / "training_progress.json"
        if not progress_file.exists():
            logger.warning(f"No training_progress.json found at {progress_file}")
            return None

        try:
            with progress_file.open("r", encoding="utf-8") as f:
                data = json.load(f)
            return {
                "learning_rate": data.get("learning_rate"),
                "batch_size": data.get("batch_size"),
                "epoch": data.get("epoch"),
            }
        except Exception as e:
            logger.warning(f"Failed to read training progress info: {e}")
            return {}
```
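The returned dictionary carries just the three fields the report needs, for example (values illustrative):

```python
info = backend.get_training_process("results")
# e.g. {"learning_rate": 0.0001, "batch_size": 16, "epoch": 10}
if info and info.get("learning_rate") is not None:
    print(f"resolved learning rate: {info['learning_rate']}")
```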
```python
    def convert_parquet_to_csv(self, output_dir: Path):
        """Convert the predictions Parquet file to CSV."""
        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
        csv_path = exp_dir / "predictions.csv"

        # Check if parquet file exists before trying to convert
        if not parquet_path.exists():
            logger.info(f"Predictions parquet file not found at {parquet_path}, skipping conversion")
            return

        try:
            df = pd.read_parquet(parquet_path)
            df.to_csv(csv_path, index=False)
            logger.info(f"Converted Parquet to CSV: {csv_path}")
        except Exception as e:
            logger.error(f"Error converting Parquet to CSV: {e}")
```
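The "newest `experiment_run*` directory by mtime" lookup recurs in `get_training_process`, here, and again in `generate_plots` and `generate_html_report` below. Factored out, it would be roughly this (a refactoring sketch, not part of the module):

```python
from pathlib import Path
from typing import Optional


def latest_experiment_dir(output_dir: Path) -> Optional[Path]:
    """Most recently modified experiment_run* directory, or None if there is none."""
    runs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
    return runs[-1] if runs else None
```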
```python
    def generate_plots(self, output_dir: Path) -> None:
        """Generate all registered Ludwig visualizations for the latest experiment run."""
        logger.info("Generating all Ludwig visualizations…")

        test_plots = {
            "compare_performance",
            "compare_classifiers_performance_from_prob",
            "compare_classifiers_performance_from_pred",
            "compare_classifiers_performance_changing_k",
            "compare_classifiers_multiclass_multimetric",
            "compare_classifiers_predictions",
            "confidence_thresholding_2thresholds_2d",
            "confidence_thresholding_2thresholds_3d",
            "confidence_thresholding",
            "confidence_thresholding_data_vs_acc",
            "binary_threshold_vs_metric",
            "roc_curves",
            "roc_curves_from_test_statistics",
            "calibration_1_vs_all",
            "calibration_multiclass",
            "confusion_matrix",
            "frequency_vs_f1",
        }
        train_plots = {
            "learning_curves",
            "compare_classifiers_performance_subset",
        }

        output_dir = Path(output_dir)
        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )
        if not exp_dirs:
            logger.warning(f"No experiment run dirs found in {output_dir}")
            return
        exp_dir = exp_dirs[-1]

        viz_dir = exp_dir / "visualizations"
        viz_dir.mkdir(exist_ok=True)
        train_viz = viz_dir / "train"
        test_viz = viz_dir / "test"
        train_viz.mkdir(parents=True, exist_ok=True)
        test_viz.mkdir(parents=True, exist_ok=True)

        def _check(p: Path) -> Optional[str]:
            return str(p) if p.exists() else None

        training_stats = _check(exp_dir / "training_statistics.json")
        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)

        dataset_path = None
        split_file = None
        desc = exp_dir / DESCRIPTION_FILE_NAME
        if desc.exists():
            with open(desc, "r") as f:
                cfg = json.load(f)
            dataset_path = _check(Path(cfg.get("dataset", "")))
            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))

        output_feature = ""
        if desc.exists():
            try:
                output_feature = cfg["config"]["output_features"][0]["name"]
            except Exception:
                pass
        if not output_feature and test_stats:
            with open(test_stats, "r") as f:
                stats = json.load(f)
            output_feature = next(iter(stats.keys()), "")

        viz_registry = get_visualizations_registry()
        for viz_name, viz_func in viz_registry.items():
            if viz_name in train_plots:
                viz_dir_plot = train_viz
            elif viz_name in test_plots:
                viz_dir_plot = test_viz
            else:
                continue

            try:
                viz_func(
                    training_statistics=[training_stats] if training_stats else [],
                    test_statistics=[test_stats] if test_stats else [],
                    probabilities=[probs_path] if probs_path else [],
                    output_feature_name=output_feature,
                    ground_truth_split=2,
                    top_n_classes=[0],
                    top_k=3,
                    ground_truth_metadata=gt_metadata,
                    ground_truth=dataset_path,
                    split_file=split_file,
                    output_directory=str(viz_dir_plot),
                    normalize=False,
                    file_format="png",
                )
                logger.info(f"✔ Generated {viz_name}")
            except Exception as e:
                logger.warning(f"✘ Skipped {viz_name}: {e}")

        logger.info(f"All visualizations written to {viz_dir}")

    def generate_html_report(
        self,
        title: str,
        output_dir: str,
        config: dict,
        split_info: str,
    ) -> Path:
        """Assemble an HTML report from visualizations under train_val/ and test/ folders."""
        cwd = Path.cwd()
        report_name = title.lower().replace(" ", "_") + "_report.html"
        report_path = cwd / report_name
        output_dir = Path(output_dir)
        output_type = None

        exp_dirs = sorted(
            output_dir.glob("experiment_run*"),
            key=lambda p: p.stat().st_mtime,
        )
        if not exp_dirs:
            raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
        exp_dir = exp_dirs[-1]

        base_viz_dir = exp_dir / "visualizations"
        train_viz_dir = base_viz_dir / "train"
        test_viz_dir = base_viz_dir / "test"

        html = get_html_template()

        # Extra CSS & JS: center Plotly and enable CSV download for predictions table
        html += """
        <style>
          /* Center Plotly figures (both wrapper and native classes) */
          .plotly-center { display: flex; justify-content: center; }
          .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
          .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }

          /* Download button for predictions table */
          .download-btn {
            padding: 8px 12px;
            border: 1px solid #4CAF50;
            background: #4CAF50;
            color: white;
            border-radius: 6px;
            cursor: pointer;
          }
          .download-btn:hover { filter: brightness(0.95); }
          .preds-controls {
            display: flex;
            justify-content: flex-end;
            gap: 8px;
            margin: 8px 0;
          }
        </style>
        <script>
        function tableToCSV(table){
          const rows = Array.from(table.querySelectorAll('tr'));
          return rows.map(row =>
            Array.from(row.querySelectorAll('th,td')).map(cell => {
              let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim();
              if (text.includes('"') || text.includes(',')) {
                text = '"' + text.replace(/"/g,'""') + '"';
              }
              return text;
            }).join(',')
          ).join('\\n');
        }
        document.addEventListener('DOMContentLoaded', function(){
          const btn = document.getElementById('downloadPredsCsv');
          if(btn){
            btn.addEventListener('click', function(){
              const tbl = document.querySelector('.predictions-table');
              if(!tbl){ alert('Predictions table not found.'); return; }
              const csv = tableToCSV(tbl);
              const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'});
              const url = URL.createObjectURL(blob);
              const a = document.createElement('a');
              a.href = url;
              a.download = 'ground_truth_vs_predictions.csv';
              document.body.appendChild(a);
              a.click();
              document.body.removeChild(a);
              URL.revokeObjectURL(url);
            });
          }
        });
        </script>
        """
        html += f"<h1>{title}</h1>"
```
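The inline script escapes cells RFC 4180-style: fields containing quotes or commas are wrapped in quotes, with embedded quotes doubled. The same rule in Python for reference (a hypothetical equivalent, not used by the report itself):

```python
def csv_escape(cell: str) -> str:
    text = " ".join(cell.split())  # collapse newlines/extra whitespace, roughly like the JS
    if '"' in text or "," in text:
        text = '"' + text.replace('"', '""') + '"'
    return text


print(csv_escape('he said "hi", twice'))  # "he said ""hi"", twice"
```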
```python
        metrics_html = ""
        train_val_metrics_html = ""
        test_metrics_html = ""
        try:
            train_stats_path = exp_dir / "training_statistics.json"
            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
            if train_stats_path.exists() and test_stats_path.exists():
                with open(train_stats_path) as f:
                    train_stats = json.load(f)
                with open(test_stats_path) as f:
                    test_stats = json.load(f)
                output_type = detect_output_type(test_stats)
                metrics_html = format_stats_table_html(train_stats, test_stats, output_type)
                train_val_metrics_html = format_train_val_stats_table_html(
                    train_stats, test_stats
                )
                test_metrics_html = format_test_merged_stats_table_html(
                    extract_metrics_from_json(train_stats, test_stats, output_type)["test"],
                    output_type,
                )
        except Exception as e:
            logger.warning(
                f"Could not load stats for HTML report: {type(e).__name__}: {e}"
            )

        config_html = ""
        training_progress = self.get_training_process(output_dir)
        try:
            config_html = format_config_table_html(
                config, split_info, training_progress, output_type
            )
        except Exception as e:
            logger.warning(f"Could not load config for HTML report: {e}")
        # ---------- image rendering with exclusions ----------
        def render_img_section(
            title: str,
            dir_path: Path,
            output_type: Optional[str] = None,
            exclude_names: Optional[set] = None,
        ) -> str:
            if not dir_path.exists():
                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"

            exclude_names = exclude_names or set()

            imgs = list(dir_path.glob("*.png"))

            # Exclude ROC curves and standard confusion matrices (keep only entropy version)
            default_exclude = {
                # "roc_curves.png",  # Remove ROC curves from test tab
                "confusion_matrix__label_top5.png",  # Remove standard confusion matrix
                "confusion_matrix__label_top10.png",  # Remove duplicate
                "confusion_matrix__label_top6.png",  # Remove duplicate
                "confusion_matrix_entropy__label_top10.png",  # Keep only top5
                "confusion_matrix_entropy__label_top6.png",  # Keep only top5
            }
            title_is_test = title.lower().startswith("test")
            if title_is_test and output_type == "binary":
                default_exclude.update(
                    {
                        "confusion_matrix__label_top2.png",
                        "confusion_matrix_entropy__label_top2.png",
                        "roc_curves_from_prediction_statistics.png",
                    }
                )
            elif title_is_test and output_type == "category":
                default_exclude.update(
                    {
                        "compare_classifiers_multiclass_multimetric__label_best10.png",
                        "compare_classifiers_multiclass_multimetric__label_sorted.png",
                        "compare_classifiers_multiclass_multimetric__label_worst10.png",
                    }
                )

            imgs = [
                img
                for img in imgs
                if img.name not in default_exclude
                and img.name not in exclude_names
            ]

            if not imgs:
                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

            # Sort images by name for consistent ordering (works with string and numeric labels)
            imgs = sorted(imgs, key=lambda x: x.name)

            html_section = ""
            custom_titles = {
                "compare_classifiers_multiclass_multimetric__label_top10": "Metric Comparison by Label",
                "compare_classifiers_performance_from_prob": "Label Metric Comparison by Probability",
            }
            for img in imgs:
                b64 = encode_image_to_base64(str(img))
                default_title = img.stem.replace("_", " ").title()
                img_title = custom_titles.get(img.stem, default_title)
                html_section += (
                    f"<h2 style='text-align: center;'>{img_title}</h2>"
                    f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
                    f'<img src="data:image/png;base64,{b64}" '
                    f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
                    f"</div>"
                )
            return html_section
        tab1_content = config_html + metrics_html

        tab2_content = train_val_metrics_html + render_img_section(
            "Training and Validation Visualizations",
            train_viz_dir,
            output_type,
            exclude_names={
                "compare_classifiers_performance_from_prob.png",
                "roc_curves_from_prediction_statistics.png",
                "precision_recall_curves_from_prediction_statistics.png",
                "precision_recall_curve.png",
            },
        )

        # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
        preds_section = ""
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
        if output_type == "regression" and parquet_path.exists():
            try:
                # 1) load predictions from Parquet
                df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
                # assume the column containing the model's prediction is named
                # "prediction" or contains that substring
                pred_col = next(
                    (c for c in df_preds.columns if "prediction" in c.lower()),
                    None,
                )
                if pred_col is None:
                    raise ValueError("No prediction column found in Parquet output")
                df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})

                # 2) load ground truth for the test split from prepared CSV
                df_all = pd.read_csv(config["label_column_data_path"])
                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
                    LABEL_COLUMN_NAME
                ].reset_index(drop=True)
                # 3) concatenate side-by-side
                df_table = pd.concat([df_gt, df_pred], axis=1)
                df_table.columns = [LABEL_COLUMN_NAME, "prediction"]

                # 4) render as HTML
                preds_html = df_table.to_html(index=False, classes="predictions-table")
                preds_section = (
                    "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
                    "<div class='preds-controls'>"
                    "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
                    "</div>"
                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
                    + preds_html
                    + "</div>"
                )
            except Exception as e:
                logger.warning(f"Could not build Predictions vs GT table: {e}")

        tab3_content = test_metrics_html + preds_section

        if output_type in ("binary", "category") and test_stats_path.exists():
            try:
                interactive_plots = build_classification_plots(
                    str(test_stats_path),
                    str(train_stats_path) if train_stats_path.exists() else None,
                )
                for plot in interactive_plots:
                    tab3_content += (
                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
                        f"<div class='plotly-center'>{plot['html']}</div>"
                    )
                logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
            except Exception as e:
                logger.warning(f"Could not generate Plotly plots: {e}")

        # Add static TEST PNGs (with default dedupe/exclusions)
        tab3_content += render_img_section(
            "Test Visualizations", test_viz_dir, output_type
        )

        tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
        modal_html = get_metrics_help_modal()
        html += tabbed_html + modal_html + get_html_closing()

        try:
            with open(report_path, "w") as f:
                f.write(html)
            logger.info(f"HTML report generated at: {report_path}")
        except Exception as e:
            logger.error(f"Failed to write HTML report: {e}")
            raise

        return report_path
```
