Mercurial > repos > goeckslab > image_learner
diff image_workflow.py @ 15:d17e3a1b8659 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
| field | value |
|---|---|
| author | goeckslab |
| date | Fri, 28 Nov 2025 15:45:49 +0000 |
| parents | bcfa2e234a80 |
| children | (none) |
line wrap: on
line diff
--- a/image_workflow.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/image_workflow.py	Fri Nov 28 15:45:49 2025 +0000
@@ -127,16 +127,31 @@
             logger.error("Error loading metadata file", exc_info=True)
             raise
 
-        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
-        missing = required - set(df.columns)
-        if missing:
-            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+        label_col = self.args.target_column or LABEL_COLUMN_NAME
+        image_col = self.args.image_column or IMAGE_PATH_COLUMN_NAME
+
+        # Remember the user-specified columns for reporting
+        self.args.report_target_column = label_col
+        self.args.report_image_column = image_col
+
+        missing_cols = []
+        if label_col not in df.columns:
+            missing_cols.append(label_col)
+        if image_col not in df.columns:
+            missing_cols.append(image_col)
+        if missing_cols:
+            raise ValueError(
+                f"Missing required column(s) in metadata: {', '.join(missing_cols)}. "
+                "Update the XML selections or rename your columns."
+            )
+
+        if label_col != LABEL_COLUMN_NAME:
+            df = df.rename(columns={label_col: LABEL_COLUMN_NAME})
+        if image_col != IMAGE_PATH_COLUMN_NAME:
+            df = df.rename(columns={image_col: IMAGE_PATH_COLUMN_NAME})
 
         try:
-            # Use relative paths that Ludwig can resolve from its internal working directory
-            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
-                lambda p: str(Path("images") / p)
-            )
+            df = self._map_image_paths_with_search(df)
         except Exception:
             logger.error("Error updating image paths", exc_info=True)
             raise
@@ -205,6 +220,71 @@
 
         self.label_metadata = metadata
         self.output_type_hint = "binary" if metadata.get("is_binary") else None
+    def _map_image_paths_with_search(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Map image identifiers to actual files by searching the extracted directory."""
+        if not self.image_extract_dir:
+            raise RuntimeError("Image directory is not initialized.")
+
+        # Build lookup maps for fast resolution by stem or full name
+        lookup_by_stem = {}
+        lookup_by_name = {}
+        for fpath in self.image_extract_dir.rglob("*"):
+            if fpath.is_file():
+                stem_key = fpath.stem.lower()
+                name_key = fpath.name.lower()
+                # Prefer first encounter; warn on collisions
+                if stem_key in lookup_by_stem and lookup_by_stem[stem_key] != fpath:
+                    logger.warning(
+                        "Multiple files share the same stem '%s'. Using '%s'.",
+                        stem_key,
+                        lookup_by_stem[stem_key],
+                    )
+                else:
+                    lookup_by_stem[stem_key] = fpath
+                if name_key in lookup_by_name and lookup_by_name[name_key] != fpath:
+                    logger.warning(
+                        "Multiple files share the same name '%s'. Using '%s'.",
+                        name_key,
+                        lookup_by_name[name_key],
+                    )
+                else:
+                    lookup_by_name[name_key] = fpath
+
+        resolved_paths = []
+        missing_count = 0
+        missing_samples = []
+
+        for raw in df[IMAGE_PATH_COLUMN_NAME]:
+            raw_str = str(raw)
+            name_key = Path(raw_str).name.lower()
+            stem_key = Path(raw_str).stem.lower()
+            resolved = lookup_by_name.get(name_key) or lookup_by_stem.get(stem_key)
+
+            if resolved is None:
+                missing_count += 1
+                missing_samples.append(raw_str)
+                resolved_paths.append(pd.NA)
+                continue
+
+            try:
+                rel_path = resolved.relative_to(self.image_extract_dir)
+            except ValueError:
+                rel_path = resolved
+            resolved_paths.append(str(Path("images") / rel_path))
+
+        if missing_count:
+            logger.warning(
+                "Unable to locate %d image(s) from the metadata in the extracted images directory.",
+                missing_count,
+            )
+            preview = ", ".join(missing_samples[:5])
+            logger.warning("Missing samples (showing up to 5): %s", preview)
+
+        df = df.copy()
+        df[IMAGE_PATH_COLUMN_NAME] = resolved_paths
+        df = df.dropna(subset=[IMAGE_PATH_COLUMN_NAME]).reset_index(drop=True)
+        return df
+
     # Removed duplicate method
 
     def _detect_image_dimensions(self) -> Tuple[int, int]:
@@ -275,6 +355,9 @@
             "threshold": self.args.threshold,
             "label_metadata": self.label_metadata,
             "output_type_hint": self.output_type_hint,
+            "validation_metric": self.args.validation_metric,
+            "target_column": getattr(self.args, "report_target_column", LABEL_COLUMN_NAME),
+            "image_column": getattr(self.args, "report_image_column", IMAGE_PATH_COLUMN_NAME),
         }
 
         yaml_str = self.backend.prepare_config(backend_args, split_cfg)
@@ -297,6 +380,9 @@
 
         if ran_ok:
             logger.info("Workflow completed successfully.")
+            # Convert predictions parquet → csv
+            self.backend.convert_parquet_to_csv(self.args.output_dir)
+            logger.info("Converted Parquet to CSV.")
             # Generate a very small set of plots to conserve disk space
             self.backend.generate_plots(self.args.output_dir)
             # Build HTML report (robust to missing metrics)
@@ -307,9 +393,6 @@
                 split_info,
             )
             logger.info(f"HTML report generated at: {report_file}")
-            # Convert predictions parquet → csv
-            self.backend.convert_parquet_to_csv(self.args.output_dir)
-            logger.info("Converted Parquet to CSV.")
             # Post-process cleanup to reduce disk footprint for subsequent tests
             try:
                 self._postprocess_cleanup(self.args.output_dir)
