diff image_learner_cli.py @ 9:9e912fce264c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit eace0d7c2b2939029c052991d238a54947d2e191
author   | goeckslab
date     | Wed, 27 Aug 2025 21:02:48 +0000
parents  | 85e6f4b2ad18
children |
--- a/image_learner_cli.py Thu Aug 14 14:53:10 2025 +0000 +++ b/image_learner_cli.py Wed Aug 27 21:02:48 2025 +0000 @@ -21,7 +21,7 @@ SPLIT_COLUMN_NAME, TEMP_CONFIG_FILENAME, TEMP_CSV_FILENAME, - TEMP_DIR_PREFIX + TEMP_DIR_PREFIX, ) from ludwig.globals import ( DESCRIPTION_FILE_NAME, @@ -38,13 +38,13 @@ encode_image_to_base64, get_html_closing, get_html_template, - get_metrics_help_modal + get_metrics_help_modal, ) # --- Logging Setup --- logging.basicConfig( level=logging.INFO, - format='%(asctime)s %(levelname)s %(name)s: %(message)s', + format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) logger = logging.getLogger("ImageLearner") @@ -67,7 +67,9 @@ "early_stop", "threshold", ] + rows = [] + for key in display_keys: val = config.get(key, None) if key == "threshold": @@ -134,7 +136,9 @@ val_str = val else: val_str = val if val is not None else "N/A" - if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential + if val_str == "N/A" and key not in [ + "task_type" + ]: # Skip if N/A for non-essential continue rows.append( f"<tr>" @@ -166,7 +170,7 @@ <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> </tr></thead> <tbody> - {''.join(rows)} + {"".join(rows)} </tbody> </table> </div><br> @@ -251,6 +255,7 @@ "roc_auc": get_last_value(label_stats, "roc_auc"), "hits_at_k": get_last_value(label_stats, "hits_at_k"), } + # Test metrics: dynamic extraction according to exclusions test_label_stats = test_stats.get("label", {}) if not test_label_stats: @@ -258,11 +263,13 @@ else: combined_stats = test_stats.get("combined", {}) overall_stats = test_label_stats.get("overall_stats", {}) + # Define exclusions if output_type == "binary": exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} else: exclude = {"per_class_stats", "confusion_matrix"} + # 1. Get all scalar test_label_stats not excluded test_metrics = {} for k, v in test_label_stats.items(): @@ -272,9 +279,11 @@ continue if isinstance(v, (int, float, str, bool)): test_metrics[k] = v + # 2. Add overall_stats (flattened) for k, v in overall_stats.items(): test_metrics[k] = v + # 3. 
Optionally include combined/loss if present and not already if "loss" in combined_stats and "loss" not in test_metrics: test_metrics["loss"] = combined_stats["loss"] @@ -315,8 +324,10 @@ te = all_metrics["test"].get(metric_key) if all(x is not None for x in [t, v, te]): rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) + if not rows: return "<table><tr><td>No metric values found.</td></tr></table>" + html = ( "<h2 style='text-align: center;'>Model Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" @@ -331,7 +342,7 @@ for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", ) html += "</tbody></table></div><br>" return html @@ -357,8 +368,10 @@ v = all_metrics["validation"].get(metric_key) if t is not None and v is not None: rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) + if not rows: return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" + html = ( "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" @@ -372,7 +385,7 @@ for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", ) html += "</tbody></table></div><br>" return html @@ -393,8 +406,10 @@ value = test_metrics[key] if value is not None: rows.append([display_name, f"{value:.4f}"]) + if not rows: return "<table><tr><td>No test metric values found.</td></tr></table>" + html = ( "<h2 style='text-align: center;'>Test Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" @@ -407,7 +422,7 @@ for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", ) html += "</tbody></table></div><br>" return html @@ -436,10 +451,14 @@ min_samples_per_class = label_counts.min() if min_samples_per_class * validation_size < 1: # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size - adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class) + adjusted_validation_size = min( + validation_size, 1.0 / min_samples_per_class + ) if adjusted_validation_size != validation_size: validation_size = adjusted_validation_size - logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") + logger.info( + f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation" + ) stratify_arr = out.loc[idx_train, label_column] logger.info("Using stratified split for validation set") else: @@ -486,7 +505,9 @@ # initialize split column out[split_column] = 0 if not label_column or label_column not in out.columns: - logger.warning("No label column found; using random split without stratification") + logger.warning( + "No label column found; using random split without stratification" + ) # fall back to simple random assignment indices = out.index.tolist() np.random.seed(random_state) @@ -529,7 +550,9 @@ stratify=out[label_column], ) # second split: separate training and validation from remaining data - val_size_adjusted = split_probabilities[1] / 
(split_probabilities[0] + split_probabilities[1]) + val_size_adjusted = split_probabilities[1] / ( + split_probabilities[0] + split_probabilities[1] + ) train_idx, val_idx = train_test_split( train_val_idx, test_size=val_size_adjusted, @@ -541,12 +564,15 @@ out.loc[val_idx, split_column] = 1 out.loc[test_idx, split_column] = 2 logger.info("Successfully applied stratified random split") - logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") + logger.info( + f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" + ) return out.astype({split_column: int}) class Backend(Protocol): """Interface for a machine learning backend.""" + def prepare_config( self, config_params: Dict[str, Any], @@ -578,12 +604,14 @@ class LudwigDirectBackend: """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" + def prepare_config( self, config_params: Dict[str, Any], split_config: Dict[str, Any], ) -> str: logger.info("LudwigDirectBackend: Preparing YAML configuration.") + model_name = config_params.get("model_name", "resnet18") use_pretrained = config_params.get("use_pretrained", False) fine_tune = config_params.get("fine_tune", False) @@ -606,7 +634,9 @@ } else: encoder_config = {"type": raw_encoder} + batch_size_cfg = batch_size or "auto" + label_column_path = config_params.get("label_column_data_path") label_series = None if label_column_path is not None and Path(label_column_path).exists(): @@ -614,6 +644,7 @@ label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] except Exception as e: logger.warning(f"Could not read label column for task detection: {e}") + if ( label_series is not None and ptypes.is_numeric_dtype(label_series.dtype) @@ -622,7 +653,9 @@ task_type = "regression" else: task_type = "classification" + config_params["task_type"] = task_type + image_feat: Dict[str, Any] = { "name": IMAGE_PATH_COLUMN_NAME, "type": "image", @@ -630,6 +663,7 @@ } if config_params.get("augmentation") is not None: image_feat["augmentation"] = config_params["augmentation"] + if task_type == "regression": output_feat = { "name": LABEL_COLUMN_NAME, @@ -645,6 +679,7 @@ }, } val_metric = config_params.get("validation_metric", "mean_squared_error") + else: num_unique_labels = ( label_series.nunique() if label_series is not None else 2 @@ -654,6 +689,7 @@ if output_type == "binary" and config_params.get("threshold") is not None: output_feat["threshold"] = float(config_params["threshold"]) val_metric = None + conf: Dict[str, Any] = { "model_type": "ecd", "input_features": [image_feat], @@ -673,6 +709,7 @@ "in_memory": False, }, } + logger.debug("LudwigDirectBackend: Config dict built.") try: yaml_str = yaml.dump(conf, sort_keys=False, indent=2) @@ -694,6 +731,7 @@ ) -> None: """Invoke Ludwig's internal experiment_cli function to run the experiment.""" logger.info("LudwigDirectBackend: Starting experiment execution.") + try: from ludwig.experiment import experiment_cli except ImportError as e: @@ -702,7 +740,9 @@ exc_info=True, ) raise RuntimeError("Ludwig import failed.") from e + output_dir.mkdir(parents=True, exist_ok=True) + try: experiment_cli( dataset=str(dataset_path), @@ -733,13 +773,16 @@ output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, ) + if not exp_dirs: logger.warning(f"No experiment run directories found in {output_dir}") return None + progress_file = exp_dirs[-1] / "model" / "training_progress.json" if not progress_file.exists(): logger.warning(f"No training_progress.json 
found in {progress_file}") return None + try: with progress_file.open("r", encoding="utf-8") as f: data = json.load(f) @@ -775,6 +818,7 @@ def generate_plots(self, output_dir: Path) -> None: """Generate all registered Ludwig visualizations for the latest experiment run.""" logger.info("Generating all Ludwig visualizations…") + test_plots = { "compare_performance", "compare_classifiers_performance_from_prob", @@ -798,6 +842,7 @@ "learning_curves", "compare_classifiers_performance_subset", } + output_dir = Path(output_dir) exp_dirs = sorted( output_dir.glob("experiment_run*"), @@ -807,6 +852,7 @@ logger.warning(f"No experiment run dirs found in {output_dir}") return exp_dir = exp_dirs[-1] + viz_dir = exp_dir / "visualizations" viz_dir.mkdir(exist_ok=True) train_viz = viz_dir / "train" @@ -821,6 +867,7 @@ test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) + dataset_path = None split_file = None desc = exp_dir / DESCRIPTION_FILE_NAME @@ -829,6 +876,7 @@ cfg = json.load(f) dataset_path = _check(Path(cfg.get("dataset", ""))) split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) + output_feature = "" if desc.exists(): try: @@ -839,6 +887,7 @@ with open(test_stats, "r") as f: stats = json.load(f) output_feature = next(iter(stats.keys()), "") + viz_registry = get_visualizations_registry() for viz_name, viz_func in viz_registry.items(): if viz_name in train_plots: @@ -847,6 +896,7 @@ viz_dir_plot = test_viz else: continue + try: viz_func( training_statistics=[training_stats] if training_stats else [], @@ -866,6 +916,7 @@ logger.info(f"✔ Generated {viz_name}") except Exception as e: logger.warning(f"✘ Skipped {viz_name}: {e}") + logger.info(f"All visualizations written to {viz_dir}") def generate_html_report( @@ -881,6 +932,7 @@ report_path = cwd / report_name output_dir = Path(output_dir) output_type = None + exp_dirs = sorted( output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, @@ -888,11 +940,14 @@ if not exp_dirs: raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") exp_dir = exp_dirs[-1] + base_viz_dir = exp_dir / "visualizations" train_viz_dir = base_viz_dir / "train" test_viz_dir = base_viz_dir / "test" + html = get_html_template() html += f"<h1>{title}</h1>" + metrics_html = "" train_val_metrics_html = "" test_metrics_html = "" @@ -918,11 +973,12 @@ logger.warning( f"Could not load stats for HTML report: {type(e).__name__}: {e}" ) + config_html = "" training_progress = self.get_training_process(output_dir) try: config_html = format_config_table_html( - config, split_info, training_progress + config, split_info, training_progress, output_type ) except Exception as e: logger.warning(f"Could not load config for HTML report: {e}") @@ -936,7 +992,8 @@ imgs = list(dir_path.glob("*.png")) # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- imgs = [ - img for img in imgs + img + for img in imgs if not ( img.name == "confusion_matrix.png" or img.name.startswith("confusion_matrix__label_top") @@ -972,7 +1029,9 @@ valid_imgs = [img for img in imgs if img.name not in unwanted] img_map = {img.name: img for img in valid_imgs} ordered = [img_map[n] for n in display_order if n in img_map] - others = sorted(img for img in valid_imgs if img.name not in display_order) + others = sorted( + img for img in valid_imgs if img.name not in display_order + ) imgs = ordered + others else: # regression: just 
sort whatever's left @@ -1012,7 +1071,9 @@ df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) # 2) load ground truth for the test split from prepared CSV df_all = pd.read_csv(config["label_column_data_path"]) - df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True) + df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ + LABEL_COLUMN_NAME + ].reset_index(drop=True) # 3) concatenate side-by-side df_table = pd.concat([df_gt, df_pred], axis=1) df_table.columns = [LABEL_COLUMN_NAME, "prediction"] @@ -1036,7 +1097,9 @@ for plot in interactive_plots: # 2) inject the static "roc_curves_from_prediction_statistics.png" if plot["title"] == "ROC-AUC": - static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png" + static_img = ( + test_viz_dir / "roc_curves_from_prediction_statistics.png" + ) if static_img.exists(): b64 = encode_image_to_base64(str(static_img)) tab3_content += ( @@ -1054,14 +1117,13 @@ + plot["html"] ) tab3_content += render_img_section( - "Test Visualizations", - test_viz_dir, - output_type + "Test Visualizations", test_viz_dir, output_type ) # assemble the tabs and help modal tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) modal_html = get_metrics_help_modal() html += tabbed_html + modal_html + get_html_closing() + try: with open(report_path, "w") as f: f.write(html) @@ -1069,11 +1131,13 @@ except Exception as e: logger.error(f"Failed to write HTML report: {e}") raise + return report_path class WorkflowOrchestrator: """Manages the image-classification workflow.""" + def __init__(self, args: argparse.Namespace, backend: Backend): self.args = args self.backend = backend @@ -1113,16 +1177,19 @@ """Load CSV, update image paths, handle splits, and write prepared CSV.""" if not self.temp_dir or not self.image_extract_dir: raise RuntimeError("Temp dirs not initialized before data prep.") + try: df = pd.read_csv(self.args.csv_file) logger.info(f"Loaded CSV: {self.args.csv_file}") except Exception: logger.error("Error loading CSV file", exc_info=True) raise + required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} missing = required - set(df.columns) if missing: raise ValueError(f"Missing CSV columns: {', '.join(missing)}") + try: df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( lambda p: str((self.image_extract_dir / p).resolve()) @@ -1150,13 +1217,16 @@ f"{[int(p * 100) for p in self.args.split_probabilities]}% " f"for train/val/test with balanced label distribution." 
) + final_csv = self.temp_dir / TEMP_CSV_FILENAME + try: df.to_csv(final_csv, index=False) logger.info(f"Saved prepared data to {final_csv}") except Exception: logger.error("Error saving prepared CSV", exc_info=True) raise + return final_csv, split_config, split_info def _process_fixed_split( @@ -1171,6 +1241,7 @@ ) if df[SPLIT_COLUMN_NAME].isna().any(): logger.warning("Split column contains non-numeric/missing values.") + unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) logger.info(f"Unique split values: {unique}") if unique == {0, 2}: @@ -1193,7 +1264,9 @@ logger.info("Using fixed split as-is.") else: raise ValueError(f"Unexpected split values: {unique}") + return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info + except Exception: logger.error("Error processing fixed split", exc_info=True) raise @@ -1209,11 +1282,14 @@ """Execute the full workflow end-to-end.""" logger.info("Starting workflow...") self.args.output_dir.mkdir(parents=True, exist_ok=True) + try: self._create_temp_dirs() self._extract_images() csv_path, split_cfg, split_info = self._prepare_data() + use_pretrained = self.args.use_pretrained or self.args.fine_tune + backend_args = { "model_name": self.args.model_name, "fine_tune": self.args.fine_tune, @@ -1230,9 +1306,11 @@ "threshold": self.args.threshold, } yaml_str = self.backend.prepare_config(backend_args, split_cfg) + config_file = self.temp_dir / TEMP_CONFIG_FILENAME config_file.write_text(yaml_str) logger.info(f"Wrote backend config: {config_file}") + self.backend.run_experiment( csv_path, config_file, @@ -1374,8 +1452,7 @@ action=SplitProbAction, default=[0.7, 0.1, 0.2], help=( - "Random split proportions (e.g., 0.7 0.1 0.2)." - "Only used if no split column." + "Random split proportions (e.g., 0.7 0.1 0.2).Only used if no split column." ), ) parser.add_argument( @@ -1408,9 +1485,10 @@ help=( "Decision threshold for binary classification (0.0–1.0)." "Overrides default 0.5." - ) + ), ) args = parser.parse_args() + if not 0.0 <= args.validation_size <= 1.0: parser.error("validation-size must be between 0.0 and 1.0") if not args.csv_file.is_file(): @@ -1423,8 +1501,10 @@ setattr(args, "augmentation", augmentation_setup) except ValueError as e: parser.error(str(e)) + backend_instance = LudwigDirectBackend() orchestrator = WorkflowOrchestrator(args, backend_instance) + exit_code = 0 try: orchestrator.run() @@ -1439,6 +1519,7 @@ if __name__ == "__main__": try: import ludwig + logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") except ImportError: logger.error( @@ -1446,4 +1527,5 @@ "('pip install ludwig[image]')" ) sys.exit(1) + main()
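Several of the hunks above only re-wrap the two-stage stratified split that populates the split column (0 = train, 1 = validation, 2 = test). For readers skimming the flattened diff, here is a minimal standalone sketch of that logic in Python; the function name and default arguments are illustrative, not the tool's exact API.

import pandas as pd
from sklearn.model_selection import train_test_split

def stratified_three_way_split(
    df: pd.DataFrame,
    label_column: str,
    split_probabilities=(0.7, 0.1, 0.2),
    random_state: int = 42,
    split_column: str = "split",
) -> pd.DataFrame:
    """Assign 0 (train), 1 (validation), 2 (test) while keeping the
    label distribution balanced across all three subsets."""
    out = df.copy()
    out[split_column] = 0

    # First split: carve the test set off the full index, stratified by label.
    train_val_idx, test_idx = train_test_split(
        out.index,
        test_size=split_probabilities[2],
        random_state=random_state,
        stratify=out[label_column],
    )

    # Second split: separate validation from the remaining train+val pool.
    # The validation fraction is rescaled because it now applies to a smaller pool.
    val_size_adjusted = split_probabilities[1] / (
        split_probabilities[0] + split_probabilities[1]
    )
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size_adjusted,
        random_state=random_state,
        stratify=out.loc[train_val_idx, label_column],
    )

    out.loc[train_idx, split_column] = 0  # explicit, though 0 is already the default
    out.loc[val_idx, split_column] = 1
    out.loc[test_idx, split_column] = 2
    return out.astype({split_column: int})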
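The LudwigDirectBackend hunks likewise only reformat how the Ludwig config dict is assembled and dumped to YAML. A simplified, hypothetical sketch of that pattern follows; the feature definitions and trainer values are placeholders rather than the CLI's actual defaults, and only the yaml.dump call mirrors the diff.

from typing import Any, Dict

import yaml

# Placeholder feature definitions; the real CLI derives these from its
# command-line arguments and the detected task type.
image_feat: Dict[str, Any] = {"name": "image_path", "type": "image"}
output_feat: Dict[str, Any] = {"name": "label", "type": "category"}

conf: Dict[str, Any] = {
    "model_type": "ecd",
    "input_features": [image_feat],
    "output_features": [output_feat],
    "trainer": {"epochs": 10, "batch_size": "auto"},
    "preprocessing": {"in_memory": False},
}

# Serialized the same way the backend does before writing the temp config file.
yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
print(yaml_str)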