diff image_learner_cli.py @ 10:b0d893d04d4c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1594d503179f28987720594eb49b48a15486f073
author goeckslab
date Mon, 08 Sep 2025 22:38:35 +0000
parents 9e912fce264c
children
line wrap: on
line diff
--- a/image_learner_cli.py	Wed Aug 27 21:02:48 2025 +0000
+++ b/image_learner_cli.py	Mon Sep 08 22:38:35 2025 +0000
@@ -69,7 +69,6 @@
     ]
 
     rows = []
-
     for key in display_keys:
         val = config.get(key, None)
         if key == "threshold":
@@ -136,15 +135,15 @@
                         val_str = val
             else:
                 val_str = val if val is not None else "N/A"
-            if val_str == "N/A" and key not in [
-                "task_type"
-            ]:  # Skip if N/A for non-essential
+            if val_str == "N/A" and key not in ["task_type"]:
                 continue
         rows.append(
             f"<tr>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
+            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>"
             f"{key.replace('_', ' ').title()}</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
+            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>"
             f"{val_str}</td>"
             f"</tr>"
         )
@@ -153,13 +152,17 @@
         types = [str(a.get("type", "")) for a in aug_cfg]
         aug_val = ", ".join(types)
         rows.append(
-            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>"
+            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
+            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
+            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>"
         )
     if split_info:
         rows.append(
-            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>"
+            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
+            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
+            f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>"
         )
     html = f"""
         <h2 style="text-align: center;">Model and Training Summary</h2>
@@ -946,6 +949,66 @@
         test_viz_dir = base_viz_dir / "test"
 
         html = get_html_template()
+
+        # Extra CSS & JS: center Plotly and enable CSV download for predictions table
+        html += """
+<style>
+  /* Center Plotly figures (both wrapper and native classes) */
+  .plotly-center { display: flex; justify-content: center; }
+  .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
+  .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
+
+  /* Download button for predictions table */
+  .download-btn {
+    padding: 8px 12px;
+    border: 1px solid #4CAF50;
+    background: #4CAF50;
+    color: white;
+    border-radius: 6px;
+    cursor: pointer;
+  }
+  .download-btn:hover { filter: brightness(0.95); }
+  .preds-controls {
+    display: flex;
+    justify-content: flex-end;
+    gap: 8px;
+    margin: 8px 0;
+  }
+</style>
+<script>
+  function tableToCSV(table){
+    const rows = Array.from(table.querySelectorAll('tr'));
+    return rows.map(row =>
+      Array.from(row.querySelectorAll('th,td')).map(cell => {
+        let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim();
+        if (text.includes('"') || text.includes(',')) {
+          text = '"' + text.replace(/"/g,'""') + '"';
+        }
+        return text;
+      }).join(',')
+    ).join('\\n');
+  }
+  document.addEventListener('DOMContentLoaded', function(){
+    const btn = document.getElementById('downloadPredsCsv');
+    if(btn){
+      btn.addEventListener('click', function(){
+        const tbl = document.querySelector('.predictions-table');
+        if(!tbl){ alert('Predictions table not found.'); return; }
+        const csv = tableToCSV(tbl);
+        const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'});
+        const url = URL.createObjectURL(blob);
+        const a = document.createElement('a');
+        a.href = url;
+        a.download = 'ground_truth_vs_predictions.csv';
+        document.body.appendChild(a);
+        a.click();
+        document.body.removeChild(a);
+        URL.revokeObjectURL(url);
+      });
+    }
+  });
+</script>
+"""
         html += f"<h1>{title}</h1>"
 
         metrics_html = ""
@@ -983,31 +1046,38 @@
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
 
+        # ---------- image rendering with exclusions ----------
         def render_img_section(
-            title: str, dir_path: Path, output_type: str = None
+            title: str,
+            dir_path: Path,
+            output_type: str = None,
+            exclude_names: Optional[set] = None,
         ) -> str:
             if not dir_path.exists():
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
-            # collect every PNG
+
+            exclude_names = exclude_names or set()
+
             imgs = list(dir_path.glob("*.png"))
-            # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files ---
+
+            default_exclude = {"confusion_matrix.png", "roc_curves.png"}
+
             imgs = [
                 img
                 for img in imgs
-                if not (
-                    img.name == "confusion_matrix.png"
-                    or img.name.startswith("confusion_matrix__label_top")
-                    or img.name == "roc_curves.png"
-                )
+                if img.name not in default_exclude
+                and img.name not in exclude_names
+                and not img.name.startswith("confusion_matrix__label_top")
             ]
+
             if not imgs:
                 return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+
             if output_type == "binary":
                 order = [
                     "roc_curves_from_prediction_statistics.png",
                     "compare_performance_label.png",
                     "confusion_matrix_entropy__label_top2.png",
-                    # ...you can tweak ordering as needed
                 ]
                 img_names = {img.name: img for img in imgs}
                 ordered = [img_names[n] for n in order if n in img_names]
@@ -1019,14 +1089,13 @@
                     "compare_classifiers_multiclass_multimetric__label_top10.png",
                     "compare_classifiers_multiclass_multimetric__label_worst10.png",
                 }
+                valid_imgs = [img for img in imgs if img.name not in unwanted]
                 display_order = [
                     "roc_curves.png",
                     "compare_performance_label.png",
                     "compare_classifiers_performance_from_prob.png",
                     "confusion_matrix_entropy__label_top10.png",
                 ]
-                # filter and order
-                valid_imgs = [img for img in imgs if img.name not in unwanted]
                 img_map = {img.name: img for img in valid_imgs}
                 ordered = [img_map[n] for n in display_order if n in img_map]
                 others = sorted(
@@ -1034,27 +1103,36 @@
                 )
                 imgs = ordered + others
             else:
-                # regression: just sort whatever's left
                 imgs = sorted(imgs)
-            # render each remaining PNG
-            html = ""
+
+            html_section = ""
             for img in imgs:
                 b64 = encode_image_to_base64(str(img))
                 img_title = img.stem.replace("_", " ").title()
-                html += (
+                html_section += (
                     f"<h2 style='text-align: center;'>{img_title}</h2>"
                     f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
                     f'<img src="data:image/png;base64,{b64}" '
                     f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
                     f"</div>"
                 )
-            return html
+            return html_section
 
         tab1_content = config_html + metrics_html
+
         tab2_content = train_val_metrics_html + render_img_section(
-            "Training and Validation Visualizations", train_viz_dir
+            "Training and Validation Visualizations",
+            train_viz_dir,
+            output_type,
+            exclude_names={
+                "compare_classifiers_performance_from_prob.png",
+                "roc_curves_from_prediction_statistics.png",
+                "precision_recall_curves_from_prediction_statistics.png",
+                "precision_recall_curve.png",
+            },
         )
-        # --- Predictions vs Ground Truth table ---
+
+        # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
         preds_section = ""
         parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
         if output_type == "regression" and parquet_path.exists():
@@ -1081,13 +1159,19 @@
                 preds_html = df_table.to_html(index=False, classes="predictions-table")
                 preds_section = (
                     "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
-                    "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>"
+                    "<div class='preds-controls'>"
+                    "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
+                    "</div>"
+                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
                     + preds_html
                     + "</div>"
                 )
             except Exception as e:
                 logger.warning(f"Could not build Predictions vs GT table: {e}")
+
         tab3_content = test_metrics_html + preds_section
+
+        # Classification-only interactive Plotly panels (centered)
         if output_type in ("binary", "category"):
             training_stats_path = exp_dir / "training_statistics.json"
             interactive_plots = build_classification_plots(
@@ -1095,31 +1179,16 @@
                 str(training_stats_path),
             )
             for plot in interactive_plots:
-                # 2) inject the static "roc_curves_from_prediction_statistics.png"
-                if plot["title"] == "ROC-AUC":
-                    static_img = (
-                        test_viz_dir / "roc_curves_from_prediction_statistics.png"
-                    )
-                    if static_img.exists():
-                        b64 = encode_image_to_base64(str(static_img))
-                        tab3_content += (
-                            "<h2 style='text-align: center;'>"
-                            "Roc Curves From Prediction Statistics"
-                            "</h2>"
-                            f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
-                            f'<img src="data:image/png;base64,{b64}" '
-                            f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
-                            "</div>"
-                        )
-                # always render the plotly panels exactly as before
                 tab3_content += (
                     f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                    + plot["html"]
+                    f"<div class='plotly-center'>{plot['html']}</div>"
                 )
-            tab3_content += render_img_section(
-                "Test Visualizations", test_viz_dir, output_type
-            )
-        # assemble the tabs and help modal
+
+        # Add static TEST PNGs (with default dedupe/exclusions)
+        tab3_content += render_img_section(
+            "Test Visualizations", test_viz_dir, output_type
+        )
+
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()