diff ludwig_backend.py @ 15:d17e3a1b8659 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
| field | value |
|---|---|
| author | goeckslab |
| date | Fri, 28 Nov 2025 15:45:49 +0000 |
| parents | bcfa2e234a80 |
| children | |
--- a/ludwig_backend.py Wed Nov 26 22:00:32 2025 +0000
+++ b/ludwig_backend.py Fri Nov 28 15:45:49 2025 +0000
@@ -31,7 +31,13 @@
 )
 from ludwig.utils.data_utils import get_split_path
 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
-from plotly_plots import build_classification_plots
+from plotly_plots import (
+    build_classification_plots,
+    build_prediction_diagnostics,
+    build_regression_test_plots,
+    build_regression_train_val_plots,
+    build_train_validation_plots,
+)
 from utils import detect_output_type, extract_metrics_from_json
 
 logger = logging.getLogger("ImageLearner")
@@ -72,6 +78,8 @@
 class LudwigDirectBackend:
     """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
 
+    _torchvision_patched = False
+
     def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]:
         """Detect image dimensions from the first image in the dataset."""
         try:
@@ -344,6 +352,72 @@
                 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
             except (ValueError, IndexError):
                 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+
+        def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
+            """Pick a validation metric that Ludwig will accept for the resolved task."""
+            default_map = {
+                "regression": "pearson_r",
+                "binary": "roc_auc",
+                "category": "accuracy",
+            }
+            allowed_map = {
+                "regression": {
+                    "pearson_r",
+                    "mean_absolute_error",
+                    "mean_squared_error",
+                    "root_mean_squared_error",
+                    "mean_absolute_percentage_error",
+                    "r2",
+                    "explained_variance",
+                    "loss",
+                },
+                # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
+                "binary": {
+                    "roc_auc",
+                    "accuracy",
+                    "precision",
+                    "recall",
+                    "specificity",
+                    "log_loss",
+                    "loss",
+                },
+                "category": {
+                    "accuracy",
+                    "balanced_accuracy",
+                    "precision",
+                    "recall",
+                    "f1",
+                    "specificity",
+                    "log_loss",
+                    "loss",
+                },
+            }
+            alias_map = {
+                "regression": {
+                    "mae": "mean_absolute_error",
+                    "mse": "mean_squared_error",
+                    "rmse": "root_mean_squared_error",
+                    "mape": "mean_absolute_percentage_error",
+                },
+            }
+
+            default_metric = default_map.get(task)
+            allowed = allowed_map.get(task, set())
+            metric = requested or default_metric
+
+            if metric is None:
+                return None
+
+            metric = alias_map.get(task, {}).get(metric, metric)
+
+            if metric not in allowed:
+                if requested:
+                    logger.warning(
+                        f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+                    )
+                metric = default_metric
+            return metric
+
         if task_type == "regression":
             output_feat = {
                 "name": LABEL_COLUMN_NAME,
@@ -351,7 +425,7 @@
                 "decoder": {"type": "regressor"},
                 "loss": {"type": "mean_squared_error"},
             }
-            val_metric = config_params.get("validation_metric", "mean_squared_error")
+            val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
         else:
             if num_unique_labels == 2:
@@ -368,7 +442,10 @@
                     "type": "category",
                     "loss": {"type": "softmax_cross_entropy"},
                 }
-            val_metric = None
+            val_metric = _resolve_validation_metric(
+                "binary" if num_unique_labels == 2 else "category",
+                config_params.get("validation_metric"),
+            )
 
         conf: Dict[str, Any] = {
             "model_type": "ecd",
@@ -380,7 +457,7 @@
                 "early_stop": early_stop,
                 "batch_size": batch_size_cfg,
                 "learning_rate": learning_rate,
-                # only set validation_metric for regression
+                # set validation_metric when provided
                 **({"validation_metric": val_metric} if val_metric else {}),
             },
             "preprocessing": {
@@ -402,6 +479,41 @@
             )
             raise
 
+    def _patch_torchvision_download(self) -> None:
+        """
+        Torchvision weight downloads sometimes fail checksum validation behind
+        corporate proxies that rewrite binaries. Skip hash checking to allow
+        pre-trained weights to load in those environments.
+        """
+        if LudwigDirectBackend._torchvision_patched:
+            return
+        try:
+            import torch.hub as torch_hub
+
+            original = torch_hub.load_state_dict_from_url
+            original_download = torch_hub.download_url_to_file
+
+            def _no_hash(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
+                return original(
+                    url,
+                    model_dir=model_dir,
+                    map_location=map_location,
+                    progress=progress,
+                    check_hash=False,
+                    file_name=file_name,
+                )
+
+            def _download_no_hash(url, dst, hash_prefix=None, progress=True):
+                # Torchvision's download_url_to_file signature does not accept check_hash in older versions.
+                return original_download(url, dst, hash_prefix=None, progress=progress)
+
+            torch_hub.load_state_dict_from_url = _no_hash  # type: ignore[assignment]
+            torch_hub.download_url_to_file = _download_no_hash  # type: ignore[assignment]
+            LudwigDirectBackend._torchvision_patched = True
+            logger.info("Disabled torchvision weight hash verification to avoid proxy-corrupted downloads.")
+        except Exception as exc:
+            logger.warning(f"Could not patch torchvision download hash check: {exc}")
+
     def run_experiment(
         self,
         dataset_path: Path,
@@ -412,6 +524,9 @@
         """Invoke Ludwig's internal experiment_cli function to run the experiment."""
 
         logger.info("LudwigDirectBackend: Starting experiment execution.")
+        # Avoid strict hash validation for torchvision weights (common in proxied environments)
+        self._patch_torchvision_download()
+
         try:
             from ludwig.experiment import experiment_cli
         except ImportError as e:
@@ -506,24 +621,10 @@
         """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")
 
+        # Keep only lightweight plots (drop compare_performance/roc_curves)
         test_plots = {
-            "compare_performance",
-            "compare_classifiers_performance_from_prob",
-            "compare_classifiers_performance_from_pred",
-            "compare_classifiers_performance_changing_k",
-            "compare_classifiers_multiclass_multimetric",
-            "compare_classifiers_predictions",
-            "confidence_thresholding_2thresholds_2d",
-            "confidence_thresholding_2thresholds_3d",
-            "confidence_thresholding",
-            "confidence_thresholding_data_vs_acc",
-            "binary_threshold_vs_metric",
-            "roc_curves",
             "roc_curves_from_test_statistics",
-            "calibration_1_vs_all",
-            "calibration_multiclass",
             "confusion_matrix",
-            "frequency_vs_f1",
         }
         train_plots = {
             "learning_curves",
@@ -627,6 +728,70 @@
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
+        train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
+        label_metadata_path = config.get("label_column_data_path")
+        if label_metadata_path:
+            label_metadata_path = Path(label_metadata_path)
+
+        # Pull additional config details from description.json if available
+        config_for_summary = dict(config)
+        if "target_column" not in config_for_summary or not config_for_summary.get("target_column"):
+            config_for_summary["target_column"] = LABEL_COLUMN_NAME
+        desc_path = exp_dir / DESCRIPTION_FILE_NAME
+        if desc_path.exists():
+            try:
+                with open(desc_path, "r") as f:
+                    desc_cfg = json.load(f).get("config", {})
+                encoder_cfg = (
+                    desc_cfg.get("input_features", [{}])[0].get("encoder", {})
+                    if isinstance(desc_cfg.get("input_features", [{}]), list)
+                    else {}
+                )
+                output_cfg = (
+                    desc_cfg.get("output_features", [{}])[0]
+                    if isinstance(desc_cfg.get("output_features", [{}]), list)
+                    else {}
+                )
+                trainer_cfg = desc_cfg.get("trainer", {}) if isinstance(desc_cfg, dict) else {}
+                loss_cfg = output_cfg.get("loss", {}) if isinstance(output_cfg, dict) else {}
+                opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {}
+                clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {}
+
+                arch_type = encoder_cfg.get("type")
+                arch_variant = encoder_cfg.get("model_variant")
+                arch_name = None
+                if arch_type:
+                    arch_base = str(arch_type).replace("_", " ").title()
+                    arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+
+                summary_fields = {
+                    "architecture": arch_name,
+                    "model_variant": arch_variant,
+                    "pretrained": encoder_cfg.get("use_pretrained"),
+                    "trainable": encoder_cfg.get("trainable"),
+                    "target_column": output_cfg.get("column"),
+                    "task_type": output_cfg.get("type"),
+                    "validation_metric": trainer_cfg.get("validation_metric"),
+                    "loss_function": loss_cfg.get("type"),
+                    "threshold": output_cfg.get("threshold"),
+                    "total_epochs": trainer_cfg.get("epochs"),
+                    "early_stop": trainer_cfg.get("early_stop"),
+                    "batch_size": trainer_cfg.get("batch_size"),
+                    "optimizer": opt_cfg.get("type"),
+                    "learning_rate": trainer_cfg.get("learning_rate"),
+                    "random_seed": desc_cfg.get("random_seed") or config.get("random_seed"),
+                    "use_mixed_precision": trainer_cfg.get("use_mixed_precision"),
+                    "gradient_clipping": clip_cfg.get("clipglobalnorm"),
+                }
+                for k, v in summary_fields.items():
+                    if v is None:
+                        continue
+                    # Do not override user-passed target/image column names in config
+                    if k in {"target_column", "image_column"} and config_for_summary.get(k):
+                        continue
+                    config_for_summary.setdefault(k, v)
+            except Exception as e:  # pragma: no cover - defensive
+                logger.warning(f"Could not merge description.json into config summary: {e}")
 
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
@@ -698,9 +863,10 @@
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
+        output_type = None
+        train_stats_path = exp_dir / "training_statistics.json"
+        test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
         try:
-            train_stats_path = exp_dir / "training_statistics.json"
-            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
             if train_stats_path.exists() and test_stats_path.exists():
                 with open(train_stats_path) as f:
                     train_stats = json.load(f)
@@ -725,10 +891,19 @@
         training_progress = self.get_training_process(output_dir)
         try:
             config_html = format_config_table_html(
-                config, split_info, training_progress, output_type
+                config_for_summary, split_info, training_progress, output_type
             )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
+            config_html = (
+                "<h2 style='text-align: center;'>Model and Training Summary</h2>"
+                "<p style='text-align:center; color:#666;'>Configuration details unavailable.</p>"
+            )
+        if not config_html:
+            config_html = (
+                "<h2 style='text-align: center;'>Model and Training Summary</h2>"
+                "<p style='text-align:center; color:#666;'>No configuration details found.</p>"
+            )
 
         # ---------- image rendering with exclusions ----------
         def render_img_section(
@@ -776,6 +951,11 @@
                 for img in imgs
                 if img.name not in default_exclude
                 and img.name not in exclude_names
+                and not (
+                    "learning_curves" in img.stem
+                    and "loss" in img.stem
+                    and "label" in img.stem
+                )
             ]
 
             if not imgs:
@@ -802,7 +982,8 @@
             )
             return html_section
 
-        tab1_content = config_html + metrics_html
+        # Show performance first, then config
+        tab1_content = metrics_html + config_html
 
         tab2_content = train_val_metrics_html + render_img_section(
             "Training and Validation Visualizations",
@@ -815,6 +996,21 @@
                 "precision_recall_curve.png",
             },
         )
+        if train_stats_path.exists():
+            try:
+                if output_type == "regression":
+                    tv_plots = build_regression_train_val_plots(str(train_stats_path))
+                else:
+                    tv_plots = build_train_validation_plots(str(train_stats_path))
+                for plot in tv_plots:
+                    tab2_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if tv_plots:
+                    logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots")
+            except Exception as e:
+                logger.warning(f"Could not generate train/val plots: {e}")
 
         # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
         preds_section = ""
@@ -849,7 +1045,7 @@
                     "<div class='preds-controls'>"
                    "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
                     "</div>"
-                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
+                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:350px; margin-bottom:20px;'>"
                     + preds_html
                     + "</div>"
                 )
@@ -857,27 +1053,75 @@
                 logger.warning(f"Could not build Predictions vs GT table: {e}")
 
         tab3_content = test_metrics_html + preds_section
+        test_plotly_added = False
+
+        if output_type == "regression" and train_stats_path.exists():
+            try:
+                test_plots = build_regression_test_plots(str(train_stats_path))
+                for plot in test_plots:
+                    tab3_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if test_plots:
+                    test_plotly_added = True
+                    logger.info(f"Generated {len(test_plots)} regression test plots")
+            except Exception as e:
+                logger.warning(f"Could not generate regression test plots: {e}")
 
         if output_type in ("binary", "category") and test_stats_path.exists():
             try:
                 interactive_plots = build_classification_plots(
                     str(test_stats_path),
                     str(train_stats_path) if train_stats_path.exists() else None,
+                    metadata_csv_path=str(label_metadata_path)
+                    if label_metadata_path and label_metadata_path.exists()
+                    else None,
+                    train_set_metadata_path=str(train_set_metadata_path)
+                    if train_set_metadata_path.exists()
+                    else None,
                 )
                 for plot in interactive_plots:
                     tab3_content += (
                         f"<h2 style='text-align: center;'>{plot['title']}</h2>"
                         f"<div class='plotly-center'>{plot['html']}</div>"
                     )
+                if interactive_plots:
+                    test_plotly_added = True
                 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
            except Exception as e:
                 logger.warning(f"Could not generate Plotly plots: {e}")
 
+        # Add prediction diagnostics from predictions.csv
+        predictions_csv_path = exp_dir / "predictions.csv"
+        try:
+            diag_plots = build_prediction_diagnostics(
+                str(predictions_csv_path),
+                label_data_path=str(config.get("label_column_data_path"))
+                if config.get("label_column_data_path")
+                else None,
+                threshold=config.get("threshold"),
+            )
+            for plot in diag_plots:
+                tab3_content += (
+                    f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                    f"<div class='plotly-center'>{plot['html']}</div>"
+                )
+            if diag_plots:
+                test_plotly_added = True
+                logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
+        except Exception as e:
+            logger.warning(f"Could not generate prediction diagnostics: {e}")
+
+        # Fallback: include static PNGs if no interactive plots were added
+        if not test_plotly_added:
+            tab3_content += render_img_section(
+                "Test Visualizations (PNG fallback)",
+                test_viz_dir,
+                output_type,
+            )
+
         # Add static TEST PNGs (with default dedupe/exclusions)
-        tab3_content += render_img_section(
-            "Test Visualizations", test_viz_dir, output_type
-        )
-
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()
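For context, the validation-metric fallback introduced in this changeset can be tried in isolation. The sketch below is not part of the repository code: it mirrors the alias/allow-list idea behind the new `_resolve_validation_metric` helper with trimmed maps (the full maps and the warning log live in the diff above), so the behaviour is easy to exercise outside Ludwig.

```python
# Illustrative sketch only: mirrors the fallback logic of the new
# _resolve_validation_metric helper with trimmed maps; not repository code.
from typing import Optional

DEFAULTS = {"regression": "pearson_r", "binary": "roc_auc", "category": "accuracy"}
ALLOWED = {
    "regression": {"pearson_r", "mean_absolute_error", "root_mean_squared_error", "loss"},
    "binary": {"roc_auc", "accuracy", "precision", "recall", "loss"},
    "category": {"accuracy", "balanced_accuracy", "f1", "loss"},
}
ALIASES = {"regression": {"mae": "mean_absolute_error", "rmse": "root_mean_squared_error"}}


def resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
    """Return a metric the trainer will accept, falling back to the task default."""
    default = DEFAULTS.get(task)
    metric = requested or default
    if metric is None:
        return None
    # Expand shorthand such as "mae" before checking the allow-list.
    metric = ALIASES.get(task, {}).get(metric, metric)
    if metric not in ALLOWED.get(task, set()):
        # Unsupported request falls back to the task default
        # (the real helper also logs a warning here).
        metric = default
    return metric


if __name__ == "__main__":
    print(resolve_validation_metric("regression", "mae"))  # mean_absolute_error
    print(resolve_validation_metric("binary", "f1"))       # roc_auc (f1 not accepted for binary outputs)
    print(resolve_validation_metric("category", None))     # accuracy (default)
```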
