diff ml_visualization_ex.py @ 15:2eb5c017958d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:15:27 +0000
parents caf7d2b71a48
children adf9f8d5ab9a
line wrap: on
line diff
--- a/ml_visualization_ex.py	Thu Aug 11 09:13:40 2022 +0000
+++ b/ml_visualization_ex.py	Wed Aug 09 13:15:27 2023 +0000
@@ -9,13 +9,18 @@
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
-from galaxy_ml.utils import load_model, read_columns, SafeEval
-from keras.models import model_from_json
-from keras.utils import plot_model
-from sklearn.feature_selection.base import SelectorMixin
-from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
-                             precision_recall_curve, roc_curve)
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import read_columns, SafeEval
+from sklearn.feature_selection._base import SelectorMixin
+from sklearn.metrics import (
+    auc,
+    average_precision_score,
+    precision_recall_curve,
+    roc_curve,
+)
 from sklearn.pipeline import Pipeline
+from tensorflow.keras.models import model_from_json
+from tensorflow.keras.utils import plot_model
 
 safe_eval = SafeEval()
 
@@ -253,30 +258,6 @@
     os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
 
 
-def get_dataframe(file_path, plot_selection, header_name, column_name):
-    header = "infer" if plot_selection[header_name] else None
-    column_option = plot_selection[column_name]["selected_column_selector_option"]
-    if column_option in [
-        "by_index_number",
-        "all_but_by_index_number",
-        "by_header_name",
-        "all_but_by_header_name",
-    ]:
-        col = plot_selection[column_name]["col1"]
-    else:
-        col = None
-    _, input_df = read_columns(
-        file_path,
-        c=col,
-        c_option=column_option,
-        return_df=True,
-        sep="\t",
-        header=header,
-        parse_dates=True,
-    )
-    return input_df
-
-
 def main(
     inputs,
     infile_estimator=None,
@@ -290,10 +271,6 @@
     targets=None,
     fasta_path=None,
     model_config=None,
-    true_labels=None,
-    predicted_labels=None,
-    plot_color=None,
-    title=None,
 ):
     """
     Parameter
@@ -334,18 +311,6 @@
 
     model_config : str, default is None
         File path to dataset containing JSON config for neural networks
-
-    true_labels : str, default is None
-        File path to dataset containing true labels
-
-    predicted_labels : str, default is None
-        File path to dataset containing true predicted labels
-
-    plot_color : str, default is None
-        Color of the confusion matrix heatmap
-
-    title : str, default is None
-        Title of the confusion matrix heatmap
     """
     warnings.simplefilter("ignore")
 
@@ -357,8 +322,7 @@
     plot_format = params["plotting_selection"]["plot_format"]
 
     if plot_type == "feature_importances":
-        with open(infile_estimator, "rb") as estimator_handler:
-            estimator = load_model(estimator_handler)
+        estimator = load_model_from_h5(infile_estimator)
 
         column_option = params["plotting_selection"]["column_selector_options"][
             "selected_column_selector_option"
@@ -570,36 +534,6 @@
 
         return 0
 
-    elif plot_type == "classification_confusion_matrix":
-        plot_selection = params["plotting_selection"]
-        input_true = get_dataframe(
-            true_labels, plot_selection, "header_true", "column_selector_options_true"
-        )
-        header_predicted = "infer" if plot_selection["header_predicted"] else None
-        input_predicted = pd.read_csv(
-            predicted_labels, sep="\t", parse_dates=True, header=header_predicted
-        )
-        true_classes = input_true.iloc[:, -1].copy()
-        predicted_classes = input_predicted.iloc[:, -1].copy()
-        axis_labels = list(set(true_classes))
-        c_matrix = confusion_matrix(true_classes, predicted_classes)
-        fig, ax = plt.subplots(figsize=(7, 7))
-        im = plt.imshow(c_matrix, cmap=plot_color)
-        for i in range(len(c_matrix)):
-            for j in range(len(c_matrix)):
-                ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
-        ax.set_ylabel("True class labels")
-        ax.set_xlabel("Predicted class labels")
-        ax.set_title(title)
-        ax.set_xticks(axis_labels)
-        ax.set_yticks(axis_labels)
-        fig.colorbar(im, ax=ax)
-        fig.tight_layout()
-        plt.savefig("output.png", dpi=125)
-        os.rename("output.png", "output")
-
-        return 0
-
     # save pdf file to disk
     # fig.write_image("image.pdf", format='pdf')
     # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
@@ -619,10 +553,6 @@
     aparser.add_argument("-t", "--targets", dest="targets")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
     aparser.add_argument("-c", "--model_config", dest="model_config")
-    aparser.add_argument("-tl", "--true_labels", dest="true_labels")
-    aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
-    aparser.add_argument("-pc", "--plot_color", dest="plot_color")
-    aparser.add_argument("-pt", "--title", dest="title")
     args = aparser.parse_args()
 
     main(
@@ -638,8 +568,4 @@
         targets=args.targets,
         fasta_path=args.fasta_path,
         model_config=args.model_config,
-        true_labels=args.true_labels,
-        predicted_labels=args.predicted_labels,
-        plot_color=args.plot_color,
-        title=args.title,
     )