Mercurial > repos > bgruening > plotly_ml_performance_plots
annotate plot_ml_performance.py @ 4:f234e2e59d76 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author | bgruening |
---|---|
date | Wed, 07 Aug 2024 10:20:17 +0000 |
parents | 1c5dcef5ce0f |
children |
rev | line source |
---|---|
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
1 import argparse |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
2 |
4
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
3 import matplotlib.pyplot as plt |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
4 import pandas as pd |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
5 import plotly |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
6 import plotly.graph_objs as go |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
7 from galaxy_ml.model_persist import load_model_from_h5 |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
8 from galaxy_ml.utils import clean_params |
4
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
9 from sklearn.metrics import ( |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
10 auc, |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
11 confusion_matrix, |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
12 precision_recall_fscore_support, |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
13 roc_curve, |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
14 ) |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
15 from sklearn.preprocessing import label_binarize |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
16 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
17 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
18 def main(infile_input, infile_output, infile_trained_model): |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
19 """ |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
20 Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
21 Args: |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
22 infile_input: str, input tabular file with true labels |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
23 infile_output: str, input tabular file with predicted labels |
4
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
24 infile_trained_model: str, input trained model file (h5mlm) |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
25 """ |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
26 |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
27 df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True) |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
28 df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True) |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
29 true_labels = df_input.iloc[:, -1].copy() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
30 predicted_labels = df_output.iloc[:, -1].copy() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
31 axis_labels = list(set(true_labels)) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
32 c_matrix = confusion_matrix(true_labels, predicted_labels) |
4
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
33 fig, ax = plt.subplots(figsize=(7, 7)) |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
34 im = plt.imshow(c_matrix, cmap="viridis") |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
35 # add number of samples to each cell of confusion matrix plot |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
36 for i in range(len(c_matrix)): |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
37 for j in range(len(c_matrix)): |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
38 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
39 ax.set_ylabel("True class labels") |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
40 ax.set_xlabel("Predicted class labels") |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
41 ax.set_title("Confusion Matrix between true and predicted class labels") |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
42 ax.set_xticks(axis_labels) |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
43 ax.set_yticks(axis_labels) |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
44 fig.colorbar(im, ax=ax) |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
45 fig.tight_layout() |
f234e2e59d76
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents:
3
diff
changeset
|
46 plt.savefig("output_confusion.png", dpi=120) |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
47 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
48 # plot precision, recall and f_score for each class label |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
49 precision, recall, f_score, _ = precision_recall_fscore_support( |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
50 true_labels, predicted_labels |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
51 ) |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
52 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
53 trace_precision = go.Scatter( |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
54 x=axis_labels, y=precision, mode="lines+markers", name="Precision" |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
55 ) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
56 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
57 trace_recall = go.Scatter( |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
58 x=axis_labels, y=recall, mode="lines+markers", name="Recall" |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
59 ) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
60 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
61 trace_fscore = go.Scatter( |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
62 x=axis_labels, y=f_score, mode="lines+markers", name="F-score" |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
63 ) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
64 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
65 layout_prf = go.Layout( |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
66 title="Precision, recall and f-score of true and predicted class labels", |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
67 xaxis=dict(title="Class labels"), |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
68 yaxis=dict(title="Precision, recall and f-score"), |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
69 ) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
70 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
71 data_prf = [trace_precision, trace_recall, trace_fscore] |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
72 fig_prf = go.Figure(data=data_prf, layout=layout_prf) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
73 plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
74 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
75 # plot roc and auc curves for different classes |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
76 classifier_object = load_model_from_h5(infile_trained_model) |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
77 model = clean_params(classifier_object) |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
78 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
79 # remove the last column (label column) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
80 test_data = df_input.iloc[:, :-1] |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
81 model_items = dir(model) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
82 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
83 try: |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
84 # find the probability estimating method |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
85 if "predict_proba" in model_items: |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
86 y_score = model.predict_proba(test_data) |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
87 elif "decision_function" in model_items: |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
88 y_score = model.decision_function(test_data) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
89 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
90 true_labels_list = true_labels.tolist() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
91 one_hot_labels = label_binarize(true_labels_list, classes=axis_labels) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
92 data_roc = list() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
93 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
94 if len(axis_labels) > 2: |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
95 fpr = dict() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
96 tpr = dict() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
97 roc_auc = dict() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
98 for i in axis_labels: |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
99 fpr[i], tpr[i], _ = roc_curve(one_hot_labels[:, i], y_score[:, i]) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
100 roc_auc[i] = auc(fpr[i], tpr[i]) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
101 for i in range(len(axis_labels)): |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
102 trace = go.Scatter( |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
103 x=fpr[i], |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
104 y=tpr[i], |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
105 mode="lines+markers", |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
106 name="ROC curve of class {0} (AUC = {1:0.2f})".format( |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
107 i, roc_auc[i] |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
108 ), |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
109 ) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
110 data_roc.append(trace) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
111 else: |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
112 try: |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
113 y_score_binary = y_score[:, 1] |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
114 except Exception: |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
115 y_score_binary = y_score |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
116 fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
117 roc_auc = auc(fpr, tpr) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
118 trace = go.Scatter( |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
119 x=fpr, |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
120 y=tpr, |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
121 mode="lines+markers", |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
122 name="ROC curve (AUC = {0:0.2f})".format(roc_auc), |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
123 ) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
124 data_roc.append(trace) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
125 |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
126 trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance") |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
127 data_roc.append(trace_diag) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
128 layout_roc = go.Layout( |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
129 title="Receiver operating characteristics (ROC) and area under curve (AUC)", |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
130 xaxis=dict(title="False positive rate"), |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
131 yaxis=dict(title="True positive rate"), |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
132 ) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
133 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
134 fig_roc = go.Figure(data=data_roc, layout=layout_roc) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
135 plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
136 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
137 except Exception as exp: |
3
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
138 print( |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
139 "Plotting the ROC-AUC graph failed. This exception was raised: {}".format( |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
140 exp |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
141 ) |
1c5dcef5ce0f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents:
2
diff
changeset
|
142 ) |
0
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
143 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
144 |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
145 if __name__ == "__main__": |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
146 aparser = argparse.ArgumentParser() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
147 aparser.add_argument("-i", "--input", dest="infile_input", required=True) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
148 aparser.add_argument("-j", "--output", dest="infile_output", required=True) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
149 aparser.add_argument("-k", "--model", dest="infile_trained_model", required=True) |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
150 args = aparser.parse_args() |
4fac53da862f
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff
changeset
|
151 main(args.infile_input, args.infile_output, args.infile_trained_model) |