Mercurial > repos > bgruening > sklearn_nn_classifier
diff keras_deep_learning.py @ 27:22f0b9db4ea1 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 12:57:05 +0000 |
parents | 823ecc0bce45 |
children |
line wrap: on
line diff
--- a/keras_deep_learning.py Thu Aug 11 09:54:23 2022 +0000 +++ b/keras_deep_learning.py Wed Aug 09 12:57:05 2023 +0000 @@ -1,21 +1,19 @@ import argparse import json -import pickle import warnings from ast import literal_eval -import keras -import pandas as pd import six -from galaxy_ml.utils import get_search_params, SafeEval, try_get_attr -from keras.models import Model, Sequential +from galaxy_ml.model_persist import dump_model_to_h5 +from galaxy_ml.utils import SafeEval, try_get_attr +from tensorflow import keras +from tensorflow.keras.models import Model, Sequential safe_eval = SafeEval() def _handle_shape(literal): - """ - Eval integer or list/tuple of integers from string + """Eval integer or list/tuple of integers from string Parameters: ----------- @@ -32,8 +30,7 @@ def _handle_regularizer(literal): - """ - Construct regularizer from string literal + """Construct regularizer from string literal Parameters ---------- @@ -57,8 +54,7 @@ def _handle_constraint(config): - """ - Construct constraint from galaxy tool parameters. + """Construct constraint from galaxy tool parameters. Suppose correct dictionary format Parameters @@ -91,9 +87,7 @@ def _handle_layer_parameters(params): - """ - Access to handle all kinds of parameters - """ + """Access to handle all kinds of parameters""" for key, value in six.iteritems(params): if value in ("None", ""): params[key] = None @@ -104,28 +98,24 @@ ): continue - if ( - key - in [ - "input_shape", - "noise_shape", - "shape", - "batch_shape", - "target_shape", - "dims", - "kernel_size", - "strides", - "dilation_rate", - "output_padding", - "cropping", - "size", - "padding", - "pool_size", - "axis", - "shared_axes", - ] - and isinstance(value, str) - ): + if key in [ + "input_shape", + "noise_shape", + "shape", + "batch_shape", + "target_shape", + "dims", + "kernel_size", + "strides", + "dilation_rate", + "output_padding", + "cropping", + "size", + "padding", + "pool_size", + "axis", + "shared_axes", + ] and isinstance(value, str): params[key] = _handle_shape(value) elif key.endswith("_regularizer") and isinstance(value, dict): @@ -141,8 +131,7 @@ def get_sequential_model(config): - """ - Construct keras Sequential model from Galaxy tool parameters + """Construct keras Sequential model from Galaxy tool parameters Parameters: ----------- @@ -165,7 +154,7 @@ options.update(kwargs) # add input_shape to the first layer only - if not getattr(model, "_layers") and input_shape is not None: + if not model.get_config()["layers"] and input_shape is not None: options["input_shape"] = input_shape model.add(klass(**options)) @@ -174,8 +163,7 @@ def get_functional_model(config): - """ - Construct keras functional model from Galaxy tool parameters + """Construct keras functional model from Galaxy tool parameters Parameters ----------- @@ -221,8 +209,7 @@ def get_batch_generator(config): - """ - Construct keras online data generator from Galaxy tool parameters + """Construct keras online data generator from Galaxy tool parameters Parameters ----------- @@ -246,8 +233,7 @@ def config_keras_model(inputs, outfile): - """ - config keras model layers and output JSON + """config keras model layers and output JSON Parameters ---------- @@ -271,16 +257,8 @@ json.dump(json.loads(json_string), f, indent=2) -def build_keras_model( - inputs, - outfile, - model_json, - infile_weights=None, - batch_mode=False, - outfile_params=None, -): - """ - for `keras_model_builder` tool +def build_keras_model(inputs, outfile, model_json, batch_mode=False): + """for `keras_model_builder` tool Parameters ---------- @@ -290,12 +268,8 @@ Path to galaxy dataset containing the keras_galaxy model output. model_json : str Path to dataset containing keras model JSON. - infile_weights : str or None - If string, path to dataset containing model weights. batch_mode : bool, default=False Whether to build online batch classifier. - outfile_params : str, default=None - File path to search parameters output. """ with open(model_json, "r") as f: json_model = json.load(f) @@ -307,7 +281,7 @@ if json_model["class_name"] == "Sequential": options["model_type"] = "sequential" klass = Sequential - elif json_model["class_name"] == "Model": + elif json_model["class_name"] == "Functional": options["model_type"] = "functional" klass = Model else: @@ -315,8 +289,9 @@ # load prefitted model if inputs["mode_selection"]["mode_type"] == "prefitted": - estimator = klass.from_config(config) - estimator.load_weights(infile_weights) + # estimator = klass.from_config(config) + # estimator.load_weights(infile_weights) + raise Exception("Prefitted was deprecated!") # build train model else: cls_name = inputs["mode_selection"]["learning_type"] @@ -338,8 +313,10 @@ ) train_metrics = inputs["mode_selection"]["compile_params"]["metrics"] + if not isinstance(train_metrics, list): # for older galaxy + train_metrics = train_metrics.split(",") if train_metrics[-1] == "none": - train_metrics = train_metrics[:-1] + train_metrics.pop() options["metrics"] = train_metrics options.update(inputs["mode_selection"]["fit_params"]) @@ -355,19 +332,10 @@ "class_positive_factor" ] estimator = klass(config, **options) - if outfile_params: - hyper_params = get_search_params(estimator) - # TODO: remove this after making `verbose` tunable - for h_param in hyper_params: - if h_param[1].endswith("verbose"): - h_param[0] = "@" - df = pd.DataFrame(hyper_params, columns=["", "Parameter", "Value"]) - df.to_csv(outfile_params, sep="\t", index=False) print(repr(estimator)) - # save model by pickle - with open(outfile, "wb") as f: - pickle.dump(estimator, f, pickle.HIGHEST_PROTOCOL) + # save model + dump_model_to_h5(estimator, outfile, verbose=1) if __name__ == "__main__": @@ -377,9 +345,7 @@ aparser.add_argument("-i", "--inputs", dest="inputs", required=True) aparser.add_argument("-m", "--model_json", dest="model_json") aparser.add_argument("-t", "--tool_id", dest="tool_id") - aparser.add_argument("-w", "--infile_weights", dest="infile_weights") aparser.add_argument("-o", "--outfile", dest="outfile") - aparser.add_argument("-p", "--outfile_params", dest="outfile_params") args = aparser.parse_args() input_json_path = args.inputs @@ -388,9 +354,7 @@ tool_id = args.tool_id outfile = args.outfile - outfile_params = args.outfile_params model_json = args.model_json - infile_weights = args.infile_weights # for keras_model_config tool if tool_id == "keras_model_config": @@ -403,10 +367,5 @@ batch_mode = True build_keras_model( - inputs=inputs, - model_json=model_json, - infile_weights=infile_weights, - batch_mode=batch_mode, - outfile=outfile, - outfile_params=outfile_params, + inputs=inputs, model_json=model_json, batch_mode=batch_mode, outfile=outfile )