# view pycaret_train.py @ 8:1aed7d47c5ec (draft, default, tip)
#
# planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
# author: goeckslab
# date: Fri, 25 Jul 2025 19:02:32 +0000
# parents: f4cb41f458fd
# children:
# line wrap: on
# line source

import argparse
import logging

from pycaret_classification import ClassificationModelTrainer
from pycaret_regression import RegressionModelTrainer

# Import-time side effect: configure the root logger at DEBUG level.
# basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)


def _build_parser():
    """Create the argument parser for the training command line.

    Every optional PyCaret setting defaults to ``None`` (including the
    ``store_true`` flags), so that "not given on the command line" can be
    told apart from an explicit value and dropped before the trainer is
    called. ``--tune_model`` (default ``False``) and ``--random_seed``
    (default ``42``) are the exceptions.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", help="Path to the input file")
    parser.add_argument("--target_col", help="Column number of the target")
    parser.add_argument("--output_dir", help="Path to the output directory")
    parser.add_argument(
        "--model_type",
        choices=["classification", "regression"],
        help="Type of the model",
    )

    def add_flag(option, help_text, default=None):
        # Boolean switch: True when present, `default` when absent.
        parser.add_argument(
            option,
            action="store_true",
            default=default,
            help=help_text,
        )

    parser.add_argument(
        "--train_size",
        type=float,
        default=None,
        help="Train size for PyCaret setup",
    )
    add_flag("--normalize", "Normalize data for PyCaret setup")
    add_flag(
        "--feature_selection",
        "Perform feature selection for PyCaret setup",
    )
    add_flag(
        "--cross_validation",
        "Enable cross-validation for PyCaret setup",
    )
    add_flag(
        "--no_cross_validation",
        "Disable cross-validation for PyCaret setup",
    )
    parser.add_argument(
        "--cross_validation_folds",
        type=int,
        default=None,
        help="Number of cross-validation folds for PyCaret setup",
    )
    add_flag("--remove_outliers", "Remove outliers for PyCaret setup")
    add_flag(
        "--remove_multicollinearity",
        "Remove multicollinearity for PyCaret setup",
    )
    add_flag(
        "--polynomial_features",
        "Generate polynomial features for PyCaret setup",
    )
    add_flag(
        "--feature_interaction",
        "Generate feature interactions for PyCaret setup",
    )
    add_flag("--feature_ratio", "Generate feature ratios for PyCaret setup")
    add_flag("--fix_imbalance", "Fix class imbalance for PyCaret setup")
    parser.add_argument(
        "--models",
        nargs="+",
        default=None,
        help="Selected models for training",
    )
    add_flag(
        "--tune_model",
        "Tune the best model hyperparameters after training",
        default=False,
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=42,
        help="Random seed for PyCaret setup",
    )
    parser.add_argument(
        "--test_file",
        type=str,
        default=None,
        help="Path to the test data file",
    )
    return parser


def _split_models(raw_models):
    """Flatten the ``--models`` values into a clean list of model names.

    The option may arrive as one comma-separated string (``lr,rf``), as
    several space-separated values (``lr rf``), or as a mix of both. Every
    value is split on commas; surrounding whitespace and empty names are
    discarded.

    Args:
        raw_models: list of strings as produced by argparse ``nargs="+"``.

    Returns:
        list of non-empty model names, in the order given.
    """
    names = []
    for item in raw_models:
        names.extend(part.strip() for part in item.split(","))
    return [name for name in names if name]


def main():
    """Parse the command line, build the matching trainer and run it.

    Options left unset on the command line are omitted from the keyword
    arguments passed to the trainer. ``--no_cross_validation`` takes
    precedence over ``--cross_validation``.

    Raises:
        SystemExit: with status 1 when ``--model_type`` is missing, so the
            failure is visible to the calling process instead of exiting 0
            without training anything. argparse itself also exits on
            malformed arguments.
    """
    args = _build_parser().parse_args()

    # --no_cross_validation overrides --cross_validation. If neither flag was
    # passed, args.cross_validation stays None and is dropped below.
    if args.no_cross_validation:
        args.cross_validation = False

    model_kwargs = {
        "train_size": args.train_size,
        "normalize": args.normalize,
        "feature_selection": args.feature_selection,
        "cross_validation": args.cross_validation,
        "cross_validation_folds": args.cross_validation_folds,
        "remove_outliers": args.remove_outliers,
        "remove_multicollinearity": args.remove_multicollinearity,
        "polynomial_features": args.polynomial_features,
        "feature_interaction": args.feature_interaction,
        "feature_ratio": args.feature_ratio,
        "fix_imbalance": args.fix_imbalance,
        "tune_model": args.tune_model,
    }
    LOG.info("Model kwargs: %s", model_kwargs)

    # Accept both "a,b" in a single element and "a b" as separate elements;
    # previously only the first element was used and the rest were lost.
    if args.models:
        selected_models = _split_models(args.models)
        if selected_models:
            model_kwargs["models"] = selected_models

    # Drop None entries so the trainer falls back to its own defaults.
    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
    LOG.info("Model kwargs 2: %s", model_kwargs)

    if args.model_type == "classification":
        trainer_cls = ClassificationModelTrainer
    elif args.model_type == "regression":
        # regression doesn't support fix_imbalance
        model_kwargs.pop("fix_imbalance", None)
        trainer_cls = RegressionModelTrainer
    else:
        # Reached when --model_type is omitted (argparse already rejects
        # values outside `choices`).
        LOG.error("Invalid model type. Please choose 'classification' or 'regression'.")
        raise SystemExit(1)

    trainer = trainer_cls(
        args.input_file,
        args.target_col,
        args.output_dir,
        args.model_type,
        args.random_seed,
        args.test_file,
        **model_kwargs,
    )
    trainer.run()


# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()