Mercurial > repos > goeckslab > pycaret_predict
changeset 9:c6c1f8777aae draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author | goeckslab |
---|---|
date | Thu, 31 Jul 2025 15:41:24 +0000 (4 hours ago) |
parents | 1aed7d47c5ec |
children | |
files | base_model_trainer.py pycaret_train.py |
diffstat | 2 files changed, 26 insertions(+), 7 deletions(-) [+] |
line wrap: on
line diff
--- a/base_model_trainer.py Fri Jul 25 19:02:32 2025 +0000 +++ b/base_model_trainer.py Thu Jul 31 15:41:24 2025 +0000 @@ -175,7 +175,13 @@ if self.task_type == "classification": self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) - _ = self.exp.predict_model(self.best_model) + + prob_thresh = getattr(self, "probability_threshold", None) + if self.task_type == "classification" and prob_thresh is not None: + _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh) + else: + _ = self.exp.predict_model(self.best_model) + self.test_result_df = self.exp.pull() if self.task_type == "classification": self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) @@ -233,7 +239,7 @@ best_model_name = type(self.best_model).__name__ LOG.info(f"Best model determined as: {best_model_name}") - # 2) Compute training sample count + # 2) Compute training sample count try: n_train = self.exp.X_train.shape[0] except Exception: @@ -241,7 +247,10 @@ total_rows = self.data.shape[0] # 3) Build setup parameters table - all_params = self.setup_params + all_params = self.setup_params.copy() + if self.task_type == "classification" and hasattr(self, "probability_threshold"): + all_params["probability_threshold"] = self.probability_threshold + display_keys = [ "Target", "Session ID", @@ -255,6 +264,7 @@ "Polynomial Features", "Fix Imbalance", "Models", + "Probability Threshold", ] setup_rows = [] for key in display_keys: @@ -281,6 +291,8 @@ dv = v if v is not None else "None" elif key == "Models": dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" + elif key == "Probability Threshold": + dv = v if v is not None else "None" else: dv = v if v is not None else "None" setup_rows.append([key, dv])
--- a/pycaret_train.py Fri Jul 25 19:02:32 2025 +0000 +++ b/pycaret_train.py Thu Jul 31 15:41:24 2025 +0000 @@ -103,16 +103,22 @@ help="Tune the best model hyperparameters after training", ) parser.add_argument( + "--test_file", + type=str, + default=None, + help="Path to the test data file", + ) + parser.add_argument( "--random_seed", type=int, default=42, help="Random seed for PyCaret setup", ) parser.add_argument( - "--test_file", - type=str, + "--probability_threshold", + type=float, default=None, - help="Path to the test data file", + help="Probability threshold for classification decision,", ) args = parser.parse_args() @@ -120,7 +126,7 @@ # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation if args.no_cross_validation: args.cross_validation = False - # If --cross_validation was passed, args.cross_validation is True + # If --cross_validation was passed, args.cross_validation is True # If neither was passed, args.cross_validation remains None # Build the model_kwargs dict from CLI args @@ -137,6 +143,7 @@ "feature_ratio": args.feature_ratio, "fix_imbalance": args.fix_imbalance, "tune_model": args.tune_model, + "probability_threshold": args.probability_threshold, } LOG.info(f"Model kwargs: {model_kwargs}")