comparison view: plot_ml_performance.py @ 0:4fac53da862f (draft)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
| field | value |
|---|---|
| author | bgruening |
| date | Thu, 11 Oct 2018 14:37:54 -0400 |
| parents | |
| children | 85da91bbdbfb |
comparison of -1:000000000000 (null parent) with 0:4fac53da862f: plot_ml_performance.py is introduced in this revision, so all lines below are insertions.
```python
import argparse
import pickle

import pandas as pd
import plotly
import plotly.graph_objs as go
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
from sklearn.preprocessing import label_binarize


def main(infile_input, infile_output, infile_trained_model):
    """
    Produce an interactive confusion matrix (heatmap), precision, recall,
    f-score and ROC/AUC plots.
    Args:
        infile_input: str, input tabular file with true labels
        infile_output: str, input tabular file with predicted labels
        infile_trained_model: str, input trained model file (loaded with pickle)
    """

    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
    # the last column of each tabular file holds the class labels
    true_labels = df_input.iloc[:, -1].copy()
    predicted_labels = df_output.iloc[:, -1].copy()
    # sort the labels so they match the row/column order used by confusion_matrix
    axis_labels = sorted(set(true_labels))
    c_matrix = confusion_matrix(true_labels, predicted_labels)

    # plot the confusion matrix as an interactive heatmap;
    # rows of c_matrix are true labels (y axis), columns are predictions (x axis)
    data = [
        go.Heatmap(
            z=c_matrix,
            x=axis_labels,
            y=axis_labels,
            colorscale='Portland',
        )
    ]

    layout = go.Layout(
        title='Confusion Matrix between true and predicted class labels',
        xaxis=dict(title='Predicted class labels'),
        yaxis=dict(title='True class labels')
    )

    fig = go.Figure(data=data, layout=layout)
    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)

    # plot precision, recall and f-score for each class label
    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)

    trace_precision = go.Scatter(
        x=axis_labels,
        y=precision,
        mode='lines+markers',
        name='Precision'
    )

    trace_recall = go.Scatter(
        x=axis_labels,
        y=recall,
        mode='lines+markers',
        name='Recall'
    )

    trace_fscore = go.Scatter(
        x=axis_labels,
        y=f_score,
        mode='lines+markers',
        name='F-score'
    )

    layout_prf = go.Layout(
        title='Precision, recall and f-score of true and predicted class labels',
        xaxis=dict(title='Class labels'),
        yaxis=dict(title='Precision, recall and f-score')
    )

    data_prf = [trace_precision, trace_recall, trace_fscore]
    fig_prf = go.Figure(data=data_prf, layout=layout_prf)
    plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)

    # plot ROC and AUC curves for the different classes
    with open(infile_trained_model, 'rb') as model_file:
        model = pickle.load(model_file)

    # remove the last column (label column) to recover the feature matrix
    test_data = df_input.iloc[:, :-1]
    model_items = dir(model)

    try:
        # find the probability estimating method of the trained model
        if 'predict_proba' in model_items:
            y_score = model.predict_proba(test_data)
        elif 'decision_function' in model_items:
            y_score = model.decision_function(test_data)

        true_labels_list = true_labels.tolist()
        one_hot_labels = label_binarize(true_labels_list, classes=axis_labels)
        data_roc = list()

        if len(axis_labels) > 2:
            # multi-class: one ROC curve per class (one-vs-rest)
            fpr = dict()
            tpr = dict()
            roc_auc = dict()
            for i in range(len(axis_labels)):
                fpr[i], tpr[i], _ = roc_curve(one_hot_labels[:, i], y_score[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])
            for i in range(len(axis_labels)):
                trace = go.Scatter(
                    x=fpr[i],
                    y=tpr[i],
                    mode='lines+markers',
                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(axis_labels[i], roc_auc[i])
                )
                data_roc.append(trace)
        else:
            # binary case: use the scores of the positive class
            try:
                y_score_binary = y_score[:, 1]
            except IndexError:
                # decision_function returns a 1d array for binary problems
                y_score_binary = y_score
            fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
            roc_auc = auc(fpr, tpr)
            trace = go.Scatter(
                x=fpr,
                y=tpr,
                mode='lines+markers',
                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
            )
            data_roc.append(trace)

        # diagonal reference line for a random (chance) classifier
        trace_diag = go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode='lines',
            name='Chance'
        )
        data_roc.append(trace_diag)
        layout_roc = go.Layout(
            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
            xaxis=dict(title='False positive rate'),
            yaxis=dict(title='True positive rate')
        )

        fig_roc = go.Figure(data=data_roc, layout=layout_roc)
        plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)

    except Exception as exp:
        print("ROC/AUC plot could not be created: {0}".format(exp))


if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--input", dest="infile_input", required=True)
    aparser.add_argument("-j", "--output", dest="infile_output", required=True)
    aparser.add_argument("-k", "--model", dest="infile_trained_model", required=True)
    args = aparser.parse_args()
    main(args.infile_input, args.infile_output, args.infile_trained_model)
```
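A minimal usage sketch, assuming the file above is saved as plot_ml_performance.py on the Python path and that pandas, plotly and scikit-learn are installed; the iris data, the LogisticRegression model and the file names test_data.tsv, predictions.tsv and model.pkl are illustrative only, not part of the tool:

```python
# Hypothetical end-to-end example: train a small model, write the tabular
# inputs the script expects, then call main() to produce the three HTML plots.
import pickle

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from plot_ml_performance import main  # assumes the script above is importable

X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=200).fit(X, y)

df_true = pd.DataFrame(X)
df_true['label'] = y                   # last column holds the true class labels
df_true.to_csv('test_data.tsv', sep='\t', index=False)

df_pred = pd.DataFrame(X)
df_pred['label'] = model.predict(X)    # last column holds the predicted labels
df_pred.to_csv('predictions.tsv', sep='\t', index=False)

with open('model.pkl', 'wb') as fh:
    pickle.dump(model, fh)

# writes output_confusion.html, output_prf.html and output_roc.html
# into the current working directory
main('test_data.tsv', 'predictions.tsv', 'model.pkl')
```

The same run from a shell would use the argparse flags defined above, e.g. `python plot_ml_performance.py -i test_data.tsv -j predictions.tsv -k model.pkl`.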
