Mercurial > repos > goeckslab > pycaret_compare
diff base_model_trainer.py @ 4:4aa511539199 draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit cf47efb521b91a9cb44ae5c5ade860627f9b9030
author | goeckslab |
---|---|
date | Tue, 03 Jun 2025 19:31:16 +0000 |
parents | 02f7746e7772 |
children |
line wrap: on
line diff
--- a/base_model_trainer.py Wed Jan 01 03:19:40 2025 +0000 +++ b/base_model_trainer.py Tue Jun 03 19:31:16 2025 +0000 @@ -3,18 +3,12 @@ import os import tempfile -from feature_importance import FeatureImportanceAnalyzer - import h5py - import joblib - import numpy as np - import pandas as pd - +from feature_importance import FeatureImportanceAnalyzer from sklearn.metrics import average_precision_score - from utils import get_html_closing, get_html_template logging.basicConfig(level=logging.DEBUG) @@ -31,8 +25,7 @@ task_type, random_seed, test_file=None, - **kwargs - ): + **kwargs): self.exp = None # This will be set in the subclass self.input_file = input_file self.target_col = target_col @@ -71,7 +64,7 @@ LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") names = self.data.columns.to_list() - target_index = int(self.target_col)-1 + target_index = int(self.target_col) - 1 self.target = names[target_index] self.features_name = [name for i, name in enumerate(names) @@ -97,7 +90,7 @@ pd.to_numeric, errors='coerce') self.test_data.columns = self.test_data.columns.str.replace( '.', '_' - ) + ) def setup_pycaret(self): LOG.info("Initializing PyCaret") @@ -206,19 +199,22 @@ for k, v in self.setup_params.items() if k not in excluded_params } setup_params_table = pd.DataFrame( - list(filtered_setup_params.items()), - columns=['Parameter', 'Value']) + list(filtered_setup_params.items()), columns=['Parameter', 'Value'] + ) best_model_params = pd.DataFrame( self.best_model.get_params().items(), - columns=['Parameter', 'Value']) + columns=['Parameter', 'Value'] + ) best_model_params.to_csv( - os.path.join(self.output_dir, 'best_model.csv'), - index=False) - self.results.to_csv(os.path.join( - self.output_dir, "comparison_results.csv")) - self.test_result_df.to_csv(os.path.join( - self.output_dir, "test_results.csv")) + os.path.join(self.output_dir, "best_model.csv"), index=False + ) + self.results.to_csv( + os.path.join(self.output_dir, "comparison_results.csv") + ) + self.test_result_df.to_csv( + os.path.join(self.output_dir, "test_results.csv") + ) plots_html = "" length = len(self.plots) @@ -250,7 +246,8 @@ data=self.data, target_col=self.target_col, task_type=self.task_type, - output_dir=self.output_dir) + output_dir=self.output_dir, + ) feature_importance_html = analyzer.run() html_content = f""" @@ -263,38 +260,37 @@ Best Model Plots</div> <div class="tab" onclick="openTab(event, 'feature')"> Feature Importance</div> - """ + """ if self.plots_explainer_html: html_content += """ - "<div class="tab" onclick="openTab(event, 'explainer')">" + <div class="tab" onclick="openTab(event, 'explainer')"> Explainer Plots</div> """ html_content += f""" </div> <div id="summary" class="tab-content"> <h2>Setup Parameters</h2> - <table> - <tr><th>Parameter</th><th>Value</th></tr> - {setup_params_table.to_html( - index=False, header=False, classes='table')} - </table> + {setup_params_table.to_html( + index=False, + header=True, + classes='table sortable' + )} <h5>If you want to know all the experiment setup parameters, please check the PyCaret documentation for the classification/regression <code>exp</code> function.</h5> <h2>Best Model: {model_name}</h2> - <table> - <tr><th>Parameter</th><th>Value</th></tr> - {best_model_params.to_html( - index=False, header=False, classes='table')} - </table> + {best_model_params.to_html( + index=False, + header=True, + classes='table sortable' + )} <h2>Comparison Results on the Cross-Validation Set</h2> - <table> - {self.results.to_html(index=False, classes='table')} - </table> + {self.results.to_html(index=False, classes='table sortable')} <h2>Results on the Test Set for the best model</h2> - <table> - {self.test_result_df.to_html(index=False, classes='table')} - </table> + {self.test_result_df.to_html( + index=False, + classes='table sortable' + )} </div> <div id="plots" class="tab-content"> <h2>Best Model Plots on the testing set</h2> @@ -310,14 +306,66 @@ {self.plots_explainer_html} {tree_plots} </div> - {get_html_closing()} """ - else: - html_content += f""" - {get_html_closing()} - """ - with open(os.path.join( - self.output_dir, "comparison_result.html"), "w") as file: + html_content += """ + <script> + document.addEventListener("DOMContentLoaded", function() { + var tables = document.querySelectorAll("table.sortable"); + tables.forEach(function(table) { + var headers = table.querySelectorAll("th"); + headers.forEach(function(header, index) { + header.style.cursor = "pointer"; + // Add initial arrow (up) to indicate sortability + header.innerHTML += '<span class="sort-arrow"> ↑</span>'; + header.addEventListener("click", function() { + var direction = this.getAttribute( + "data-sort-direction" + ) || "asc"; + // Reset arrows in all headers of this table + headers.forEach(function(h) { + var arrow = h.querySelector(".sort-arrow"); + if (arrow) arrow.textContent = " ↑"; + }); + // Set arrow for clicked header + var arrow = this.querySelector(".sort-arrow"); + arrow.textContent = direction === "asc" ? " ↓" : " ↑"; + sortTable(table, index, direction); + this.setAttribute("data-sort-direction", + direction === "asc" ? "desc" : "asc"); + }); + }); + }); + }); + + function sortTable(table, colNum, direction) { + var tb = table.tBodies[0]; + var tr = Array.prototype.slice.call(tb.rows, 0); + var multiplier = direction === "asc" ? 1 : -1; + tr = tr.sort(function(a, b) { + var aText = a.cells[colNum].textContent.trim(); + var bText = b.cells[colNum].textContent.trim(); + // Remove arrow from text comparison + aText = aText.replace(/[↑↓]/g, '').trim(); + bText = bText.replace(/[↑↓]/g, '').trim(); + if (!isNaN(aText) && !isNaN(bText)) { + return multiplier * ( + parseFloat(aText) - parseFloat(bText) + ); + } else { + return multiplier * aText.localeCompare(bText); + } + }); + for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]); + } + </script> + """ + html_content += f""" + {get_html_closing()} + """ + with open( + os.path.join(self.output_dir, "comparison_result.html"), + "w" + ) as file: file.write(html_content) def save_dashboard(self):