diff base_model_trainer.py @ 4:4aa511539199 draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit cf47efb521b91a9cb44ae5c5ade860627f9b9030
author goeckslab
date Tue, 03 Jun 2025 19:31:16 +0000
parents 02f7746e7772
children
line wrap: on
line diff
--- a/base_model_trainer.py	Wed Jan 01 03:19:40 2025 +0000
+++ b/base_model_trainer.py	Tue Jun 03 19:31:16 2025 +0000
@@ -3,18 +3,12 @@
 import os
 import tempfile
 
-from feature_importance import FeatureImportanceAnalyzer
-
 import h5py
-
 import joblib
-
 import numpy as np
-
 import pandas as pd
-
+from feature_importance import FeatureImportanceAnalyzer
 from sklearn.metrics import average_precision_score
-
 from utils import get_html_closing, get_html_template
 
 logging.basicConfig(level=logging.DEBUG)
@@ -31,8 +25,7 @@
             task_type,
             random_seed,
             test_file=None,
-            **kwargs
-            ):
+            **kwargs):
         self.exp = None  # This will be set in the subclass
         self.input_file = input_file
         self.target_col = target_col
@@ -71,7 +64,7 @@
             LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
 
         names = self.data.columns.to_list()
-        target_index = int(self.target_col)-1
+        target_index = int(self.target_col) - 1
         self.target = names[target_index]
         self.features_name = [name
                               for i, name in enumerate(names)
@@ -97,7 +90,7 @@
                 pd.to_numeric, errors='coerce')
             self.test_data.columns = self.test_data.columns.str.replace(
                 '.', '_'
-                )
+            )
 
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
@@ -206,19 +199,22 @@
             for k, v in self.setup_params.items() if k not in excluded_params
         }
         setup_params_table = pd.DataFrame(
-            list(filtered_setup_params.items()),
-            columns=['Parameter', 'Value'])
+            list(filtered_setup_params.items()), columns=['Parameter', 'Value']
+        )
 
         best_model_params = pd.DataFrame(
             self.best_model.get_params().items(),
-            columns=['Parameter', 'Value'])
+            columns=['Parameter', 'Value']
+        )
         best_model_params.to_csv(
-            os.path.join(self.output_dir, 'best_model.csv'),
-            index=False)
-        self.results.to_csv(os.path.join(
-            self.output_dir, "comparison_results.csv"))
-        self.test_result_df.to_csv(os.path.join(
-            self.output_dir, "test_results.csv"))
+            os.path.join(self.output_dir, "best_model.csv"), index=False
+        )
+        self.results.to_csv(
+            os.path.join(self.output_dir, "comparison_results.csv")
+        )
+        self.test_result_df.to_csv(
+            os.path.join(self.output_dir, "test_results.csv")
+        )
 
         plots_html = ""
         length = len(self.plots)
@@ -250,7 +246,8 @@
             data=self.data,
             target_col=self.target_col,
             task_type=self.task_type,
-            output_dir=self.output_dir)
+            output_dir=self.output_dir,
+        )
         feature_importance_html = analyzer.run()
 
         html_content = f"""
@@ -263,38 +260,37 @@
                 Best Model Plots</div>
                 <div class="tab" onclick="openTab(event, 'feature')">
                 Feature Importance</div>
-            """
+        """
         if self.plots_explainer_html:
             html_content += """
-                "<div class="tab" onclick="openTab(event, 'explainer')">"
+                <div class="tab" onclick="openTab(event, 'explainer')">
                 Explainer Plots</div>
             """
         html_content += f"""
             </div>
             <div id="summary" class="tab-content">
                 <h2>Setup Parameters</h2>
-                <table>
-                    <tr><th>Parameter</th><th>Value</th></tr>
-                    {setup_params_table.to_html(
-                        index=False, header=False, classes='table')}
-                </table>
+                {setup_params_table.to_html(
+                    index=False,
+                    header=True,
+                    classes='table sortable'
+                )}
                 <h5>If you want to know all the experiment setup parameters,
                   please check the PyCaret documentation for
                   the classification/regression <code>exp</code> function.</h5>
                 <h2>Best Model: {model_name}</h2>
-                <table>
-                    <tr><th>Parameter</th><th>Value</th></tr>
-                    {best_model_params.to_html(
-                        index=False, header=False, classes='table')}
-                </table>
+                {best_model_params.to_html(
+                    index=False,
+                    header=True,
+                    classes='table sortable'
+                )}
                 <h2>Comparison Results on the Cross-Validation Set</h2>
-                <table>
-                    {self.results.to_html(index=False, classes='table')}
-                </table>
+                {self.results.to_html(index=False, classes='table sortable')}
                 <h2>Results on the Test Set for the best model</h2>
-                <table>
-                    {self.test_result_df.to_html(index=False, classes='table')}
-                </table>
+                {self.test_result_df.to_html(
+                    index=False,
+                    classes='table sortable'
+                )}
             </div>
             <div id="plots" class="tab-content">
                 <h2>Best Model Plots on the testing set</h2>
@@ -310,14 +306,66 @@
                 {self.plots_explainer_html}
                 {tree_plots}
             </div>
-            {get_html_closing()}
             """
-        else:
-            html_content += f"""
-            {get_html_closing()}
-            """
-        with open(os.path.join(
-                self.output_dir, "comparison_result.html"), "w") as file:
+        html_content += """
+        <script>
+        document.addEventListener("DOMContentLoaded", function() {
+            var tables = document.querySelectorAll("table.sortable");
+            tables.forEach(function(table) {
+                var headers = table.querySelectorAll("th");
+                headers.forEach(function(header, index) {
+                    header.style.cursor = "pointer";
+                    // Add initial arrow (up) to indicate sortability
+                    header.innerHTML += '<span class="sort-arrow"> ↑</span>';
+                    header.addEventListener("click", function() {
+                        var direction = this.getAttribute(
+                            "data-sort-direction"
+                        ) || "asc";
+                        // Reset arrows in all headers of this table
+                        headers.forEach(function(h) {
+                            var arrow = h.querySelector(".sort-arrow");
+                            if (arrow) arrow.textContent = " ↑";
+                        });
+                        // Set arrow for clicked header
+                        var arrow = this.querySelector(".sort-arrow");
+                        arrow.textContent = direction === "asc" ? " ↓" : " ↑";
+                        sortTable(table, index, direction);
+                        this.setAttribute("data-sort-direction",
+                        direction === "asc" ? "desc" : "asc");
+                    });
+                });
+            });
+        });
+
+        function sortTable(table, colNum, direction) {
+            var tb = table.tBodies[0];
+            var tr = Array.prototype.slice.call(tb.rows, 0);
+            var multiplier = direction === "asc" ? 1 : -1;
+            tr = tr.sort(function(a, b) {
+                var aText = a.cells[colNum].textContent.trim();
+                var bText = b.cells[colNum].textContent.trim();
+                // Remove arrow from text comparison
+                aText = aText.replace(/[↑↓]/g, '').trim();
+                bText = bText.replace(/[↑↓]/g, '').trim();
+                if (!isNaN(aText) && !isNaN(bText)) {
+                    return multiplier * (
+                        parseFloat(aText) - parseFloat(bText)
+                    );
+                } else {
+                    return multiplier * aText.localeCompare(bText);
+                }
+            });
+            for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
+        }
+        </script>
+        """
+        html_content += f"""
+        {get_html_closing()}
+        """
+        with open(
+            os.path.join(self.output_dir, "comparison_result.html"),
+            "w"
+        ) as file:
             file.write(html_content)
 
     def save_dashboard(self):