diff ludwig_backend.py @ 15:d17e3a1b8659 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
author goeckslab
date Fri, 28 Nov 2025 15:45:49 +0000
parents bcfa2e234a80
children
--- a/ludwig_backend.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/ludwig_backend.py	Fri Nov 28 15:45:49 2025 +0000
@@ -31,7 +31,13 @@
 )
 from ludwig.utils.data_utils import get_split_path
 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
-from plotly_plots import build_classification_plots
+from plotly_plots import (
+    build_classification_plots,
+    build_prediction_diagnostics,
+    build_regression_test_plots,
+    build_regression_train_val_plots,
+    build_train_validation_plots,
+)
 from utils import detect_output_type, extract_metrics_from_json
 
 logger = logging.getLogger("ImageLearner")
@@ -72,6 +78,8 @@
 class LudwigDirectBackend:
     """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
 
+    _torchvision_patched = False
+
     def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]:
         """Detect image dimensions from the first image in the dataset."""
         try:
@@ -344,6 +352,72 @@
                     logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
             except (ValueError, IndexError):
                 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+
+        def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
+            """Pick a validation metric that Ludwig will accept for the resolved task."""
+            default_map = {
+                "regression": "pearson_r",
+                "binary": "roc_auc",
+                "category": "accuracy",
+            }
+            allowed_map = {
+                "regression": {
+                    "pearson_r",
+                    "mean_absolute_error",
+                    "mean_squared_error",
+                    "root_mean_squared_error",
+                    "mean_absolute_percentage_error",
+                    "r2",
+                    "explained_variance",
+                    "loss",
+                },
+                # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
+                "binary": {
+                    "roc_auc",
+                    "accuracy",
+                    "precision",
+                    "recall",
+                    "specificity",
+                    "log_loss",
+                    "loss",
+                },
+                "category": {
+                    "accuracy",
+                    "balanced_accuracy",
+                    "precision",
+                    "recall",
+                    "f1",
+                    "specificity",
+                    "log_loss",
+                    "loss",
+                },
+            }
+            alias_map = {
+                "regression": {
+                    "mae": "mean_absolute_error",
+                    "mse": "mean_squared_error",
+                    "rmse": "root_mean_squared_error",
+                    "mape": "mean_absolute_percentage_error",
+                },
+            }
+
+            default_metric = default_map.get(task)
+            allowed = allowed_map.get(task, set())
+            metric = requested or default_metric
+
+            if metric is None:
+                return None
+
+            metric = alias_map.get(task, {}).get(metric, metric)
+
+            if metric not in allowed:
+                if requested:
+                    logger.warning(
+                        f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+                    )
+                metric = default_metric
+            return metric
+
         if task_type == "regression":
             output_feat = {
                 "name": LABEL_COLUMN_NAME,
@@ -351,7 +425,7 @@
                 "decoder": {"type": "regressor"},
                 "loss": {"type": "mean_squared_error"},
             }
-            val_metric = config_params.get("validation_metric", "mean_squared_error")
+            val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
 
         else:
             if num_unique_labels == 2:
@@ -368,7 +442,10 @@
                     "type": "category",
                     "loss": {"type": "softmax_cross_entropy"},
                 }
-            val_metric = None
+            val_metric = _resolve_validation_metric(
+                "binary" if num_unique_labels == 2 else "category",
+                config_params.get("validation_metric"),
+            )
 
         conf: Dict[str, Any] = {
             "model_type": "ecd",
@@ -380,7 +457,7 @@
                 "early_stop": early_stop,
                 "batch_size": batch_size_cfg,
                 "learning_rate": learning_rate,
-                # only set validation_metric for regression
+                # set validation_metric when provided
                 **({"validation_metric": val_metric} if val_metric else {}),
             },
             "preprocessing": {
@@ -402,6 +479,41 @@
             )
             raise
 
+    def _patch_torchvision_download(self) -> None:
+        """
+        Torchvision weight downloads sometimes fail checksum validation behind
+        corporate proxies that rewrite binaries. Skip hash checking to allow
+        pre-trained weights to load in those environments.
+        """
+        if LudwigDirectBackend._torchvision_patched:
+            return
+        try:
+            import torch.hub as torch_hub
+
+            original = torch_hub.load_state_dict_from_url
+            original_download = torch_hub.download_url_to_file
+
+            def _no_hash(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
+                return original(
+                    url,
+                    model_dir=model_dir,
+                    map_location=map_location,
+                    progress=progress,
+                    check_hash=False,
+                    file_name=file_name,
+                )
+
+            def _download_no_hash(url, dst, hash_prefix=None, progress=True):
+                # torch.hub.download_url_to_file only verifies when hash_prefix is given; pass None to skip it.
+                return original_download(url, dst, hash_prefix=None, progress=progress)
+
+            torch_hub.load_state_dict_from_url = _no_hash  # type: ignore[assignment]
+            torch_hub.download_url_to_file = _download_no_hash  # type: ignore[assignment]
+            LudwigDirectBackend._torchvision_patched = True
+            logger.info("Disabled torchvision weight hash verification; downloads through rewriting proxies would otherwise fail the checksum check.")
+        except Exception as exc:
+            logger.warning(f"Could not patch torchvision download hash check: {exc}")
+
     def run_experiment(
         self,
         dataset_path: Path,
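The patch above is applied once per process (guarded by the class flag) and stays in effect for the lifetime of the interpreter. As a design note, the same idea can be written as a reversible context manager; the sketch below is illustrative only and assumes nothing beyond the torch.hub entry point already used above:

    import contextlib

    import torch.hub as torch_hub


    @contextlib.contextmanager
    def torchvision_no_hash_check():
        """Temporarily force check_hash=False for torch.hub weight downloads."""
        original = torch_hub.load_state_dict_from_url

        def _no_hash(url, model_dir=None, map_location=None, progress=True,
                     check_hash=False, file_name=None):
            # Ignore the caller's check_hash and always skip verification.
            return original(url, model_dir=model_dir, map_location=map_location,
                            progress=progress, check_hash=False, file_name=file_name)

        torch_hub.load_state_dict_from_url = _no_hash
        try:
            yield
        finally:
            torch_hub.load_state_dict_from_url = original

The one-shot patch in the diff is presumably the simpler fit here, since the download is triggered deep inside Ludwig's experiment_cli rather than in a scope the backend controls.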
@@ -412,6 +524,9 @@
         """Invoke Ludwig's internal experiment_cli function to run the experiment."""
         logger.info("LudwigDirectBackend: Starting experiment execution.")
 
+        # Avoid strict hash validation for torchvision weights (common in proxied environments)
+        self._patch_torchvision_download()
+
         try:
             from ludwig.experiment import experiment_cli
         except ImportError as e:
@@ -506,24 +621,10 @@
         """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")
 
+        # Keep only the lightweight test plots (confusion matrix, ROC from test statistics); drop the heavier comparison/calibration plots
         test_plots = {
-            "compare_performance",
-            "compare_classifiers_performance_from_prob",
-            "compare_classifiers_performance_from_pred",
-            "compare_classifiers_performance_changing_k",
-            "compare_classifiers_multiclass_multimetric",
-            "compare_classifiers_predictions",
-            "confidence_thresholding_2thresholds_2d",
-            "confidence_thresholding_2thresholds_3d",
-            "confidence_thresholding",
-            "confidence_thresholding_data_vs_acc",
-            "binary_threshold_vs_metric",
-            "roc_curves",
             "roc_curves_from_test_statistics",
-            "calibration_1_vs_all",
-            "calibration_multiclass",
             "confusion_matrix",
-            "frequency_vs_f1",
         }
         train_plots = {
             "learning_curves",
@@ -627,6 +728,70 @@
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
+        train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
+        label_metadata_path = config.get("label_column_data_path")
+        if label_metadata_path:
+            label_metadata_path = Path(label_metadata_path)
+
+        # Pull additional config details from description.json if available
+        config_for_summary = dict(config)
+        if not config_for_summary.get("target_column"):
+            config_for_summary["target_column"] = LABEL_COLUMN_NAME
+        desc_path = exp_dir / DESCRIPTION_FILE_NAME
+        if desc_path.exists():
+            try:
+                with open(desc_path, "r") as f:
+                    desc_cfg = json.load(f).get("config", {})
+                input_features = desc_cfg.get("input_features") or [{}]
+                output_features = desc_cfg.get("output_features") or [{}]
+                encoder_cfg = (
+                    input_features[0].get("encoder", {})
+                    if isinstance(input_features, list) else {}
+                )
+                output_cfg = (
+                    output_features[0]
+                    if isinstance(output_features, list) else {}
+                )
+                trainer_cfg = desc_cfg.get("trainer", {}) if isinstance(desc_cfg, dict) else {}
+                loss_cfg = output_cfg.get("loss", {}) if isinstance(output_cfg, dict) else {}
+                opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {}
+                clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {}
+
+                arch_type = encoder_cfg.get("type")
+                arch_variant = encoder_cfg.get("model_variant")
+                arch_name = None
+                if arch_type:
+                    arch_base = str(arch_type).replace("_", " ").title()
+                    arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+
+                summary_fields = {
+                    "architecture": arch_name,
+                    "model_variant": arch_variant,
+                    "pretrained": encoder_cfg.get("use_pretrained"),
+                    "trainable": encoder_cfg.get("trainable"),
+                    "target_column": output_cfg.get("column"),
+                    "task_type": output_cfg.get("type"),
+                    "validation_metric": trainer_cfg.get("validation_metric"),
+                    "loss_function": loss_cfg.get("type"),
+                    "threshold": output_cfg.get("threshold"),
+                    "total_epochs": trainer_cfg.get("epochs"),
+                    "early_stop": trainer_cfg.get("early_stop"),
+                    "batch_size": trainer_cfg.get("batch_size"),
+                    "optimizer": opt_cfg.get("type"),
+                    "learning_rate": trainer_cfg.get("learning_rate"),
+                    "random_seed": desc_cfg.get("random_seed") or config.get("random_seed"),
+                    "use_mixed_precision": trainer_cfg.get("use_mixed_precision"),
+                    "gradient_clipping": clip_cfg.get("clipglobalnorm"),
+                }
+                for k, v in summary_fields.items():
+                    if v is None:
+                        continue
+                    # Do not override user-passed target/image column names in config
+                    if k in {"target_column", "image_column"} and config_for_summary.get(k):
+                        continue
+                    config_for_summary.setdefault(k, v)
+            except Exception as e:  # pragma: no cover - defensive
+                logger.warning(f"Could not merge description.json into config summary: {e}")
 
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
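To make the merge above concrete, here is a made-up description.json "config" fragment and a few of the summary fields the code would derive from it (key names mirror the ones the code reads; every value is invented for illustration):

    desc_cfg = {
        "input_features": [{"encoder": {"type": "resnet", "model_variant": 18,
                                        "use_pretrained": True, "trainable": False}}],
        "output_features": [{"column": "label", "type": "binary",
                             "loss": {"type": "binary_weighted_cross_entropy"},
                             "threshold": 0.5}],
        "trainer": {"epochs": 10, "early_stop": 3, "batch_size": "auto",
                    "learning_rate": 0.001, "validation_metric": "roc_auc",
                    "optimizer": {"type": "adam"},
                    "gradient_clipping": {"clipglobalnorm": 0.5}},
        "random_seed": 42,
    }
    # Derived summary fields (among others):
    #   architecture -> "Resnet 18", pretrained -> True, trainable -> False
    #   task_type -> "binary", loss_function -> "binary_weighted_cross_entropy"
    #   optimizer -> "adam", learning_rate -> 0.001, total_epochs -> 10, random_seed -> 42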
@@ -698,9 +863,10 @@
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
+        output_type = None
+        train_stats_path = exp_dir / "training_statistics.json"
+        test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
         try:
-            train_stats_path = exp_dir / "training_statistics.json"
-            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
             if train_stats_path.exists() and test_stats_path.exists():
                 with open(train_stats_path) as f:
                     train_stats = json.load(f)
@@ -725,10 +891,19 @@
         training_progress = self.get_training_process(output_dir)
         try:
             config_html = format_config_table_html(
-                config, split_info, training_progress, output_type
+                config_for_summary, split_info, training_progress, output_type
             )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
+            config_html = (
+                "<h2 style='text-align: center;'>Model and Training Summary</h2>"
+                "<p style='text-align:center; color:#666;'>Configuration details unavailable.</p>"
+            )
+        if not config_html:
+            config_html = (
+                "<h2 style='text-align: center;'>Model and Training Summary</h2>"
+                "<p style='text-align:center; color:#666;'>No configuration details found.</p>"
+            )
 
         # ---------- image rendering with exclusions ----------
         def render_img_section(
@@ -776,6 +951,11 @@
                 for img in imgs
                 if img.name not in default_exclude
                 and img.name not in exclude_names
+                # Skip the loss learning-curve PNGs for the label output feature.
+                and not (
+                    "learning_curves" in img.stem
+                    and "loss" in img.stem
+                    and "label" in img.stem
+                )
             ]
 
             if not imgs:
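Concretely, the stem filter above only suppresses the loss learning-curve PNGs for the label output; other gallery images pass through. The filenames below are hypothetical but follow Ludwig's learning_curves naming:

    from pathlib import Path

    for name in ("learning_curves_label_loss.png",
                 "learning_curves_label_accuracy.png",
                 "confusion_matrix__label_top5.png"):
        stem = Path(name).stem
        skipped = ("learning_curves" in stem and "loss" in stem and "label" in stem)
        print(f"{name}: {'skipped' if skipped else 'kept'}")
    # -> only learning_curves_label_loss.png is skipped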
@@ -802,7 +982,8 @@
                 )
             return html_section
 
-        tab1_content = config_html + metrics_html
+        # Show performance first, then config
+        tab1_content = metrics_html + config_html
 
         tab2_content = train_val_metrics_html + render_img_section(
             "Training and Validation Visualizations",
@@ -815,6 +996,21 @@
                 "precision_recall_curve.png",
             },
         )
+        if train_stats_path.exists():
+            try:
+                if output_type == "regression":
+                    tv_plots = build_regression_train_val_plots(str(train_stats_path))
+                else:
+                    tv_plots = build_train_validation_plots(str(train_stats_path))
+                for plot in tv_plots:
+                    tab2_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if tv_plots:
+                    logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots")
+            except Exception as e:
+                logger.warning(f"Could not generate train/val plots: {e}")
 
         # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
         preds_section = ""
@@ -849,7 +1045,7 @@
                     "<div class='preds-controls'>"
                     "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
                     "</div>"
-                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
+                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:350px; margin-bottom:20px;'>"
                     + preds_html
                     + "</div>"
                 )
@@ -857,27 +1053,75 @@
                 logger.warning(f"Could not build Predictions vs GT table: {e}")
 
         tab3_content = test_metrics_html + preds_section
+        test_plotly_added = False
+
+        if output_type == "regression" and train_stats_path.exists():
+            try:
+                test_plots = build_regression_test_plots(str(train_stats_path))
+                for plot in test_plots:
+                    tab3_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if test_plots:
+                    test_plotly_added = True
+                    logger.info(f"Generated {len(test_plots)} regression test plots")
+            except Exception as e:
+                logger.warning(f"Could not generate regression test plots: {e}")
 
         if output_type in ("binary", "category") and test_stats_path.exists():
             try:
                 interactive_plots = build_classification_plots(
                     str(test_stats_path),
                     str(train_stats_path) if train_stats_path.exists() else None,
+                    metadata_csv_path=str(label_metadata_path)
+                    if label_metadata_path and label_metadata_path.exists()
+                    else None,
+                    train_set_metadata_path=str(train_set_metadata_path)
+                    if train_set_metadata_path.exists()
+                    else None,
                 )
                 for plot in interactive_plots:
                     tab3_content += (
                         f"<h2 style='text-align: center;'>{plot['title']}</h2>"
                         f"<div class='plotly-center'>{plot['html']}</div>"
                     )
+                if interactive_plots:
+                    test_plotly_added = True
                 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
             except Exception as e:
                 logger.warning(f"Could not generate Plotly plots: {e}")
 
+            # Add prediction diagnostics from predictions.csv
+            predictions_csv_path = exp_dir / "predictions.csv"
+            try:
+                diag_plots = build_prediction_diagnostics(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    threshold=config.get("threshold"),
+                )
+                for plot in diag_plots:
+                    tab3_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if diag_plots:
+                    test_plotly_added = True
+                    logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
+            except Exception as e:
+                logger.warning(f"Could not generate prediction diagnostics: {e}")
+
+        # Fallback: include static PNGs if no interactive plots were added
+        if not test_plotly_added:
+            tab3_content += render_img_section(
+                "Test Visualizations (PNG fallback)",
+                test_viz_dir,
+                output_type,
+            )
+
         # Add static TEST PNGs (with default dedupe/exclusions)
-        tab3_content += render_img_section(
-            "Test Visualizations", test_viz_dir, output_type
-        )
-
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()
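A closing note on the interactive plot builders used in the tabs above: the rendering loops only assume that each plotly_plots helper returns a list of dicts with "title" and "html" keys, where "html" is an embeddable fragment. A minimal builder honoring that contract (the figure and its data are purely illustrative) could look like:

    import plotly.graph_objects as go

    def build_dummy_diagnostic():
        fig = go.Figure(go.Scatter(x=[0, 1, 2], y=[0.42, 0.61, 0.70], mode="lines+markers"))
        fig.update_layout(xaxis_title="epoch", yaxis_title="metric")
        return [{
            "title": "Example diagnostic",
            # full_html=False yields a <div> fragment suitable for embedding in the report.
            "html": fig.to_html(full_html=False, include_plotlyjs="cdn"),
        }]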