comparison image_workflow.py @ 17:db9be962dc13 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
author goeckslab
date Wed, 10 Dec 2025 00:24:13 +0000
parents d17e3a1b8659
children
comparison
equal deleted inserted replaced
16:8729f69e9207 17:db9be962dc13
3 import os 3 import os
4 import shutil 4 import shutil
5 import tempfile 5 import tempfile
6 import zipfile 6 import zipfile
7 from pathlib import Path 7 from pathlib import Path
8 from typing import Any, Dict, Optional, Tuple 8 from typing import Any, Dict, List, Optional, Tuple
9 9
10 import pandas as pd 10 import pandas as pd
11 import pandas.api.types as ptypes 11 import pandas.api.types as ptypes
12 from constants import ( 12 from constants import (
13 IMAGE_PATH_COLUMN_NAME, 13 IMAGE_PATH_COLUMN_NAME,
33 self.backend = backend 33 self.backend = backend
34 self.temp_dir: Optional[Path] = None 34 self.temp_dir: Optional[Path] = None
35 self.image_extract_dir: Optional[Path] = None 35 self.image_extract_dir: Optional[Path] = None
36 self.label_metadata: Dict[str, Any] = {} 36 self.label_metadata: Dict[str, Any] = {}
37 self.output_type_hint: Optional[str] = None 37 self.output_type_hint: Optional[str] = None
38 self.label_split_counts: List[Dict[str, int]] = []
39 self.split_counts: Dict[int, int] = {}
38 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") 40 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
39 41
40 def _create_temp_dirs(self) -> None: 42 def _create_temp_dirs(self) -> None:
41 """Create temporary output and image extraction directories.""" 43 """Create temporary output and image extraction directories."""
42 try: 44 try:
183 df.to_csv(final_csv, index=False) 185 df.to_csv(final_csv, index=False)
184 logger.info(f"Saved prepared data to {final_csv}") 186 logger.info(f"Saved prepared data to {final_csv}")
185 except Exception: 187 except Exception:
186 logger.error("Error saving prepared CSV", exc_info=True) 188 logger.error("Error saving prepared CSV", exc_info=True)
187 raise 189 raise
190
191 # Capture actual split counts for downstream reporting (avoids heuristic 70/10/20 tables)
192 try:
193 split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce")
194 split_series = split_series.dropna().astype(int)
195 self.split_counts = {int(k): int(v) for k, v in split_series.value_counts().to_dict().items()}
196 if LABEL_COLUMN_NAME in df.columns:
197 counts = (
198 df.dropna(subset=[LABEL_COLUMN_NAME])
199 .assign(**{SPLIT_COLUMN_NAME: split_series})
200 .groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
201 .size()
202 .unstack(fill_value=0)
203 .sort_index()
204 )
205 self.label_split_counts = [
206 {
207 "label": str(lbl),
208 "train": int(row.get(0, 0)),
209 "validation": int(row.get(1, 0)),
210 "test": int(row.get(2, 0)),
211 }
212 for lbl, row in counts.iterrows()
213 ]
214 except Exception:
215 logger.warning("Unable to capture split counts for reporting", exc_info=True)
216 self.label_split_counts = []
217 self.split_counts = {}
188 218
189 self._capture_label_metadata(df) 219 self._capture_label_metadata(df)
190 220
191 return final_csv, split_config, split_info 221 return final_csv, split_config, split_info
192 222
347 "split_probabilities": self.args.split_probabilities, 377 "split_probabilities": self.args.split_probabilities,
348 "learning_rate": self.args.learning_rate, 378 "learning_rate": self.args.learning_rate,
349 "random_seed": self.args.random_seed, 379 "random_seed": self.args.random_seed,
350 "early_stop": self.args.early_stop, 380 "early_stop": self.args.early_stop,
351 "label_column_data_path": csv_path, 381 "label_column_data_path": csv_path,
382 "label_split_counts": self.label_split_counts,
383 "split_counts": self.split_counts,
352 "augmentation": self.args.augmentation, 384 "augmentation": self.args.augmentation,
353 "image_resize": self.args.image_resize, 385 "image_resize": self.args.image_resize,
354 "image_zip": self.args.image_zip, 386 "image_zip": self.args.image_zip,
355 "threshold": self.args.threshold, 387 "threshold": self.args.threshold,
356 "label_metadata": self.label_metadata, 388 "label_metadata": self.label_metadata,