diff plot_ml_performance.py @ 4:f234e2e59d76 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author bgruening
date Wed, 07 Aug 2024 10:20:17 +0000
parents 1c5dcef5ce0f
--- a/plot_ml_performance.py	Tue May 07 14:11:16 2024 +0000
+++ b/plot_ml_performance.py	Wed Aug 07 10:20:17 2024 +0000
@@ -1,12 +1,17 @@
 import argparse
 
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
 from galaxy_ml.model_persist import load_model_from_h5
 from galaxy_ml.utils import clean_params
-from sklearn.metrics import (auc, confusion_matrix,
-                             precision_recall_fscore_support, roc_curve)
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
 
 
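For context, the reformatted `sklearn.metrics` imports presumably feed the ROC plotting further down in this file. A minimal sketch, with toy `y_true`/`y_score` stand-ins (not part of this changeset), of how `label_binarize`, `roc_curve`, and `auc` typically combine for a per-class ROC curve:

```python
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize

# Toy stand-ins; the tool reads these from its tabular inputs.
y_true = np.array([0, 1, 2, 1, 0, 2])
y_score = np.random.rand(6, 3)  # stand-in per-class probability scores

# One indicator column per class, so each class gets its own binary ROC.
y_bin = label_binarize(y_true, classes=[0, 1, 2])
fpr, tpr, _ = roc_curve(y_bin[:, 0], y_score[:, 0])  # curve for class 0
print("AUC(class 0) =", auc(fpr, tpr))
```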
@@ -16,7 +21,7 @@
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
 
     df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
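A minimal sketch, assuming the usual `galaxy_ml` loading pattern (the path is a hypothetical stand-in, not taken from this diff), of how the `h5mlm` model named in the docstring would be deserialized with the helpers imported above:

```python
from galaxy_ml.model_persist import load_model_from_h5
from galaxy_ml.utils import clean_params

# Hypothetical path; load_model_from_h5 restores the trained estimator
# and clean_params strips Galaxy-specific wrapper parameters from it.
model = load_model_from_h5("trained_model.h5mlm")
model = clean_params(model)
```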
@@ -25,23 +30,20 @@
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = sorted(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale="Portland",
-        )
-    ]
-
-    layout = go.Layout(
-        title="Confusion Matrix between true and predicted class labels",
-        xaxis=dict(title="Predicted class labels"),
-        yaxis=dict(title="True class labels"),
-    )
-
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = ax.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(range(len(axis_labels)), labels=axis_labels)
+    ax.set_yticks(range(len(axis_labels)), labels=axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    fig.savefig("output_confusion.png", dpi=120)
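As a self-contained illustration of the switch from an interactive Plotly HTML heatmap to a static matplotlib PNG, here is a runnable sketch with toy string labels (the `Agg` backend and the class names are assumptions; the ticks-with-labels call assumes matplotlib >= 3.5):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; Galaxy jobs have no display
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

true_labels = ["cat", "dog", "dog", "cat", "dog"]
predicted_labels = ["cat", "dog", "cat", "cat", "dog"]
classes = sorted(set(true_labels))  # fixed order shared by matrix and axes
cm = confusion_matrix(true_labels, predicted_labels, labels=classes)

fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(cm, cmap="viridis")
for i in range(len(cm)):        # annotate each cell with its sample count
    for j in range(len(cm)):
        ax.text(j, i, cm[i, j], ha="center", va="center", color="w")
ax.set_xticks(range(len(classes)), labels=classes)
ax.set_yticks(range(len(classes)), labels=classes)
ax.set_xlabel("Predicted class labels")
ax.set_ylabel("True class labels")
fig.colorbar(im, ax=ax)
fig.tight_layout()
fig.savefig("output_confusion.png", dpi=120)
```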
 
     # plot precision, recall and f_score for each class label
     precision, recall, f_score, _ = precision_recall_fscore_support(