comparison base_model_trainer.py @ 7:f4cb41f458fd draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
author goeckslab
date Wed, 09 Jul 2025 01:13:01 +0000
parents a32ff7201629
children
comparison
equal deleted inserted replaced
6:a32ff7201629 7:f4cb41f458fd
125 if ( 125 if (
126 hasattr(self, "cross_validation") 126 hasattr(self, "cross_validation")
127 and self.cross_validation is not None 127 and self.cross_validation is not None
128 and self.cross_validation is False 128 and self.cross_validation is False
129 ): 129 ):
130 self.setup_params["cross_validation"] = self.cross_validation 130 logging.info(
131 131 "cross_validation is set to False. This will disable cross-validation."
132 if hasattr(self, "cross_validation") and self.cross_validation is not None: 132 )
133
134 if hasattr(self, "cross_validation") and self.cross_validation:
133 if hasattr(self, "cross_validation_folds"): 135 if hasattr(self, "cross_validation_folds"):
134 self.setup_params["fold"] = self.cross_validation_folds 136 self.setup_params["fold"] = self.cross_validation_folds
135 137
136 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: 138 if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
137 self.setup_params["remove_outliers"] = self.remove_outliers 139 self.setup_params["remove_outliers"] = self.remove_outliers
180 score_func=average_precision_score, 182 score_func=average_precision_score,
181 average="weighted", 183 average="weighted",
182 ) 184 )
183 185
184 if hasattr(self, "models") and self.models is not None: 186 if hasattr(self, "models") and self.models is not None:
185 self.best_model = self.exp.compare_models(include=self.models) 187 self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
186 else: 188 else:
187 self.best_model = self.exp.compare_models() 189 self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
188 self.results = self.exp.pull() 190 self.results = self.exp.pull()
191
189 if self.task_type == "classification": 192 if self.task_type == "classification":
190 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) 193 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
191 194
192 _ = self.exp.predict_model(self.best_model) 195 _ = self.exp.predict_model(self.best_model)
193 self.test_result_df = self.exp.pull() 196 self.test_result_df = self.exp.pull()
312 "Explainer Plots</div>" 315 "Explainer Plots</div>"
313 ) 316 )
314 html_content += ( 317 html_content += (
315 "</div>" 318 "</div>"
316 '<div id="summary" class="tab-content">' 319 '<div id="summary" class="tab-content">'
317 "<h2>Model Metrics from Cross-Validation Set</h2>" 320 f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>"
318 f"<h2>Best Model: {model_name}</h2>" 321 f"<h2>Best Model: {model_name}</h2>"
319 "<h5>The best model is selected by: Accuracy (Classification)" 322 "<h5>The best model is selected by: Accuracy (Classification)"
320 " or R2 (Regression).</h5>" 323 " or R2 (Regression).</h5>"
321 f"{self.results.to_html(index=False, classes='table sortable')}" 324 f"{self.results.to_html(index=False, classes='table sortable')}"
322 "<h2>Best Model's Hyperparameters</h2>" 325 "<h2>Best Model's Hyperparameters</h2>"