Mercurial > repos > goeckslab > image_learner
diff image_workflow.py @ 17:db9be962dc13 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
| author | goeckslab |
|---|---|
| date | Wed, 10 Dec 2025 00:24:13 +0000 |
| parents | d17e3a1b8659 |
| children |
line wrap: on
line diff
--- a/image_workflow.py Wed Dec 03 01:28:52 2025 +0000
+++ b/image_workflow.py Wed Dec 10 00:24:13 2025 +0000
@@ -5,7 +5,7 @@
 import tempfile
 import zipfile
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pandas as pd
 import pandas.api.types as ptypes
@@ -35,6 +35,8 @@
         self.image_extract_dir: Optional[Path] = None
         self.label_metadata: Dict[str, Any] = {}
         self.output_type_hint: Optional[str] = None
+        self.label_split_counts: List[Dict[str, int]] = []
+        self.split_counts: Dict[int, int] = {}
         logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
 
     def _create_temp_dirs(self) -> None:
@@ -186,6 +188,34 @@
             logger.error("Error saving prepared CSV", exc_info=True)
             raise
 
+        # Capture actual split counts for downstream reporting (avoids heuristic 70/10/20 tables)
+        try:
+            split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce")
+            split_series = split_series.dropna().astype(int)
+            self.split_counts = {int(k): int(v) for k, v in split_series.value_counts().to_dict().items()}
+            if LABEL_COLUMN_NAME in df.columns:
+                counts = (
+                    df.dropna(subset=[LABEL_COLUMN_NAME])
+                    .assign(**{SPLIT_COLUMN_NAME: split_series})
+                    .groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+                    .size()
+                    .unstack(fill_value=0)
+                    .sort_index()
+                )
+                self.label_split_counts = [
+                    {
+                        "label": str(lbl),
+                        "train": int(row.get(0, 0)),
+                        "validation": int(row.get(1, 0)),
+                        "test": int(row.get(2, 0)),
+                    }
+                    for lbl, row in counts.iterrows()
+                ]
+        except Exception:
+            logger.warning("Unable to capture split counts for reporting", exc_info=True)
+            self.label_split_counts = []
+            self.split_counts = {}
+
         self._capture_label_metadata(df)
         return final_csv, split_config, split_info
 
@@ -349,6 +379,8 @@
             "random_seed": self.args.random_seed,
             "early_stop": self.args.early_stop,
             "label_column_data_path": csv_path,
+            "label_split_counts": self.label_split_counts,
+            "split_counts": self.split_counts,
             "augmentation": self.args.augmentation,
             "image_resize": self.args.image_resize,
             "image_zip": self.args.image_zip,
