changeset 4:f234e2e59d76 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author bgruening
date Wed, 07 Aug 2024 10:20:17 +0000
parents 1c5dcef5ce0f
children
files plot_ml_performance.py plotly_ml_performance_plots.xml
diffstat 2 files changed, 29 insertions(+), 32 deletions(-) [+]
line wrap: on
line diff
--- a/plot_ml_performance.py	Tue May 07 14:11:16 2024 +0000
+++ b/plot_ml_performance.py	Wed Aug 07 10:20:17 2024 +0000
@@ -1,12 +1,17 @@
 import argparse
 
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
 from galaxy_ml.model_persist import load_model_from_h5
 from galaxy_ml.utils import clean_params
-from sklearn.metrics import (auc, confusion_matrix,
-                             precision_recall_fscore_support, roc_curve)
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
 
 
@@ -16,7 +21,7 @@
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
 
     df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
@@ -25,23 +30,20 @@
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale="Portland",
-        )
-    ]
-
-    layout = go.Layout(
-        title="Confusion Matrix between true and predicted class labels",
-        xaxis=dict(title="Predicted class labels"),
-        yaxis=dict(title="True class labels"),
-    )
-
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = plt.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(axis_labels)
+    ax.set_yticks(axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    plt.savefig("output_confusion.png", dpi=120)
 
     # plot precision, recall and f_score for each class label
     precision, recall, f_score, _ = precision_recall_fscore_support(
--- a/plotly_ml_performance_plots.xml	Tue May 07 14:11:16 2024 +0000
+++ b/plotly_ml_performance_plots.xml	Wed Aug 07 10:20:17 2024 +0000
@@ -1,4 +1,4 @@
-<tool id="plotly_ml_performance_plots" name="Plot confusion matrix, precision, recall and ROC and AUC curves" version="0.3" profile="22.05">
+<tool id="plotly_ml_performance_plots" name="Plot confusion matrix, precision, recall and ROC and AUC curves" version="0.4" profile="22.05">
     <description>of tabular data</description>
     <requirements>
 	<requirement type="package" version="0.10.0">galaxy-ml</requirement>
@@ -18,7 +18,7 @@
     </inputs>
 
     <outputs>
-        <data name="output_confusion" format="html" from_work_dir="output_confusion.html" label="Confusion matrix of tabular data on ${on_string}"/>
+        <data name="output_confusion" format="png" from_work_dir="output_confusion.png" label="Confusion matrix of tabular data on ${on_string}"/>
         <data name="output_prf" format="html" from_work_dir="output_prf.html" label="Precision, recall and f-score of tabular data on ${on_string}"/>
         <data name="output_roc" format="html" from_work_dir="output_roc.html" label="ROC and AUC curves of tabular data on ${on_string}"/>
     </outputs>
@@ -30,8 +30,7 @@
             <param name="infile_trained_model" value="model_binary_sgd.h5mlm" ftype="h5mlm"/>
 	    <output name="output_confusion">
                 <assert_contents>
-		    <has_size value="3486809" delta="10000" />
-		    <has_text text="html" />
+		    <has_size value="31751" delta="1000" />
 		</assert_contents>
             </output>
 	    <output name="output_prf">
@@ -47,8 +46,7 @@
             <param name="infile_trained_model" value="model_binary_linearsvm.h5mlm" ftype="h5mlm"/>
 	    <output name="output_confusion">
 	        <assert_contents>
-		    <has_size value="3486810" delta="10000" />
-		    <has_text text="html" />
+		    <has_size value="31983" delta="1000" />
                 </assert_contents>
 	    </output>
             <output name="output_prf">
@@ -70,8 +68,7 @@
             <param name="infile_trained_model" value="model_binary_rfc.h5mlm" ftype="h5mlm"/>
             <output name="output_confusion">
 	        <assert_contents>
-		    <has_size value="3486806" delta="10000" />
-		    <has_text text="html" />
+		    <has_size value="34096" delta="1000" />
                 </assert_contents>
 	    </output>
             <output name="output_prf">
@@ -93,8 +90,7 @@
             <param name="infile_trained_model" value="model_binary_knn.h5mlm" ftype="h5mlm"/>
             <output name="output_confusion">
 	        <assert_contents>
-		    <has_size value="3486856" delta="10000" />
-		    <has_text text="html" />
+		    <has_size value="32398" delta="1000" />
                 </assert_contents>
 	    </output>
             <output name="output_prf">
@@ -116,8 +112,7 @@
             <param name="infile_trained_model" value="model_multi_lr.h5mlm" ftype="h5mlm"/>
             <output name="output_confusion">
 	        <assert_contents>
-		    <has_size value="3486832" delta="10000" />
-		    <has_text text="html" />
+		    <has_size value="34474" delta="1000" />
                 </assert_contents>
 	    </output>
             <output name="output_prf">