diff image_workflow.py @ 15:d17e3a1b8659 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
author goeckslab
date Fri, 28 Nov 2025 15:45:49 +0000
parents bcfa2e234a80
children
--- a/image_workflow.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/image_workflow.py	Fri Nov 28 15:45:49 2025 +0000
@@ -127,16 +127,31 @@
             logger.error("Error loading metadata file", exc_info=True)
             raise
 
-        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
-        missing = required - set(df.columns)
-        if missing:
-            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+        label_col = self.args.target_column or LABEL_COLUMN_NAME
+        image_col = self.args.image_column or IMAGE_PATH_COLUMN_NAME
+
+        # Remember the user-specified columns for reporting
+        self.args.report_target_column = label_col
+        self.args.report_image_column = image_col
+
+        missing_cols = []
+        if label_col not in df.columns:
+            missing_cols.append(label_col)
+        if image_col not in df.columns:
+            missing_cols.append(image_col)
+        if missing_cols:
+            raise ValueError(
+                f"Missing required column(s) in metadata: {', '.join(missing_cols)}. "
+                "Update the XML selections or rename your columns."
+            )
+
+        if label_col != LABEL_COLUMN_NAME:
+            df = df.rename(columns={label_col: LABEL_COLUMN_NAME})
+        if image_col != IMAGE_PATH_COLUMN_NAME:
+            df = df.rename(columns={image_col: IMAGE_PATH_COLUMN_NAME})
 
         try:
-            # Use relative paths that Ludwig can resolve from its internal working directory
-            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
-                lambda p: str(Path("images") / p)
-            )
+            df = self._map_image_paths_with_search(df)
         except Exception:
             logger.error("Error updating image paths", exc_info=True)
             raise
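
Review note: this hunk replaces the hard-coded column check with user-selectable column names, validating both before renaming them to the canonical names the rest of the workflow expects. A minimal sketch of that validate-and-rename step in isolation (constants and sample data here are stand-ins, not the module's actual values):

    import pandas as pd

    # Stand-ins for the module-level constants.
    IMAGE_PATH_COLUMN_NAME = "image_path"
    LABEL_COLUMN_NAME = "label"

    def normalize_columns(df, image_col=None, label_col=None):
        """Validate user-selected columns, then rename them to the canonical names."""
        image_col = image_col or IMAGE_PATH_COLUMN_NAME
        label_col = label_col or LABEL_COLUMN_NAME
        missing = [c for c in (label_col, image_col) if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {', '.join(missing)}")
        return df.rename(columns={image_col: IMAGE_PATH_COLUMN_NAME,
                                  label_col: LABEL_COLUMN_NAME})

    df = pd.DataFrame({"file_name": ["a.png"], "diagnosis": ["benign"]})
    df = normalize_columns(df, image_col="file_name", label_col="diagnosis")
    assert list(df.columns) == ["image_path", "label"]

Renaming to the canonical names keeps the change local: downstream code continues to reference IMAGE_PATH_COLUMN_NAME and LABEL_COLUMN_NAME unchanged.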
@@ -205,6 +220,71 @@
         self.label_metadata = metadata
         self.output_type_hint = "binary" if metadata.get("is_binary") else None
 
+    def _map_image_paths_with_search(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Map image identifiers to actual files by searching the extracted directory."""
+        if not self.image_extract_dir:
+            raise RuntimeError("Image directory is not initialized.")
+
+        # Build lookup maps for fast resolution by stem or full name
+        lookup_by_stem = {}
+        lookup_by_name = {}
+        for fpath in self.image_extract_dir.rglob("*"):
+            if fpath.is_file():
+                stem_key = fpath.stem.lower()
+                name_key = fpath.name.lower()
+                # Prefer first encounter; warn on collisions
+                if stem_key in lookup_by_stem and lookup_by_stem[stem_key] != fpath:
+                    logger.warning(
+                        "Multiple files share the same stem '%s'. Using '%s'.",
+                        stem_key,
+                        lookup_by_stem[stem_key],
+                    )
+                else:
+                    lookup_by_stem[stem_key] = fpath
+                if name_key in lookup_by_name and lookup_by_name[name_key] != fpath:
+                    logger.warning(
+                        "Multiple files share the same name '%s'. Using '%s'.",
+                        name_key,
+                        lookup_by_name[name_key],
+                    )
+                else:
+                    lookup_by_name[name_key] = fpath
+
+        resolved_paths = []
+        missing_count = 0
+        missing_samples = []
+
+        for raw in df[IMAGE_PATH_COLUMN_NAME]:
+            raw_str = str(raw)
+            name_key = Path(raw_str).name.lower()
+            stem_key = Path(raw_str).stem.lower()
+            resolved = lookup_by_name.get(name_key) or lookup_by_stem.get(stem_key)
+
+            if resolved is None:
+                missing_count += 1
+                missing_samples.append(raw_str)
+                resolved_paths.append(pd.NA)
+                continue
+
+            try:
+                rel_path = resolved.relative_to(self.image_extract_dir)
+            except ValueError:
+                rel_path = resolved
+            resolved_paths.append(str(Path("images") / rel_path))
+
+        if missing_count:
+            logger.warning(
+                "Unable to locate %d image(s) from the metadata in the extracted images directory.",
+                missing_count,
+            )
+            preview = ", ".join(missing_samples[:5])
+            logger.warning("Missing samples (showing up to 5): %s", preview)
+
+        df = df.copy()
+        df[IMAGE_PATH_COLUMN_NAME] = resolved_paths
+        df = df.dropna(subset=[IMAGE_PATH_COLUMN_NAME]).reset_index(drop=True)
+        return df
+
 # Removed duplicate method
 
     def _detect_image_dimensions(self) -> Tuple[int, int]:
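
Review note: _map_image_paths_with_search resolves a metadata identifier by exact file name first, then by bare stem, so the metadata may list images with or without an extension. A runnable sketch of that two-level lookup (helper name and sample layout are illustrative; the collision warnings from the hunk are omitted):

    from pathlib import Path
    import tempfile

    def build_lookups(root: Path):
        """Index every file under root by lowercased name and stem, keeping the first hit."""
        by_name, by_stem = {}, {}
        for f in root.rglob("*"):
            if f.is_file():
                by_name.setdefault(f.name.lower(), f)
                by_stem.setdefault(f.stem.lower(), f)
        return by_name, by_stem

    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "scans").mkdir()
        (root / "scans" / "IMG_001.png").touch()
        by_name, by_stem = build_lookups(root)
        # An exact file name and a bare stem resolve to the same file.
        assert by_name["img_001.png"] == by_stem["img_001"]

One subtlety the hunk relies on: when relative_to raises and the absolute path is kept, Path("images") / <absolute path> evaluates to the absolute path itself (pathlib discards the left operand when the right side is absolute), so those rows still point at a real file.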
@@ -275,6 +355,9 @@
                 "threshold": self.args.threshold,
                 "label_metadata": self.label_metadata,
                 "output_type_hint": self.output_type_hint,
+                "validation_metric": self.args.validation_metric,
+                "target_column": getattr(self.args, "report_target_column", LABEL_COLUMN_NAME),
+                "image_column": getattr(self.args, "report_image_column", IMAGE_PATH_COLUMN_NAME),
             }
             yaml_str = self.backend.prepare_config(backend_args, split_cfg)
 
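Review note: report_target_column and report_image_column are only attached to args when the metadata loader runs, so the config step reads them through getattr with defaults. A one-line sketch of why that fallback matters (the args object is simulated here):

    from types import SimpleNamespace

    args = SimpleNamespace()  # loader never ran, so the report_* attributes are absent
    assert getattr(args, "report_target_column", "label") == "label"  # falls back cleanly
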
@@ -297,6 +380,9 @@
 
             if ran_ok:
                 logger.info("Workflow completed successfully.")
+                # Convert predictions parquet → csv
+                self.backend.convert_parquet_to_csv(self.args.output_dir)
+                logger.info("Converted Parquet to CSV.")
                 # Generate a very small set of plots to conserve disk space
                 self.backend.generate_plots(self.args.output_dir)
                 # Build HTML report (robust to missing metrics)
@@ -307,9 +393,6 @@
                     split_info,
                 )
                 logger.info(f"HTML report generated at: {report_file}")
-                # Convert predictions parquet → csv
-                self.backend.convert_parquet_to_csv(self.args.output_dir)
-                logger.info("Converted Parquet to CSV.")
                 # Post-process cleanup to reduce disk footprint for subsequent tests
                 try:
                     self._postprocess_cleanup(self.args.output_dir)
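
Review note: the last two hunks reorder the pipeline so the Parquet→CSV conversion runs before the HTML report is built, ensuring the CSV predictions already exist when the report is assembled. The backend's converter is not shown in this diff; a plausible minimal equivalent with pandas (the recursive glob over output_dir is an assumption about its behaviour):

    from pathlib import Path
    import pandas as pd

    def convert_parquet_to_csv(output_dir: str) -> None:
        """Write a sibling .csv for every .parquet found under output_dir."""
        for pq in Path(output_dir).rglob("*.parquet"):
            pd.read_parquet(pq).to_csv(pq.with_suffix(".csv"), index=False)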