diff ml_visualization_ex.py @ 33:b4ba9602a36e draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9e28f4466084464d38d3f8db2aff07974be4ba69"
author bgruening
date Wed, 11 Mar 2020 13:36:07 -0400
parents df579b31311d
children 1e99cfb71f40
line wrap: on
line diff
--- a/ml_visualization_ex.py	Wed Jan 22 07:43:35 2020 -0500
+++ b/ml_visualization_ex.py	Wed Mar 11 13:36:07 2020 -0400
@@ -13,7 +13,7 @@
 from keras.utils import plot_model
 from sklearn.feature_selection.base import SelectorMixin
 from sklearn.metrics import precision_recall_curve, average_precision_score
-from sklearn.metrics import roc_curve, auc
+from sklearn.metrics import roc_curve, auc, confusion_matrix
 from sklearn.pipeline import Pipeline
 from galaxy_ml.utils import load_model, read_columns, SafeEval
 
@@ -266,12 +266,29 @@
               os.path.join(folder, "output"))
 
 
+def get_dataframe(file_path, plot_selection, header_name, column_name):
+    header = 'infer' if plot_selection[header_name] else None
+    column_option = plot_selection[column_name]["selected_column_selector_option"]
+    if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]:
+        col = plot_selection[column_name]["col1"]
+    else:
+        col = None
+    _, input_df = read_columns(file_path, c=col,
+                                   c_option=column_option,
+                                   return_df=True,
+                                   sep='\t', header=header,
+                                   parse_dates=True)
+    return input_df
+
+
 def main(inputs, infile_estimator=None, infile1=None,
          infile2=None, outfile_result=None,
          outfile_object=None, groups=None,
          ref_seq=None, intervals=None,
          targets=None, fasta_path=None,
-         model_config=None):
+         model_config=None, true_labels=None,
+         predicted_labels=None, plot_color=None,
+         title=None):
     """
     Parameter
     ---------
@@ -311,6 +328,18 @@
 
     model_config : str, default is None
         File path to dataset containing JSON config for neural networks
+
+    true_labels : str, default is None
+        File path to dataset containing true labels
+
+    predicted_labels : str, default is None
+        File path to dataset containing true predicted labels
+
+    plot_color : str, default is None
+        Color of the confusion matrix heatmap
+
+    title : str, default is None
+        Title of the confusion matrix heatmap
     """
     warnings.simplefilter('ignore')
 
@@ -543,6 +572,32 @@
 
         return 0
 
+    elif plot_type == 'classification_confusion_matrix':
+        plot_selection = params["plotting_selection"]
+        input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true")
+        header_predicted = 'infer' if plot_selection["header_predicted"] else None
+        input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted)
+        true_classes = input_true.iloc[:, -1].copy()
+        predicted_classes = input_predicted.iloc[:, -1].copy()
+        axis_labels = list(set(true_classes))
+        c_matrix = confusion_matrix(true_classes, predicted_classes)
+        fig, ax = plt.subplots(figsize=(7, 7))
+        im = plt.imshow(c_matrix, cmap=plot_color)
+        for i in range(len(c_matrix)):
+            for j in range(len(c_matrix)):
+                ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+        ax.set_ylabel('True class labels')
+        ax.set_xlabel('Predicted class labels')
+        ax.set_title(title)
+        ax.set_xticks(axis_labels)
+        ax.set_yticks(axis_labels)
+        fig.colorbar(im, ax=ax)
+        fig.tight_layout()
+        plt.savefig("output.png", dpi=125)
+        os.rename('output.png', 'output')
+
+        return 0
+
     # save pdf file to disk
     # fig.write_image("image.pdf", format='pdf')
     # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
@@ -562,10 +617,17 @@
     aparser.add_argument("-t", "--targets", dest="targets")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
     aparser.add_argument("-c", "--model_config", dest="model_config")
+    aparser.add_argument("-tl", "--true_labels", dest="true_labels")
+    aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
+    aparser.add_argument("-pc", "--plot_color", dest="plot_color")
+    aparser.add_argument("-pt", "--title", dest="title")
     args = aparser.parse_args()
 
     main(args.inputs, args.infile_estimator, args.infile1, args.infile2,
          args.outfile_result, outfile_object=args.outfile_object,
          groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals,
          targets=args.targets, fasta_path=args.fasta_path,
-         model_config=args.model_config)
+         model_config=args.model_config, true_labels=args.true_labels,
+         predicted_labels=args.predicted_labels,
+         plot_color=args.plot_color,
+         title=args.title)