Mercurial > repos > goeckslab > image_learner
comparison image_workflow.py @ 17:db9be962dc13 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
| author | goeckslab |
|---|---|
| date | Wed, 10 Dec 2025 00:24:13 +0000 |
| parents | d17e3a1b8659 |
| children |
comparison
equal
deleted
inserted
replaced
| 16:8729f69e9207 | 17:db9be962dc13 |
|---|---|
| 3 import os | 3 import os |
| 4 import shutil | 4 import shutil |
| 5 import tempfile | 5 import tempfile |
| 6 import zipfile | 6 import zipfile |
| 7 from pathlib import Path | 7 from pathlib import Path |
| 8 from typing import Any, Dict, Optional, Tuple | 8 from typing import Any, Dict, List, Optional, Tuple |
| 9 | 9 |
| 10 import pandas as pd | 10 import pandas as pd |
| 11 import pandas.api.types as ptypes | 11 import pandas.api.types as ptypes |
| 12 from constants import ( | 12 from constants import ( |
| 13 IMAGE_PATH_COLUMN_NAME, | 13 IMAGE_PATH_COLUMN_NAME, |
| 33 self.backend = backend | 33 self.backend = backend |
| 34 self.temp_dir: Optional[Path] = None | 34 self.temp_dir: Optional[Path] = None |
| 35 self.image_extract_dir: Optional[Path] = None | 35 self.image_extract_dir: Optional[Path] = None |
| 36 self.label_metadata: Dict[str, Any] = {} | 36 self.label_metadata: Dict[str, Any] = {} |
| 37 self.output_type_hint: Optional[str] = None | 37 self.output_type_hint: Optional[str] = None |
| | 38 self.label_split_counts: List[Dict[str, int]] = [] |
| | 39 self.split_counts: Dict[int, int] = {} |
| 38 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") | 40 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") |
| 39 | 41 |
| 40 def _create_temp_dirs(self) -> None: | 42 def _create_temp_dirs(self) -> None: |
| 41 """Create temporary output and image extraction directories.""" | 43 """Create temporary output and image extraction directories.""" |
| 42 try: | 44 try: |
| 183 df.to_csv(final_csv, index=False) | 185 df.to_csv(final_csv, index=False) |
| 184 logger.info(f"Saved prepared data to {final_csv}") | 186 logger.info(f"Saved prepared data to {final_csv}") |
| 185 except Exception: | 187 except Exception: |
| 186 logger.error("Error saving prepared CSV", exc_info=True) | 188 logger.error("Error saving prepared CSV", exc_info=True) |
| 187 raise | 189 raise |
| | 190 |
| | 191 # Capture actual split counts for downstream reporting (avoids heuristic 70/10/20 tables) |
| | 192 try: |
| | 193 split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce") |
| | 194 split_series = split_series.dropna().astype(int) |
| | 195 self.split_counts = {int(k): int(v) for k, v in split_series.value_counts().to_dict().items()} |
| | 196 if LABEL_COLUMN_NAME in df.columns: |
| | 197 counts = ( |
| | 198 df.dropna(subset=[LABEL_COLUMN_NAME]) |
| | 199 .assign(**{SPLIT_COLUMN_NAME: split_series}) |
| | 200 .groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]) |
| | 201 .size() |
| | 202 .unstack(fill_value=0) |
| | 203 .sort_index() |
| | 204 ) |
| | 205 self.label_split_counts = [ |
| | 206 { |
| | 207 "label": str(lbl), |
| | 208 "train": int(row.get(0, 0)), |
| | 209 "validation": int(row.get(1, 0)), |
| | 210 "test": int(row.get(2, 0)), |
| | 211 } |
| | 212 for lbl, row in counts.iterrows() |
| | 213 ] |
| | 214 except Exception: |
| | 215 logger.warning("Unable to capture split counts for reporting", exc_info=True) |
| | 216 self.label_split_counts = [] |
| | 217 self.split_counts = {} |
| 188 | 218 |
| 189 self._capture_label_metadata(df) | 219 self._capture_label_metadata(df) |
| 190 | 220 |
| 191 return final_csv, split_config, split_info | 221 return final_csv, split_config, split_info |
| 192 | 222 |
| 347 "split_probabilities": self.args.split_probabilities, | 377 "split_probabilities": self.args.split_probabilities, |
| 348 "learning_rate": self.args.learning_rate, | 378 "learning_rate": self.args.learning_rate, |
| 349 "random_seed": self.args.random_seed, | 379 "random_seed": self.args.random_seed, |
| 350 "early_stop": self.args.early_stop, | 380 "early_stop": self.args.early_stop, |
| 351 "label_column_data_path": csv_path, | 381 "label_column_data_path": csv_path, |
| | 382 "label_split_counts": self.label_split_counts, |
| | 383 "split_counts": self.split_counts, |
| 352 "augmentation": self.args.augmentation, | 384 "augmentation": self.args.augmentation, |
| 353 "image_resize": self.args.image_resize, | 385 "image_resize": self.args.image_resize, |
| 354 "image_zip": self.args.image_zip, | 386 "image_zip": self.args.image_zip, |
| 355 "threshold": self.args.threshold, | 387 "threshold": self.args.threshold, |
| 356 "label_metadata": self.label_metadata, | 388 "label_metadata": self.label_metadata, |
