diff pycaret_train.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents f4cb41f458fd
children
line wrap: on
line diff
--- a/pycaret_train.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_train.py	Fri Jul 25 19:02:32 2025 +0000
@@ -12,68 +12,123 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_file", help="Path to the input file")
     parser.add_argument("--target_col", help="Column number of the target")
-    parser.add_argument("--output_dir",
-                        help="Path to the output directory")
-    parser.add_argument("--model_type",
-                        choices=["classification", "regression"],
-                        help="Type of the model")
-    parser.add_argument("--train_size", type=float,
-                        default=None,
-                        help="Train size for PyCaret setup")
-    parser.add_argument("--normalize", action="store_true",
-                        default=None,
-                        help="Normalize data for PyCaret setup")
-    parser.add_argument("--feature_selection", action="store_true",
-                        default=None,
-                        help="Perform feature selection for PyCaret setup")
-    parser.add_argument("--cross_validation", action="store_true",
-                        default=None,
-                        help="Perform cross-validation for PyCaret setup")
-    parser.add_argument("--no_cross_validation", action="store_true",
-                        default=None,
-                        help="Don't perform cross-validation for PyCaret setup")
-    parser.add_argument("--cross_validation_folds", type=int,
-                        default=None,
-                        help="Number of cross-validation folds \
-                          for PyCaret setup")
-    parser.add_argument("--remove_outliers", action="store_true",
-                        default=None,
-                        help="Remove outliers for PyCaret setup")
-    parser.add_argument("--remove_multicollinearity", action="store_true",
-                        default=None,
-                        help="Remove multicollinearity for PyCaret setup")
-    parser.add_argument("--polynomial_features", action="store_true",
-                        default=None,
-                        help="Generate polynomial features for PyCaret setup")
-    parser.add_argument("--feature_interaction", action="store_true",
-                        default=None,
-                        help="Generate feature interactions for PyCaret setup")
-    parser.add_argument("--feature_ratio", action="store_true",
-                        default=None,
-                        help="Generate feature ratios for PyCaret setup")
-    parser.add_argument("--fix_imbalance", action="store_true",
-                        default=None,
-                        help="Fix class imbalance for PyCaret setup")
-    parser.add_argument("--models", nargs='+',
-                        default=None,
-                        help="Selected models for training")
-    parser.add_argument("--random_seed", type=int,
-                        default=42,
-                        help="Random seed for PyCaret setup")
-    parser.add_argument("--test_file", type=str, default=None,
-                        help="Path to the test data file")
+    parser.add_argument("--output_dir", help="Path to the output directory")
+    parser.add_argument(
+        "--model_type",
+        choices=["classification", "regression"],
+        help="Type of the model",
+    )
+    parser.add_argument(
+        "--train_size",
+        type=float,
+        default=None,
+        help="Train size for PyCaret setup",
+    )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        default=None,
+        help="Normalize data for PyCaret setup",
+    )
+    parser.add_argument(
+        "--feature_selection",
+        action="store_true",
+        default=None,
+        help="Perform feature selection for PyCaret setup",
+    )
+    parser.add_argument(
+        "--cross_validation",
+        action="store_true",
+        default=None,
+        help="Enable cross-validation for PyCaret setup",
+    )
+    parser.add_argument(
+        "--no_cross_validation",
+        action="store_true",
+        default=None,
+        help="Disable cross-validation for PyCaret setup",
+    )
+    parser.add_argument(
+        "--cross_validation_folds",
+        type=int,
+        default=None,
+        help="Number of cross-validation folds for PyCaret setup",
+    )
+    parser.add_argument(
+        "--remove_outliers",
+        action="store_true",
+        default=None,
+        help="Remove outliers for PyCaret setup",
+    )
+    parser.add_argument(
+        "--remove_multicollinearity",
+        action="store_true",
+        default=None,
+        help="Remove multicollinearity for PyCaret setup",
+    )
+    parser.add_argument(
+        "--polynomial_features",
+        action="store_true",
+        default=None,
+        help="Generate polynomial features for PyCaret setup",
+    )
+    parser.add_argument(
+        "--feature_interaction",
+        action="store_true",
+        default=None,
+        help="Generate feature interactions for PyCaret setup",
+    )
+    parser.add_argument(
+        "--feature_ratio",
+        action="store_true",
+        default=None,
+        help="Generate feature ratios for PyCaret setup",
+    )
+    parser.add_argument(
+        "--fix_imbalance",
+        action="store_true",
+        default=None,
+        help="Fix class imbalance for PyCaret setup",
+    )
+    parser.add_argument(
+        "--models",
+        nargs="+",
+        default=None,
+        help="Selected models for training",
+    )
+    parser.add_argument(
+        "--tune_model",
+        action="store_true",
+        default=False,
+        help="Tune the best model hyperparameters after training",
+    )
+    parser.add_argument(
+        "--random_seed",
+        type=int,
+        default=42,
+        help="Random seed for PyCaret setup",
+    )
+    parser.add_argument(
+        "--test_file",
+        type=str,
+        default=None,
+        help="Path to the test data file",
+    )
 
     args = parser.parse_args()
 
-    cross_validation = True
+    # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
     if args.no_cross_validation:
-        cross_validation = False
+        args.cross_validation = False
+    # If --cross_validation was passed, args.cross_validation is True
+    # If neither was passed, args.cross_validation remains None
 
+    # Build the model_kwargs dict from CLI args
     model_kwargs = {
         "train_size": args.train_size,
         "normalize": args.normalize,
         "feature_selection": args.feature_selection,
-        "cross_validation": cross_validation,
+        "cross_validation": args.cross_validation,
         "cross_validation_folds": args.cross_validation_folds,
         "remove_outliers": args.remove_outliers,
         "remove_multicollinearity": args.remove_multicollinearity,
@@ -81,17 +136,19 @@
         "feature_interaction": args.feature_interaction,
         "feature_ratio": args.feature_ratio,
         "fix_imbalance": args.fix_imbalance,
+        "tune_model": args.tune_model,
     }
     LOG.info(f"Model kwargs: {model_kwargs}")
 
-    # Remove None values from model_kwargs
-
-    LOG.info(f"Model kwargs 2: {model_kwargs}")
+    # If the XML passed a comma-separated string in a single list element, split it out
     if args.models:
         model_kwargs["models"] = args.models[0].split(",")
 
+    # Drop None entries so PyCaret uses its default values
     model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
+    LOG.info(f"Model kwargs 2: {model_kwargs}")
 
+    # Instantiate the appropriate trainer
     if args.model_type == "classification":
         trainer = ClassificationModelTrainer(
             args.input_file,
@@ -100,10 +157,11 @@
             args.model_type,
             args.random_seed,
             args.test_file,
-            **model_kwargs)
+            **model_kwargs,
+        )
     elif args.model_type == "regression":
-        if "fix_imbalance" in model_kwargs:
-            del model_kwargs["fix_imbalance"]
+        # regression doesn't support fix_imbalance
+        model_kwargs.pop("fix_imbalance", None)
         trainer = RegressionModelTrainer(
             args.input_file,
             args.target_col,
@@ -111,11 +169,12 @@
             args.model_type,
             args.random_seed,
             args.test_file,
-            **model_kwargs)
+            **model_kwargs,
+        )
     else:
-        LOG.error("Invalid model type. Please choose \
-                  'classification' or 'regression'.")
+        LOG.error("Invalid model type. Please choose 'classification' or 'regression'.")
         return
+
     trainer.run()