diff utils.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents a32ff7201629
children
line wrap: on
line diff
--- a/utils.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/utils.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,5 +1,6 @@
 import base64
 import logging
+from typing import Optional
 
 import numpy as np
 
@@ -7,7 +8,7 @@
 LOG = logging.getLogger(__name__)
 
 
-def get_html_template():
+def get_html_template() -> str:
     return """
     <html>
     <head>
@@ -20,13 +21,16 @@
               padding: 20px;
               background-color: #f4f4f4;
           }
+          /* allow horizontal scrolling if content overflows */
           .container {
               max-width: 800px;
               margin: auto;
               background: white;
               padding: 20px;
               box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+              overflow-x: auto;
           }
+
           h1 {
               text-align: center;
               color: #333;
@@ -36,6 +40,14 @@
               color: #4CAF50;
               padding-bottom: 5px;
           }
+
+          /* wrapper for tables to allow individual horizontal scroll */
+          .table-wrapper {
+              overflow-x: auto;
+              margin: 1rem 0;
+          }
+
+          /* revert table styling to full borders */
           table {
               width: 100%;
               border-collapse: collapse;
@@ -52,6 +64,7 @@
               background-color: #4CAF50;
               color: white;
           }
+
           .plot {
               text-align: center;
               margin: 20px 0;
@@ -60,106 +73,202 @@
               max-width: 100%;
               height: auto;
           }
+
           .tabs {
               display: flex;
-              margin-bottom: 20px;
-              cursor: pointer;
-              justify-content: space-around;
+              align-items: center;
+              border-bottom: 2px solid #ccc;
+              margin-bottom: 1rem;
           }
           .tab {
-              padding: 10px;
-              background-color: #4CAF50;
-              color: white;
-              border-radius: 5px 5px 0 0;
-              flex-grow: 1;
-              text-align: center;
-              margin: 0 5px;
+              padding: 10px 20px;
+              cursor: pointer;
+              border: 1px solid #ccc;
+              border-bottom: none;
+              background: #f9f9f9;
+              margin-right: 5px;
+              border-top-left-radius: 8px;
+              border-top-right-radius: 8px;
           }
-          .tab.active-tab {
-              background-color: #333;
+          .tab.active {
+              background: white;
+              font-weight: bold;
           }
+
           .tab-content {
               display: none;
               padding: 20px;
-              border: 1px solid #ddd;
+              border: 1px solid #ccc;
               border-top: none;
-              background-color: white;
+              background: white;
           }
-          .tab-content.active-content {
+          .tab-content.active {
               display: block;
           }
-      </style>
+
+          .help-btn {
+              margin-left: auto;
+              padding: 6px 12px;
+              font-size: 0.9rem;
+              border: 1px solid #4CAF50;
+              border-radius: 4px;
+              background: #4CAF50;
+              color: white;
+              cursor: pointer;
+          }
+
+          /* sortable table header arrows */
+          table.sortable th {
+              position: relative;
+              padding-right: 20px; /* room for the arrow */
+              cursor: pointer;
+          }
+          table.sortable th::after {
+              content: '↕';
+              position: absolute;
+              right: 8px;
+              opacity: 0.4;
+              transition: opacity 0.2s;
+          }
+          table.sortable th:hover::after {
+              opacity: 0.7;
+          }
+          table.sortable th.sorted-asc::after {
+              content: '↑';
+              opacity: 1;
+          }
+          table.sortable th.sorted-desc::after {
+              content: '↓';
+              opacity: 1;
+          }
+        </style>
     </head>
     <body>
     <div class="container">
     """
 
 
-def get_html_closing():
+def get_html_closing() -> str:
     return """
-        </div>
-        <script>
-            function openTab(evt, tabName) {{
-                var i, tabcontent, tablinks;
-                tabcontent = document.getElementsByClassName("tab-content");
-                for (i = 0; i < tabcontent.length; i++) {{
-                    tabcontent[i].style.display = "none";
-                }}
-                tablinks = document.getElementsByClassName("tab");
-                for (i = 0; i < tablinks.length; i++) {{
-                    tablinks[i].className =
-                        tablinks[i].className.replace(" active-tab", "");
-                }}
-                document.getElementById(tabName).style.display = "block";
-                evt.currentTarget.className += " active-tab";
-            }}
-            document.addEventListener("DOMContentLoaded", function() {{
-                document.querySelector(".tab").click();
-            }});
-        </script>
+    </div>
+    <script>
+    document.addEventListener('DOMContentLoaded', () => {
+      document.querySelectorAll('table.sortable').forEach(table => {
+        const getCellValue = (row, idx) =>
+          row.children[idx].innerText.trim() || '';
+
+        const comparer = (idx, asc) => (a, b) => {
+          const v1 = getCellValue(asc ? a : b, idx);
+          const v2 = getCellValue(asc ? b : a, idx);
+          const n1 = parseFloat(v1), n2 = parseFloat(v2);
+          if (!isNaN(n1) && !isNaN(n2)) return n1 - n2;
+          return v1.localeCompare(v2);
+        };
+
+        table.querySelectorAll('th').forEach((th, idx) => {
+          let asc = true;
+          th.addEventListener('click', () => {
+            // sort rows
+            const tbody = table.tBodies[0];
+            Array.from(tbody.rows)
+              .sort(comparer(idx, asc))
+              .forEach(row => tbody.appendChild(row));
+            // update arrow classes
+            table.querySelectorAll('th').forEach(h => {
+              h.classList.remove('sorted-asc','sorted-desc');
+            });
+            th.classList.add(asc ? 'sorted-asc' : 'sorted-desc');
+            asc = !asc;
+          });
+        });
+      });
+    });
+    </script>
     </body>
     </html>
     """
 
 
-def customize_figure_layout(fig, margin_dict=None):
+def build_tabbed_html(
+    summary_html: str,
+    test_html: str,
+    feature_html: str,
+    explainer_html: Optional[str] = None,
+) -> str:
+    """
+    Render the tabbed sections and an always-visible Help button.
     """
-    Update the layout of a Plotly figure to reduce margins.
+    # CSS
+    css = get_html_template().split("<body>")[1].rsplit("</style>", 1)[0] + "</style>"
+
+    # Tabs header
+    tabs = [
+        '<div class="tabs">',
+        '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary & Config</div>',
+        '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>',
+        '<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>',
+    ]
+    if explainer_html:
+        tabs.append(
+            '<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>'
+        )
+    tabs.append('<button id="openMetricsHelp" class="help-btn">Help</button>')
+    tabs.append("</div>")
+    tabs_section = "\n".join(tabs)
 
-    Parameters:
-        fig (plotly.graph_objects.Figure): The Plotly figure to customize.
-        margin_dict (dict, optional): A dictionary specifying margin sizes.
-            Example: {'l': 10, 'r': 10, 't': 10, 'b': 10}
+    # Content
+    contents = [
+        f'<div id="summary" class="tab-content active">{summary_html}</div>',
+        f'<div id="test" class="tab-content">{test_html}</div>',
+        f'<div id="feature" class="tab-content">{feature_html}</div>',
+    ]
+    if explainer_html:
+        contents.append(
+            f'<div id="explainer" class="tab-content">{explainer_html}</div>'
+        )
+    content_section = "\n".join(contents)
 
-    Returns:
-        plotly.graph_objects.Figure: The updated Plotly figure.
-    """
+    # JS
+    js = """
+<script>
+function showTab(id) {
+  document.querySelectorAll('.tab-content').forEach(el=>el.classList.remove('active'));
+  document.querySelectorAll('.tab').forEach(el=>el.classList.remove('active'));
+  document.getElementById(id).classList.add('active');
+  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
+}
+</script>
+"""
+
+    return css + "\n" + tabs_section + "\n" + content_section + "\n" + js
+
+
+def customize_figure_layout(fig, margin_dict=None):
     if margin_dict is None:
-        # Set default smaller margins
-        margin_dict = {'l': 40, 'r': 40, 't': 40, 'b': 40}
-
+        margin_dict = {"l": 40, "r": 40, "t": 40, "b": 40}
     fig.update_layout(margin=margin_dict)
     return fig
 
 
-def add_plot_to_html(fig, include_plotlyjs=True):
-    custom_margin = {'l': 40, 'r': 40, 't': 60, 'b': 60}
+def add_plot_to_html(fig, include_plotlyjs=True) -> str:
+    custom_margin = {"l": 40, "r": 40, "t": 60, "b": 60}
     fig = customize_figure_layout(fig, margin_dict=custom_margin)
-    return fig.to_html(full_html=False,
-                       default_height=350,
-                       include_plotlyjs="cdn" if include_plotlyjs else False)
+    return fig.to_html(
+        full_html=False,
+        default_height=350,
+        include_plotlyjs="cdn" if include_plotlyjs else False,
+    )
 
 
-def add_hr_to_html():
+def add_hr_to_html() -> str:
     return "<hr>"
 
 
-def encode_image_to_base64(image_path):
-    """Convert an image file to a base64 encoded string."""
+def encode_image_to_base64(image_path: str) -> str:
     with open(image_path, "rb") as img_file:
         return base64.b64encode(img_file.read()).decode("utf-8")
 
 
 def predict_proba(self, X):
     pred = self.predict(X)
-    return np.array([1 - pred, pred]).T
+    return np.vstack((1 - pred, pred)).T