Mercurial > repos > goeckslab > galaxy_pycaret
comparison pycaret_train.py @ 0:1bc26b9636d2 draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit 5089a5dffc154c8202624cfbd5f1be0f36a9f0cc
| author | goeckslab |
|---|---|
| date | Wed, 11 Dec 2024 03:29:00 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:1bc26b9636d2 |
|---|---|
| 1 import argparse | |
| 2 import logging | |
| 3 | |
| 4 from pycaret_classification import ClassificationModelTrainer | |
| 5 | |
| 6 from pycaret_regression import RegressionModelTrainer | |
| 7 | |
# Module-level logger shared by everything in this script.
LOG = logging.getLogger(__name__)
# Configure the root logger at DEBUG level so records from this script (and
# from any library that propagates to the root logger) are emitted.
logging.basicConfig(level=logging.DEBUG)
| 10 | |
| 11 | |
def _build_parser():
    """Return the argument parser for the training command line.

    Every optional PyCaret setup option defaults to ``None`` (the boolean
    switches included, instead of the usual ``False``) so that options the
    user did not pass can be recognised and dropped before the remaining
    ones are forwarded to the trainer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", help="Path to the input file")
    parser.add_argument("--target_col", help="Column number of the target")
    parser.add_argument("--output_dir",
                        help="Path to the output directory")
    parser.add_argument("--model_type",
                        choices=["classification", "regression"],
                        help="Type of the model")
    parser.add_argument("--train_size", type=float,
                        default=None,
                        help="Train size for PyCaret setup")
    parser.add_argument("--normalize", action="store_true",
                        default=None,
                        help="Normalize data for PyCaret setup")
    parser.add_argument("--feature_selection", action="store_true",
                        default=None,
                        help="Perform feature selection for PyCaret setup")
    parser.add_argument("--cross_validation", action="store_true",
                        default=None,
                        help="Perform cross-validation for PyCaret setup")
    parser.add_argument("--cross_validation_folds", type=int,
                        default=None,
                        help="Number of cross-validation folds "
                             "for PyCaret setup")
    parser.add_argument("--remove_outliers", action="store_true",
                        default=None,
                        help="Remove outliers for PyCaret setup")
    parser.add_argument("--remove_multicollinearity", action="store_true",
                        default=None,
                        help="Remove multicollinearity for PyCaret setup")
    parser.add_argument("--polynomial_features", action="store_true",
                        default=None,
                        help="Generate polynomial features for PyCaret setup")
    parser.add_argument("--feature_interaction", action="store_true",
                        default=None,
                        help="Generate feature interactions for PyCaret setup")
    parser.add_argument("--feature_ratio", action="store_true",
                        default=None,
                        help="Generate feature ratios for PyCaret setup")
    parser.add_argument("--fix_imbalance", action="store_true",
                        default=None,
                        help="Fix class imbalance for PyCaret setup")
    parser.add_argument("--models", nargs='+',
                        default=None,
                        help="Selected models for training")
    parser.add_argument("--random_seed", type=int,
                        default=42,
                        help="Random seed for PyCaret setup")
    parser.add_argument("--test_file", type=str, default=None,
                        help="Path to the test data file")
    return parser


def _split_model_names(raw_values):
    """Flatten the ``--models`` values into a list of individual names.

    Names may be comma-separated (``lr,rf``), space-separated (``lr rf``)
    or a mix of both. Surrounding whitespace and empty entries (for example
    from a trailing comma) are discarded.
    """
    return [
        name.strip()
        for value in raw_values
        for name in value.split(",")
        if name.strip()
    ]


def main():
    """Parse the command line, build the requested trainer and run it.

    The PyCaret setup options that were actually supplied are collected
    into keyword arguments and handed, together with the input/output
    locations, to ``ClassificationModelTrainer`` or
    ``RegressionModelTrainer`` depending on ``--model_type``. When no model
    type is given an error is logged and the function returns without
    training.
    """
    args = _build_parser().parse_args()

    model_kwargs = {
        "train_size": args.train_size,
        "normalize": args.normalize,
        "feature_selection": args.feature_selection,
        "cross_validation": args.cross_validation,
        "cross_validation_folds": args.cross_validation_folds,
        "remove_outliers": args.remove_outliers,
        "remove_multicollinearity": args.remove_multicollinearity,
        "polynomial_features": args.polynomial_features,
        "feature_interaction": args.feature_interaction,
        "feature_ratio": args.feature_ratio,
        "fix_imbalance": args.fix_imbalance,
    }
    LOG.info("Model kwargs: %s", model_kwargs)

    if args.models:
        # Honour every value given to --models, not only the first one, so
        # both "--models lr,rf" and "--models lr rf" select two models.
        selected_models = _split_model_names(args.models)
        if selected_models:
            model_kwargs["models"] = selected_models

    # Remove None values from model_kwargs so that only options the user
    # actually set are forwarded to the trainer.
    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
    LOG.info("Model kwargs 2: %s", model_kwargs)

    trainer_classes = {
        "classification": ClassificationModelTrainer,
        "regression": RegressionModelTrainer,
    }
    trainer_cls = trainer_classes.get(args.model_type)
    if trainer_cls is None:
        # argparse already rejects values outside `choices`, so this is
        # reached when --model_type was omitted. Adjacent literals are used
        # (not a backslash continuation inside the string) to keep source
        # indentation out of the logged message.
        LOG.error("Invalid model type. Please choose "
                  "'classification' or 'regression'.")
        return

    if args.model_type == "regression":
        # fix_imbalance is never forwarded for regression; presumably the
        # regression trainer does not accept it - confirm against
        # RegressionModelTrainer.
        model_kwargs.pop("fix_imbalance", None)

    trainer = trainer_cls(
        args.input_file,
        args.target_col,
        args.output_dir,
        args.model_type,
        args.random_seed,
        args.test_file,
        **model_kwargs)
    trainer.run()
| 114 | |
| 115 | |
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
