diff image_learner_cli.py @ 9:9e912fce264c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit eace0d7c2b2939029c052991d238a54947d2e191
author:   goeckslab
date:     Wed, 27 Aug 2025 21:02:48 +0000
parents:  85e6f4b2ad18
children: (none)
--- a/image_learner_cli.py	Thu Aug 14 14:53:10 2025 +0000
+++ b/image_learner_cli.py	Wed Aug 27 21:02:48 2025 +0000
@@ -21,7 +21,7 @@
     SPLIT_COLUMN_NAME,
     TEMP_CONFIG_FILENAME,
     TEMP_CSV_FILENAME,
-    TEMP_DIR_PREFIX
+    TEMP_DIR_PREFIX,
 )
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
@@ -38,13 +38,13 @@
     encode_image_to_base64,
     get_html_closing,
     get_html_template,
-    get_metrics_help_modal
+    get_metrics_help_modal,
 )
 
 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
+    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
 )
 logger = logging.getLogger("ImageLearner")
 
@@ -67,7 +67,9 @@
         "early_stop",
         "threshold",
     ]
+
     rows = []
+
     for key in display_keys:
         val = config.get(key, None)
         if key == "threshold":
@@ -134,7 +136,9 @@
                         val_str = val
             else:
                 val_str = val if val is not None else "N/A"
-            if val_str == "N/A" and key not in ["task_type"]:  # Skip if N/A for non-essential
+            if val_str == "N/A" and key not in [
+                "task_type"
+            ]:  # Skip if N/A for non-essential
                 continue
         rows.append(
             f"<tr>"
@@ -166,7 +170,7 @@
               <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th>
             </tr></thead>
             <tbody>
-              {''.join(rows)}
+              {"".join(rows)}
             </tbody>
           </table>
         </div><br>
@@ -251,6 +255,7 @@
                 "roc_auc": get_last_value(label_stats, "roc_auc"),
                 "hits_at_k": get_last_value(label_stats, "hits_at_k"),
             }
+
     # Test metrics: dynamic extraction according to exclusions
     test_label_stats = test_stats.get("label", {})
     if not test_label_stats:
@@ -258,11 +263,13 @@
     else:
         combined_stats = test_stats.get("combined", {})
         overall_stats = test_label_stats.get("overall_stats", {})
+
         # Define exclusions
         if output_type == "binary":
             exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
         else:
             exclude = {"per_class_stats", "confusion_matrix"}
+
         # 1. Get all scalar test_label_stats not excluded
         test_metrics = {}
         for k, v in test_label_stats.items():
@@ -272,9 +279,11 @@
                 continue
             if isinstance(v, (int, float, str, bool)):
                 test_metrics[k] = v
+
         # 2. Add overall_stats (flattened)
         for k, v in overall_stats.items():
             test_metrics[k] = v
+
         # 3. Optionally include combined/loss if present and not already
         if "loss" in combined_stats and "loss" not in test_metrics:
             test_metrics["loss"] = combined_stats["loss"]
@@ -315,8 +324,10 @@
             te = all_metrics["test"].get(metric_key)
             if all(x is not None for x in [t, v, te]):
                 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])
+
     if not rows:
         return "<table><tr><td>No metric values found.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Model Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
@@ -331,7 +342,7 @@
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
         )
     html += "</tbody></table></div><br>"
     return html
@@ -357,8 +368,10 @@
             v = all_metrics["validation"].get(metric_key)
             if t is not None and v is not None:
                 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
+
     if not rows:
         return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
@@ -372,7 +385,7 @@
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
         )
     html += "</tbody></table></div><br>"
     return html
@@ -393,8 +406,10 @@
         value = test_metrics[key]
         if value is not None:
             rows.append([display_name, f"{value:.4f}"])
+
     if not rows:
         return "<table><tr><td>No test metric values found.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Test Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
@@ -407,7 +422,7 @@
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
         )
     html += "</tbody></table></div><br>"
     return html
@@ -436,10 +451,14 @@
             min_samples_per_class = label_counts.min()
             if min_samples_per_class * validation_size < 1:
                 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size
-                adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class)
+                adjusted_validation_size = min(
+                    validation_size, 1.0 / min_samples_per_class
+                )
                 if adjusted_validation_size != validation_size:
                     validation_size = adjusted_validation_size
-                    logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation")
+                    logger.info(
+                        f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation"
+                    )
             stratify_arr = out.loc[idx_train, label_column]
             logger.info("Using stratified split for validation set")
         else:
@@ -486,7 +505,9 @@
     # initialize split column
     out[split_column] = 0
     if not label_column or label_column not in out.columns:
-        logger.warning("No label column found; using random split without stratification")
+        logger.warning(
+            "No label column found; using random split without stratification"
+        )
         # fall back to simple random assignment
         indices = out.index.tolist()
         np.random.seed(random_state)
@@ -529,7 +550,9 @@
         stratify=out[label_column],
     )
     # second split: separate training and validation from remaining data
-    val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
+    val_size_adjusted = split_probabilities[1] / (
+        split_probabilities[0] + split_probabilities[1]
+    )
     train_idx, val_idx = train_test_split(
         train_val_idx,
         test_size=val_size_adjusted,
@@ -541,12 +564,15 @@
     out.loc[val_idx, split_column] = 1
     out.loc[test_idx, split_column] = 2
     logger.info("Successfully applied stratified random split")
-    logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
+    logger.info(
+        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
+    )
     return out.astype({split_column: int})
 
 
 class Backend(Protocol):
     """Interface for a machine learning backend."""
+
     def prepare_config(
         self,
         config_params: Dict[str, Any],
@@ -578,12 +604,14 @@
 
 class LudwigDirectBackend:
     """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
+
     def prepare_config(
         self,
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")
+
         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
@@ -606,7 +634,9 @@
             }
         else:
             encoder_config = {"type": raw_encoder}
+
         batch_size_cfg = batch_size or "auto"
+
         label_column_path = config_params.get("label_column_data_path")
         label_series = None
         if label_column_path is not None and Path(label_column_path).exists():
@@ -614,6 +644,7 @@
                 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
             except Exception as e:
                 logger.warning(f"Could not read label column for task detection: {e}")
+
         if (
             label_series is not None
             and ptypes.is_numeric_dtype(label_series.dtype)
@@ -622,7 +653,9 @@
             task_type = "regression"
         else:
             task_type = "classification"
+
         config_params["task_type"] = task_type
+
         image_feat: Dict[str, Any] = {
             "name": IMAGE_PATH_COLUMN_NAME,
             "type": "image",
@@ -630,6 +663,7 @@
         }
         if config_params.get("augmentation") is not None:
             image_feat["augmentation"] = config_params["augmentation"]
+
         if task_type == "regression":
             output_feat = {
                 "name": LABEL_COLUMN_NAME,
@@ -645,6 +679,7 @@
                 },
             }
             val_metric = config_params.get("validation_metric", "mean_squared_error")
+
         else:
             num_unique_labels = (
                 label_series.nunique() if label_series is not None else 2
@@ -654,6 +689,7 @@
             if output_type == "binary" and config_params.get("threshold") is not None:
                 output_feat["threshold"] = float(config_params["threshold"])
             val_metric = None
+
         conf: Dict[str, Any] = {
             "model_type": "ecd",
             "input_features": [image_feat],
@@ -673,6 +709,7 @@
                 "in_memory": False,
             },
         }
+
         logger.debug("LudwigDirectBackend: Config dict built.")
         try:
             yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
@@ -694,6 +731,7 @@
     ) -> None:
         """Invoke Ludwig's internal experiment_cli function to run the experiment."""
         logger.info("LudwigDirectBackend: Starting experiment execution.")
+
         try:
             from ludwig.experiment import experiment_cli
         except ImportError as e:
@@ -702,7 +740,9 @@
                 exc_info=True,
             )
             raise RuntimeError("Ludwig import failed.") from e
+
         output_dir.mkdir(parents=True, exist_ok=True)
+
         try:
             experiment_cli(
                 dataset=str(dataset_path),
@@ -733,13 +773,16 @@
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
         )
+
         if not exp_dirs:
             logger.warning(f"No experiment run directories found in {output_dir}")
             return None
+
         progress_file = exp_dirs[-1] / "model" / "training_progress.json"
         if not progress_file.exists():
             logger.warning(f"No training_progress.json found in {progress_file}")
             return None
+
         try:
             with progress_file.open("r", encoding="utf-8") as f:
                 data = json.load(f)
@@ -775,6 +818,7 @@
     def generate_plots(self, output_dir: Path) -> None:
         """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")
+
         test_plots = {
             "compare_performance",
             "compare_classifiers_performance_from_prob",
@@ -798,6 +842,7 @@
             "learning_curves",
             "compare_classifiers_performance_subset",
         }
+
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
@@ -807,6 +852,7 @@
             logger.warning(f"No experiment run dirs found in {output_dir}")
             return
         exp_dir = exp_dirs[-1]
+
         viz_dir = exp_dir / "visualizations"
         viz_dir.mkdir(exist_ok=True)
         train_viz = viz_dir / "train"
@@ -821,6 +867,7 @@
         test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
         probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
         gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
+
         dataset_path = None
         split_file = None
         desc = exp_dir / DESCRIPTION_FILE_NAME
@@ -829,6 +876,7 @@
                 cfg = json.load(f)
             dataset_path = _check(Path(cfg.get("dataset", "")))
             split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+
         output_feature = ""
         if desc.exists():
             try:
@@ -839,6 +887,7 @@
             with open(test_stats, "r") as f:
                 stats = json.load(f)
             output_feature = next(iter(stats.keys()), "")
+
         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
             if viz_name in train_plots:
@@ -847,6 +896,7 @@
                 viz_dir_plot = test_viz
             else:
                 continue
+
             try:
                 viz_func(
                     training_statistics=[training_stats] if training_stats else [],
@@ -866,6 +916,7 @@
                 logger.info(f"✔ Generated {viz_name}")
             except Exception as e:
                 logger.warning(f"✘ Skipped {viz_name}: {e}")
+
         logger.info(f"All visualizations written to {viz_dir}")
 
     def generate_html_report(
@@ -881,6 +932,7 @@
         report_path = cwd / report_name
         output_dir = Path(output_dir)
         output_type = None
+
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
@@ -888,11 +940,14 @@
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
+
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
         test_viz_dir = base_viz_dir / "test"
+
         html = get_html_template()
         html += f"<h1>{title}</h1>"
+
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
@@ -918,11 +973,12 @@
             logger.warning(
                 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
             )
+
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
             config_html = format_config_table_html(
-                config, split_info, training_progress
+                config, split_info, training_progress, output_type
             )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
@@ -936,7 +992,8 @@
             imgs = list(dir_path.glob("*.png"))
             # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files ---
             imgs = [
-                img for img in imgs
+                img
+                for img in imgs
                 if not (
                     img.name == "confusion_matrix.png"
                     or img.name.startswith("confusion_matrix__label_top")
@@ -972,7 +1029,9 @@
                 valid_imgs = [img for img in imgs if img.name not in unwanted]
                 img_map = {img.name: img for img in valid_imgs}
                 ordered = [img_map[n] for n in display_order if n in img_map]
-                others = sorted(img for img in valid_imgs if img.name not in display_order)
+                others = sorted(
+                    img for img in valid_imgs if img.name not in display_order
+                )
                 imgs = ordered + others
             else:
                 # regression: just sort whatever's left
@@ -1012,7 +1071,9 @@
                 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
                 # 2) load ground truth for the test split from prepared CSV
                 df_all = pd.read_csv(config["label_column_data_path"])
-                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True)
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                    LABEL_COLUMN_NAME
+                ].reset_index(drop=True)
                 # 3) concatenate side-by-side
                 df_table = pd.concat([df_gt, df_pred], axis=1)
                 df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
@@ -1036,7 +1097,9 @@
             for plot in interactive_plots:
                 # 2) inject the static "roc_curves_from_prediction_statistics.png"
                 if plot["title"] == "ROC-AUC":
-                    static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png"
+                    static_img = (
+                        test_viz_dir / "roc_curves_from_prediction_statistics.png"
+                    )
                     if static_img.exists():
                         b64 = encode_image_to_base64(str(static_img))
                         tab3_content += (
@@ -1054,14 +1117,13 @@
                     + plot["html"]
                 )
             tab3_content += render_img_section(
-                "Test Visualizations",
-                test_viz_dir,
-                output_type
+                "Test Visualizations", test_viz_dir, output_type
             )
         # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()
+
         try:
             with open(report_path, "w") as f:
                 f.write(html)
@@ -1069,11 +1131,13 @@
         except Exception as e:
             logger.error(f"Failed to write HTML report: {e}")
             raise
+
         return report_path
 
 
 class WorkflowOrchestrator:
     """Manages the image-classification workflow."""
+
     def __init__(self, args: argparse.Namespace, backend: Backend):
         self.args = args
         self.backend = backend
@@ -1113,16 +1177,19 @@
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")
+
         try:
             df = pd.read_csv(self.args.csv_file)
             logger.info(f"Loaded CSV: {self.args.csv_file}")
         except Exception:
             logger.error("Error loading CSV file", exc_info=True)
             raise
+
         required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
         missing = required - set(df.columns)
         if missing:
             raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+
         try:
             df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                 lambda p: str((self.image_extract_dir / p).resolve())
@@ -1150,13 +1217,16 @@
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                 f"for train/val/test with balanced label distribution."
             )
+
         final_csv = self.temp_dir / TEMP_CSV_FILENAME
+
         try:
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
             logger.error("Error saving prepared CSV", exc_info=True)
             raise
+
         return final_csv, split_config, split_info
 
     def _process_fixed_split(
@@ -1171,6 +1241,7 @@
             )
             if df[SPLIT_COLUMN_NAME].isna().any():
                 logger.warning("Split column contains non-numeric/missing values.")
+
             unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
             logger.info(f"Unique split values: {unique}")
             if unique == {0, 2}:
@@ -1193,7 +1264,9 @@
                 logger.info("Using fixed split as-is.")
             else:
                 raise ValueError(f"Unexpected split values: {unique}")
+
             return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
+
         except Exception:
             logger.error("Error processing fixed split", exc_info=True)
             raise
@@ -1209,11 +1282,14 @@
         """Execute the full workflow end-to-end."""
         logger.info("Starting workflow...")
         self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
         try:
             self._create_temp_dirs()
             self._extract_images()
             csv_path, split_cfg, split_info = self._prepare_data()
+
             use_pretrained = self.args.use_pretrained or self.args.fine_tune
+
             backend_args = {
                 "model_name": self.args.model_name,
                 "fine_tune": self.args.fine_tune,
@@ -1230,9 +1306,11 @@
                 "threshold": self.args.threshold,
             }
             yaml_str = self.backend.prepare_config(backend_args, split_cfg)
+
             config_file = self.temp_dir / TEMP_CONFIG_FILENAME
             config_file.write_text(yaml_str)
             logger.info(f"Wrote backend config: {config_file}")
+
             self.backend.run_experiment(
                 csv_path,
                 config_file,
@@ -1374,8 +1452,7 @@
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
         help=(
-            "Random split proportions (e.g., 0.7 0.1 0.2)."
-            "Only used if no split column."
+            "Random split proportions (e.g., 0.7 0.1 0.2).Only used if no split column."
         ),
     )
     parser.add_argument(
@@ -1408,9 +1485,10 @@
         help=(
             "Decision threshold for binary classification (0.0–1.0)."
             "Overrides default 0.5."
-        )
+        ),
     )
     args = parser.parse_args()
+
     if not 0.0 <= args.validation_size <= 1.0:
         parser.error("validation-size must be between 0.0 and 1.0")
     if not args.csv_file.is_file():
@@ -1423,8 +1501,10 @@
             setattr(args, "augmentation", augmentation_setup)
         except ValueError as e:
             parser.error(str(e))
+
     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)
+
     exit_code = 0
     try:
         orchestrator.run()
@@ -1439,6 +1519,7 @@
 if __name__ == "__main__":
     try:
         import ludwig
+
         logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
     except ImportError:
         logger.error(
@@ -1446,4 +1527,5 @@
             "('pip install ludwig[image]')"
         )
         sys.exit(1)
+
     main()