diff simple_model_fit.py @ 21:1d3447c2203c draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 17:48:25 +0000
parents 55c7d3e58eae
children 34d31bd995e9
line wrap: on
line diff
--- a/simple_model_fit.py	Thu Oct 01 20:23:20 2020 +0000
+++ b/simple_model_fit.py	Tue Apr 13 17:48:25 2021 +0000
@@ -4,10 +4,11 @@
 import pickle
 
 from galaxy_ml.utils import load_model, read_columns
+from scipy.io import mmread
 from sklearn.pipeline import Pipeline
 
 
-N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
+N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
 
 
 # TODO import from galaxy_ml.utils in future versions
@@ -20,33 +21,35 @@
     ------
     Cleaned estimator object
     """
-    ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN',
-                         'ReduceLROnPlateau', 'CSVLogger', 'None')
+    ALLOWED_CALLBACKS = (
+        "EarlyStopping",
+        "TerminateOnNaN",
+        "ReduceLROnPlateau",
+        "CSVLogger",
+        "None",
+    )
 
     estimator_params = estimator.get_params()
 
     for name, p in estimator_params.items():
         # all potential unauthorized file write
-        if name == 'memory' or name.endswith('__memory') \
-                or name.endswith('_path'):
+        if name == "memory" or name.endswith("__memory") or name.endswith("_path"):
             new_p = {name: None}
             estimator.set_params(**new_p)
-        elif n_jobs is not None and (name == 'n_jobs' or
-                                     name.endswith('__n_jobs')):
+        elif n_jobs is not None and (name == 'n_jobs' or name.endswith('__n_jobs')):
             new_p = {name: n_jobs}
             estimator.set_params(**new_p)
-        elif name.endswith('callbacks'):
+        elif name.endswith("callbacks"):
             for cb in p:
-                cb_type = cb['callback_selection']['callback_type']
+                cb_type = cb["callback_selection"]["callback_type"]
                 if cb_type not in ALLOWED_CALLBACKS:
-                    raise ValueError(
-                        "Prohibited callback type: %s!" % cb_type)
+                    raise ValueError("Prohibited callback type: %s!" % cb_type)
 
     return estimator
 
 
 def _get_X_y(params, infile1, infile2):
-    """ read from inputs and output X and y
+    """read from inputs and output X and y
 
     Parameters
     ----------
@@ -61,35 +64,40 @@
     # store read dataframe object
     loaded_df = {}
 
-    input_type = params['input_options']['selected_input']
+    input_type = params["input_options"]["selected_input"]
     # tabular input
-    if input_type == 'tabular':
-        header = 'infer' if params['input_options']['header1'] else None
-        column_option = (params['input_options']['column_selector_options_1']
-                         ['selected_column_selector_option'])
-        if column_option in ['by_index_number', 'all_but_by_index_number',
-                             'by_header_name', 'all_but_by_header_name']:
-            c = params['input_options']['column_selector_options_1']['col1']
+    if input_type == "tabular":
+        header = "infer" if params["input_options"]["header1"] else None
+        column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
+        if column_option in [
+            "by_index_number",
+            "all_but_by_index_number",
+            "by_header_name",
+            "all_but_by_header_name",
+        ]:
+            c = params["input_options"]["column_selector_options_1"]["col1"]
         else:
             c = None
 
         df_key = infile1 + repr(header)
-        df = pd.read_csv(infile1, sep='\t', header=header,
-                         parse_dates=True)
+        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
         loaded_df[df_key] = df
 
         X = read_columns(df, c=c, c_option=column_option).astype(float)
     # sparse input
-    elif input_type == 'sparse':
-        X = mmread(open(infile1, 'r'))
+    elif input_type == "sparse":
+        X = mmread(open(infile1, "r"))
 
     # Get target y
-    header = 'infer' if params['input_options']['header2'] else None
-    column_option = (params['input_options']['column_selector_options_2']
-                     ['selected_column_selector_option2'])
-    if column_option in ['by_index_number', 'all_but_by_index_number',
-                         'by_header_name', 'all_but_by_header_name']:
-        c = params['input_options']['column_selector_options_2']['col2']
+    header = "infer" if params["input_options"]["header2"] else None
+    column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"]
+    if column_option in [
+        "by_index_number",
+        "all_but_by_index_number",
+        "by_header_name",
+        "all_but_by_header_name",
+    ]:
+        c = params["input_options"]["column_selector_options_2"]["col2"]
     else:
         c = None
 
@@ -97,26 +105,23 @@
     if df_key in loaded_df:
         infile2 = loaded_df[df_key]
     else:
-        infile2 = pd.read_csv(infile2, sep='\t',
-                              header=header, parse_dates=True)
+        infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
         loaded_df[df_key] = infile2
 
-    y = read_columns(
-            infile2,
-            c=c,
-            c_option=column_option,
-            sep='\t',
-            header=header,
-            parse_dates=True)
+    y = read_columns(infile2,
+                     c=c,
+                     c_option=column_option,
+                     sep='\t',
+                     header=header,
+                     parse_dates=True)
     if len(y.shape) == 2 and y.shape[1] == 1:
         y = y.ravel()
 
     return X, y
 
 
-def main(inputs, infile_estimator, infile1, infile2, out_object,
-         out_weights=None):
-    """ main
+def main(inputs, infile_estimator, infile1, infile2, out_object, out_weights=None):
+    """main
 
     Parameters
     ----------
@@ -139,38 +144,37 @@
         File path for output of weights
 
     """
-    with open(inputs, 'r') as param_handler:
+    with open(inputs, "r") as param_handler:
         params = json.load(param_handler)
 
     # load model
-    with open(infile_estimator, 'rb') as est_handler:
+    with open(infile_estimator, "rb") as est_handler:
         estimator = load_model(est_handler)
     estimator = clean_params(estimator, n_jobs=N_JOBS)
 
     X_train, y_train = _get_X_y(params, infile1, infile2)
 
     estimator.fit(X_train, y_train)
-    
+
     main_est = estimator
     if isinstance(main_est, Pipeline):
         main_est = main_est.steps[-1][-1]
-    if hasattr(main_est, 'model_') \
-            and hasattr(main_est, 'save_weights'):
+    if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
         if out_weights:
             main_est.save_weights(out_weights)
         del main_est.model_
         del main_est.fit_params
         del main_est.model_class_
-        del main_est.validation_data
-        if getattr(main_est, 'data_generator_', None):
+        if getattr(main_est, "validation_data", None):
+            del main_est.validation_data
+        if getattr(main_est, "data_generator_", None):
             del main_est.data_generator_
 
-    with open(out_object, 'wb') as output_handler:
-        pickle.dump(estimator, output_handler,
-                    pickle.HIGHEST_PROTOCOL)
+    with open(out_object, "wb") as output_handler:
+        pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-X", "--infile_estimator", dest="infile_estimator")
@@ -180,5 +184,11 @@
     aparser.add_argument("-t", "--out_weights", dest="out_weights")
     args = aparser.parse_args()
 
-    main(args.inputs, args.infile_estimator, args.infile1,
-         args.infile2, args.out_object, args.out_weights)
+    main(
+        args.inputs,
+        args.infile_estimator,
+        args.infile1,
+        args.infile2,
+        args.out_object,
+        args.out_weights,
+    )