Mercurial > repos > goeckslab > tabular_learner
comparison base_model_trainer.py @ 3:f6a65e05d6ec draft
planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
| author | goeckslab |
|---|---|
| date | Wed, 09 Jul 2025 01:12:48 +0000 |
| parents | 77c88226bfde |
| children | 11fdac5affb3 |
comparison
equal
deleted
inserted
replaced
| 2:77c88226bfde | 3:f6a65e05d6ec |
|---|---|
| 125 if ( | 125 if ( |
| 126 hasattr(self, "cross_validation") | 126 hasattr(self, "cross_validation") |
| 127 and self.cross_validation is not None | 127 and self.cross_validation is not None |
| 128 and self.cross_validation is False | 128 and self.cross_validation is False |
| 129 ): | 129 ): |
| 130 self.setup_params["cross_validation"] = self.cross_validation | 130 logging.info( |
| 131 | 131 "cross_validation is set to False. This will disable cross-validation." |
| 132 if hasattr(self, "cross_validation") and self.cross_validation is not None: | 132 ) |
| 133 | |
| 134 if hasattr(self, "cross_validation") and self.cross_validation: | |
| 133 if hasattr(self, "cross_validation_folds"): | 135 if hasattr(self, "cross_validation_folds"): |
| 134 self.setup_params["fold"] = self.cross_validation_folds | 136 self.setup_params["fold"] = self.cross_validation_folds |
| 135 | 137 |
| 136 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: | 138 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: |
| 137 self.setup_params["remove_outliers"] = self.remove_outliers | 139 self.setup_params["remove_outliers"] = self.remove_outliers |
| 180 score_func=average_precision_score, | 182 score_func=average_precision_score, |
| 181 average="weighted", | 183 average="weighted", |
| 182 ) | 184 ) |
| 183 | 185 |
| 184 if hasattr(self, "models") and self.models is not None: | 186 if hasattr(self, "models") and self.models is not None: |
| 185 self.best_model = self.exp.compare_models(include=self.models) | 187 self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation) |
| 186 else: | 188 else: |
| 187 self.best_model = self.exp.compare_models() | 189 self.best_model = self.exp.compare_models(cross_validation=self.cross_validation) |
| 188 self.results = self.exp.pull() | 190 self.results = self.exp.pull() |
| 191 | |
| 189 if self.task_type == "classification": | 192 if self.task_type == "classification": |
| 190 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 193 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
| 191 | 194 |
| 192 _ = self.exp.predict_model(self.best_model) | 195 _ = self.exp.predict_model(self.best_model) |
| 193 self.test_result_df = self.exp.pull() | 196 self.test_result_df = self.exp.pull() |
| 312 "Explainer Plots</div>" | 315 "Explainer Plots</div>" |
| 313 ) | 316 ) |
| 314 html_content += ( | 317 html_content += ( |
| 315 "</div>" | 318 "</div>" |
| 316 '<div id="summary" class="tab-content">' | 319 '<div id="summary" class="tab-content">' |
| 317 "<h2>Model Metrics from Cross-Validation Set</h2>" | 320 f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>" |
| 318 f"<h2>Best Model: {model_name}</h2>" | 321 f"<h2>Best Model: {model_name}</h2>" |
| 319 "<h5>The best model is selected by: Accuracy (Classification)" | 322 "<h5>The best model is selected by: Accuracy (Classification)" |
| 320 " or R2 (Regression).</h5>" | 323 " or R2 (Regression).</h5>" |
| 321 f"{self.results.to_html(index=False, classes='table sortable')}" | 324 f"{self.results.to_html(index=False, classes='table sortable')}" |
| 322 "<h2>Best Model's Hyperparameters</h2>" | 325 "<h2>Best Model's Hyperparameters</h2>" |
