Mercurial > repos > goeckslab > pycaret_predict
changeset 13:f07850192bc2 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 84d5cd0b1fa5c1ff0ad892bc39c95dad1ceb4920
| author | goeckslab |
|---|---|
| date | Sat, 08 Nov 2025 14:20:33 +0000 |
| parents | e674b9e946fb |
| children | |
| files | base_model_trainer.py pycaret_macros.xml pycaret_train.py test-data/expected_best_model_classification_customized.csv test-data/expected_model_classification_customized.h5 |
| diffstat | 5 files changed, 43 insertions(+), 22 deletions(-) [+] |
line wrap: on
line diff
--- a/base_model_trainer.py Mon Sep 08 22:39:12 2025 +0000 +++ b/base_model_trainer.py Sat Nov 08 14:20:33 2025 +0000 @@ -199,6 +199,28 @@ self.exp.setup(self.data, **self.setup_params) self.setup_params.update(self.user_kwargs) + def _normalize_metric(self, m: str) -> str: + if not m: + return "R2" if self.task_type == "regression" else "Accuracy" + m_low = str(m).strip().lower() + alias = { + "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC", + "accuracy": "Accuracy", + "precision": "Precision", + "recall": "Recall", + "f1": "F1", + "kappa": "Kappa", + "logloss": "Log Loss", "log_loss": "Log Loss", + "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted", + "r2": "R2", + "mae": "MAE", + "mse": "MSE", + "rmse": "RMSE", + "rmsle": "RMSLE", + "mape": "MAPE", + } + return alias.get(m_low, m) + def train_model(self): LOG.info("Training and selecting the best model") if self.task_type == "classification": @@ -222,6 +244,15 @@ if getattr(self, "cross_validation_folds", None) is not None: compare_kwargs["fold"] = self.cross_validation_folds + chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None)) + if chosen_metric: + compare_kwargs["sort"] = chosen_metric + self.chosen_metric_label = chosen_metric + try: + setattr(self.exp, "_fold_metric", chosen_metric) + except Exception as e: + LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True) + LOG.info(f"compare_models kwargs: {compare_kwargs}") self.best_model = self.exp.compare_models(**compare_kwargs) self.results = self.exp.pull() @@ -369,8 +400,8 @@ else: dv = v if v is not None else "None" setup_rows.append([key, dv]) - if hasattr(self.exp, "_fold_metric"): - setup_rows.append(["best_model_metric", self.exp._fold_metric]) + if getattr(self, "chosen_metric_label", None): + setup_rows.append(["Best Model Metric", self.chosen_metric_label]) df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) df_setup.to_csv(
--- a/pycaret_macros.xml Mon Sep 08 22:39:12 2025 +0000 +++ b/pycaret_macros.xml Sat Nov 08 14:20:33 2025 +0000 @@ -1,5 +1,5 @@ <macros> - <token name="@TABULAR_LEARNER_VERSION@">0.1.0.1</token> + <token name="@TABULAR_LEARNER_VERSION@">0.1.1</token> <token name="@PYCARET_VERSION@">3.3.2</token> <token name="@SUFFIX@">1</token> <token name="@PYCARET_PREDICT_VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>
--- a/pycaret_train.py Mon Sep 08 22:39:12 2025 +0000 +++ b/pycaret_train.py Sat Nov 08 14:20:33 2025 +0000 @@ -120,6 +120,12 @@ default=None, help="Probability threshold for classification decision,", ) + parser.add_argument( + "--best_model_metric", + type=str, + default=None, + help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).", + ) args = parser.parse_args() @@ -144,6 +150,7 @@ "fix_imbalance": args.fix_imbalance, "tune_model": args.tune_model, "probability_threshold": args.probability_threshold, + "best_model_metric": args.best_model_metric, } LOG.info(f"Model kwargs: {model_kwargs}")
--- a/test-data/expected_best_model_classification_customized.csv Mon Sep 08 22:39:12 2025 +0000 +++ b/test-data/expected_best_model_classification_customized.csv Sat Nov 08 14:20:33 2025 +0000 @@ -1,20 +1,3 @@ Parameter,Value -boosting_type,gbdt -class_weight, -colsample_bytree,1.0 -importance_type,split -learning_rate,0.1 -max_depth,-1 -min_child_samples,20 -min_child_weight,0.001 -min_split_gain,0.0 -n_estimators,100 -n_jobs,-1 -num_leaves,31 -objective, -random_state,42 -reg_alpha,0.0 -reg_lambda,0.0 -subsample,1.0 -subsample_for_bin,200000 -subsample_freq,0 +priors, +var_smoothing,1e-09
