diff image_workflow.py @ 17:db9be962dc13 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
author goeckslab
date Wed, 10 Dec 2025 00:24:13 +0000
parents d17e3a1b8659
children
line wrap: on
line diff
--- a/image_workflow.py	Wed Dec 03 01:28:52 2025 +0000
+++ b/image_workflow.py	Wed Dec 10 00:24:13 2025 +0000
@@ -5,7 +5,7 @@
 import tempfile
 import zipfile
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pandas as pd
 import pandas.api.types as ptypes
@@ -35,6 +35,8 @@
         self.image_extract_dir: Optional[Path] = None
         self.label_metadata: Dict[str, Any] = {}
         self.output_type_hint: Optional[str] = None
+        self.label_split_counts: List[Dict[str, int]] = []
+        self.split_counts: Dict[int, int] = {}
         logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
 
     def _create_temp_dirs(self) -> None:
@@ -186,6 +188,34 @@
             logger.error("Error saving prepared CSV", exc_info=True)
             raise
 
+        # Capture actual split counts for downstream reporting (avoids heuristic 70/10/20 tables)
+        try:
+            split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce")
+            split_series = split_series.dropna().astype(int)
+            self.split_counts = {int(k): int(v) for k, v in split_series.value_counts().to_dict().items()}
+            if LABEL_COLUMN_NAME in df.columns:
+                counts = (
+                    df.dropna(subset=[LABEL_COLUMN_NAME])
+                    .assign(**{SPLIT_COLUMN_NAME: split_series})
+                    .groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+                    .size()
+                    .unstack(fill_value=0)
+                    .sort_index()
+                )
+                self.label_split_counts = [
+                    {
+                        "label": str(lbl),
+                        "train": int(row.get(0, 0)),
+                        "validation": int(row.get(1, 0)),
+                        "test": int(row.get(2, 0)),
+                    }
+                    for lbl, row in counts.iterrows()
+                ]
+        except Exception:
+            logger.warning("Unable to capture split counts for reporting", exc_info=True)
+            self.label_split_counts = []
+            self.split_counts = {}
+
         self._capture_label_metadata(df)
 
         return final_csv, split_config, split_info
@@ -349,6 +379,8 @@
                 "random_seed": self.args.random_seed,
                 "early_stop": self.args.early_stop,
                 "label_column_data_path": csv_path,
+                "label_split_counts": self.label_split_counts,
+                "split_counts": self.split_counts,
                 "augmentation": self.args.augmentation,
                 "image_resize": self.args.image_resize,
                 "image_zip": self.args.image_zip,