diff model_prediction.py @ 14:9c19cf3c4ea0 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:08:43 +0000
parents 4a229a7ad638
children
line wrap: on
line diff
--- a/model_prediction.py	Thu Aug 11 09:29:32 2022 +0000
+++ b/model_prediction.py	Wed Aug 09 13:08:43 2023 +0000
@@ -4,9 +4,10 @@
 
 import numpy as np
 import pandas as pd
-from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import (clean_params, get_module, read_columns,
+                             try_get_attr)
 from scipy.io import mmread
-from sklearn.pipeline import Pipeline
 
 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
 
@@ -15,7 +16,6 @@
     inputs,
     infile_estimator,
     outfile_predict,
-    infile_weights=None,
     infile1=None,
     fasta_path=None,
     ref_seq=None,
@@ -27,15 +27,12 @@
     inputs : str
         File path to galaxy tool parameter
 
-    infile_estimator : strgit
+    infile_estimator : str
         File path to trained estimator input
 
     outfile_predict : str
         File path to save the prediction results, tabular
 
-    infile_weights : str
-        File path to weights input
-
     infile1 : str
         File path to dataset containing features
 
@@ -54,19 +51,8 @@
         params = json.load(param_handler)
 
     # load model
-    with open(infile_estimator, "rb") as est_handler:
-        estimator = load_model(est_handler)
-
-    main_est = estimator
-    if isinstance(estimator, Pipeline):
-        main_est = estimator.steps[-1][-1]
-    if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
-        if not infile_weights or infile_weights == "None":
-            raise ValueError(
-                "The selected model skeleton asks for weights, "
-                "but dataset for weights wan not selected!"
-            )
-        main_est.load_weights(infile_weights)
+    estimator = load_model_from_h5(infile_estimator)
+    estimator = clean_params(estimator)
 
     # handle data input
     input_type = params["input_options"]["selected_input"]
@@ -221,7 +207,6 @@
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
-    aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
     aparser.add_argument("-X", "--infile1", dest="infile1")
     aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
@@ -233,7 +218,6 @@
         args.inputs,
         args.infile_estimator,
         args.outfile_predict,
-        infile_weights=args.infile_weights,
         infile1=args.infile1,
         fasta_path=args.fasta_path,
         ref_seq=args.ref_seq,