Mercurial > repos > goeckslab > pycaret_predict
annotate pycaret_classification.py @ 12:e674b9e946fb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
| author | goeckslab | 
|---|---|
| date | Mon, 08 Sep 2025 22:39:12 +0000 | 
| parents | 1aed7d47c5ec | 
| children | 
| rev | line source | 
|---|---|
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 1 import logging | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 2 import types | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 3 from typing import Dict | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 4 | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 5 import numpy as np | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 6 import pandas as pd | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 7 import plotly.graph_objects as go | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 8 from base_model_trainer import BaseModelTrainer | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 9 from dashboard import generate_classifier_explainer_dashboard | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 10 from pycaret.classification import ClassificationExperiment | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 12 from utils import predict_proba | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 13 | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 14 LOG = logging.getLogger(__name__) | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 15 | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 16 | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
def _apply_report_layout(fig: go.Figure) -> go.Figure:
    """Apply the shared report styling to a Plotly figure.

    Both x and y axes get ``automargin=True`` and a 12px title standoff so
    Plotly can reserve room for axis titles/ticks. The layout is switched to
    autosize with a fixed outer margin that gives the left side extra space
    for the y-axis title and tick labels.

    Args:
        fig: The figure to style; it is updated through its ``update_*``
            methods.

    Returns:
        The same figure object, so the call can be chained.
    """
    axis_options = {"automargin": True, "title_standoff": 12}
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(**axis_options)

    # Increase "l" if y-axis labels are still clipped in the rendered report.
    report_margin = {"l": 120, "r": 40, "t": 60, "b": 60}
    fig.update_layout(autosize=True, margin=report_margin)
    return fig
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 26 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 27 | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 28 class ClassificationModelTrainer(BaseModelTrainer): | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
def __init__(
    self,
    input_file,
    target_col,
    output_dir,
    task_type,
    random_seed,
    test_file=None,
    **kwargs,
):
    """Create a classification trainer backed by a PyCaret experiment.

    Every argument is forwarded unchanged, and in the same positional order,
    to ``BaseModelTrainer.__init__``; any extra keyword arguments are passed
    through as-is. After the base initialisation a fresh
    ``ClassificationExperiment`` is stored on ``self.exp``.
    """
    base_args = (
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file,
    )
    super().__init__(*base_args, **kwargs)
    self.exp = ClassificationExperiment()
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 49 | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
def save_dashboard(self):
    """Build the explainer dashboard for ``best_model`` and save it as HTML.

    The dashboard comes from ``generate_classifier_explainer_dashboard``,
    called with ``self.exp`` and ``self.best_model``. It is written to the
    relative path ``dashboard.html``, which is presumably resolved against
    the current working directory (NOTE(review): not ``output_dir``).
    """
    LOG.info("Saving explainer dashboard")
    generate_classifier_explainer_dashboard(
        self.exp, self.best_model
    ).save_html("dashboard.html")
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 54 | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
def generate_plots(self):
    """Render and save the standard PyCaret classification plots.

    If ``best_model`` has no ``predict_proba`` attribute, the
    ``utils.predict_proba`` helper is bound onto the model instance and a
    warning is logged, so probability-based plots can still be attempted.

    Each plot is produced with ``self.exp.plot_model(..., save=True)``, and
    the returned value (the saved plot path) is stored in
    ``self.plots[plot_name]``. This is best effort: a plot that raises is
    logged at ERROR level with its traceback and skipped, and the remaining
    plots are still generated.
    """
    LOG.info("Generating and saving plots")

    if not hasattr(self.best_model, "predict_proba"):
        self.best_model.predict_proba = types.MethodType(
            predict_proba, self.best_model
        )
        LOG.warning(
            f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
        )

    plots = (
        "auc",
        "threshold",
        "pr",
        "error",
        "class_report",
        "learning",
        "calibration",
        "vc",
        "dimension",
        "manifold",
        "rfe",
        "feature",
        "feature_all",
    )
    for plot_name in plots:
        try:
            if plot_name == "threshold":
                extra = {"plot_kwargs": {"binary": True, "percentage": True}}
            elif plot_name == "auc" and not self.exp.is_multiclass:
                # Binary task: request only the binary ROC curve and turn
                # off the micro/macro/per-class curves.
                extra = {
                    "plot_kwargs": {
                        "micro": False,
                        "macro": False,
                        "per_class": False,
                        "binary": True,
                    }
                }
            else:
                extra = {}
            self.plots[plot_name] = self.exp.plot_model(
                self.best_model, plot=plot_name, save=True, **extra
            )
        except Exception as e:
            # Keep going with the other plots, but keep the traceback so the
            # failure can actually be diagnosed (LOG.error dropped it).
            LOG.exception("Error generating plot %s: %s", plot_name, e)
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 112 | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 113 def generate_plots_explainer(self): | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 114 from explainerdashboard import ClassifierExplainer | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 115 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 116 LOG.info("Generating explainer plots") | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 117 | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 118 # Ensure predict_proba is available here too | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 119 if not hasattr(self.best_model, "predict_proba"): | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 120 self.best_model.predict_proba = types.MethodType( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 121 predict_proba, self.best_model | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 122 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 123 LOG.warning( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 124 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 125 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 126 | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 127 X_test = self.exp.X_test_transformed.copy() | 
| 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 128 y_test = self.exp.y_test_transformed | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 129 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 130 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 131 # a dict to hold the raw Figure objects or callables | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 132 self.explainer_plots: Dict[str, go.Figure] = {} | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 133 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 134 # --- Threshold-aware overrides for CM / ROC / PR --- | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 135 prob_thresh = getattr(self, "probability_threshold", None) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 136 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 137 # Only for binary classification and when threshold is provided | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 138 if (prob_thresh is not None) and (not self.exp.is_multiclass): | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 139 X = self.exp.X_test_transformed | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 140 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 141 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 142 # Get positive-class scores (robust defaults) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 143 classes = list(getattr(self.best_model, "classes_", [0, 1])) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 144 try: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 145 pos_idx = classes.index(1) if 1 in classes else 1 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 146 except Exception: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 147 pos_idx = 1 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 148 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 149 proba = self.best_model.predict_proba(X) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 150 y_scores = proba[:, pos_idx] | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 151 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 152 # Derive label names consistently | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 153 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 154 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 155 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 156 # ---- Confusion Matrix @ threshold ---- | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 157 try: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 158 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 159 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 160 fig_cm = go.Figure( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 161 data=go.Heatmap( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 162 z=cm, | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 163 x=[f"Pred {neg_label}", f"Pred {pos_label}"], | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 164 y=[f"True {neg_label}", f"True {pos_label}"], | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 165 text=cm, | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 166 texttemplate="%{text}", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 167 colorscale="Blues", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 168 showscale=False, | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 169 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 170 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 171 fig_cm.update_layout( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 172 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 173 xaxis_title="Predicted label", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 174 yaxis_title="True label", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 175 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 176 _apply_report_layout(fig_cm) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 177 self.explainer_plots["confusion_matrix"] = fig_cm | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 178 except Exception as e: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 179 LOG.warning( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 180 f"Threshold-aware confusion matrix failed; falling back: {e}" | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 181 ) | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 182 | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 183 # ---- ROC with threshold marker ---- | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 184 try: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 185 fpr, tpr, thr = roc_curve(y, y_scores) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 186 roc_auc = auc(fpr, tpr) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 187 fig_roc = go.Figure() | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 188 fig_roc.add_scatter( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 189 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 190 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 191 if len(thr): | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 192 mask = np.isfinite(thr) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 193 if mask.any(): | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 194 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 195 idx = np.where(mask)[0][idx_local] | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 196 if 0 <= idx < len(fpr): | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 197 fig_roc.add_scatter( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 198 x=[fpr[idx]], | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 199 y=[tpr[idx]], | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 200 mode="markers", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 201 name=f"@ {prob_thresh:.2f}", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 202 marker=dict(size=10), | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 203 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 204 fig_roc.update_layout( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 205 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 206 xaxis_title="False Positive Rate", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 207 yaxis_title="True Positive Rate", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 208 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 209 _apply_report_layout(fig_roc) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 210 self.explainer_plots["roc_auc"] = fig_roc | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 211 except Exception as e: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 212 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 213 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 214 # ---- PR with threshold marker ---- | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 215 try: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 216 precision, recall, thr_pr = precision_recall_curve(y, y_scores) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 217 pr_auc = auc(recall, precision) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 218 fig_pr = go.Figure() | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 219 fig_pr.add_scatter( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 220 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 221 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 222 if len(thr_pr): | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 223 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 224 # note: thr_pr has length = len(precision) - 1 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 225 idx_pr = max(0, min(idx_pr, len(recall) - 1)) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 226 fig_pr.add_scatter( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 227 x=[recall[idx_pr]], | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 228 y=[precision[idx_pr]], | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 229 mode="markers", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 230 name=f"@ {prob_thresh:.2f}", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 231 marker=dict(size=10), | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 232 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 233 fig_pr.update_layout( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 234 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 235 xaxis_title="Recall", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 236 yaxis_title="Precision", | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 237 ) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 238 _apply_report_layout(fig_pr) | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 239 self.explainer_plots["pr_auc"] = fig_pr | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 240 except Exception as e: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 241 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 242 | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 243 # these go into the Test tab (don't overwrite overrides) | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 244 for key, fn in [ | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 245 ("roc_auc", explainer.plot_roc_auc), | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 246 ("pr_auc", explainer.plot_pr_auc), | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 247 ("lift_curve", explainer.plot_lift_curve), | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 248 ("confusion_matrix", explainer.plot_confusion_matrix), | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 249 ("threshold", explainer.plot_precision), # percentage vs probability | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 250 ("cumulative_precision", explainer.plot_cumulative_precision), | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 251 ]: | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 252 if key in self.explainer_plots: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 253 continue | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 254 try: | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 255 fig = fn() | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 256 if fig is not None: | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 257 self.explainer_plots[key] = fig | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 258 except Exception as e: | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 259 LOG.error(f"Error generating explainer plot {key}: {e}") | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 260 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 261 # mean SHAP importances | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 262 try: | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 263 self.explainer_plots["shap_mean"] = explainer.plot_importances() | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 264 except Exception as e: | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 265 LOG.warning(f"Could not generate shap_mean: {e}") | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 266 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 267 # permutation importances | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 268 try: | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 269 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 270 kind="permutation" | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 271 ) | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 272 except Exception as e: | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 273 LOG.warning(f"Could not generate shap_perm: {e}") | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 274 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 275 # PDPs for each feature (appended last) | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 276 valid_feats = [] | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 277 for feat in self.features_name: | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 278 if feat in explainer.X.columns or feat in explainer.onehot_cols: | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 279 valid_feats.append(feat) | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 280 else: | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 281 LOG.warning( | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 282 f"Skipping PDP for feature {feat!r}: not found in explainer data" | 
| 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 283 ) | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 284 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 285 for feat in valid_feats: | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 286 # wrap each PDP call to catch any unexpected AssertionErrors | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 287 def make_pdp_plotter(f): | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 288 def _plot(): | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 289 try: | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 290 return explainer.plot_pdp(f) | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 291 except AssertionError as ae: | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 292 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 293 return None | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 294 except Exception as e: | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 295 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") | 
| 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 296 return None | 
| 12 
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
 goeckslab parents: 
8diff
changeset | 297 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 298 return _plot | 
| 0 
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
 goeckslab parents: diff
changeset | 299 | 
| 8 
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
 goeckslab parents: 
3diff
changeset | 300 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) | 
