comparison plot_ml_performance.py @ 4:f234e2e59d76 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author bgruening
date Wed, 07 Aug 2024 10:20:17 +0000
parents 1c5dcef5ce0f
children
comparison
equal deleted inserted replaced
3:1c5dcef5ce0f 4:f234e2e59d76
1 import argparse 1 import argparse
2 2
3 import matplotlib.pyplot as plt
3 import pandas as pd 4 import pandas as pd
4 import plotly 5 import plotly
5 import plotly.graph_objs as go 6 import plotly.graph_objs as go
6 from galaxy_ml.model_persist import load_model_from_h5 7 from galaxy_ml.model_persist import load_model_from_h5
7 from galaxy_ml.utils import clean_params 8 from galaxy_ml.utils import clean_params
8 from sklearn.metrics import (auc, confusion_matrix, 9 from sklearn.metrics import (
9 precision_recall_fscore_support, roc_curve) 10 auc,
11 confusion_matrix,
12 precision_recall_fscore_support,
13 roc_curve,
14 )
10 from sklearn.preprocessing import label_binarize 15 from sklearn.preprocessing import label_binarize
11 16
12 17
13 def main(infile_input, infile_output, infile_trained_model): 18 def main(infile_input, infile_output, infile_trained_model):
14 """ 19 """
15 Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots 20 Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots
16 Args: 21 Args:
17 infile_input: str, input tabular file with true labels 22 infile_input: str, input tabular file with true labels
18 infile_output: str, input tabular file with predicted labels 23 infile_output: str, input tabular file with predicted labels
19 infile_trained_model: str, input trained model file (zip) 24 infile_trained_model: str, input trained model file (h5mlm)
20 """ 25 """
21 26
22 df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True) 27 df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
23 df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True) 28 df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
24 true_labels = df_input.iloc[:, -1].copy() 29 true_labels = df_input.iloc[:, -1].copy()
25 predicted_labels = df_output.iloc[:, -1].copy() 30 predicted_labels = df_output.iloc[:, -1].copy()
26 axis_labels = list(set(true_labels)) 31 axis_labels = list(set(true_labels))
27 c_matrix = confusion_matrix(true_labels, predicted_labels) 32 c_matrix = confusion_matrix(true_labels, predicted_labels)
28 data = [ 33 fig, ax = plt.subplots(figsize=(7, 7))
29 go.Heatmap( 34 im = plt.imshow(c_matrix, cmap="viridis")
30 z=c_matrix, 35 # add number of samples to each cell of confusion matrix plot
31 x=axis_labels, 36 for i in range(len(c_matrix)):
32 y=axis_labels, 37 for j in range(len(c_matrix)):
33 colorscale="Portland", 38 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
34 ) 39 ax.set_ylabel("True class labels")
35 ] 40 ax.set_xlabel("Predicted class labels")
36 41 ax.set_title("Confusion Matrix between true and predicted class labels")
37 layout = go.Layout( 42 ax.set_xticks(axis_labels)
38 title="Confusion Matrix between true and predicted class labels", 43 ax.set_yticks(axis_labels)
39 xaxis=dict(title="Predicted class labels"), 44 fig.colorbar(im, ax=ax)
40 yaxis=dict(title="True class labels"), 45 fig.tight_layout()
41 ) 46 plt.savefig("output_confusion.png", dpi=120)
42
43 fig = go.Figure(data=data, layout=layout)
44 plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
45 47
46 # plot precision, recall and f_score for each class label 48 # plot precision, recall and f_score for each class label
47 precision, recall, f_score, _ = precision_recall_fscore_support( 49 precision, recall, f_score, _ = precision_recall_fscore_support(
48 true_labels, predicted_labels 50 true_labels, predicted_labels
49 ) 51 )