Mercurial > repos > bgruening > create_tool_recommendation_model
comparison main.py @ 5:4f7e6612906b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author | bgruening |
---|---|
date | Fri, 06 May 2022 09:05:18 +0000 |
parents | afec8c595124 |
children | e94dc7945639 |
Comparison legend: equal, deleted, inserted, replaced
4:afec8c595124 | 5:4f7e6612906b |
---|---|
1 """ | 1 """ |
2 Predict next tools in the Galaxy workflows | 2 Predict next tools in the Galaxy workflows |
3 using machine learning (recurrent neural network) | 3 using machine learning (recurrent neural network) |
4 """ | 4 """ |
5 | 5 |
6 import numpy as np | |
7 import argparse | 6 import argparse |
8 import time | 7 import time |
9 | 8 |
10 # machine learning library | 9 import extract_workflow_connections |
11 import tensorflow as tf | |
12 from keras import backend as K | |
13 import keras.callbacks as callbacks | 10 import keras.callbacks as callbacks |
14 | 11 import numpy as np |
15 import extract_workflow_connections | 12 import optimise_hyperparameters |
16 import prepare_data | 13 import prepare_data |
17 import optimise_hyperparameters | |
18 import utils | 14 import utils |
19 | 15 |
20 | 16 |
21 class PredictTool: | 17 class PredictTool: |
22 | |
23 def __init__(self, num_cpus): | 18 def __init__(self, num_cpus): |
24 """ Init method. """ | 19 """ Init method. """ |
25 # set the number of cpus | 20 |
26 cpu_config = tf.ConfigProto( | 21 def find_train_best_network( |
27 device_count={"CPU": num_cpus}, | 22 self, |
28 intra_op_parallelism_threads=num_cpus, | 23 network_config, |
29 inter_op_parallelism_threads=num_cpus, | 24 reverse_dictionary, |
30 allow_soft_placement=True | 25 train_data, |
31 ) | 26 train_labels, |
32 K.set_session(tf.Session(config=cpu_config)) | 27 test_data, |
33 | 28 test_labels, |
34 def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, tool_freq, tool_tr_samples): | 29 n_epochs, |
30 class_weights, | |
31 usage_pred, | |
32 standard_connections, | |
33 tool_freq, | |
34 tool_tr_samples, | |
35 ): | |
35 """ | 36 """ |
36 Define recurrent neural network and train sequential data | 37 Define recurrent neural network and train sequential data |
37 """ | 38 """ |
38 # get tools with lowest representation | 39 # get tools with lowest representation |
39 lowest_tool_ids = utils.get_lowest_tools(tool_freq) | 40 lowest_tool_ids = utils.get_lowest_tools(tool_freq) |
40 | 41 |
41 print("Start hyperparameter optimisation...") | 42 print("Start hyperparameter optimisation...") |
42 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() | 43 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() |
43 best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights) | 44 best_params, best_model = hyper_opt.train_model( |
45 network_config, | |
46 reverse_dictionary, | |
47 train_data, | |
48 train_labels, | |
49 test_data, | |
50 test_labels, | |
51 tool_tr_samples, | |
52 class_weights, | |
53 ) | |
44 | 54 |
45 # define callbacks | 55 # define callbacks |
46 early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True) | 56 early_stopping = callbacks.EarlyStopping( |
47 predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids) | 57 monitor="loss", |
58 mode="min", | |
59 verbose=1, | |
60 min_delta=1e-1, | |
61 restore_best_weights=True, | |
62 ) | |
63 predict_callback_test = PredictCallback( | |
64 test_data, | |
65 test_labels, | |
66 reverse_dictionary, | |
67 n_epochs, | |
68 usage_pred, | |
69 standard_connections, | |
70 lowest_tool_ids, | |
71 ) | |
48 | 72 |
49 callbacks_list = [predict_callback_test, early_stopping] | 73 callbacks_list = [predict_callback_test, early_stopping] |
50 batch_size = int(best_params["batch_size"]) | 74 batch_size = int(best_params["batch_size"]) |
51 | 75 |
52 print("Start training on the best model...") | 76 print("Start training on the best model...") |
55 utils.balanced_sample_generator( | 79 utils.balanced_sample_generator( |
56 train_data, | 80 train_data, |
57 train_labels, | 81 train_labels, |
58 batch_size, | 82 batch_size, |
59 tool_tr_samples, | 83 tool_tr_samples, |
60 reverse_dictionary | 84 reverse_dictionary, |
61 ), | 85 ), |
62 steps_per_epoch=len(train_data) // batch_size, | 86 steps_per_epoch=len(train_data) // batch_size, |
63 epochs=n_epochs, | 87 epochs=n_epochs, |
64 callbacks=callbacks_list, | 88 callbacks=callbacks_list, |
65 validation_data=(test_data, test_labels), | 89 validation_data=(test_data, test_labels), |
66 verbose=2, | 90 verbose=2, |
67 shuffle=True | 91 shuffle=True, |
68 ) | 92 ) |
69 train_performance["validation_loss"] = np.array(trained_model.history["val_loss"]) | 93 train_performance["validation_loss"] = np.array( |
94 trained_model.history["val_loss"] | |
95 ) | |
70 train_performance["precision"] = predict_callback_test.precision | 96 train_performance["precision"] = predict_callback_test.precision |
71 train_performance["usage_weights"] = predict_callback_test.usage_weights | 97 train_performance["usage_weights"] = predict_callback_test.usage_weights |
72 train_performance["published_precision"] = predict_callback_test.published_precision | 98 train_performance[ |
73 train_performance["lowest_pub_precision"] = predict_callback_test.lowest_pub_precision | 99 "published_precision" |
74 train_performance["lowest_norm_precision"] = predict_callback_test.lowest_norm_precision | 100 ] = predict_callback_test.published_precision |
101 train_performance[ | |
102 "lowest_pub_precision" | |
103 ] = predict_callback_test.lowest_pub_precision | |
104 train_performance[ | |
105 "lowest_norm_precision" | |
106 ] = predict_callback_test.lowest_norm_precision | |
75 train_performance["train_loss"] = np.array(trained_model.history["loss"]) | 107 train_performance["train_loss"] = np.array(trained_model.history["loss"]) |
76 train_performance["model"] = best_model | 108 train_performance["model"] = best_model |
77 train_performance["best_parameters"] = best_params | 109 train_performance["best_parameters"] = best_params |
78 return train_performance | 110 return train_performance |
79 | 111 |
80 | 112 |
81 class PredictCallback(callbacks.Callback): | 113 class PredictCallback(callbacks.Callback): |
82 def __init__(self, test_data, test_labels, reverse_data_dictionary, n_epochs, usg_scores, standard_connections, lowest_tool_ids): | 114 def __init__( |
115 self, | |
116 test_data, | |
117 test_labels, | |
118 reverse_data_dictionary, | |
119 n_epochs, | |
120 usg_scores, | |
121 standard_connections, | |
122 lowest_tool_ids, | |
123 ): | |
83 self.test_data = test_data | 124 self.test_data = test_data |
84 self.test_labels = test_labels | 125 self.test_labels = test_labels |
85 self.reverse_data_dictionary = reverse_data_dictionary | 126 self.reverse_data_dictionary = reverse_data_dictionary |
86 self.precision = list() | 127 self.precision = list() |
87 self.usage_weights = list() | 128 self.usage_weights = list() |
96 def on_epoch_end(self, epoch, logs={}): | 137 def on_epoch_end(self, epoch, logs={}): |
97 """ | 138 """ |
98 Compute absolute and compatible precision for test data | 139 Compute absolute and compatible precision for test data |
99 """ | 140 """ |
100 if len(self.test_data) > 0: | 141 if len(self.test_data) > 0: |
101 usage_weights, precision, precision_pub, low_pub_prec, low_norm_prec, low_num = utils.verify_model(self.model, self.test_data, self.test_labels, self.reverse_data_dictionary, self.pred_usage_scores, self.standard_connections, self.lowest_tool_ids) | 142 ( |
143 usage_weights, | |
144 precision, | |
145 precision_pub, | |
146 low_pub_prec, | |
147 low_norm_prec, | |
148 low_num, | |
149 ) = utils.verify_model( | |
150 self.model, | |
151 self.test_data, | |
152 self.test_labels, | |
153 self.reverse_data_dictionary, | |
154 self.pred_usage_scores, | |
155 self.standard_connections, | |
156 self.lowest_tool_ids, | |
157 ) | |
102 self.precision.append(precision) | 158 self.precision.append(precision) |
103 self.usage_weights.append(usage_weights) | 159 self.usage_weights.append(usage_weights) |
104 self.published_precision.append(precision_pub) | 160 self.published_precision.append(precision_pub) |
105 self.lowest_pub_precision.append(low_pub_prec) | 161 self.lowest_pub_precision.append(low_pub_prec) |
106 self.lowest_norm_precision.append(low_norm_prec) | 162 self.lowest_norm_precision.append(low_norm_prec) |
107 print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights)) | 163 print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights)) |
108 print("Epoch %d normal precision: %s" % (epoch + 1, precision)) | 164 print("Epoch %d normal precision: %s" % (epoch + 1, precision)) |
109 print("Epoch %d published precision: %s" % (epoch + 1, precision_pub)) | 165 print("Epoch %d published precision: %s" % (epoch + 1, precision_pub)) |
110 print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec)) | 166 print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec)) |
111 print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec)) | 167 print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec)) |
112 print("Epoch %d number of test samples with lowest tool ids: %s" % (epoch + 1, low_num)) | 168 print( |
169 "Epoch %d number of test samples with lowest tool ids: %s" | |
170 % (epoch + 1, low_num) | |
171 ) | |
113 | 172 |
114 | 173 |
115 if __name__ == "__main__": | 174 if __name__ == "__main__": |
116 start_time = time.time() | 175 start_time = time.time() |
117 | 176 |
118 arg_parser = argparse.ArgumentParser() | 177 arg_parser = argparse.ArgumentParser() |
119 arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file") | 178 arg_parser.add_argument( |
120 arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file") | 179 "-wf", "--workflow_file", required=True, help="workflows tabular file" |
121 arg_parser.add_argument("-om", "--output_model", required=True, help="trained model file") | 180 ) |
181 arg_parser.add_argument( | |
182 "-tu", "--tool_usage_file", required=True, help="tool usage file" | |
183 ) | |
184 arg_parser.add_argument( | |
185 "-om", "--output_model", required=True, help="trained model file" | |
186 ) | |
122 # data parameters | 187 # data parameters |
123 arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage") | 188 arg_parser.add_argument( |
124 arg_parser.add_argument("-pl", "--maximum_path_length", required=True, help="maximum length of tool path") | 189 "-cd", |
125 arg_parser.add_argument("-ep", "--n_epochs", required=True, help="number of iterations to run to create model") | 190 "--cutoff_date", |
126 arg_parser.add_argument("-oe", "--optimize_n_epochs", required=True, help="number of iterations to run to find best model parameters") | 191 required=True, |
127 arg_parser.add_argument("-me", "--max_evals", required=True, help="maximum number of configuration evaluations") | 192 help="earliest date for taking tool usage", |
128 arg_parser.add_argument("-ts", "--test_share", required=True, help="share of data to be used for testing") | 193 ) |
194 arg_parser.add_argument( | |
195 "-pl", | |
196 "--maximum_path_length", | |
197 required=True, | |
198 help="maximum length of tool path", | |
199 ) | |
200 arg_parser.add_argument( | |
201 "-ep", | |
202 "--n_epochs", | |
203 required=True, | |
204 help="number of iterations to run to create model", | |
205 ) | |
206 arg_parser.add_argument( | |
207 "-oe", | |
208 "--optimize_n_epochs", | |
209 required=True, | |
210 help="number of iterations to run to find best model parameters", | |
211 ) | |
212 arg_parser.add_argument( | |
213 "-me", | |
214 "--max_evals", | |
215 required=True, | |
216 help="maximum number of configuration evaluations", | |
217 ) | |
218 arg_parser.add_argument( | |
219 "-ts", | |
220 "--test_share", | |
221 required=True, | |
222 help="share of data to be used for testing", | |
223 ) | |
129 # neural network parameters | 224 # neural network parameters |
130 arg_parser.add_argument("-bs", "--batch_size", required=True, help="size of the tranining batch i.e. the number of samples per batch") | 225 arg_parser.add_argument( |
131 arg_parser.add_argument("-ut", "--units", required=True, help="number of hidden recurrent units") | 226 "-bs", |
132 arg_parser.add_argument("-es", "--embedding_size", required=True, help="size of the fixed vector learned for each tool") | 227 "--batch_size", |
133 arg_parser.add_argument("-dt", "--dropout", required=True, help="percentage of neurons to be dropped") | 228 required=True, |
134 arg_parser.add_argument("-sd", "--spatial_dropout", required=True, help="1d dropout used for embedding layer") | 229 help="size of the tranining batch i.e. the number of samples per batch", |
135 arg_parser.add_argument("-rd", "--recurrent_dropout", required=True, help="dropout for the recurrent layers") | 230 ) |
136 arg_parser.add_argument("-lr", "--learning_rate", required=True, help="learning rate") | 231 arg_parser.add_argument( |
232 "-ut", "--units", required=True, help="number of hidden recurrent units" | |
233 ) | |
234 arg_parser.add_argument( | |
235 "-es", | |
236 "--embedding_size", | |
237 required=True, | |
238 help="size of the fixed vector learned for each tool", | |
239 ) | |
240 arg_parser.add_argument( | |
241 "-dt", "--dropout", required=True, help="percentage of neurons to be dropped" | |
242 ) | |
243 arg_parser.add_argument( | |
244 "-sd", | |
245 "--spatial_dropout", | |
246 required=True, | |
247 help="1d dropout used for embedding layer", | |
248 ) | |
249 arg_parser.add_argument( | |
250 "-rd", | |
251 "--recurrent_dropout", | |
252 required=True, | |
253 help="dropout for the recurrent layers", | |
254 ) | |
255 arg_parser.add_argument( | |
256 "-lr", "--learning_rate", required=True, help="learning rate" | |
257 ) | |
137 | 258 |
138 # get argument values | 259 # get argument values |
139 args = vars(arg_parser.parse_args()) | 260 args = vars(arg_parser.parse_args()) |
140 tool_usage_path = args["tool_usage_file"] | 261 tool_usage_path = args["tool_usage_file"] |
141 workflows_path = args["workflow_file"] | 262 workflows_path = args["workflow_file"] |
154 recurrent_dropout = args["recurrent_dropout"] | 275 recurrent_dropout = args["recurrent_dropout"] |
155 learning_rate = args["learning_rate"] | 276 learning_rate = args["learning_rate"] |
156 num_cpus = 16 | 277 num_cpus = 16 |
157 | 278 |
158 config = { | 279 config = { |
159 'cutoff_date': cutoff_date, | 280 "cutoff_date": cutoff_date, |
160 'maximum_path_length': maximum_path_length, | 281 "maximum_path_length": maximum_path_length, |
161 'n_epochs': n_epochs, | 282 "n_epochs": n_epochs, |
162 'optimize_n_epochs': optimize_n_epochs, | 283 "optimize_n_epochs": optimize_n_epochs, |
163 'max_evals': max_evals, | 284 "max_evals": max_evals, |
164 'test_share': test_share, | 285 "test_share": test_share, |
165 'batch_size': batch_size, | 286 "batch_size": batch_size, |
166 'units': units, | 287 "units": units, |
167 'embedding_size': embedding_size, | 288 "embedding_size": embedding_size, |
168 'dropout': dropout, | 289 "dropout": dropout, |
169 'spatial_dropout': spatial_dropout, | 290 "spatial_dropout": spatial_dropout, |
170 'recurrent_dropout': recurrent_dropout, | 291 "recurrent_dropout": recurrent_dropout, |
171 'learning_rate': learning_rate | 292 "learning_rate": learning_rate, |
172 } | 293 } |
173 | 294 |
174 # Extract and process workflows | 295 # Extract and process workflows |
175 connections = extract_workflow_connections.ExtractWorkflowConnections() | 296 connections = extract_workflow_connections.ExtractWorkflowConnections() |
176 workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path) | 297 ( |
298 workflow_paths, | |
299 compatible_next_tools, | |
300 standard_connections, | |
301 ) = connections.read_tabular_file(workflows_path) | |
177 # Process the paths from workflows | 302 # Process the paths from workflows |
178 print("Dividing data...") | 303 print("Dividing data...") |
179 data = prepare_data.PrepareData(maximum_path_length, test_share) | 304 data = prepare_data.PrepareData(maximum_path_length, test_share) |
180 train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, train_tool_freq, tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections) | 305 ( |
306 train_data, | |
307 train_labels, | |
308 test_data, | |
309 test_labels, | |
310 data_dictionary, | |
311 reverse_dictionary, | |
312 class_weights, | |
313 usage_pred, | |
314 train_tool_freq, | |
315 tool_tr_samples, | |
316 ) = data.get_data_labels_matrices( | |
317 workflow_paths, | |
318 tool_usage_path, | |
319 cutoff_date, | |
320 compatible_next_tools, | |
321 standard_connections, | |
322 ) | |
181 # find the best model and start training | 323 # find the best model and start training |
182 predict_tool = PredictTool(num_cpus) | 324 predict_tool = PredictTool(num_cpus) |
183 # start training with weighted classes | 325 # start training with weighted classes |
184 print("Training with weighted classes and samples ...") | 326 print("Training with weighted classes and samples ...") |
185 results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, train_tool_freq, tool_tr_samples) | 327 results_weighted = predict_tool.find_train_best_network( |
186 utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections) | 328 config, |
329 reverse_dictionary, | |
330 train_data, | |
331 train_labels, | |
332 test_data, | |
333 test_labels, | |
334 n_epochs, | |
335 class_weights, | |
336 usage_pred, | |
337 standard_connections, | |
338 train_tool_freq, | |
339 tool_tr_samples, | |
340 ) | |
341 utils.save_model( | |
342 results_weighted, | |
343 data_dictionary, | |
344 compatible_next_tools, | |
345 trained_model_path, | |
346 class_weights, | |
347 standard_connections, | |
348 ) | |
187 end_time = time.time() | 349 end_time = time.time() |
188 print("Program finished in %s seconds" % str(end_time - start_time)) | 350 print("Program finished in %s seconds" % str(end_time - start_time)) |