diff plot_ml_performance.py @ 4:f234e2e59d76 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
| author | bgruening |
| --- | --- |
| date | Wed, 07 Aug 2024 10:20:17 +0000 |
| parents | 1c5dcef5ce0f |
| children | |
--- a/plot_ml_performance.py	Tue May 07 14:11:16 2024 +0000
+++ b/plot_ml_performance.py	Wed Aug 07 10:20:17 2024 +0000
@@ -1,12 +1,17 @@
 import argparse
 
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
 from galaxy_ml.model_persist import load_model_from_h5
 from galaxy_ml.utils import clean_params
-from sklearn.metrics import (auc, confusion_matrix,
-                             precision_recall_fscore_support, roc_curve)
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
 
 
@@ -16,7 +21,7 @@
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
 
     df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
@@ -25,23 +30,20 @@
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale="Portland",
-        )
-    ]
-
-    layout = go.Layout(
-        title="Confusion Matrix between true and predicted class labels",
-        xaxis=dict(title="Predicted class labels"),
-        yaxis=dict(title="True class labels"),
-    )
-
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = plt.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(axis_labels)
+    ax.set_yticks(axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    plt.savefig("output_confusion.png", dpi=120)
 
     # plot precision, recall and f_score for each class label
     precision, recall, f_score, _ = precision_recall_fscore_support(
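For context, the commit swaps the Plotly heatmap for a matplotlib imshow rendering of the confusion matrix. Below is a minimal, self-contained sketch of that approach, not the tool's code: the sample labels are illustrative, the Agg backend is an assumption (Galaxy jobs run headless), and it uses the tick-positions-plus-labels idiom (`set_xticks(range(n))` / `set_xticklabels`), since the diff's `ax.set_xticks(axis_labels)` only works when the class labels are themselves numeric positions.

    # Hypothetical standalone example; sample data is illustrative only.
    import matplotlib

    matplotlib.use("Agg")  # assumption: headless rendering, as in a Galaxy job
    import matplotlib.pyplot as plt
    from sklearn.metrics import confusion_matrix

    true_labels = [0, 1, 1, 0, 1, 0]
    predicted_labels = [0, 1, 0, 0, 1, 1]
    axis_labels = sorted(set(true_labels))

    c_matrix = confusion_matrix(true_labels, predicted_labels)

    fig, ax = plt.subplots(figsize=(7, 7))
    im = ax.imshow(c_matrix, cmap="viridis")  # equivalent to plt.imshow on the current axes
    # annotate each cell with its sample count, as the commit does
    for i in range(len(c_matrix)):
        for j in range(len(c_matrix)):
            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
    ax.set_ylabel("True class labels")
    ax.set_xlabel("Predicted class labels")
    ax.set_title("Confusion Matrix between true and predicted class labels")
    # one tick per class; set_xticks expects positions, so string class
    # labels are applied via set_xticklabels rather than passed directly
    ax.set_xticks(range(len(axis_labels)))
    ax.set_yticks(range(len(axis_labels)))
    ax.set_xticklabels(axis_labels)
    ax.set_yticklabels(axis_labels)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig("output_confusion.png", dpi=120)

The practical effect of the change is that the confusion matrix is now written as a static PNG (`output_confusion.png`) rather than an offline Plotly HTML page (`output_confusion.html`).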