Mercurial > repos > goeckslab > pycaret_compare
comparison base_model_trainer.py @ 4:4aa511539199 draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit cf47efb521b91a9cb44ae5c5ade860627f9b9030
author | goeckslab |
---|---|
date | Tue, 03 Jun 2025 19:31:16 +0000 |
parents | 02f7746e7772 |
children |
comparison
equal
deleted
inserted
replaced
3:02f7746e7772 | 4:4aa511539199 |
---|---|
1 import base64 | 1 import base64 |
2 import logging | 2 import logging |
3 import os | 3 import os |
4 import tempfile | 4 import tempfile |
5 | 5 |
6 import h5py | |
7 import joblib | |
8 import numpy as np | |
9 import pandas as pd | |
6 from feature_importance import FeatureImportanceAnalyzer | 10 from feature_importance import FeatureImportanceAnalyzer |
7 | |
8 import h5py | |
9 | |
10 import joblib | |
11 | |
12 import numpy as np | |
13 | |
14 import pandas as pd | |
15 | |
16 from sklearn.metrics import average_precision_score | 11 from sklearn.metrics import average_precision_score |
17 | |
18 from utils import get_html_closing, get_html_template | 12 from utils import get_html_closing, get_html_template |
19 | 13 |
20 logging.basicConfig(level=logging.DEBUG) | 14 logging.basicConfig(level=logging.DEBUG) |
21 LOG = logging.getLogger(__name__) | 15 LOG = logging.getLogger(__name__) |
22 | 16 |
29 target_col, | 23 target_col, |
30 output_dir, | 24 output_dir, |
31 task_type, | 25 task_type, |
32 random_seed, | 26 random_seed, |
33 test_file=None, | 27 test_file=None, |
34 **kwargs | 28 **kwargs): |
35 ): | |
36 self.exp = None # This will be set in the subclass | 29 self.exp = None # This will be set in the subclass |
37 self.input_file = input_file | 30 self.input_file = input_file |
38 self.target_col = target_col | 31 self.target_col = target_col |
39 self.output_dir = output_dir | 32 self.output_dir = output_dir |
40 self.task_type = task_type | 33 self.task_type = task_type |
69 | 62 |
70 if len(non_numeric_cols) > 0: | 63 if len(non_numeric_cols) > 0: |
71 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") | 64 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") |
72 | 65 |
73 names = self.data.columns.to_list() | 66 names = self.data.columns.to_list() |
74 target_index = int(self.target_col)-1 | 67 target_index = int(self.target_col) - 1 |
75 self.target = names[target_index] | 68 self.target = names[target_index] |
76 self.features_name = [name | 69 self.features_name = [name |
77 for i, name in enumerate(names) | 70 for i, name in enumerate(names) |
78 if i != target_index] | 71 if i != target_index] |
79 if hasattr(self, 'missing_value_strategy'): | 72 if hasattr(self, 'missing_value_strategy'): |
95 self.test_file, sep=None, engine='python') | 88 self.test_file, sep=None, engine='python') |
96 self.test_data = self.test_data[numeric_cols].apply( | 89 self.test_data = self.test_data[numeric_cols].apply( |
97 pd.to_numeric, errors='coerce') | 90 pd.to_numeric, errors='coerce') |
98 self.test_data.columns = self.test_data.columns.str.replace( | 91 self.test_data.columns = self.test_data.columns.str.replace( |
99 '.', '_' | 92 '.', '_' |
100 ) | 93 ) |
101 | 94 |
102 def setup_pycaret(self): | 95 def setup_pycaret(self): |
103 LOG.info("Initializing PyCaret") | 96 LOG.info("Initializing PyCaret") |
104 self.setup_params = { | 97 self.setup_params = { |
105 'target': self.target, | 98 'target': self.target, |
204 filtered_setup_params = { | 197 filtered_setup_params = { |
205 k: v | 198 k: v |
206 for k, v in self.setup_params.items() if k not in excluded_params | 199 for k, v in self.setup_params.items() if k not in excluded_params |
207 } | 200 } |
208 setup_params_table = pd.DataFrame( | 201 setup_params_table = pd.DataFrame( |
209 list(filtered_setup_params.items()), | 202 list(filtered_setup_params.items()), columns=['Parameter', 'Value'] |
210 columns=['Parameter', 'Value']) | 203 ) |
211 | 204 |
212 best_model_params = pd.DataFrame( | 205 best_model_params = pd.DataFrame( |
213 self.best_model.get_params().items(), | 206 self.best_model.get_params().items(), |
214 columns=['Parameter', 'Value']) | 207 columns=['Parameter', 'Value'] |
208 ) | |
215 best_model_params.to_csv( | 209 best_model_params.to_csv( |
216 os.path.join(self.output_dir, 'best_model.csv'), | 210 os.path.join(self.output_dir, "best_model.csv"), index=False |
217 index=False) | 211 ) |
218 self.results.to_csv(os.path.join( | 212 self.results.to_csv( |
219 self.output_dir, "comparison_results.csv")) | 213 os.path.join(self.output_dir, "comparison_results.csv") |
220 self.test_result_df.to_csv(os.path.join( | 214 ) |
221 self.output_dir, "test_results.csv")) | 215 self.test_result_df.to_csv( |
216 os.path.join(self.output_dir, "test_results.csv") | |
217 ) | |
222 | 218 |
223 plots_html = "" | 219 plots_html = "" |
224 length = len(self.plots) | 220 length = len(self.plots) |
225 for i, (plot_name, plot_path) in enumerate(self.plots.items()): | 221 for i, (plot_name, plot_path) in enumerate(self.plots.items()): |
226 encoded_image = self.encode_image_to_base64(plot_path) | 222 encoded_image = self.encode_image_to_base64(plot_path) |
248 | 244 |
249 analyzer = FeatureImportanceAnalyzer( | 245 analyzer = FeatureImportanceAnalyzer( |
250 data=self.data, | 246 data=self.data, |
251 target_col=self.target_col, | 247 target_col=self.target_col, |
252 task_type=self.task_type, | 248 task_type=self.task_type, |
253 output_dir=self.output_dir) | 249 output_dir=self.output_dir, |
250 ) | |
254 feature_importance_html = analyzer.run() | 251 feature_importance_html = analyzer.run() |
255 | 252 |
256 html_content = f""" | 253 html_content = f""" |
257 {get_html_template()} | 254 {get_html_template()} |
258 <h1>PyCaret Model Training Report</h1> | 255 <h1>PyCaret Model Training Report</h1> |
261 Setup & Best Model</div> | 258 Setup & Best Model</div> |
262 <div class="tab" onclick="openTab(event, 'plots')"> | 259 <div class="tab" onclick="openTab(event, 'plots')"> |
263 Best Model Plots</div> | 260 Best Model Plots</div> |
264 <div class="tab" onclick="openTab(event, 'feature')"> | 261 <div class="tab" onclick="openTab(event, 'feature')"> |
265 Feature Importance</div> | 262 Feature Importance</div> |
266 """ | 263 """ |
267 if self.plots_explainer_html: | 264 if self.plots_explainer_html: |
268 html_content += """ | 265 html_content += """ |
269 "<div class="tab" onclick="openTab(event, 'explainer')">" | 266 <div class="tab" onclick="openTab(event, 'explainer')"> |
270 Explainer Plots</div> | 267 Explainer Plots</div> |
271 """ | 268 """ |
272 html_content += f""" | 269 html_content += f""" |
273 </div> | 270 </div> |
274 <div id="summary" class="tab-content"> | 271 <div id="summary" class="tab-content"> |
275 <h2>Setup Parameters</h2> | 272 <h2>Setup Parameters</h2> |
276 <table> | 273 {setup_params_table.to_html( |
277 <tr><th>Parameter</th><th>Value</th></tr> | 274 index=False, |
278 {setup_params_table.to_html( | 275 header=True, |
279 index=False, header=False, classes='table')} | 276 classes='table sortable' |
280 </table> | 277 )} |
281 <h5>If you want to know all the experiment setup parameters, | 278 <h5>If you want to know all the experiment setup parameters, |
282 please check the PyCaret documentation for | 279 please check the PyCaret documentation for |
283 the classification/regression <code>exp</code> function.</h5> | 280 the classification/regression <code>exp</code> function.</h5> |
284 <h2>Best Model: {model_name}</h2> | 281 <h2>Best Model: {model_name}</h2> |
285 <table> | 282 {best_model_params.to_html( |
286 <tr><th>Parameter</th><th>Value</th></tr> | 283 index=False, |
287 {best_model_params.to_html( | 284 header=True, |
288 index=False, header=False, classes='table')} | 285 classes='table sortable' |
289 </table> | 286 )} |
290 <h2>Comparison Results on the Cross-Validation Set</h2> | 287 <h2>Comparison Results on the Cross-Validation Set</h2> |
291 <table> | 288 {self.results.to_html(index=False, classes='table sortable')} |
292 {self.results.to_html(index=False, classes='table')} | |
293 </table> | |
294 <h2>Results on the Test Set for the best model</h2> | 289 <h2>Results on the Test Set for the best model</h2> |
295 <table> | 290 {self.test_result_df.to_html( |
296 {self.test_result_df.to_html(index=False, classes='table')} | 291 index=False, |
297 </table> | 292 classes='table sortable' |
293 )} | |
298 </div> | 294 </div> |
299 <div id="plots" class="tab-content"> | 295 <div id="plots" class="tab-content"> |
300 <h2>Best Model Plots on the testing set</h2> | 296 <h2>Best Model Plots on the testing set</h2> |
301 {plots_html} | 297 {plots_html} |
302 </div> | 298 </div> |
308 html_content += f""" | 304 html_content += f""" |
309 <div id="explainer" class="tab-content"> | 305 <div id="explainer" class="tab-content"> |
310 {self.plots_explainer_html} | 306 {self.plots_explainer_html} |
311 {tree_plots} | 307 {tree_plots} |
312 </div> | 308 </div> |
313 {get_html_closing()} | |
314 """ | 309 """ |
315 else: | 310 html_content += """ |
316 html_content += f""" | 311 <script> |
317 {get_html_closing()} | 312 document.addEventListener("DOMContentLoaded", function() { |
318 """ | 313 var tables = document.querySelectorAll("table.sortable"); |
319 with open(os.path.join( | 314 tables.forEach(function(table) { |
320 self.output_dir, "comparison_result.html"), "w") as file: | 315 var headers = table.querySelectorAll("th"); |
316 headers.forEach(function(header, index) { | |
317 header.style.cursor = "pointer"; | |
318 // Add initial arrow (up) to indicate sortability | |
319 header.innerHTML += '<span class="sort-arrow"> ↑</span>'; | |
320 header.addEventListener("click", function() { | |
321 var direction = this.getAttribute( | |
322 "data-sort-direction" | |
323 ) || "asc"; | |
324 // Reset arrows in all headers of this table | |
325 headers.forEach(function(h) { | |
326 var arrow = h.querySelector(".sort-arrow"); | |
327 if (arrow) arrow.textContent = " ↑"; | |
328 }); | |
329 // Set arrow for clicked header | |
330 var arrow = this.querySelector(".sort-arrow"); | |
331 arrow.textContent = direction === "asc" ? " ↓" : " ↑"; | |
332 sortTable(table, index, direction); | |
333 this.setAttribute("data-sort-direction", | |
334 direction === "asc" ? "desc" : "asc"); | |
335 }); | |
336 }); | |
337 }); | |
338 }); | |
339 | |
340 function sortTable(table, colNum, direction) { | |
341 var tb = table.tBodies[0]; | |
342 var tr = Array.prototype.slice.call(tb.rows, 0); | |
343 var multiplier = direction === "asc" ? 1 : -1; | |
344 tr = tr.sort(function(a, b) { | |
345 var aText = a.cells[colNum].textContent.trim(); | |
346 var bText = b.cells[colNum].textContent.trim(); | |
347 // Remove arrow from text comparison | |
348 aText = aText.replace(/[↑↓]/g, '').trim(); | |
349 bText = bText.replace(/[↑↓]/g, '').trim(); | |
350 if (!isNaN(aText) && !isNaN(bText)) { | |
351 return multiplier * ( | |
352 parseFloat(aText) - parseFloat(bText) | |
353 ); | |
354 } else { | |
355 return multiplier * aText.localeCompare(bText); | |
356 } | |
357 }); | |
358 for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]); | |
359 } | |
360 </script> | |
361 """ | |
362 html_content += f""" | |
363 {get_html_closing()} | |
364 """ | |
365 with open( | |
366 os.path.join(self.output_dir, "comparison_result.html"), | |
367 "w" | |
368 ) as file: | |
321 file.write(html_content) | 369 file.write(html_content) |
322 | 370 |
323 def save_dashboard(self): | 371 def save_dashboard(self): |
324 raise NotImplementedError("Subclasses should implement this method") | 372 raise NotImplementedError("Subclasses should implement this method") |
325 | 373 |