annotate pycaret_classification.py @ 16:4fee4504646e draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
author goeckslab
date Fri, 28 Nov 2025 22:28:26 +0000
parents a2aeeb754d76
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
1 import logging
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
2 import types
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
3 from typing import Dict
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
4
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
5 import numpy as np
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
6 import pandas as pd
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
7 import plotly.graph_objects as go
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
8 from base_model_trainer import BaseModelTrainer
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
9 from dashboard import generate_classifier_explainer_dashboard
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
10 from pycaret.classification import ClassificationExperiment
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
12 from utils import predict_proba
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
13
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
14 LOG = logging.getLogger(__name__)
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
15
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
16
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
17 def _apply_report_layout(fig: go.Figure) -> go.Figure:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
18 # Give the left side more space for y-axis title/ticks and let axes auto-reserve room
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
19 fig.update_xaxes(automargin=True, title_standoff=12)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
20 fig.update_yaxes(automargin=True, title_standoff=12)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
21 fig.update_layout(
15
a2aeeb754d76 planemo upload for repository https://github.com/goeckslab/gleam commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents: 12
diff changeset
22 plot_bgcolor="#ffffff",
a2aeeb754d76 planemo upload for repository https://github.com/goeckslab/gleam commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents: 12
diff changeset
23 paper_bgcolor="#ffffff",
a2aeeb754d76 planemo upload for repository https://github.com/goeckslab/gleam commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents: 12
diff changeset
24 )
a2aeeb754d76 planemo upload for repository https://github.com/goeckslab/gleam commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents: 12
diff changeset
25 fig.update_xaxes(gridcolor="#e8e8e8")
a2aeeb754d76 planemo upload for repository https://github.com/goeckslab/gleam commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents: 12
diff changeset
26 fig.update_yaxes(gridcolor="#e8e8e8")
a2aeeb754d76 planemo upload for repository https://github.com/goeckslab/gleam commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents: 12
diff changeset
27 fig.update_layout(
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
28 autosize=True,
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
29 margin=dict(l=120, r=40, t=60, b=60), # bump 'l' if you still see clipping
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
30 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
31 return fig
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
32
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
33
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
34 class ClassificationModelTrainer(BaseModelTrainer):
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def __init__(
    self,
    input_file,
    target_col,
    output_dir,
    task_type,
    random_seed,
    test_file=None,
    **kwargs,
):
    """Initialise the base trainer and attach a classification experiment.

    The six named arguments are handed to the base-class constructor
    positionally, in the order listed, together with any extra keyword
    arguments. Their meaning is defined by the base class, which is not
    visible from this block.
    """
    base_args = (
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file,
    )
    super().__init__(*base_args, **kwargs)

    # Experiment object used by the other methods for plotting/explaining.
    self.exp = ClassificationExperiment()
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
55
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def save_dashboard(self):
    """Build the classifier explainer dashboard and save it as HTML.

    The dashboard is generated from ``self.exp`` and ``self.best_model`` and
    written via ``save_html("dashboard.html")``. The path is relative, so
    it presumably lands in the current working directory (confirm against
    the dashboard library).
    """
    LOG.info("Saving explainer dashboard")
    generate_classifier_explainer_dashboard(
        self.exp, self.best_model
    ).save_html("dashboard.html")
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
60
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def generate_plots(self):
    """Render the diagnostic plots for ``self.best_model``.

    Every plot is produced through ``self.exp.plot_model(..., save=True)``
    and the value it returns is stored in ``self.plots`` under the plot
    name. A plot that raises is logged and skipped so the remaining plots
    are still attempted.
    """
    LOG.info("Generating and saving plots")

    # Models lacking predict_proba get the helper from utils bound as a
    # method before any plotting is attempted.
    if not hasattr(self.best_model, "predict_proba"):
        self.best_model.predict_proba = types.MethodType(
            predict_proba, self.best_model
        )
        LOG.warning(
            f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
        )

    plot_names = (
        "auc",
        "threshold",
        "pr",
        "error",
        "class_report",
        "learning",
        "calibration",
        "vc",
        "dimension",
        "manifold",
        "rfe",
        "feature",
        "feature_all",
    )

    for plot_name in plot_names:
        try:
            # Only two plots take extra options; all others are called
            # without a plot_kwargs argument at all.
            extra = {}
            if plot_name == "threshold":
                extra["plot_kwargs"] = {"binary": True, "percentage": True}
            elif plot_name == "auc" and not self.exp.is_multiclass:
                extra["plot_kwargs"] = {
                    "micro": False,
                    "macro": False,
                    "per_class": False,
                    "binary": True,
                }

            self.plots[plot_name] = self.exp.plot_model(
                self.best_model, plot=plot_name, save=True, **extra
            )
        except Exception as e:
            # Best-effort: report the failure and move on to the next plot.
            LOG.error(f"Error generating plot {plot_name}: {e}")
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
118
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
119 def generate_plots_explainer(self):
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
120 from explainerdashboard import ClassifierExplainer
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
121
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
122 LOG.info("Generating explainer plots")
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
123
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
124 # Ensure predict_proba is available here too
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
125 if not hasattr(self.best_model, "predict_proba"):
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
126 self.best_model.predict_proba = types.MethodType(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
127 predict_proba, self.best_model
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
128 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
129 LOG.warning(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
130 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
131 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
132
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
133 X_test = self.exp.X_test_transformed.copy()
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
134 y_test = self.exp.y_test_transformed
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
135 explainer = ClassifierExplainer(self.best_model, X_test, y_test)
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
136
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
137 # a dict to hold the raw Figure objects or callables
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
138 self.explainer_plots: Dict[str, go.Figure] = {}
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
139
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
140 # --- Threshold-aware overrides for CM / ROC / PR ---
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
141 prob_thresh = getattr(self, "probability_threshold", None)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
142
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
143 # Only for binary classification and when threshold is provided
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
144 if (prob_thresh is not None) and (not self.exp.is_multiclass):
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
145 X = self.exp.X_test_transformed
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
146 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
147
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
148 # Get positive-class scores (robust defaults)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
149 classes = list(getattr(self.best_model, "classes_", [0, 1]))
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
150 try:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
151 pos_idx = classes.index(1) if 1 in classes else 1
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
152 except Exception:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
153 pos_idx = 1
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
154
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
155 proba = self.best_model.predict_proba(X)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
156 y_scores = proba[:, pos_idx]
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
157
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
158 # Derive label names consistently
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
159 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
160 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
161
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
162 # ---- Confusion Matrix @ threshold ----
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
163 try:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
164 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
165 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
166 fig_cm = go.Figure(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
167 data=go.Heatmap(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
168 z=cm,
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
169 x=[f"Pred {neg_label}", f"Pred {pos_label}"],
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
170 y=[f"True {neg_label}", f"True {pos_label}"],
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
171 text=cm,
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
172 texttemplate="%{text}",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
173 colorscale="Blues",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
174 showscale=False,
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
175 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
176 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
177 fig_cm.update_layout(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
178 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
179 xaxis_title="Predicted label",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
180 yaxis_title="True label",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
181 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
182 _apply_report_layout(fig_cm)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
183 self.explainer_plots["confusion_matrix"] = fig_cm
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
184 except Exception as e:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
185 LOG.warning(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
186 f"Threshold-aware confusion matrix failed; falling back: {e}"
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
187 )
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
188
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
189 # ---- ROC with threshold marker ----
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
190 try:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
191 fpr, tpr, thr = roc_curve(y, y_scores)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
192 roc_auc = auc(fpr, tpr)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
193 fig_roc = go.Figure()
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
194 fig_roc.add_scatter(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
195 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
196 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
197 if len(thr):
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
198 mask = np.isfinite(thr)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
199 if mask.any():
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
200 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
201 idx = np.where(mask)[0][idx_local]
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
202 if 0 <= idx < len(fpr):
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
203 fig_roc.add_scatter(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
204 x=[fpr[idx]],
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
205 y=[tpr[idx]],
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
206 mode="markers",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
207 name=f"@ {prob_thresh:.2f}",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
208 marker=dict(size=10),
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
209 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
210 fig_roc.update_layout(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
211 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
212 xaxis_title="False Positive Rate",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
213 yaxis_title="True Positive Rate",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
214 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
215 _apply_report_layout(fig_roc)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
216 self.explainer_plots["roc_auc"] = fig_roc
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
217 except Exception as e:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
218 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
219
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
220 # ---- PR with threshold marker ----
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
221 try:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
222 precision, recall, thr_pr = precision_recall_curve(y, y_scores)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
223 pr_auc = auc(recall, precision)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
224 fig_pr = go.Figure()
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
225 fig_pr.add_scatter(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
226 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
227 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
228 if len(thr_pr):
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
229 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh)))
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
230 # note: thr_pr has length = len(precision) - 1
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
231 idx_pr = max(0, min(idx_pr, len(recall) - 1))
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
232 fig_pr.add_scatter(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
233 x=[recall[idx_pr]],
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
234 y=[precision[idx_pr]],
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
235 mode="markers",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
236 name=f"@ {prob_thresh:.2f}",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
237 marker=dict(size=10),
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
238 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
239 fig_pr.update_layout(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
240 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
241 xaxis_title="Recall",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
242 yaxis_title="Precision",
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
243 )
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
244 _apply_report_layout(fig_pr)
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
245 self.explainer_plots["pr_auc"] = fig_pr
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
246 except Exception as e:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
247 LOG.warning(f"Threshold marker on PR failed; falling back: {e}")
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
248
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
249 # these go into the Test tab (don't overwrite overrides)
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
250 for key, fn in [
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
251 ("roc_auc", explainer.plot_roc_auc),
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
252 ("pr_auc", explainer.plot_pr_auc),
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
253 ("lift_curve", explainer.plot_lift_curve),
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
254 ("confusion_matrix", explainer.plot_confusion_matrix),
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
255 ("threshold", explainer.plot_precision), # percentage vs probability
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
256 ("cumulative_precision", explainer.plot_cumulative_precision),
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
257 ]:
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
258 if key in self.explainer_plots:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
259 continue
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
260 try:
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
261 fig = fn()
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
262 if fig is not None:
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
263 self.explainer_plots[key] = fig
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
264 except Exception as e:
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
265 LOG.error(f"Error generating explainer plot {key}: {e}")
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
266
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
267 # mean SHAP importances
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
268 try:
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
269 self.explainer_plots["shap_mean"] = explainer.plot_importances()
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
270 except Exception as e:
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
271 LOG.warning(f"Could not generate shap_mean: {e}")
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
272
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
273 # permutation importances
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
274 try:
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
275 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances(
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
276 kind="permutation"
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
277 )
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
278 except Exception as e:
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
279 LOG.warning(f"Could not generate shap_perm: {e}")
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
280
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
281 # PDPs for each feature (appended last)
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
282 valid_feats = []
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
283 for feat in self.features_name:
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
284 if feat in explainer.X.columns or feat in explainer.onehot_cols:
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
285 valid_feats.append(feat)
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
286 else:
12
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
287 LOG.warning(
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
288 f"Skipping PDP for feature {feat!r}: not found in explainer data"
e674b9e946fb planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents: 8
diff changeset
289 )
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
290
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
291 for feat in valid_feats:
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
292 # wrap each PDP call to catch any unexpected AssertionErrors
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
def make_pdp_plotter(f):
    """Return a zero-argument callable that draws the PDP for feature ``f``.

    The feature name is captured as this factory's argument, so every
    plotter created in the surrounding loop keeps its own ``f``.

    The returned callable yields whatever ``explainer.plot_pdp(f)`` returns.
    If that call raises, the error is logged and ``None`` is returned
    instead of propagating the exception.
    """

    def _plot():
        figure = None
        try:
            figure = explainer.plot_pdp(f)
        except AssertionError as assertion_err:
            # Assertion failures are logged at warning level only.
            LOG.warning(f"PDP AssertionError for {f!r}: {assertion_err}")
        except Exception as unexpected_err:
            # Anything else is logged as an error; the plot is still skipped.
            LOG.error(f"Unexpected error plotting PDP for {f!r}: {unexpected_err}")
        return figure

    return _plot
0
1f20fe57fdee planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
305
8
1aed7d47c5ec planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents: 3
diff changeset
306 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)