comparison image_learner_cli.py @ 7:801a8b6973fb draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 67df782ea551181e1d240d463764016ba528eba9
author goeckslab
date Fri, 08 Aug 2025 13:06:28 +0000
parents 09904b1f61f5
children
comparison
equal deleted inserted replaced
6:09904b1f61f5 7:801a8b6973fb
7 import tempfile 7 import tempfile
8 import zipfile 8 import zipfile
9 from pathlib import Path 9 from pathlib import Path
10 from typing import Any, Dict, Optional, Protocol, Tuple 10 from typing import Any, Dict, Optional, Protocol, Tuple
11 11
12 import numpy as np
12 import pandas as pd 13 import pandas as pd
13 import pandas.api.types as ptypes 14 import pandas.api.types as ptypes
14 import yaml 15 import yaml
15 from constants import ( 16 from constants import (
16 IMAGE_PATH_COLUMN_NAME, 17 IMAGE_PATH_COLUMN_NAME,
416 417
417 418
418 def split_data_0_2( 419 def split_data_0_2(
419 df: pd.DataFrame, 420 df: pd.DataFrame,
420 split_column: str, 421 split_column: str,
421 validation_size: float = 0.15, 422 validation_size: float = 0.1,
422 random_state: int = 42, 423 random_state: int = 42,
423 label_column: Optional[str] = None, 424 label_column: Optional[str] = None,
424 ) -> pd.DataFrame: 425 ) -> pd.DataFrame:
425 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" 426 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation)."""
426 out = df.copy() 427 out = df.copy()
429 idx_train = out.index[out[split_column] == 0].tolist() 430 idx_train = out.index[out[split_column] == 0].tolist()
430 431
431 if not idx_train: 432 if not idx_train:
432 logger.info("No rows with split=0; nothing to do.") 433 logger.info("No rows with split=0; nothing to do.")
433 return out 434 return out
435
436 # Always use stratify if possible
434 stratify_arr = None 437 stratify_arr = None
435 if label_column and label_column in out.columns: 438 if label_column and label_column in out.columns:
436 label_counts = out.loc[idx_train, label_column].value_counts() 439 label_counts = out.loc[idx_train, label_column].value_counts()
437 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: 440 if label_counts.size > 1:
441 # Force stratify even with fewer samples - adjust validation_size if needed
442 min_samples_per_class = label_counts.min()
443 if min_samples_per_class * validation_size < 1:
444 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size
445 adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class)
446 if adjusted_validation_size != validation_size:
447 validation_size = adjusted_validation_size
448 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation")
438 stratify_arr = out.loc[idx_train, label_column] 449 stratify_arr = out.loc[idx_train, label_column]
450 logger.info("Using stratified split for validation set")
439 else: 451 else:
440 logger.warning( 452 logger.warning("Only one label class found; cannot stratify")
441 "Cannot stratify (too few labels); splitting without stratify." 453
442 )
443 if validation_size <= 0: 454 if validation_size <= 0:
444 logger.info("validation_size <= 0; keeping all as train.") 455 logger.info("validation_size <= 0; keeping all as train.")
445 return out 456 return out
446 if validation_size >= 1: 457 if validation_size >= 1:
447 logger.info("validation_size >= 1; moving all train → validation.") 458 logger.info("validation_size >= 1; moving all train → validation.")
448 out.loc[idx_train, split_column] = 1 459 out.loc[idx_train, split_column] = 1
449 return out 460 return out
461
462 # Always try stratified split first
450 try: 463 try:
451 train_idx, val_idx = train_test_split( 464 train_idx, val_idx = train_test_split(
452 idx_train, 465 idx_train,
453 test_size=validation_size, 466 test_size=validation_size,
454 random_state=random_state, 467 random_state=random_state,
455 stratify=stratify_arr, 468 stratify=stratify_arr,
456 ) 469 )
470 logger.info("Successfully applied stratified split")
457 except ValueError as e: 471 except ValueError as e:
458 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") 472 logger.warning(f"Stratified split failed ({e}); falling back to random split.")
459 train_idx, val_idx = train_test_split( 473 train_idx, val_idx = train_test_split(
460 idx_train, 474 idx_train,
461 test_size=validation_size, 475 test_size=validation_size,
462 random_state=random_state, 476 random_state=random_state,
463 stratify=None, 477 stratify=None,
464 ) 478 )
479
465 out.loc[train_idx, split_column] = 0 480 out.loc[train_idx, split_column] = 0
466 out.loc[val_idx, split_column] = 1 481 out.loc[val_idx, split_column] = 1
467 out[split_column] = out[split_column].astype(int) 482 out[split_column] = out[split_column].astype(int)
468 return out 483 return out
484
485
def create_stratified_random_split(
    df: pd.DataFrame,
    split_column: str,
    split_probabilities: Optional[list] = None,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Create a stratified train/validation/test split when no split column exists.

    Assigns 0 (train), 1 (validation), 2 (test) to ``split_column`` on a copy of
    ``df``. Stratifies on ``label_column`` when every class has at least as many
    samples as there are splits; otherwise falls back to a plain random split.

    Args:
        df: Input frame; not modified in place.
        split_column: Name of the column to write the 0/1/2 split codes into.
        split_probabilities: Train/val/test fractions; defaults to [0.7, 0.1, 0.2].
            (Default is created per call to avoid the mutable-default pitfall.)
        random_state: Seed for both the fallback shuffle and the stratified splits.
        label_column: Column to stratify on; if missing, the fallback is used.

    Returns:
        Copy of ``df`` with ``split_column`` populated and cast to int.
    """
    if split_probabilities is None:
        split_probabilities = [0.7, 0.1, 0.2]

    out = df.copy()

    # initialize split column
    out[split_column] = 0

    def _random_fallback() -> pd.DataFrame:
        # Plain (non-stratified) random assignment. A local RandomState seeded
        # with random_state reproduces np.random.seed + np.random.shuffle
        # exactly, without clobbering the global NumPy RNG state.
        indices = out.index.tolist()
        np.random.RandomState(random_state).shuffle(indices)

        n_total = len(indices)
        n_train = int(n_total * split_probabilities[0])
        n_val = int(n_total * split_probabilities[1])

        out.loc[indices[:n_train], split_column] = 0
        out.loc[indices[n_train:n_train + n_val], split_column] = 1
        out.loc[indices[n_train + n_val:], split_column] = 2

        return out.astype({split_column: int})

    if not label_column or label_column not in out.columns:
        logger.warning("No label column found; using random split without stratification")
        return _random_fallback()

    # check if stratification is possible:
    # each class must have at least as many samples as the number of splits,
    # so that each split can receive at least one sample per class.
    label_counts = out[label_column].value_counts()
    min_samples_per_class = label_counts.min()
    min_samples_required = len(split_probabilities)
    if min_samples_per_class < min_samples_required:
        logger.warning(
            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
        )
        return _random_fallback()

    logger.info("Using stratified random split for train/validation/test sets")

    # first split: separate test set
    train_val_idx, test_idx = train_test_split(
        out.index.tolist(),
        test_size=split_probabilities[2],
        random_state=random_state,
        stratify=out[label_column],
    )

    # second split: separate training and validation from remaining data;
    # rescale the validation fraction relative to the remaining train+val mass
    val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size_adjusted,
        random_state=random_state,
        stratify=out.loc[train_val_idx, label_column],
    )

    # assign split values
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out.loc[test_idx, split_column] = 2

    logger.info("Successfully applied stratified random split")
    logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

    return out.astype({split_column: int})
469 571
470 572
471 class Backend(Protocol): 573 class Backend(Protocol):
472 """Interface for a machine learning backend.""" 574 """Interface for a machine learning backend."""
473 575
1087 raise 1189 raise
1088 1190
1089 if SPLIT_COLUMN_NAME in df.columns: 1191 if SPLIT_COLUMN_NAME in df.columns:
1090 df, split_config, split_info = self._process_fixed_split(df) 1192 df, split_config, split_info = self._process_fixed_split(df)
1091 else: 1193 else:
1092 logger.info("No split column; using random split") 1194 logger.info("No split column; creating stratified random split")
1195 df = create_stratified_random_split(
1196 df=df,
1197 split_column=SPLIT_COLUMN_NAME,
1198 split_probabilities=self.args.split_probabilities,
1199 random_state=self.args.random_seed,
1200 label_column=LABEL_COLUMN_NAME,
1201 )
1093 split_config = { 1202 split_config = {
1094 "type": "random", 1203 "type": "fixed",
1095 "probabilities": self.args.split_probabilities, 1204 "column": SPLIT_COLUMN_NAME,
1096 } 1205 }
1097 split_info = ( 1206 split_info = (
1098 f"No split column in CSV. Used random split: " 1207 f"No split column in CSV. Created stratified random split: "
1099 f"{[int(p * 100) for p in self.args.split_probabilities]}% " 1208 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
1100 f"for train/val/test." 1209 f"for train/val/test with balanced label distribution."
1101 ) 1210 )
1102 1211
1103 final_csv = self.temp_dir / TEMP_CSV_FILENAME 1212 final_csv = self.temp_dir / TEMP_CSV_FILENAME
1104 try: 1213 try:
1105 1214
1137 ) 1246 )
1138 split_info = ( 1247 split_info = (
1139 "Detected a split column (with values 0 and 2) in the input CSV. " 1248 "Detected a split column (with values 0 and 2) in the input CSV. "
1140 f"Used this column as a base and reassigned " 1249 f"Used this column as a base and reassigned "
1141 f"{self.args.validation_size * 100:.1f}% " 1250 f"{self.args.validation_size * 100:.1f}% "
1142 "of the training set (originally labeled 0) to validation (labeled 1)." 1251 "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling."
1143 ) 1252 )
1144 logger.info("Applied custom 0/2 split.") 1253 logger.info("Applied custom 0/2 split.")
1145 elif unique.issubset({0, 1, 2}): 1254 elif unique.issubset({0, 1, 2}):
1146 split_info = "Used user-defined split column from CSV." 1255 split_info = "Used user-defined split column from CSV."
1147 logger.info("Using fixed split as-is.") 1256 logger.info("Using fixed split as-is.")
1317 help="Where to write outputs", 1426 help="Where to write outputs",
1318 ) 1427 )
1319 parser.add_argument( 1428 parser.add_argument(
1320 "--validation-size", 1429 "--validation-size",
1321 type=float, 1430 type=float,
1322 default=0.15, 1431 default=0.1,
1323 help="Fraction for validation (0.0–1.0)", 1432 help="Fraction for validation (0.0–1.0)",
1324 ) 1433 )
1325 parser.add_argument( 1434 parser.add_argument(
1326 "--preprocessing-num-processes", 1435 "--preprocessing-num-processes",
1327 type=int, 1436 type=int,