comparison main.py @ 0:f4619200cb0a draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/jupyter_job commit f945b1bff5008ba01da31c7de64e5326579394d6"
author bgruening
date Sat, 11 Dec 2021 17:56:38 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:f4619200cb0a
1 import argparse
2 import os
3 import subprocess
4 import warnings
5 from zipfile import ZipFile
6
7 import h5py
8 import yaml
9 from skl2onnx import convert_sklearn
10 from skl2onnx.common.data_types import FloatTensorType
11
12
# Global quietening applied at import time: all Python warnings are ignored
# and TensorFlow's native (C++) log threshold is raised.
warnings.filterwarnings("ignore")
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})

# The lists below are substrings that check_vars looks for inside
# str(obj.__class__) to decide how a variable is persisted.

# scikit-learn objects -> exported to ONNX by save_sklearn_model.
SKLEARN_MODELS = [
    "sklearn.ensemble", "sklearn.tree", "sklearn.linear_model", "sklearn.svm",
    "sklearn.neighbors", "sklearn.preprocessing", "sklearn.cluster",
]

# TensorFlow / Keras models -> SavedModel + ONNX via save_tf_model.
# Both the tf-bundled keras and standalone keras class paths are listed.
TF_MODELS = [
    "tensorflow.python.keras.engine.training.Model",
    "tensorflow.python.keras.engine.sequential.Sequential",
    "tensorflow.python.keras.engine.functional.Functional",
    "tensorflow.python.keras.layers",
    "keras.engine.functional.Functional",
    "keras.engine.sequential.Sequential",
    "keras.engine.training.Model",
    "keras.layers",
]

# Array-like values -> HDF5 datasets via save_primitives.
ARRAYS = ["numpy.ndarray", "list"]

# pandas DataFrames -> HDF5 via save_dataframe.
DATAFRAME = ["pandas.core.frame.DataFrame"]

# Scalars -> HDF5 datasets via save_primitives.
SCALAR_TYPES = ["int", "float", "str"]
51
52
def find_replace_paths(script_file, updated_data_dict):
    """Return ``script_file`` with each key of ``updated_data_dict`` replaced by its value.

    Replacements run one after another in the mapping's iteration order, so
    a later replacement operates on the output of the earlier ones.

    Args:
        script_file: source text to rewrite.
        updated_data_dict: mapping of old substring -> new substring.

    Returns:
        The rewritten text.
    """
    result = script_file
    for old_path, new_path in updated_data_dict.items():
        result = result.replace(old_path, new_path)
    return result
58
59
def update_ml_files_paths(old_file_paths, new_file_paths):
    """Pair comma-separated old paths with comma-separated new paths by position.

    Args:
        old_file_paths: comma-separated paths as they appear in the script
            (may be ``None`` or ``""``).
        new_file_paths: comma-separated replacement paths, matched to
            ``old_file_paths`` by position (may be ``None`` or ``""``).

    Returns:
        dict mapping each non-empty old path to its new path; ``{}`` when
        either argument is ``None`` or ``""``. Whitespace around entries is
        stripped. Extra new paths beyond the number of old paths are ignored.

    Raises:
        ValueError: if a non-empty old path has no corresponding new path.
    """
    if not old_file_paths or not new_file_paths:
        return dict()
    o_files = [path.strip() for path in old_file_paths.split(",")]
    n_files = [path.strip() for path in new_file_paths.split(",")]
    new_paths_dict = dict()
    for i, o_f in enumerate(o_files):
        # An empty key (e.g. from a trailing comma) would make str.replace
        # insert the new path between every character of the script.
        if not o_f:
            continue
        if i >= len(n_files):
            raise ValueError(
                "No replacement path for '{}': got {} old path(s) but only {} new path(s)".format(
                    o_f, len(o_files), len(n_files)))
        new_paths_dict[o_f] = n_files[i]
    return new_paths_dict
69
70
def read_loaded_file(new_paths_dict, p_loaded_file, a_file, w_dir, z_file):
    """Execute a Python script with remapped paths, then persist its results.

    Args:
        new_paths_dict: mapping of path strings to rewrite in the script
            source before it is compiled (applied by ``find_replace_paths``).
        p_loaded_file: value passed through ``yaml.safe_load`` to obtain the
            script's file path (presumably a plain path string).
        a_file: HDF5 output file for arrays, scalars and DataFrames.
        w_dir: working directory for model outputs; its contents are zipped.
        z_file: path of the output zip archive.
    """
    global_vars = dict()
    input_file = yaml.safe_load(p_loaded_file)
    # Close the script file promptly instead of leaking the handle.
    with open(input_file, "r") as script:
        code_string = script.read()
    re_code_string = find_replace_paths(code_string, new_paths_dict)
    # The file name is used in tracebacks raised from the executed code.
    compiled_code = compile(re_code_string, input_file, 'exec')
    # NOTE(review): exec runs arbitrary user-supplied code in this process.
    # That is the purpose of the tool, but it must only run inside a sandboxed job.
    exec(compiled_code, global_vars)
    # Everything the script left in its global namespace is inspected and saved.
    check_vars(w_dir, global_vars, a_file)
    zip_files(w_dir, z_file)
80
81
def zip_files(w_dir, z_file):
    """Write every file under ``w_dir`` (recursively) into the zip ``z_file``.

    Archive member names are relative to ``w_dir``. Previously:
      * bare names from ``os.listdir`` were resolved against the current
        directory, not ``w_dir``;
      * sub-directories (e.g. ``model_outputs``) were added as empty
        directory entries, because ``ZipFile.write`` does not recurse.
    If ``z_file`` itself lies inside ``w_dir`` it is skipped, so the archive
    never tries to contain itself.

    Args:
        w_dir: directory whose contents are archived.
        z_file: path of the zip archive to create (overwritten if present).
    """
    z_real = os.path.realpath(z_file)
    with ZipFile(z_file, 'w') as zip_file:
        for root, dirs, files in os.walk(w_dir):
            # Sort for a deterministic archive order.
            dirs.sort()
            for name in sorted(files):
                f_path = os.path.join(root, name)
                if os.path.realpath(f_path) == z_real:
                    continue
                zip_file.write(f_path, arcname=os.path.relpath(f_path, w_dir))
86
87
def create_model_path(curr_path, key):
    """Return ``<curr_path>/model_outputs/onnx_model_<key>.onnx``.

    The ``model_outputs`` directory is created first if nothing exists at
    that path yet.

    Args:
        curr_path: base directory.
        key: identifier embedded in the file name.
    """
    output_dir = "/".join([curr_path, "model_outputs"])
    if os.path.exists(output_dir) is False:
        os.makedirs(output_dir)
    file_name = "onnx_model_{}.onnx".format(key)
    return output_dir + "/" + file_name
94
95
def save_sklearn_model(w_dir, key, obj, n_features=None):
    """Convert a scikit-learn object to ONNX and write it under ``w_dir``.

    The file is written to ``<w_dir>/model_outputs/onnx_model_<key>.onnx``.

    Args:
        w_dir: working directory for model outputs.
        key: name used in the output file name (the script variable name).
        obj: scikit-learn estimator/transformer, presumably already fitted.
        n_features: number of input columns declared for the ONNX graph's
            float input. When None, it is taken from ``obj.n_features_in_``
            if the object has that attribute; otherwise the previous
            hard-coded value 4 is used.
    """
    if n_features is None:
        # The old constant 4 only fits 4-column data. Fitted scikit-learn
        # objects (>= 0.24) record their real input width in n_features_in_.
        n_features = getattr(obj, "n_features_in_", 4)
    initial_type = [('float_input', FloatTensorType([None, int(n_features)]))]
    onx = convert_sklearn(obj, initial_types=initial_type)
    sk_model_path = create_model_path(w_dir, key)
    with open(sk_model_path, "wb") as f:
        f.write(onx.SerializeToString())
102
103
def save_tf_model(w_dir, key, obj, opset=15):
    """Save a TensorFlow/Keras model as a SavedModel and as ONNX.

    The SavedModel is written to ``<w_dir>/tf_model_<key>``. It is then
    converted with ``tf2onnx`` into
    ``<w_dir>/model_outputs/onnx_model_<key>.onnx``.

    Args:
        w_dir: working directory for model outputs.
        key: name used in the output paths (the script variable name).
        obj: the model object to save.
        opset: ONNX opset passed to tf2onnx's ``--opset``. It sets which
            TensorFlow operations ONNX supports (default 15, as before).

    Raises:
        subprocess.CalledProcessError: if the tf2onnx conversion fails.
    """
    import sys

    import tensorflow as tf
    tf_file_key = "tf_model_{}".format(key)
    tf_model_path = "{}/{}".format(w_dir, tf_file_key)
    if not os.path.exists(tf_model_path):
        os.makedirs(tf_model_path)
    # save model as tf model
    tf.saved_model.save(obj, tf_model_path)
    # save model as ONNX
    tf_onnx_model_p = create_model_path(w_dir, key)
    # Use an argument list, not a concatenated shell string, so paths with
    # spaces or shell metacharacters are passed through verbatim. Run the
    # converter with the same interpreter that is running this script.
    command = [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", tf_model_path,
        "--output", tf_onnx_model_p,
        "--opset", str(opset),
    ]
    subprocess.run(command, check=True)
118
119
def save_primitives(payload, a_file):
    """Write each value of ``payload`` as an HDF5 dataset named by its key.

    ``a_file`` is opened in mode ``"w"``, so any existing content is
    discarded. Saving is best-effort: a value h5py cannot store is reported
    on stdout (with its key) and skipped rather than aborting the export.

    Args:
        payload: mapping of dataset name -> array-like or scalar value.
        a_file: path of the HDF5 file to create.
    """
    # The context manager guarantees the file is closed even on interruption.
    with h5py.File(a_file, "w") as hf_file:
        for key, value in payload.items():
            try:
                hf_file.create_dataset(key, data=value)
            except Exception as e:
                # One unsupported value must not lose the others; name the
                # variable so the user knows what was dropped.
                print("Could not save '{}': {}".format(key, e))
129
130
def save_dataframe(payload, a_file):
    """Store every DataFrame in ``payload`` in the HDF5 file ``a_file``.

    Each frame goes under its dictionary key via ``to_hdf``, called with
    pandas' default mode. That mode presumably appends to the file
    save_primitives created, so confirm against the pinned pandas version.

    Args:
        payload: mapping of HDF5 key -> DataFrame.
        a_file: path of the HDF5 file.
    """
    for name, frame in payload.items():
        frame.to_hdf(a_file, key=name)
134
135
def check_vars(w_dir, var_dict, a_file):
    """Persist the variables a script left behind, according to their type.

    Each entry of ``var_dict`` is classified by substring-matching
    ``str(obj.__class__)`` against the module-level pattern lists. The first
    matching category wins, checked in this order:

    1. ``TF_MODELS``      -> exported immediately via ``save_tf_model``
    2. ``SKLEARN_MODELS`` -> exported immediately via ``save_sklearn_model``
    3. ``ARRAYS``         -> collected for ``save_primitives``
    4. ``DATAFRAME``      -> collected for ``save_dataframe``
    5. ``SCALAR_TYPES``   -> collected for ``save_primitives``

    Values matching none of these are ignored. The collected payloads are
    then written to ``a_file``, primitives first and DataFrames second.
    Nothing at all happens when ``var_dict`` is None.

    Args:
        w_dir: directory passed on to the model exporters.
        var_dict: mapping of variable name -> object, or None.
        a_file: HDF5 output path for arrays, scalars and DataFrames.
    """
    if var_dict is None:
        return

    def _is_a(class_repr, patterns):
        # Plain substring test: e.g. "int" also matches "<class 'numpy.int64'>".
        return any(pattern in class_repr for pattern in patterns)

    primitives = {}
    frames = {}
    for name, value in var_dict.items():
        class_repr = str(value.__class__)
        if _is_a(class_repr, TF_MODELS):
            save_tf_model(w_dir, name, value)
        elif _is_a(class_repr, SKLEARN_MODELS):
            save_sklearn_model(w_dir, name, value)
        elif _is_a(class_repr, ARRAYS):
            primitives[name] = value
        elif _is_a(class_repr, DATAFRAME):
            frames[name] = value
        elif _is_a(class_repr, SCALAR_TYPES):
            primitives[name] = value
    save_primitives(primitives, a_file)
    save_dataframe(frames, a_file)
161
162
if __name__ == "__main__":

    # Every option previously had help="", which made --help useless.
    arg_parser = argparse.ArgumentParser(
        description="Execute a Python script, then save the models, arrays, "
                    "scalars and DataFrames it defines.")
    arg_parser.add_argument(
        "-mlp", "--ml_paths", required=True,
        help="Comma-separated file paths used inside the script; each is "
             "replaced by the entry at the same position in --ml_h5_files.")
    arg_parser.add_argument(
        "-ldf", "--loaded_file", required=True,
        help="Path of the Python script to execute (parsed with yaml.safe_load).")
    arg_parser.add_argument(
        "-wd", "--working_dir", required=True,
        help="Directory for model outputs; its contents are zipped into --output_zip.")
    arg_parser.add_argument(
        "-oz", "--output_zip", required=True,
        help="Path of the output zip archive.")
    arg_parser.add_argument(
        "-oa", "--output_array", required=True,
        help="Path of the output HDF5 file for arrays, scalars and DataFrames.")
    arg_parser.add_argument(
        "-mlf", "--ml_h5_files", required=True,
        help="Comma-separated replacement paths, matched to --ml_paths by position.")
    # get argument values
    args = vars(arg_parser.parse_args())
    ml_paths = args["ml_paths"]
    loaded_file = args["loaded_file"]
    array_output_file = args["output_array"]
    zip_output_file = args["output_zip"]
    working_dir = args["working_dir"]
    ml_h5_files = args["ml_h5_files"]
    new_paths_dict = update_ml_files_paths(ml_paths, ml_h5_files)
    read_loaded_file(new_paths_dict, loaded_file, array_output_file, working_dir, zip_output_file)
172 # get argument values
173 args = vars(arg_parser.parse_args())
174 ml_paths = args["ml_paths"]
175 loaded_file = args["loaded_file"]
176 array_output_file = args["output_array"]
177 zip_output_file = args["output_zip"]
178 working_dir = args["working_dir"]
179 ml_h5_files = args["ml_h5_files"]
180 new_paths_dict = update_ml_files_paths(ml_paths, ml_h5_files)
181 read_loaded_file(new_paths_dict, loaded_file, array_output_file, working_dir, zip_output_file)