diff keras_deep_learning.py @ 4:9349ed2749c6 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:01:50 +0000
parents af2624d5ab32
children
line wrap: on
line diff
--- a/keras_deep_learning.py	Thu Aug 11 09:38:31 2022 +0000
+++ b/keras_deep_learning.py	Wed Aug 09 13:01:50 2023 +0000
@@ -1,21 +1,19 @@
 import argparse
 import json
-import pickle
 import warnings
 from ast import literal_eval
 
-import keras
-import pandas as pd
 import six
-from galaxy_ml.utils import get_search_params, SafeEval, try_get_attr
-from keras.models import Model, Sequential
+from galaxy_ml.model_persist import dump_model_to_h5
+from galaxy_ml.utils import SafeEval, try_get_attr
+from tensorflow import keras
+from tensorflow.keras.models import Model, Sequential
 
 safe_eval = SafeEval()
 
 
 def _handle_shape(literal):
-    """
-    Eval integer or list/tuple of integers from string
+    """Eval integer or list/tuple of integers from string
 
     Parameters:
     -----------
@@ -32,8 +30,7 @@
 
 
 def _handle_regularizer(literal):
-    """
-    Construct regularizer from string literal
+    """Construct regularizer from string literal
 
     Parameters
     ----------
@@ -57,8 +54,7 @@
 
 
 def _handle_constraint(config):
-    """
-    Construct constraint from galaxy tool parameters.
+    """Construct constraint from galaxy tool parameters.
     Suppose correct dictionary format
 
     Parameters
@@ -91,9 +87,7 @@
 
 
 def _handle_layer_parameters(params):
-    """
-    Access to handle all kinds of parameters
-    """
+    """Access to handle all kinds of parameters"""
     for key, value in six.iteritems(params):
         if value in ("None", ""):
             params[key] = None
@@ -104,28 +98,24 @@
         ):
             continue
 
-        if (
-            key
-            in [
-                "input_shape",
-                "noise_shape",
-                "shape",
-                "batch_shape",
-                "target_shape",
-                "dims",
-                "kernel_size",
-                "strides",
-                "dilation_rate",
-                "output_padding",
-                "cropping",
-                "size",
-                "padding",
-                "pool_size",
-                "axis",
-                "shared_axes",
-            ]
-            and isinstance(value, str)
-        ):
+        if key in [
+            "input_shape",
+            "noise_shape",
+            "shape",
+            "batch_shape",
+            "target_shape",
+            "dims",
+            "kernel_size",
+            "strides",
+            "dilation_rate",
+            "output_padding",
+            "cropping",
+            "size",
+            "padding",
+            "pool_size",
+            "axis",
+            "shared_axes",
+        ] and isinstance(value, str):
             params[key] = _handle_shape(value)
 
         elif key.endswith("_regularizer") and isinstance(value, dict):
@@ -141,8 +131,7 @@
 
 
 def get_sequential_model(config):
-    """
-    Construct keras Sequential model from Galaxy tool parameters
+    """Construct keras Sequential model from Galaxy tool parameters
 
     Parameters:
     -----------
@@ -165,7 +154,7 @@
             options.update(kwargs)
 
         # add input_shape to the first layer only
-        if not getattr(model, "_layers") and input_shape is not None:
+        if not model.get_config()["layers"] and input_shape is not None:
             options["input_shape"] = input_shape
 
         model.add(klass(**options))
@@ -174,8 +163,7 @@
 
 
 def get_functional_model(config):
-    """
-    Construct keras functional model from Galaxy tool parameters
+    """Construct keras functional model from Galaxy tool parameters
 
     Parameters
     -----------
@@ -221,8 +209,7 @@
 
 
 def get_batch_generator(config):
-    """
-    Construct keras online data generator from Galaxy tool parameters
+    """Construct keras online data generator from Galaxy tool parameters
 
     Parameters
     -----------
@@ -246,8 +233,7 @@
 
 
 def config_keras_model(inputs, outfile):
-    """
-    config keras model layers and output JSON
+    """config keras model layers and output JSON
 
     Parameters
     ----------
@@ -271,16 +257,8 @@
         json.dump(json.loads(json_string), f, indent=2)
 
 
-def build_keras_model(
-    inputs,
-    outfile,
-    model_json,
-    infile_weights=None,
-    batch_mode=False,
-    outfile_params=None,
-):
-    """
-    for `keras_model_builder` tool
+def build_keras_model(inputs, outfile, model_json, batch_mode=False):
+    """for `keras_model_builder` tool
 
     Parameters
     ----------
@@ -290,12 +268,8 @@
         Path to galaxy dataset containing the keras_galaxy model output.
     model_json : str
         Path to dataset containing keras model JSON.
-    infile_weights : str or None
-        If string, path to dataset containing model weights.
     batch_mode : bool, default=False
         Whether to build online batch classifier.
-    outfile_params : str, default=None
-        File path to search parameters output.
     """
     with open(model_json, "r") as f:
         json_model = json.load(f)
@@ -307,7 +281,7 @@
     if json_model["class_name"] == "Sequential":
         options["model_type"] = "sequential"
         klass = Sequential
-    elif json_model["class_name"] == "Model":
+    elif json_model["class_name"] == "Functional":
         options["model_type"] = "functional"
         klass = Model
     else:
@@ -315,8 +289,9 @@
 
     # load prefitted model
     if inputs["mode_selection"]["mode_type"] == "prefitted":
-        estimator = klass.from_config(config)
-        estimator.load_weights(infile_weights)
+        # estimator = klass.from_config(config)
+        # estimator.load_weights(infile_weights)
+        raise Exception("Prefitted was deprecated!")
     # build train model
     else:
         cls_name = inputs["mode_selection"]["learning_type"]
@@ -338,8 +313,10 @@
         )
 
         train_metrics = inputs["mode_selection"]["compile_params"]["metrics"]
+        if not isinstance(train_metrics, list):  # for older galaxy
+            train_metrics = train_metrics.split(",")
         if train_metrics[-1] == "none":
-            train_metrics = train_metrics[:-1]
+            train_metrics.pop()
         options["metrics"] = train_metrics
 
         options.update(inputs["mode_selection"]["fit_params"])
@@ -355,19 +332,10 @@
                 "class_positive_factor"
             ]
         estimator = klass(config, **options)
-        if outfile_params:
-            hyper_params = get_search_params(estimator)
-            # TODO: remove this after making `verbose` tunable
-            for h_param in hyper_params:
-                if h_param[1].endswith("verbose"):
-                    h_param[0] = "@"
-            df = pd.DataFrame(hyper_params, columns=["", "Parameter", "Value"])
-            df.to_csv(outfile_params, sep="\t", index=False)
 
     print(repr(estimator))
-    # save model by pickle
-    with open(outfile, "wb") as f:
-        pickle.dump(estimator, f, pickle.HIGHEST_PROTOCOL)
+    # save model
+    dump_model_to_h5(estimator, outfile, verbose=1)
 
 
 if __name__ == "__main__":
@@ -377,9 +345,7 @@
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-m", "--model_json", dest="model_json")
     aparser.add_argument("-t", "--tool_id", dest="tool_id")
-    aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
     aparser.add_argument("-o", "--outfile", dest="outfile")
-    aparser.add_argument("-p", "--outfile_params", dest="outfile_params")
     args = aparser.parse_args()
 
     input_json_path = args.inputs
@@ -388,9 +354,7 @@
 
     tool_id = args.tool_id
     outfile = args.outfile
-    outfile_params = args.outfile_params
     model_json = args.model_json
-    infile_weights = args.infile_weights
 
     # for keras_model_config tool
     if tool_id == "keras_model_config":
@@ -403,10 +367,5 @@
             batch_mode = True
 
         build_keras_model(
-            inputs=inputs,
-            model_json=model_json,
-            infile_weights=infile_weights,
-            batch_mode=batch_mode,
-            outfile=outfile,
-            outfile_params=outfile_params,
+            inputs=inputs, model_json=model_json, batch_mode=batch_mode, outfile=outfile
         )