comparison plot_ml_performance.py @ 0:4fac53da862f draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
author      bgruening
date        Thu, 11 Oct 2018 14:37:54 -0400
parents     (none)
children    85da91bbdbfb
comparing -1:000000000000 with 0:4fac53da862f
import argparse
import pickle

import pandas as pd
import plotly
import plotly.graph_objs as go
from sklearn.metrics import auc, confusion_matrix, precision_recall_fscore_support, roc_curve
from sklearn.preprocessing import label_binarize


def main(infile_input, infile_output, infile_trained_model):
    """
    Produce interactive plots: a confusion matrix (heatmap), precision/recall/F-score
    curves and ROC/AUC curves.
    Args:
        infile_input: str, input tabular file with true labels
        infile_output: str, input tabular file with predicted labels
        infile_trained_model: str, input trained model file (zip)
    """

    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
    true_labels = df_input.iloc[:, -1].copy()
    predicted_labels = df_output.iloc[:, -1].copy()
    # sort the unique class labels so that their order matches the sorted label
    # order used by scikit-learn (e.g. by confusion_matrix and predict_proba)
    axis_labels = sorted(set(true_labels))
    c_matrix = confusion_matrix(true_labels, predicted_labels)
    data = [
        go.Heatmap(
            z=c_matrix,
            x=axis_labels,
            y=axis_labels,
            colorscale='Portland',
        )
    ]

    layout = go.Layout(
        title='Confusion matrix of true and predicted class labels',
        # confusion_matrix rows are true labels and columns are predicted labels,
        # so the heatmap's y axis shows true classes and its x axis predicted classes
        xaxis=dict(title='Predicted class labels'),
        yaxis=dict(title='True class labels')
    )

    fig = go.Figure(data=data, layout=layout)
    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)

    # plot precision, recall and f_score for each class label
    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)

    trace_precision = go.Scatter(
        x=axis_labels,
        y=precision,
        mode='lines+markers',
        name='Precision'
    )

    trace_recall = go.Scatter(
        x=axis_labels,
        y=recall,
        mode='lines+markers',
        name='Recall'
    )

    trace_fscore = go.Scatter(
        x=axis_labels,
        y=f_score,
        mode='lines+markers',
        name='F-score'
    )

    layout_prf = go.Layout(
        title='Precision, recall and f-score of true and predicted class labels',
        xaxis=dict(title='Class labels'),
        yaxis=dict(title='Precision, recall and f-score')
    )

    data_prf = [trace_precision, trace_recall, trace_fscore]
    fig_prf = go.Figure(data=data_prf, layout=layout_prf)
    plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)

    # plot roc and auc curves for different classes
    with open(infile_trained_model, 'rb') as model_file:
        model = pickle.load(model_file)

    # remove the last column (label column)
    test_data = df_input.iloc[:, :-1]
    model_items = dir(model)

    try:
        # find the probability estimating method
        if 'predict_proba' in model_items:
            y_score = model.predict_proba(test_data)
        elif 'decision_function' in model_items:
            y_score = model.decision_function(test_data)

        true_labels_list = true_labels.tolist()
        one_hot_labels = label_binarize(true_labels_list, classes=axis_labels)
        data_roc = list()

        if len(axis_labels) > 2:
            fpr = dict()
            tpr = dict()
            roc_auc = dict()
            # compute one ROC curve per class, indexing the binarized labels and
            # the score matrix by column position rather than by label value
            for i in range(len(axis_labels)):
                fpr[i], tpr[i], _ = roc_curve(one_hot_labels[:, i], y_score[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])
            for i in range(len(axis_labels)):
                trace = go.Scatter(
                    x=fpr[i],
                    y=tpr[i],
                    mode='lines+markers',
                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(axis_labels[i], roc_auc[i])
                )
                data_roc.append(trace)
        else:
            try:
                # probability estimates are 2D (one column per class);
                # keep the scores of the positive class
                y_score_binary = y_score[:, 1]
            except (IndexError, TypeError):
                # decision_function returns a 1D array for binary problems
                y_score_binary = y_score
            fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
            roc_auc = auc(fpr, tpr)
            trace = go.Scatter(
                x=fpr,
                y=tpr,
                mode='lines+markers',
                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
            )
            data_roc.append(trace)

        trace_diag = go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode='lines',
            name='Chance'
        )
        data_roc.append(trace_diag)
        layout_roc = go.Layout(
            title='Receiver operating characteristic (ROC) and area under curve (AUC)',
            xaxis=dict(title='False positive rate'),
            yaxis=dict(title='True positive rate')
        )

        fig_roc = go.Figure(data=data_roc, layout=layout_roc)
        plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)

    except Exception:
        # the model may not expose a scoring method, or the scores may not match
        # the binarized labels; skip the ROC plot in that case
        pass


if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--input", dest="infile_input", required=True)
    aparser.add_argument("-j", "--output", dest="infile_output", required=True)
    aparser.add_argument("-k", "--model", dest="infile_trained_model", required=True)
    args = aparser.parse_args()
    main(args.infile_input, args.infile_output, args.infile_trained_model)
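
Within Galaxy the tool wrapper supplies the two tabular files and the pickled model through the -i/-j/-k flags defined in the argument parser above. For a quick local check, the sketch below shows one way the three inputs could be prepared and the plots generated; the iris data, the DecisionTreeClassifier, and the file names (test_data.tabular, predicted_data.tabular, model.pkl) are illustrative assumptions, not part of the tool.

# Minimal, self-contained sketch for exercising plot_ml_performance.py locally.
# Dataset, classifier, and file names are illustrative assumptions only.
import pickle

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from plot_ml_performance import main

# train a small model and pickle it, since the script unpickles the model file
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=42, stratify=iris.target)
model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
with open("model.pkl", "wb") as handle:
    pickle.dump(model, handle)

# tab-separated files with the class label in the last column, which is where
# the script reads true and predicted labels from
test_df = pd.DataFrame(X_test, columns=iris.feature_names)
test_df["label"] = y_test
test_df.to_csv("test_data.tabular", sep="\t", index=False)

pred_df = test_df.copy()
pred_df["label"] = model.predict(X_test)
pred_df.to_csv("predicted_data.tabular", sep="\t", index=False)

# writes output_confusion.html, output_prf.html and output_roc.html
main("test_data.tabular", "predicted_data.tabular", "model.pkl")

The three HTML files are written to the working directory with the names hard-coded in main(), which is also where the Galaxy tool expects to collect them.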