comparison: plot_ml_performance.py @ 4:f234e2e59d76 (draft, default, tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
| author | bgruening |
|---|---|
| date | Wed, 07 Aug 2024 10:20:17 +0000 |
| parents | 1c5dcef5ce0f |
| children | |
Changeset 3:1c5dcef5ce0f → 4:f234e2e59d76:

```diff
--- plot_ml_performance.py@3:1c5dcef5ce0f
+++ plot_ml_performance.py@4:f234e2e59d76
@@ -1,49 +1,51 @@
 import argparse
 
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
 from galaxy_ml.model_persist import load_model_from_h5
 from galaxy_ml.utils import clean_params
-from sklearn.metrics import (auc, confusion_matrix,
-                             precision_recall_fscore_support, roc_curve)
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
 
 
 def main(infile_input, infile_output, infile_trained_model):
     """
     Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
 
     df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
     df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale="Portland",
-        )
-    ]
-
-    layout = go.Layout(
-        title="Confusion Matrix between true and predicted class labels",
-        xaxis=dict(title="Predicted class labels"),
-        yaxis=dict(title="True class labels"),
-    )
-
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = plt.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(axis_labels)
+    ax.set_yticks(axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    plt.savefig("output_confusion.png", dpi=120)
 
     # plot precision, recall and f_score for each class label
     precision, recall, f_score, _ = precision_recall_fscore_support(
         true_labels, predicted_labels
     )
```
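
For reference, a minimal, self-contained sketch of the new matplotlib code path on toy labels (the example labels, the `Agg` backend, and the figure settings are illustrative assumptions, not part of the commit). It deviates from the committed code in two small ways: it draws with `ax.imshow` instead of `plt.imshow` so everything stays on one axes object, and it sets numeric tick positions before applying the class names via `set_xticklabels`, since `ax.set_xticks` expects positions and the committed `ax.set_xticks(axis_labels)` only works when the class labels are themselves the integers 0..n-1.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen, as a headless Galaxy job would

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# toy stand-ins for the last columns of the two tabular inputs (assumption)
true_labels = np.array([0, 0, 1, 1, 2, 2, 2, 1])
predicted_labels = np.array([0, 1, 1, 1, 2, 0, 2, 1])

axis_labels = sorted(set(true_labels))
c_matrix = confusion_matrix(true_labels, predicted_labels)

fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(c_matrix, cmap="viridis")

# annotate each cell with its sample count, as the committed loop does
for i in range(len(c_matrix)):
    for j in range(len(c_matrix)):
        ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")

ax.set_xlabel("Predicted class labels")
ax.set_ylabel("True class labels")
ax.set_title("Confusion Matrix between true and predicted class labels")

# set numeric tick positions, then label them with the class names
ax.set_xticks(range(len(axis_labels)))
ax.set_xticklabels(axis_labels)
ax.set_yticks(range(len(axis_labels)))
ax.set_yticklabels(axis_labels)

fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.savefig("output_confusion.png", dpi=120)
```

Running the sketch writes `output_confusion.png` to the working directory, matching the filename the revised tool emits in place of the earlier `output_confusion.html`.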
