+
+
+"""
+
+
+def split_data_0_2(
+ df: pd.DataFrame,
+ split_column: str,
+ validation_size: float = 0.15,
+ random_state: int = 42,
+ label_column: Optional[str] = None,
+) -> pd.DataFrame:
+ """
+ Given a DataFrame whose split_column only contains {0,2}, re-assign
+ a portion of the 0s to become 1s (validation). Returns a fresh DataFrame.
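+
+    Illustrative example (hypothetical column names):
+
+        >>> df = pd.DataFrame({"split": [0, 0, 0, 0, 2],
+        ...                    "label": ["a", "a", "b", "b", "a"]})
+        >>> out = split_data_0_2(df, "split", validation_size=0.25,
+        ...                      label_column="label")
+        >>> sorted(out["split"].unique().tolist())
+        [0, 1, 2]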
+ """
+ # Work on a copy
+ out = df.copy()
+ # Ensure split col is integer dtype
+ out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)
+
+ idx_train = out.index[out[split_column] == 0].tolist()
+
+ if not idx_train:
+ logger.info("No rows with split=0; nothing to do.")
+ return out
+
+ # Determine stratify array if possible
+ stratify_arr = None
+ if label_column and label_column in out.columns:
+ # Only stratify if at least two classes and enough samples
+ label_counts = out.loc[idx_train, label_column].value_counts()
+ if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1:
+ stratify_arr = out.loc[idx_train, label_column]
+ else:
+ logger.warning("Cannot stratify (too few labels); splitting without stratify.")
+
+ # Edge cases
+ if validation_size <= 0:
+ logger.info("validation_size <= 0; keeping all as train.")
+ return out
+ if validation_size >= 1:
+ logger.info("validation_size >= 1; moving all train → validation.")
+ out.loc[idx_train, split_column] = 1
+ return out
+
+ # Do the split
+ try:
+ train_idx, val_idx = train_test_split(
+ idx_train,
+ test_size=validation_size,
+ random_state=random_state,
+ stratify=stratify_arr
+ )
+ except ValueError as e:
+ logger.warning(f"Stratified split failed ({e}); retrying without stratify.")
+ train_idx, val_idx = train_test_split(
+ idx_train,
+ test_size=validation_size,
+ random_state=random_state,
+ stratify=None
+ )
+
+ # Assign new splits
+ out.loc[train_idx, split_column] = 0
+ out.loc[val_idx, split_column] = 1
+ # idx_test stays at 2
+
+ # Cast back to a clean integer type
+ out[split_column] = out[split_column].astype(int)
+ return out
+
+
+class Backend(Protocol):
+ """Interface for a machine learning backend."""
+ def prepare_config(
+ self,
+ config_params: Dict[str, Any],
+ split_config: Dict[str, Any]
+ ) -> str:
+ ...
+
+ def run_experiment(
+ self,
+ dataset_path: Path,
+ config_path: Path,
+ output_dir: Path,
+ random_seed: int,
+ ) -> None:
+ ...
+
+ def generate_plots(
+ self,
+ output_dir: Path
+ ) -> None:
+ ...
+
+    def generate_html_report(
+        self,
+        title: str,
+        output_dir: str,
+        config: Dict[str, Any],
+        split_info: str,
+    ) -> Path:
+        ...
+
+    def convert_parquet_to_csv(
+        self,
+        output_dir: Path
+    ) -> None:
+        ...
+
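+# A minimal sketch (illustrative, not part of the tool): Protocol relies on
+# structural typing, so any class that provides these methods satisfies
+# Backend without inheriting from it.
+#
+#     class DryRunBackend:
+#         def prepare_config(self, config_params, split_config) -> str:
+#             return "model_type: ecd\n"
+#         def run_experiment(self, dataset_path, config_path, output_dir, random_seed):
+#             print(f"would run {config_path} on {dataset_path}")
+#         def generate_plots(self, output_dir):
+#             pass
+#         def generate_html_report(self, title, output_dir, config, split_info):
+#             return Path("report.html")
+#         def convert_parquet_to_csv(self, output_dir):
+#             pass
+#
+#     backend: Backend = DryRunBackend()  # accepted by a static type checker
+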
+
+class LudwigDirectBackend:
+ """
+ Backend for running Ludwig experiments directly via the internal experiment_cli function.
+ """
+
+ def prepare_config(
+ self,
+ config_params: Dict[str, Any],
+ split_config: Dict[str, Any],
+ ) -> str:
+ """
+ Build and serialize the Ludwig YAML configuration.
+ """
+ logger.info("LudwigDirectBackend: Preparing YAML configuration.")
+
+ model_name = config_params.get("model_name", "resnet18")
+ use_pretrained = config_params.get("use_pretrained", False)
+ fine_tune = config_params.get("fine_tune", False)
+ epochs = config_params.get("epochs", 10)
+ batch_size = config_params.get("batch_size")
+ num_processes = config_params.get("preprocessing_num_processes", 1)
+ early_stop = config_params.get("early_stop", None)
+ learning_rate = config_params.get("learning_rate")
+ learning_rate = "auto" if learning_rate is None else float(learning_rate)
+        # A model trained from scratch (no pretrained weights) must have
+        # trainable weights; fine-tuning likewise unfreezes the encoder.
+        trainable = fine_tune or (not use_pretrained)
+
+ # Encoder setup
+ raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name)
+ if isinstance(raw_encoder, dict):
+ encoder_config = {
+ **raw_encoder,
+ "use_pretrained": use_pretrained,
+ "trainable": trainable,
+ }
+ else:
+ encoder_config = {"type": raw_encoder}
+
+ # Trainer & optimizer
+ batch_size_cfg = batch_size or "auto"
+
+ conf: Dict[str, Any] = {
+ "model_type": "ecd",
+ "input_features": [
+ {
+ "name": IMAGE_PATH_COLUMN_NAME,
+ "type": "image",
+ "encoder": encoder_config,
+ }
+ ],
+ "output_features": [
+ {"name": LABEL_COLUMN_NAME, "type": "category"}
+ ],
+ "combiner": {"type": "concat"},
+ "trainer": {
+ "epochs": epochs,
+ "early_stop": early_stop,
+ "batch_size": batch_size_cfg,
+ "learning_rate": learning_rate,
+ },
+ "preprocessing": {
+ "split": split_config,
+ "num_processes": num_processes,
+ "in_memory": False,
+ },
+ }
+
+ logger.debug("LudwigDirectBackend: Config dict built.")
+ try:
+ yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
+ logger.info("LudwigDirectBackend: YAML config generated.")
+ return yaml_str
+ except Exception:
+ logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True)
+ raise
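+
+    # For orientation, a config produced with the defaults above looks roughly
+    # like this (abridged sketch; the encoder block comes from
+    # MODEL_ENCODER_TEMPLATES, and "auto" values are resolved by Ludwig):
+    #
+    #   model_type: ecd
+    #   input_features:
+    #     - name: <image path column>
+    #       type: image
+    #       encoder: {type: <template encoder>, use_pretrained: false, trainable: true}
+    #   output_features:
+    #     - {name: <label column>, type: category}
+    #   combiner: {type: concat}
+    #   trainer: {epochs: 10, early_stop: null, batch_size: auto, learning_rate: auto}
+    #   preprocessing: {split: <split config>, num_processes: 1, in_memory: false}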
+
+ def run_experiment(
+ self,
+ dataset_path: Path,
+ config_path: Path,
+ output_dir: Path,
+ random_seed: int = 42,
+ ) -> None:
+ """
+ Invoke Ludwig's internal experiment_cli function to run the experiment.
+ """
+ logger.info("LudwigDirectBackend: Starting experiment execution.")
+
+ try:
+ from ludwig.experiment import experiment_cli
+ except ImportError as e:
+ logger.error(
+ "LudwigDirectBackend: Could not import experiment_cli.",
+ exc_info=True
+ )
+ raise RuntimeError("Ludwig import failed.") from e
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ try:
+ experiment_cli(
+ dataset=str(dataset_path),
+ config=str(config_path),
+ output_directory=str(output_dir),
+ random_seed=random_seed,
+ )
+ logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}")
+ except TypeError as e:
+ logger.error(
+ "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
+ exc_info=True
+ )
+ raise RuntimeError("Ludwig argument error.") from e
+ except Exception:
+ logger.error(
+ "LudwigDirectBackend: Experiment execution error.",
+ exc_info=True
+ )
+ raise
+
+    def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
+        """
+        Retrieve training-progress details (learning rate, batch size, epoch)
+        from the most recent Ludwig run.
+
+        Returns:
+            dict with "learning_rate", "batch_size", and "epoch" keys,
+            or None if no run directory or progress file is found.
+        """
+ output_dir = Path(output_dir)
+ exp_dirs = sorted(
+ output_dir.glob("experiment_run*"),
+ key=lambda p: p.stat().st_mtime
+ )
+
+ if not exp_dirs:
+ logger.warning(f"No experiment run directories found in {output_dir}")
+ return None
+
+ progress_file = exp_dirs[-1] / "model" / "training_progress.json"
+ if not progress_file.exists():
+ logger.warning(f"No training_progress.json found in {progress_file}")
+ return None
+
+ try:
+ with progress_file.open("r", encoding="utf-8") as f:
+ data = json.load(f)
+ return {
+ "learning_rate": data.get("learning_rate"),
+ "batch_size": data.get("batch_size"),
+ "epoch": data.get("epoch"),
+ }
+ except Exception as e:
+            logger.warning(f"Failed to read training progress info: {e}")
+ return {}
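+
+    # Assumption: Ludwig's training_progress.json exposes top-level
+    # "learning_rate", "batch_size", and "epoch" entries, e.g. (abridged):
+    #
+    #   {"epoch": 7, "batch_size": 16, "learning_rate": 1e-05, ...}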
+
+ def convert_parquet_to_csv(self, output_dir: Path):
+ """Convert the predictions Parquet file to CSV."""
+ output_dir = Path(output_dir)
+ exp_dirs = sorted(
+ output_dir.glob("experiment_run*"),
+ key=lambda p: p.stat().st_mtime
+ )
+ if not exp_dirs:
+ logger.warning(f"No experiment run dirs found in {output_dir}")
+ return
+ exp_dir = exp_dirs[-1]
+ parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
+ csv_path = exp_dir / "predictions.csv"
+ try:
+ df = pd.read_parquet(parquet_path)
+ df.to_csv(csv_path, index=False)
+ logger.info(f"Converted Parquet to CSV: {csv_path}")
+ except Exception as e:
+ logger.error(f"Error converting Parquet to CSV: {e}")
+
+ def generate_plots(self, output_dir: Path) -> None:
+ """
+ Generate _all_ registered Ludwig visualizations for the latest experiment run.
+ """
+ logger.info("Generating all Ludwig visualizations…")
+
+ test_plots = {
+ 'compare_performance',
+ 'compare_classifiers_performance_from_prob',
+ 'compare_classifiers_performance_from_pred',
+ 'compare_classifiers_performance_changing_k',
+ 'compare_classifiers_multiclass_multimetric',
+ 'compare_classifiers_predictions',
+ 'confidence_thresholding_2thresholds_2d',
+ 'confidence_thresholding_2thresholds_3d',
+ 'confidence_thresholding',
+ 'confidence_thresholding_data_vs_acc',
+ 'binary_threshold_vs_metric',
+ 'roc_curves',
+ 'roc_curves_from_test_statistics',
+ 'calibration_1_vs_all',
+ 'calibration_multiclass',
+ 'confusion_matrix',
+ 'frequency_vs_f1',
+ }
+ train_plots = {
+ 'learning_curves',
+ 'compare_classifiers_performance_subset',
+ }
+
+ # 1) find the most recent experiment directory
+ output_dir = Path(output_dir)
+ exp_dirs = sorted(
+ output_dir.glob("experiment_run*"),
+ key=lambda p: p.stat().st_mtime
+ )
+ if not exp_dirs:
+ logger.warning(f"No experiment run dirs found in {output_dir}")
+ return
+ exp_dir = exp_dirs[-1]
+
+ # 2) ensure viz output subfolder exists
+ viz_dir = exp_dir / "visualizations"
+ viz_dir.mkdir(exist_ok=True)
+ train_viz = viz_dir / "train"
+ test_viz = viz_dir / "test"
+ train_viz.mkdir(parents=True, exist_ok=True)
+ test_viz.mkdir(parents=True, exist_ok=True)
+
+ # 3) helper to check file existence
+ def _check(p: Path) -> Optional[str]:
+ return str(p) if p.exists() else None
+
+ # 4) gather standard Ludwig output files
+ training_stats = _check(exp_dir / "training_statistics.json")
+ test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
+ probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
+ gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
+
+ # 5) try to read original dataset & split file from description.json
+ dataset_path = None
+ split_file = None
+ desc = exp_dir / DESCRIPTION_FILE_NAME
+ if desc.exists():
+ with open(desc, "r") as f:
+ cfg = json.load(f)
+ dataset_path = _check(Path(cfg.get("dataset", "")))
+ split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+
+ # 6) infer output feature name
+ output_feature = ""
+ if desc.exists():
+ try:
+ output_feature = cfg["config"]["output_features"][0]["name"]
+ except Exception:
+ pass
+ if not output_feature and test_stats:
+ with open(test_stats, "r") as f:
+ stats = json.load(f)
+ output_feature = next(iter(stats.keys()), "")
+
+ # 7) loop through every registered viz
+ viz_registry = get_visualizations_registry()
+ for viz_name, viz_func in viz_registry.items():
+            viz_dir_plot = viz_dir  # fallback for plots not mapped below
+            if viz_name in train_plots:
+                viz_dir_plot = train_viz
+            elif viz_name in test_plots:
+                viz_dir_plot = test_viz
+
+ try:
+ viz_func(
+ training_statistics=[training_stats] if training_stats else [],
+ test_statistics=[test_stats] if test_stats else [],
+ probabilities=[probs_path] if probs_path else [],
+ output_feature_name=output_feature,
+ ground_truth_split=2,
+ top_n_classes=[0],
+ top_k=3,
+ ground_truth_metadata=gt_metadata,
+ ground_truth=dataset_path,
+ split_file=split_file,
+ output_directory=str(viz_dir_plot),
+ normalize=False,
+ file_format="png",
+ )
+ logger.info(f"✔ Generated {viz_name}")
+ except Exception as e:
+ logger.warning(f"✘ Skipped {viz_name}: {e}")
+
+ logger.info(f"All visualizations written to {viz_dir}")
+
+    def generate_html_report(
+        self,
+        title: str,
+        output_dir: str,
+        config: dict,
+        split_info: str,
+    ) -> Path:
+ """
+ Assemble an HTML report from visualizations under train_val/ and test/ folders.
+ """
+ cwd = Path.cwd()
+ report_name = title.lower().replace(" ", "_") + "_report.html"
+ report_path = cwd / report_name
+ output_dir = Path(output_dir)
+
+ # Find latest experiment dir
+ exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime)
+ if not exp_dirs:
+ raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
+ exp_dir = exp_dirs[-1]
+
+ base_viz_dir = exp_dir / "visualizations"
+ train_viz_dir = base_viz_dir / "train"
+ test_viz_dir = base_viz_dir / "test"
+
+ html = get_html_template()
+ html += f"
{title}
"
+
+ metrics_html = ""
+
+ # Load and embed metrics table (training/val/test stats)
+ try:
+ train_stats_path = exp_dir / "training_statistics.json"
+ test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
+ if train_stats_path.exists() and test_stats_path.exists():
+ with open(train_stats_path) as f:
+ train_stats = json.load(f)
+ with open(test_stats_path) as f:
+ test_stats = json.load(f)
+ output_feature = next(iter(train_stats.keys()), "")
+ if output_feature:
+ metrics_html += format_stats_table_html(train_stats, test_stats)
+ except Exception as e:
+ logger.warning(f"Could not load stats for HTML report: {e}")
+
+ config_html = ""
+ training_progress = self.get_training_process(output_dir)
+ try:
+ config_html = format_config_table_html(config, split_info, training_progress)
+ except Exception as e:
+ logger.warning(f"Could not load config for HTML report: {e}")
+
+        def render_img_section(title: str, dir_path: Path) -> str:
+            if not dir_path.exists():
+                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
+            imgs = sorted(dir_path.glob("*.png"))
+            if not imgs:
+                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+
+            section_html = f"<h2>{title}</h2><div>"
+            for img in imgs:
+                b64 = encode_image_to_base64(str(img))
+                section_html += (
+                    '<div class="plot">'
+                    f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
+                    f'<img src="data:image/png;base64,{b64}" />'
+                    "</div>"
+                )
+            section_html += "</div>"
+            return section_html
+
+ train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir)
+ test_plots_html = render_img_section("Test Visualizations", test_viz_dir)
+ html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html)
+ html += get_html_closing()
+
+ try:
+ with open(report_path, "w") as f:
+ f.write(html)
+ logger.info(f"HTML report generated at: {report_path}")
+ except Exception as e:
+ logger.error(f"Failed to write HTML report: {e}")
+ raise
+
+ return report_path
+
+
+class WorkflowOrchestrator:
+ """
+ Manages the image-classification workflow:
+ 1. Creates temp dirs
+ 2. Extracts images
+ 3. Prepares data (CSV + splits)
+ 4. Renders a backend config
+ 5. Runs the experiment
+    6. Generates plots and an HTML report
+    7. Cleans up
+ """
+
+ def __init__(self, args: argparse.Namespace, backend: Backend):
+ self.args = args
+ self.backend = backend
+ self.temp_dir: Optional[Path] = None
+ self.image_extract_dir: Optional[Path] = None
+ logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
+
+ def _create_temp_dirs(self) -> None:
+ """Create temporary output and image extraction directories."""
+ try:
+ self.temp_dir = Path(tempfile.mkdtemp(
+ dir=self.args.output_dir,
+ prefix=TEMP_DIR_PREFIX
+ ))
+ self.image_extract_dir = self.temp_dir / "images"
+ self.image_extract_dir.mkdir()
+ logger.info(f"Created temp directory: {self.temp_dir}")
+ except Exception:
+ logger.error("Failed to create temporary directories", exc_info=True)
+ raise
+
+ def _extract_images(self) -> None:
+ """Extract images from ZIP into the temp image directory."""
+ if self.image_extract_dir is None:
+ raise RuntimeError("Temp image directory not initialized.")
+ logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}")
+ try:
+ with zipfile.ZipFile(self.args.image_zip, "r") as z:
+ z.extractall(self.image_extract_dir)
+ logger.info("Image extraction complete.")
+ except Exception:
+ logger.error("Error extracting zip file", exc_info=True)
+ raise
+
+    def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
+        """
+        Load CSV, update image paths, handle splits, and write prepared CSV.
+        Returns:
+            final_csv_path: Path to the prepared CSV
+            split_config: Dict for backend split settings
+            split_info: Human-readable description of the split strategy
+        """
+ if not self.temp_dir or not self.image_extract_dir:
+ raise RuntimeError("Temp dirs not initialized before data prep.")
+
+ # 1) Load
+ try:
+ df = pd.read_csv(self.args.csv_file)
+ logger.info(f"Loaded CSV: {self.args.csv_file}")
+ except Exception:
+ logger.error("Error loading CSV file", exc_info=True)
+ raise
+
+ # 2) Validate columns
+ required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
+ missing = required - set(df.columns)
+ if missing:
+ raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+
+ # 3) Update image paths
+ try:
+ df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
+ lambda p: str((self.image_extract_dir / p).resolve())
+ )
+ except Exception:
+ logger.error("Error updating image paths", exc_info=True)
+ raise
+
+ # 4) Handle splits
+ if SPLIT_COLUMN_NAME in df.columns:
+ df, split_config, split_info = self._process_fixed_split(df)
+ else:
+ logger.info("No split column; using random split")
+ split_config = {
+ "type": "random",
+ "probabilities": self.args.split_probabilities
+ }
+ split_info = (
+ f"No split column in CSV. Used random split: "
+ f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test."
+ )
+
+ # 5) Write out prepared CSV
+        final_csv = Path(TEMP_CSV_FILENAME)
+ try:
+ df.to_csv(final_csv, index=False)
+ logger.info(f"Saved prepared data to {final_csv}")
+ except Exception:
+ logger.error("Error saving prepared CSV", exc_info=True)
+ raise
+
+ return final_csv, split_config, split_info
+
+    def _process_fixed_split(
+        self, df: pd.DataFrame
+    ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
+        """Process a fixed split column (0=train, 1=val, 2=test)."""
+ logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.")
+ try:
+ col = df[SPLIT_COLUMN_NAME]
+ df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype())
+ if df[SPLIT_COLUMN_NAME].isna().any():
+ logger.warning("Split column contains non-numeric/missing values.")
+
+ unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
+ logger.info(f"Unique split values: {unique}")
+
+ if unique == {0, 2}:
+ df = split_data_0_2(
+ df, SPLIT_COLUMN_NAME,
+ validation_size=self.args.validation_size,
+ label_column=LABEL_COLUMN_NAME,
+ random_state=self.args.random_seed
+ )
+                split_info = (
+                    "Detected a split column (with values 0 and 2) in the input CSV. "
+                    "Used this column as a base and "
+                    f"reassigned {self.args.validation_size * 100:.1f}% "
+                    "of the training set (originally labeled 0) to validation (labeled 1)."
+                )
+
+ logger.info("Applied custom 0/2 split.")
+ elif unique.issubset({0, 1, 2}):
+ split_info = "Used user-defined split column from CSV."
+ logger.info("Using fixed split as-is.")
+ else:
+ raise ValueError(f"Unexpected split values: {unique}")
+
+ return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
+
+ except Exception:
+ logger.error("Error processing fixed split", exc_info=True)
+ raise
+
+ def _cleanup_temp_dirs(self) -> None:
+ """Remove any temporary directories."""
+ if self.temp_dir and self.temp_dir.exists():
+ logger.info(f"Cleaning up temp directory: {self.temp_dir}")
+ shutil.rmtree(self.temp_dir, ignore_errors=True)
+ self.temp_dir = None
+ self.image_extract_dir = None
+
+ def run(self) -> None:
+ """Execute the full workflow end-to-end."""
+ logger.info("Starting workflow...")
+ self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
+ try:
+ self._create_temp_dirs()
+ self._extract_images()
+ csv_path, split_cfg, split_info = self._prepare_data()
+
+ use_pretrained = self.args.use_pretrained or self.args.fine_tune
+
+ backend_args = {
+ "model_name": self.args.model_name,
+ "fine_tune": self.args.fine_tune,
+ "use_pretrained": use_pretrained,
+ "epochs": self.args.epochs,
+ "batch_size": self.args.batch_size,
+ "preprocessing_num_processes": self.args.preprocessing_num_processes,
+ "split_probabilities": self.args.split_probabilities,
+ "learning_rate": self.args.learning_rate,
+ "random_seed": self.args.random_seed,
+ "early_stop": self.args.early_stop,
+ }
+ yaml_str = self.backend.prepare_config(backend_args, split_cfg)
+
+ config_file = self.temp_dir / TEMP_CONFIG_FILENAME
+ config_file.write_text(yaml_str)
+ logger.info(f"Wrote backend config: {config_file}")
+
+ self.backend.run_experiment(
+ csv_path,
+ config_file,
+ self.args.output_dir,
+ self.args.random_seed
+ )
+ logger.info("Workflow completed successfully.")
+ self.backend.generate_plots(self.args.output_dir)
+ report_file = self.backend.generate_html_report(
+ "Image Classification Results",
+ self.args.output_dir,
+ backend_args,
+ split_info
+ )
+ logger.info(f"HTML report generated at: {report_file}")
+ self.backend.convert_parquet_to_csv(self.args.output_dir)
+ logger.info("Converted Parquet to CSV.")
+ except Exception:
+ logger.error("Workflow execution failed", exc_info=True)
+ raise
+
+ finally:
+ self._cleanup_temp_dirs()
+
+
+def parse_learning_rate(s):
+    """Parse a learning-rate string to float; return None if unset or invalid."""
+    try:
+        return float(s)
+    except (TypeError, ValueError):
+        return None
+
+
+class SplitProbAction(argparse.Action):
+ def __call__(self, parser, namespace, values, option_string=None):
+ # values is a list of three floats
+ train, val, test = values
+ total = train + val + test
+ if abs(total - 1.0) > 1e-6:
+ parser.error(
+ f"--split-probabilities must sum to 1.0; "
+ f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
+ )
+ setattr(namespace, self.dest, values)
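+
+# Illustrative behavior (hypothetical program name):
+#
+#   learner --split-probabilities 0.7 0.1 0.2   # accepted
+#   learner --split-probabilities 0.8 0.1 0.2   # parser.error: 0.800 + 0.100 + 0.200 = 1.100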
+
+
+def main():
+
+ parser = argparse.ArgumentParser(
+ description="Image Classification Learner with Pluggable Backends"
+ )
+ parser.add_argument(
+ "--csv-file", required=True, type=Path,
+ help="Path to the input CSV"
+ )
+ parser.add_argument(
+ "--image-zip", required=True, type=Path,
+ help="Path to the images ZIP"
+ )
+ parser.add_argument(
+ "--model-name", required=True,
+ choices=MODEL_ENCODER_TEMPLATES.keys(),
+ help="Which model template to use"
+ )
+ parser.add_argument(
+ "--use-pretrained", action="store_true",
+ help="Use pretrained weights for the model"
+ )
+ parser.add_argument(
+ "--fine-tune", action="store_true",
+ help="Enable fine-tuning"
+ )
+ parser.add_argument(
+ "--epochs", type=int, default=10,
+ help="Number of training epochs"
+ )
+ parser.add_argument(
+ "--early-stop", type=int, default=5,
+ help="Early stopping patience"
+ )
+ parser.add_argument(
+ "--batch-size", type=int,
+ help="Batch size (None = auto)"
+ )
+ parser.add_argument(
+ "--output-dir", type=Path, default=Path("learner_output"),
+ help="Where to write outputs"
+ )
+ parser.add_argument(
+ "--validation-size", type=float, default=0.15,
+ help="Fraction for validation (0.0–1.0)"
+ )
+ parser.add_argument(
+ "--preprocessing-num-processes", type=int,
+        default=max(1, (os.cpu_count() or 1) // 2),
+ help="CPU processes for data prep"
+ )
+ parser.add_argument(
+ "--split-probabilities", type=float, nargs=3,
+ metavar=("train", "val", "test"),
+ action=SplitProbAction,
+ default=[0.7, 0.1, 0.2],
+ help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present."
+ )
+ parser.add_argument(
+ "--random-seed", type=int, default=42,
+ help="Random seed used for dataset splitting (default: 42)"
+ )
+ parser.add_argument(
+ "--learning-rate", type=parse_learning_rate, default=None,
+ help="Learning rate. If not provided, Ludwig will auto-select it."
+ )
+
+ args = parser.parse_args()
+
+ # -- Validation --
+ if not 0.0 <= args.validation_size <= 1.0:
+ parser.error("validation-size must be between 0.0 and 1.0")
+ if not args.csv_file.is_file():
+ parser.error(f"CSV not found: {args.csv_file}")
+ if not args.image_zip.is_file():
+ parser.error(f"ZIP not found: {args.image_zip}")
+
+ # --- Instantiate Backend and Orchestrator ---
+ # Use the new LudwigDirectBackend
+ backend_instance = LudwigDirectBackend()
+ orchestrator = WorkflowOrchestrator(args, backend_instance)
+
+ # --- Run Workflow ---
+ exit_code = 0
+ try:
+ orchestrator.run()
+ logger.info("Main script finished successfully.")
+ except Exception as e:
+ logger.error(f"Main script failed.{e}")
+ exit_code = 1
+ finally:
+ sys.exit(exit_code)
+
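+
+# Example invocation (illustrative file and option values):
+#
+#   python image_learner.py \
+#       --csv-file data/labels.csv \
+#       --image-zip data/images.zip \
+#       --model-name resnet18 \
+#       --use-pretrained --fine-tune \
+#       --epochs 10 \
+#       --output-dir learner_output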
+
+if __name__ == '__main__':
+ try:
+ import ludwig
+ logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
+ except ImportError:
+ logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')")
+ sys.exit(1)
+
+ main()
diff -r 000000000000 -r 54b871dfc51e test-data/image_classification_results_report_mnist.html
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/image_classification_results_report_mnist.html Tue Jun 03 21:22:11 2025 +0000
@@ -0,0 +1,129 @@
+<html>
+<head>
+<title>Galaxy-Ludwig Report</title>
+</head>
+<body>
+<h1>Image Classification Results</h1>
+<!-- Tab bar: Config & Metrics | Train/Validation Plots | Test Plots -->
+<h2>Training Setup</h2>
+<table>
+<tr><th>Parameter</th><th>Value</th></tr>
+<tr><td>Model Name</td><td>resnet18</td></tr>
+<tr><td>Epochs</td><td>Because of early stopping: the training stopped at epoch 7</td></tr>
+<tr><td>Batch Size</td><td>Auto-selected batch size by Ludwig: 16</td></tr>
+<tr><td>Fine Tune</td><td>True</td></tr>
+<tr><td>Use Pretrained</td><td>True</td></tr>
+<tr><td>Learning Rate</td><td>Auto-selected learning rate by Ludwig: 1e-05. Based on model architecture and training setup (e.g., fine-tuning). See Ludwig Trainer Parameters for details.</td></tr>
+<tr><td>Random Seed / Early Stop</td><td>N/A</td></tr>
+<tr><td>Data Split</td><td>Detected a split column (with values 0 and 2) in the input CSV. Used this column as a base and reassigned 15.0% of the training set (originally labeled 0) to validation (labeled 1).</td></tr>
+</table>
+<p>Model trained using Ludwig. To learn more about Ludwig's default settings, please check their website (ludwig.ai).</p>
+</body>
+</html>
+ """
+
+
+def encode_image_to_base64(image_path):
+ """Convert an image file to a base64 encoded string."""
+ with open(image_path, "rb") as img_file:
+ return base64.b64encode(img_file.read()).decode("utf-8")
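+
+
+# Illustrative usage: the returned string is intended for inline data URIs, e.g.
+#   f'<img src="data:image/png;base64,{encode_image_to_base64("plot.png")}" />'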
+
+
+def json_to_nested_html_table(json_data, depth=0):
+ """
+ Convert JSON object to an HTML nested table.
+
+ Parameters:
+ json_data (dict or list): The JSON data to convert.
+ depth (int): Current depth level for indentation.
+
+ Returns:
+ str: HTML string for the nested table.
+ """
+ # Base case: if JSON is a simple key-value pair dictionary
+ if isinstance(json_data, dict) and all(
+ not isinstance(v, (dict, list)) for v in json_data.values()
+ ):
+ # Render a flat table
+        rows = [
+            f"<tr><th>{key}</th><td>{value}</td></tr>"
+            for key, value in json_data.items()
+        ]
+        return f"<table>{''.join(rows)}</table>"
+
+ # Base case: if JSON is a list of simple values
+ if isinstance(json_data, list) and all(
+ not isinstance(v, (dict, list)) for v in json_data
+ ):
+        rows = [
+            f"<tr><th>Index {i}</th><td>{value}</td></tr>"
+            for i, value in enumerate(json_data)
+        ]
+        return f"<table>{''.join(rows)}</table>"