comparison: plot_ml_performance.py @ 3:1c5dcef5ce0f (draft)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
| author | bgruening |
|---|---|
| date | Tue, 07 May 2024 14:11:16 +0000 |
| parents | 62e3a4e8c54c |
| children | f234e2e59d76 |
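The substantive change in this revision replaces pickle-based model loading with Galaxy-ML's HDF5 loader, as shown in the comparison below. A minimal sketch of the new loading path, assuming a Galaxy-ML-produced model file (the file name here is only a placeholder):

```python
# Sketch of the revised loading path from this revision; the previous revision
# unpickled the model file directly. "trained_model.h5" is a placeholder path.
from galaxy_ml.model_persist import load_model_from_h5
from galaxy_ml.utils import clean_params

classifier_object = load_model_from_h5("trained_model.h5")  # was: pickle.load(open(path, "rb"))
model = clean_params(classifier_object)  # applied to the loaded estimator before prediction, as in the new code
```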
| 2:62e3a4e8c54c | 3:1c5dcef5ce0f |
|---|---|
| 1 import argparse | 1 import argparse |
| | 2 |
| 2 import pandas as pd | 3 import pandas as pd |
| 3 import plotly | 4 import plotly |
| 4 import pickle | |
| 5 import plotly.graph_objs as go | 5 import plotly.graph_objs as go |
| 6 from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc | 6 from galaxy_ml.model_persist import load_model_from_h5 |
| | 7 from galaxy_ml.utils import clean_params |
| | 8 from sklearn.metrics import (auc, confusion_matrix, |
| | 9 precision_recall_fscore_support, roc_curve) |
| 7 from sklearn.preprocessing import label_binarize | 10 from sklearn.preprocessing import label_binarize |
| 8 | 11 |
| 9 | 12 |
| 10 def main(infile_input, infile_output, infile_trained_model): | 13 def main(infile_input, infile_output, infile_trained_model): |
| 11 """ | 14 """ |
| 14 infile_input: str, input tabular file with true labels | 17 infile_input: str, input tabular file with true labels |
| 15 infile_output: str, input tabular file with predicted labels | 18 infile_output: str, input tabular file with predicted labels |
| 16 infile_trained_model: str, input trained model file (zip) | 19 infile_trained_model: str, input trained model file (zip) |
| 17 """ | 20 """ |
| 18 | 21 |
| 19 df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True) | 22 df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True) |
| 20 df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True) | 23 df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True) |
| 21 true_labels = df_input.iloc[:, -1].copy() | 24 true_labels = df_input.iloc[:, -1].copy() |
| 22 predicted_labels = df_output.iloc[:, -1].copy() | 25 predicted_labels = df_output.iloc[:, -1].copy() |
| 23 axis_labels = list(set(true_labels)) | 26 axis_labels = list(set(true_labels)) |
| 24 c_matrix = confusion_matrix(true_labels, predicted_labels) | 27 c_matrix = confusion_matrix(true_labels, predicted_labels) |
| 25 data = [ | 28 data = [ |
| 26 go.Heatmap( | 29 go.Heatmap( |
| 27 z=c_matrix, | 30 z=c_matrix, |
| 28 x=axis_labels, | 31 x=axis_labels, |
| 29 y=axis_labels, | 32 y=axis_labels, |
| 30 colorscale='Portland', | 33 colorscale="Portland", |
| 31 ) | 34 ) |
| 32 ] | 35 ] |
| 33 | 36 |
| 34 layout = go.Layout( | 37 layout = go.Layout( |
| 35 title='Confusion Matrix between true and predicted class labels', | 38 title="Confusion Matrix between true and predicted class labels", |
| 36 xaxis=dict(title='Predicted class labels'), | 39 xaxis=dict(title="Predicted class labels"), |
| 37 yaxis=dict(title='True class labels') | 40 yaxis=dict(title="True class labels"), |
| 38 ) | 41 ) |
| 39 | 42 |
| 40 fig = go.Figure(data=data, layout=layout) | 43 fig = go.Figure(data=data, layout=layout) |
| 41 plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False) | 44 plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False) |
| 42 | 45 |
| 43 # plot precision, recall and f_score for each class label | 46 # plot precision, recall and f_score for each class label |
| 44 precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels) | 47 precision, recall, f_score, _ = precision_recall_fscore_support( |
| | 48 true_labels, predicted_labels |
| | 49 ) |
| 45 | 50 |
| 46 trace_precision = go.Scatter( | 51 trace_precision = go.Scatter( |
| 47 x=axis_labels, | 52 x=axis_labels, y=precision, mode="lines+markers", name="Precision" |
| 48 y=precision, | |
| 49 mode='lines+markers', | |
| 50 name='Precision' | |
| 51 ) | 53 ) |
| 52 | 54 |
| 53 trace_recall = go.Scatter( | 55 trace_recall = go.Scatter( |
| 54 x=axis_labels, | 56 x=axis_labels, y=recall, mode="lines+markers", name="Recall" |
| 55 y=recall, | |
| 56 mode='lines+markers', | |
| 57 name='Recall' | |
| 58 ) | 57 ) |
| 59 | 58 |
| 60 trace_fscore = go.Scatter( | 59 trace_fscore = go.Scatter( |
| 61 x=axis_labels, | 60 x=axis_labels, y=f_score, mode="lines+markers", name="F-score" |
| 62 y=f_score, | |
| 63 mode='lines+markers', | |
| 64 name='F-score' | |
| 65 ) | 61 ) |
| 66 | 62 |
| 67 layout_prf = go.Layout( | 63 layout_prf = go.Layout( |
| 68 title='Precision, recall and f-score of true and predicted class labels', | 64 title="Precision, recall and f-score of true and predicted class labels", |
| 69 xaxis=dict(title='Class labels'), | 65 xaxis=dict(title="Class labels"), |
| 70 yaxis=dict(title='Precision, recall and f-score') | 66 yaxis=dict(title="Precision, recall and f-score"), |
| 71 ) | 67 ) |
| 72 | 68 |
| 73 data_prf = [trace_precision, trace_recall, trace_fscore] | 69 data_prf = [trace_precision, trace_recall, trace_fscore] |
| 74 fig_prf = go.Figure(data=data_prf, layout=layout_prf) | 70 fig_prf = go.Figure(data=data_prf, layout=layout_prf) |
| 75 plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False) | 71 plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False) |
| 76 | 72 |
| 77 # plot roc and auc curves for different classes | 73 # plot roc and auc curves for different classes |
| 78 with open(infile_trained_model, 'rb') as model_file: | 74 classifier_object = load_model_from_h5(infile_trained_model) |
| 79 model = pickle.load(model_file) | 75 model = clean_params(classifier_object) |
| 80 | 76 |
| 81 # remove the last column (label column) | 77 # remove the last column (label column) |
| 82 test_data = df_input.iloc[:, :-1] | 78 test_data = df_input.iloc[:, :-1] |
| 83 model_items = dir(model) | 79 model_items = dir(model) |
| 84 | 80 |
| 85 try: | 81 try: |
| 86 # find the probability estimating method | 82 # find the probability estimating method |
| 87 if 'predict_proba' in model_items: | 83 if "predict_proba" in model_items: |
| 88 y_score = model.predict_proba(test_data) | 84 y_score = model.predict_proba(test_data) |
| 89 elif 'decision_function' in model_items: | 85 elif "decision_function" in model_items: |
| 90 y_score = model.decision_function(test_data) | 86 y_score = model.decision_function(test_data) |
| 91 | 87 |
| 92 true_labels_list = true_labels.tolist() | 88 true_labels_list = true_labels.tolist() |
| 93 one_hot_labels = label_binarize(true_labels_list, classes=axis_labels) | 89 one_hot_labels = label_binarize(true_labels_list, classes=axis_labels) |
| 94 data_roc = list() | 90 data_roc = list() |
| 102 roc_auc[i] = auc(fpr[i], tpr[i]) | 98 roc_auc[i] = auc(fpr[i], tpr[i]) |
| 103 for i in range(len(axis_labels)): | 99 for i in range(len(axis_labels)): |
| 104 trace = go.Scatter( | 100 trace = go.Scatter( |
| 105 x=fpr[i], | 101 x=fpr[i], |
| 106 y=tpr[i], | 102 y=tpr[i], |
| 107 mode='lines+markers', | 103 mode="lines+markers", |
| 108 name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i]) | 104 name="ROC curve of class {0} (AUC = {1:0.2f})".format( |
| | 105 i, roc_auc[i] |
| | 106 ), |
| 109 ) | 107 ) |
| 110 data_roc.append(trace) | 108 data_roc.append(trace) |
| 111 else: | 109 else: |
| 112 try: | 110 try: |
| 113 y_score_binary = y_score[:, 1] | 111 y_score_binary = y_score[:, 1] |
| 114 except: | 112 except Exception: |
| 115 y_score_binary = y_score | 113 y_score_binary = y_score |
| 116 fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1) | 114 fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1) |
| 117 roc_auc = auc(fpr, tpr) | 115 roc_auc = auc(fpr, tpr) |
| 118 trace = go.Scatter( | 116 trace = go.Scatter( |
| 119 x=fpr, | 117 x=fpr, |
| 120 y=tpr, | 118 y=tpr, |
| 121 mode='lines+markers', | 119 mode="lines+markers", |
| 122 name='ROC curve (AUC = {0:0.2f})'.format(roc_auc) | 120 name="ROC curve (AUC = {0:0.2f})".format(roc_auc), |
| 123 ) | 121 ) |
| 124 data_roc.append(trace) | 122 data_roc.append(trace) |
| 125 | 123 |
| 126 trace_diag = go.Scatter( | 124 trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance") |
| 127 x=[0, 1], | |
| 128 y=[0, 1], | |
| 129 mode='lines', | |
| 130 name='Chance' | |
| 131 ) | |
| 132 data_roc.append(trace_diag) | 125 data_roc.append(trace_diag) |
| 133 layout_roc = go.Layout( | 126 layout_roc = go.Layout( |
| 134 title='Receiver operating characteristics (ROC) and area under curve (AUC)', | 127 title="Receiver operating characteristics (ROC) and area under curve (AUC)", |
| 135 xaxis=dict(title='False positive rate'), | 128 xaxis=dict(title="False positive rate"), |
| 136 yaxis=dict(title='True positive rate') | 129 yaxis=dict(title="True positive rate"), |
| 137 ) | 130 ) |
| 138 | 131 |
| 139 fig_roc = go.Figure(data=data_roc, layout=layout_roc) | 132 fig_roc = go.Figure(data=data_roc, layout=layout_roc) |
| 140 plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False) | 133 plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False) |
| 141 | 134 |
| 142 except Exception as exp: | 135 except Exception as exp: |
| 143 print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp)) | 136 print( |
| 137 "Plotting the ROC-AUC graph failed. This exception was raised: {}".format( | |
| 138 exp | |
| 139 ) | |
| 140 ) | |
| 144 | 141 |
| 145 | 142 |
| 146 if __name__ == "__main__": | 143 if __name__ == "__main__": |
| 147 aparser = argparse.ArgumentParser() | 144 aparser = argparse.ArgumentParser() |
| 148 aparser.add_argument("-i", "--input", dest="infile_input", required=True) | 145 aparser.add_argument("-i", "--input", dest="infile_input", required=True) |
