Mercurial > repos > goeckslab > pycaret_predict
diff pycaret_train.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author | goeckslab |
---|---|
date | Fri, 25 Jul 2025 19:02:32 +0000 |
parents | f4cb41f458fd |
children |
line wrap: on
line diff
--- a/pycaret_train.py Wed Jul 09 01:13:01 2025 +0000 +++ b/pycaret_train.py Fri Jul 25 19:02:32 2025 +0000 @@ -12,68 +12,123 @@ parser = argparse.ArgumentParser() parser.add_argument("--input_file", help="Path to the input file") parser.add_argument("--target_col", help="Column number of the target") - parser.add_argument("--output_dir", - help="Path to the output directory") - parser.add_argument("--model_type", - choices=["classification", "regression"], - help="Type of the model") - parser.add_argument("--train_size", type=float, - default=None, - help="Train size for PyCaret setup") - parser.add_argument("--normalize", action="store_true", - default=None, - help="Normalize data for PyCaret setup") - parser.add_argument("--feature_selection", action="store_true", - default=None, - help="Perform feature selection for PyCaret setup") - parser.add_argument("--cross_validation", action="store_true", - default=None, - help="Perform cross-validation for PyCaret setup") - parser.add_argument("--no_cross_validation", action="store_true", - default=None, - help="Don't perform cross-validation for PyCaret setup") - parser.add_argument("--cross_validation_folds", type=int, - default=None, - help="Number of cross-validation folds \ - for PyCaret setup") - parser.add_argument("--remove_outliers", action="store_true", - default=None, - help="Remove outliers for PyCaret setup") - parser.add_argument("--remove_multicollinearity", action="store_true", - default=None, - help="Remove multicollinearity for PyCaret setup") - parser.add_argument("--polynomial_features", action="store_true", - default=None, - help="Generate polynomial features for PyCaret setup") - parser.add_argument("--feature_interaction", action="store_true", - default=None, - help="Generate feature interactions for PyCaret setup") - parser.add_argument("--feature_ratio", action="store_true", - default=None, - help="Generate feature ratios for PyCaret setup") - parser.add_argument("--fix_imbalance", action="store_true", - default=None, - help="Fix class imbalance for PyCaret setup") - parser.add_argument("--models", nargs='+', - default=None, - help="Selected models for training") - parser.add_argument("--random_seed", type=int, - default=42, - help="Random seed for PyCaret setup") - parser.add_argument("--test_file", type=str, default=None, - help="Path to the test data file") + parser.add_argument("--output_dir", help="Path to the output directory") + parser.add_argument( + "--model_type", + choices=["classification", "regression"], + help="Type of the model", + ) + parser.add_argument( + "--train_size", + type=float, + default=None, + help="Train size for PyCaret setup", + ) + parser.add_argument( + "--normalize", + action="store_true", + default=None, + help="Normalize data for PyCaret setup", + ) + parser.add_argument( + "--feature_selection", + action="store_true", + default=None, + help="Perform feature selection for PyCaret setup", + ) + parser.add_argument( + "--cross_validation", + action="store_true", + default=None, + help="Enable cross-validation for PyCaret setup", + ) + parser.add_argument( + "--no_cross_validation", + action="store_true", + default=None, + help="Disable cross-validation for PyCaret setup", + ) + parser.add_argument( + "--cross_validation_folds", + type=int, + default=None, + help="Number of cross-validation folds for PyCaret setup", + ) + parser.add_argument( + "--remove_outliers", + action="store_true", + default=None, + help="Remove outliers for PyCaret setup", + ) + parser.add_argument( + "--remove_multicollinearity", + action="store_true", + default=None, + help="Remove multicollinearity for PyCaret setup", + ) + parser.add_argument( + "--polynomial_features", + action="store_true", + default=None, + help="Generate polynomial features for PyCaret setup", + ) + parser.add_argument( + "--feature_interaction", + action="store_true", + default=None, + help="Generate feature interactions for PyCaret setup", + ) + parser.add_argument( + "--feature_ratio", + action="store_true", + default=None, + help="Generate feature ratios for PyCaret setup", + ) + parser.add_argument( + "--fix_imbalance", + action="store_true", + default=None, + help="Fix class imbalance for PyCaret setup", + ) + parser.add_argument( + "--models", + nargs="+", + default=None, + help="Selected models for training", + ) + parser.add_argument( + "--tune_model", + action="store_true", + default=False, + help="Tune the best model hyperparameters after training", + ) + parser.add_argument( + "--random_seed", + type=int, + default=42, + help="Random seed for PyCaret setup", + ) + parser.add_argument( + "--test_file", + type=str, + default=None, + help="Path to the test data file", + ) args = parser.parse_args() - cross_validation = True + # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation if args.no_cross_validation: - cross_validation = False + args.cross_validation = False + # If --cross_validation was passed, args.cross_validation is True + # If neither was passed, args.cross_validation remains None + # Build the model_kwargs dict from CLI args model_kwargs = { "train_size": args.train_size, "normalize": args.normalize, "feature_selection": args.feature_selection, - "cross_validation": cross_validation, + "cross_validation": args.cross_validation, "cross_validation_folds": args.cross_validation_folds, "remove_outliers": args.remove_outliers, "remove_multicollinearity": args.remove_multicollinearity, @@ -81,17 +136,19 @@ "feature_interaction": args.feature_interaction, "feature_ratio": args.feature_ratio, "fix_imbalance": args.fix_imbalance, + "tune_model": args.tune_model, } LOG.info(f"Model kwargs: {model_kwargs}") - # Remove None values from model_kwargs - - LOG.info(f"Model kwargs 2: {model_kwargs}") + # If the XML passed a comma-separated string in a single list element, split it out if args.models: model_kwargs["models"] = args.models[0].split(",") + # Drop None entries so PyCaret uses its default values model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} + LOG.info(f"Model kwargs 2: {model_kwargs}") + # Instantiate the appropriate trainer if args.model_type == "classification": trainer = ClassificationModelTrainer( args.input_file, @@ -100,10 +157,11 @@ args.model_type, args.random_seed, args.test_file, - **model_kwargs) + **model_kwargs, + ) elif args.model_type == "regression": - if "fix_imbalance" in model_kwargs: - del model_kwargs["fix_imbalance"] + # regression doesn't support fix_imbalance + model_kwargs.pop("fix_imbalance", None) trainer = RegressionModelTrainer( args.input_file, args.target_col, @@ -111,11 +169,12 @@ args.model_type, args.random_seed, args.test_file, - **model_kwargs) + **model_kwargs, + ) else: - LOG.error("Invalid model type. Please choose \ - 'classification' or 'regression'.") + LOG.error("Invalid model type. Please choose 'classification' or 'regression'.") return + trainer.run()