comparison plot_ml_performance.py @ 3:1c5dcef5ce0f draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
author: bgruening
date: Tue, 07 May 2024 14:11:16 +0000
parents: 62e3a4e8c54c
children: f234e2e59d76
comparison
--- plot_ml_performance.py @ 2:62e3a4e8c54c
+++ plot_ml_performance.py @ 3:1c5dcef5ce0f
 import argparse
+
 import pandas as pd
 import plotly
-import pickle
 import plotly.graph_objs as go
-from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import clean_params
+from sklearn.metrics import (auc, confusion_matrix,
+                             precision_recall_fscore_support, roc_curve)
 from sklearn.preprocessing import label_binarize
 
 
 def main(infile_input, infile_output, infile_trained_model):
     """
[... unchanged lines 12-13 (old) / 15-16 (new) omitted ...]
     infile_input: str, input tabular file with true labels
     infile_output: str, input tabular file with predicted labels
     infile_trained_model: str, input trained model file (zip)
     """
 
-    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
-    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
+    df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
+    df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
     data = [
         go.Heatmap(
             z=c_matrix,
             x=axis_labels,
             y=axis_labels,
-            colorscale='Portland',
+            colorscale="Portland",
         )
     ]
 
     layout = go.Layout(
-        title='Confusion Matrix between true and predicted class labels',
-        xaxis=dict(title='Predicted class labels'),
-        yaxis=dict(title='True class labels')
+        title="Confusion Matrix between true and predicted class labels",
+        xaxis=dict(title="Predicted class labels"),
+        yaxis=dict(title="True class labels"),
     )
 
     fig = go.Figure(data=data, layout=layout)
     plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
 
     # plot precision, recall and f_score for each class label
-    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)
+    precision, recall, f_score, _ = precision_recall_fscore_support(
+        true_labels, predicted_labels
+    )
 
     trace_precision = go.Scatter(
-        x=axis_labels,
-        y=precision,
-        mode='lines+markers',
-        name='Precision'
+        x=axis_labels, y=precision, mode="lines+markers", name="Precision"
     )
 
     trace_recall = go.Scatter(
-        x=axis_labels,
-        y=recall,
-        mode='lines+markers',
-        name='Recall'
+        x=axis_labels, y=recall, mode="lines+markers", name="Recall"
     )
 
     trace_fscore = go.Scatter(
-        x=axis_labels,
-        y=f_score,
-        mode='lines+markers',
-        name='F-score'
+        x=axis_labels, y=f_score, mode="lines+markers", name="F-score"
     )
 
     layout_prf = go.Layout(
-        title='Precision, recall and f-score of true and predicted class labels',
-        xaxis=dict(title='Class labels'),
-        yaxis=dict(title='Precision, recall and f-score')
+        title="Precision, recall and f-score of true and predicted class labels",
+        xaxis=dict(title="Class labels"),
+        yaxis=dict(title="Precision, recall and f-score"),
     )
 
     data_prf = [trace_precision, trace_recall, trace_fscore]
     fig_prf = go.Figure(data=data_prf, layout=layout_prf)
     plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)
 
     # plot roc and auc curves for different classes
-    with open(infile_trained_model, 'rb') as model_file:
-        model = pickle.load(model_file)
+    classifier_object = load_model_from_h5(infile_trained_model)
+    model = clean_params(classifier_object)
 
     # remove the last column (label column)
     test_data = df_input.iloc[:, :-1]
     model_items = dir(model)
 
     try:
         # find the probability estimating method
-        if 'predict_proba' in model_items:
+        if "predict_proba" in model_items:
             y_score = model.predict_proba(test_data)
-        elif 'decision_function' in model_items:
+        elif "decision_function" in model_items:
             y_score = model.decision_function(test_data)
 
         true_labels_list = true_labels.tolist()
         one_hot_labels = label_binarize(true_labels_list, classes=axis_labels)
         data_roc = list()
[... unchanged lines 95-101 (old) / 91-97 (new) omitted ...]
                 roc_auc[i] = auc(fpr[i], tpr[i])
             for i in range(len(axis_labels)):
                 trace = go.Scatter(
                     x=fpr[i],
                     y=tpr[i],
-                    mode='lines+markers',
-                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i])
+                    mode="lines+markers",
+                    name="ROC curve of class {0} (AUC = {1:0.2f})".format(
+                        i, roc_auc[i]
+                    ),
                 )
                 data_roc.append(trace)
         else:
             try:
                 y_score_binary = y_score[:, 1]
-            except:
+            except Exception:
                 y_score_binary = y_score
             fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
             roc_auc = auc(fpr, tpr)
             trace = go.Scatter(
                 x=fpr,
                 y=tpr,
-                mode='lines+markers',
-                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
+                mode="lines+markers",
+                name="ROC curve (AUC = {0:0.2f})".format(roc_auc),
             )
             data_roc.append(trace)
 
-        trace_diag = go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Chance'
-        )
+        trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance")
         data_roc.append(trace_diag)
         layout_roc = go.Layout(
-            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
-            xaxis=dict(title='False positive rate'),
-            yaxis=dict(title='True positive rate')
+            title="Receiver operating characteristics (ROC) and area under curve (AUC)",
+            xaxis=dict(title="False positive rate"),
+            yaxis=dict(title="True positive rate"),
         )
 
         fig_roc = go.Figure(data=data_roc, layout=layout_roc)
         plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)
 
     except Exception as exp:
-        print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp))
+        print(
+            "Plotting the ROC-AUC graph failed. This exception was raised: {}".format(
+                exp
+            )
+        )
 
 
 if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--input", dest="infile_input", required=True)