comparison main.py @ 5:4f7e6612906b draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author bgruening
date Fri, 06 May 2022 09:05:18 +0000
parents afec8c595124
children e94dc7945639
4:afec8c595124 (old)    5:4f7e6612906b (new)
1 """ 1 """
2 Predict next tools in the Galaxy workflows 2 Predict next tools in the Galaxy workflows
3 using machine learning (recurrent neural network) 3 using machine learning (recurrent neural network)
4 """ 4 """
5 5
6 import numpy as np
7 import argparse 6 import argparse
8 import time 7 import time
9 8
10 # machine learning library 9 import extract_workflow_connections
11 import tensorflow as tf
12 from keras import backend as K
13 import keras.callbacks as callbacks 10 import keras.callbacks as callbacks
14 11 import numpy as np
15 import extract_workflow_connections 12 import optimise_hyperparameters
16 import prepare_data 13 import prepare_data
17 import optimise_hyperparameters
18 import utils 14 import utils
19 15
20 16
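Note on the import hunk above: the old revision pulled in tensorflow and the Keras backend only to pin CPU parallelism inside PredictTool.__init__ (the tf.ConfigProto / tf.Session / K.set_session block removed further down); the new revision drops both imports, which also leaves the num_cpus value passed to PredictTool unused. A minimal sketch of equivalent thread capping against the TensorFlow 2.x API — an assumption for illustration, not part of this changeset — would be:

    import tensorflow as tf

    def limit_cpu_threads(num_cpus):
        # Cap intra-op and inter-op parallelism, replacing the removed
        # ConfigProto/Session-based setup from the old revision.
        tf.config.threading.set_intra_op_parallelism_threads(num_cpus)
        tf.config.threading.set_inter_op_parallelism_threads(num_cpus)

    limit_cpu_threads(16)  # 16 matches the hard-coded num_cpus in __main__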
21 class PredictTool: 17 class PredictTool:
22
23 def __init__(self, num_cpus): 18 def __init__(self, num_cpus):
24 """ Init method. """ 19 """ Init method. """
25 # set the number of cpus 20
26 cpu_config = tf.ConfigProto( 21 def find_train_best_network(
27 device_count={"CPU": num_cpus}, 22 self,
28 intra_op_parallelism_threads=num_cpus, 23 network_config,
29 inter_op_parallelism_threads=num_cpus, 24 reverse_dictionary,
30 allow_soft_placement=True 25 train_data,
31 ) 26 train_labels,
32 K.set_session(tf.Session(config=cpu_config)) 27 test_data,
33 28 test_labels,
34 def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, tool_freq, tool_tr_samples): 29 n_epochs,
30 class_weights,
31 usage_pred,
32 standard_connections,
33 tool_freq,
34 tool_tr_samples,
35 ):
35 """ 36 """
36 Define recurrent neural network and train sequential data 37 Define recurrent neural network and train sequential data
37 """ 38 """
38 # get tools with lowest representation 39 # get tools with lowest representation
39 lowest_tool_ids = utils.get_lowest_tools(tool_freq) 40 lowest_tool_ids = utils.get_lowest_tools(tool_freq)
40 41
41 print("Start hyperparameter optimisation...") 42 print("Start hyperparameter optimisation...")
42 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() 43 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
43 best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights) 44 best_params, best_model = hyper_opt.train_model(
45 network_config,
46 reverse_dictionary,
47 train_data,
48 train_labels,
49 test_data,
50 test_labels,
51 tool_tr_samples,
52 class_weights,
53 )
44 54
45 # define callbacks 55 # define callbacks
46 early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True) 56 early_stopping = callbacks.EarlyStopping(
47 predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids) 57 monitor="loss",
58 mode="min",
59 verbose=1,
60 min_delta=1e-1,
61 restore_best_weights=True,
62 )
63 predict_callback_test = PredictCallback(
64 test_data,
65 test_labels,
66 reverse_dictionary,
67 n_epochs,
68 usage_pred,
69 standard_connections,
70 lowest_tool_ids,
71 )
48 72
49 callbacks_list = [predict_callback_test, early_stopping] 73 callbacks_list = [predict_callback_test, early_stopping]
50 batch_size = int(best_params["batch_size"]) 74 batch_size = int(best_params["batch_size"])
51 75
52 print("Start training on the best model...") 76 print("Start training on the best model...")
55 utils.balanced_sample_generator( 79 utils.balanced_sample_generator(
56 train_data, 80 train_data,
57 train_labels, 81 train_labels,
58 batch_size, 82 batch_size,
59 tool_tr_samples, 83 tool_tr_samples,
60 reverse_dictionary 84 reverse_dictionary,
61 ), 85 ),
62 steps_per_epoch=len(train_data) // batch_size, 86 steps_per_epoch=len(train_data) // batch_size,
63 epochs=n_epochs, 87 epochs=n_epochs,
64 callbacks=callbacks_list, 88 callbacks=callbacks_list,
65 validation_data=(test_data, test_labels), 89 validation_data=(test_data, test_labels),
66 verbose=2, 90 verbose=2,
67 shuffle=True 91 shuffle=True,
68 ) 92 )
69 train_performance["validation_loss"] = np.array(trained_model.history["val_loss"]) 93 train_performance["validation_loss"] = np.array(
94 trained_model.history["val_loss"]
95 )
70 train_performance["precision"] = predict_callback_test.precision 96 train_performance["precision"] = predict_callback_test.precision
71 train_performance["usage_weights"] = predict_callback_test.usage_weights 97 train_performance["usage_weights"] = predict_callback_test.usage_weights
72 train_performance["published_precision"] = predict_callback_test.published_precision 98 train_performance[
73 train_performance["lowest_pub_precision"] = predict_callback_test.lowest_pub_precision 99 "published_precision"
74 train_performance["lowest_norm_precision"] = predict_callback_test.lowest_norm_precision 100 ] = predict_callback_test.published_precision
101 train_performance[
102 "lowest_pub_precision"
103 ] = predict_callback_test.lowest_pub_precision
104 train_performance[
105 "lowest_norm_precision"
106 ] = predict_callback_test.lowest_norm_precision
75 train_performance["train_loss"] = np.array(trained_model.history["loss"]) 107 train_performance["train_loss"] = np.array(trained_model.history["loss"])
76 train_performance["model"] = best_model 108 train_performance["model"] = best_model
77 train_performance["best_parameters"] = best_params 109 train_performance["best_parameters"] = best_params
78 return train_performance 110 return train_performance
79 111
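The training loop above feeds fit_generator from utils.balanced_sample_generator, whose arguments (tool_tr_samples, reverse_dictionary) suggest per-tool balanced sampling so that rarely used tools still appear in every batch. The generator itself lives in utils and is not shown in this changeset; a hypothetical minimal version, assuming tool_tr_samples maps a tool id to the indices of training samples ending in that tool, could look like:

    import numpy as np

    def balanced_sample_generator(train_data, train_labels, batch_size,
                                  tool_tr_samples, reverse_dictionary):
        # Hypothetical sketch: draw one tool id per batch slot, then one
        # training sequence for that tool, and yield batches indefinitely
        # as fit_generator expects.
        # reverse_dictionary is kept only to mirror the call signature above.
        tool_ids = list(tool_tr_samples.keys())
        dimension = train_data.shape[1]
        n_classes = train_labels.shape[1]
        while True:
            batch_data = np.zeros((batch_size, dimension))
            batch_labels = np.zeros((batch_size, n_classes))
            for i in range(batch_size):
                random_tool = tool_ids[np.random.randint(0, len(tool_ids))]
                sample_index = np.random.choice(tool_tr_samples[random_tool])
                batch_data[i] = train_data[sample_index]
                batch_labels[i] = train_labels[sample_index]
            yield batch_data, batch_labels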
80 112
81 class PredictCallback(callbacks.Callback): 113 class PredictCallback(callbacks.Callback):
82 def __init__(self, test_data, test_labels, reverse_data_dictionary, n_epochs, usg_scores, standard_connections, lowest_tool_ids): 114 def __init__(
115 self,
116 test_data,
117 test_labels,
118 reverse_data_dictionary,
119 n_epochs,
120 usg_scores,
121 standard_connections,
122 lowest_tool_ids,
123 ):
83 self.test_data = test_data 124 self.test_data = test_data
84 self.test_labels = test_labels 125 self.test_labels = test_labels
85 self.reverse_data_dictionary = reverse_data_dictionary 126 self.reverse_data_dictionary = reverse_data_dictionary
86 self.precision = list() 127 self.precision = list()
87 self.usage_weights = list() 128 self.usage_weights = list()
96 def on_epoch_end(self, epoch, logs={}): 137 def on_epoch_end(self, epoch, logs={}):
97 """ 138 """
98 Compute absolute and compatible precision for test data 139 Compute absolute and compatible precision for test data
99 """ 140 """
100 if len(self.test_data) > 0: 141 if len(self.test_data) > 0:
101 usage_weights, precision, precision_pub, low_pub_prec, low_norm_prec, low_num = utils.verify_model(self.model, self.test_data, self.test_labels, self.reverse_data_dictionary, self.pred_usage_scores, self.standard_connections, self.lowest_tool_ids) 142 (
143 usage_weights,
144 precision,
145 precision_pub,
146 low_pub_prec,
147 low_norm_prec,
148 low_num,
149 ) = utils.verify_model(
150 self.model,
151 self.test_data,
152 self.test_labels,
153 self.reverse_data_dictionary,
154 self.pred_usage_scores,
155 self.standard_connections,
156 self.lowest_tool_ids,
157 )
102 self.precision.append(precision) 158 self.precision.append(precision)
103 self.usage_weights.append(usage_weights) 159 self.usage_weights.append(usage_weights)
104 self.published_precision.append(precision_pub) 160 self.published_precision.append(precision_pub)
105 self.lowest_pub_precision.append(low_pub_prec) 161 self.lowest_pub_precision.append(low_pub_prec)
106 self.lowest_norm_precision.append(low_norm_prec) 162 self.lowest_norm_precision.append(low_norm_prec)
107 print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights)) 163 print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights))
108 print("Epoch %d normal precision: %s" % (epoch + 1, precision)) 164 print("Epoch %d normal precision: %s" % (epoch + 1, precision))
109 print("Epoch %d published precision: %s" % (epoch + 1, precision_pub)) 165 print("Epoch %d published precision: %s" % (epoch + 1, precision_pub))
110 print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec)) 166 print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec))
111 print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec)) 167 print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec))
112 print("Epoch %d number of test samples with lowest tool ids: %s" % (epoch + 1, low_num)) 168 print(
169 "Epoch %d number of test samples with lowest tool ids: %s"
170 % (epoch + 1, low_num)
171 )
113 172
114 173
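PredictCallback above follows the standard Keras custom-callback pattern: stash the evaluation data in __init__, then read self.model inside on_epoch_end to compute and store per-epoch metrics. A stripped-down sketch of that pattern with a placeholder top-1 metric (not the project's utils.verify_model) for reference:

    import numpy as np
    import keras.callbacks as callbacks

    class EpochEvaluator(callbacks.Callback):
        # Minimal illustration of the pattern used by PredictCallback.
        def __init__(self, eval_data, eval_labels):
            self.eval_data = eval_data
            self.eval_labels = eval_labels
            self.scores = list()

        def on_epoch_end(self, epoch, logs=None):
            # self.model is attached by Keras before training begins
            predictions = self.model.predict(self.eval_data)
            top1 = np.argmax(predictions, axis=1)
            hits = np.mean(top1 == np.argmax(self.eval_labels, axis=1))
            self.scores.append(hits)
            print("Epoch %d top-1 hit rate: %.4f" % (epoch + 1, hits))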
115 if __name__ == "__main__": 174 if __name__ == "__main__":
116 start_time = time.time() 175 start_time = time.time()
117 176
118 arg_parser = argparse.ArgumentParser() 177 arg_parser = argparse.ArgumentParser()
119 arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file") 178 arg_parser.add_argument(
120 arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file") 179 "-wf", "--workflow_file", required=True, help="workflows tabular file"
121 arg_parser.add_argument("-om", "--output_model", required=True, help="trained model file") 180 )
181 arg_parser.add_argument(
182 "-tu", "--tool_usage_file", required=True, help="tool usage file"
183 )
184 arg_parser.add_argument(
185 "-om", "--output_model", required=True, help="trained model file"
186 )
122 # data parameters 187 # data parameters
123 arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage") 188 arg_parser.add_argument(
124 arg_parser.add_argument("-pl", "--maximum_path_length", required=True, help="maximum length of tool path") 189 "-cd",
125 arg_parser.add_argument("-ep", "--n_epochs", required=True, help="number of iterations to run to create model") 190 "--cutoff_date",
126 arg_parser.add_argument("-oe", "--optimize_n_epochs", required=True, help="number of iterations to run to find best model parameters") 191 required=True,
127 arg_parser.add_argument("-me", "--max_evals", required=True, help="maximum number of configuration evaluations") 192 help="earliest date for taking tool usage",
128 arg_parser.add_argument("-ts", "--test_share", required=True, help="share of data to be used for testing") 193 )
194 arg_parser.add_argument(
195 "-pl",
196 "--maximum_path_length",
197 required=True,
198 help="maximum length of tool path",
199 )
200 arg_parser.add_argument(
201 "-ep",
202 "--n_epochs",
203 required=True,
204 help="number of iterations to run to create model",
205 )
206 arg_parser.add_argument(
207 "-oe",
208 "--optimize_n_epochs",
209 required=True,
210 help="number of iterations to run to find best model parameters",
211 )
212 arg_parser.add_argument(
213 "-me",
214 "--max_evals",
215 required=True,
216 help="maximum number of configuration evaluations",
217 )
218 arg_parser.add_argument(
219 "-ts",
220 "--test_share",
221 required=True,
222 help="share of data to be used for testing",
223 )
129 # neural network parameters 224 # neural network parameters
130 arg_parser.add_argument("-bs", "--batch_size", required=True, help="size of the training batch i.e. the number of samples per batch") 225 arg_parser.add_argument(
131 arg_parser.add_argument("-ut", "--units", required=True, help="number of hidden recurrent units") 226 "-bs",
132 arg_parser.add_argument("-es", "--embedding_size", required=True, help="size of the fixed vector learned for each tool") 227 "--batch_size",
133 arg_parser.add_argument("-dt", "--dropout", required=True, help="percentage of neurons to be dropped") 228 required=True,
134 arg_parser.add_argument("-sd", "--spatial_dropout", required=True, help="1d dropout used for embedding layer") 229 help="size of the training batch i.e. the number of samples per batch",
135 arg_parser.add_argument("-rd", "--recurrent_dropout", required=True, help="dropout for the recurrent layers") 230 )
136 arg_parser.add_argument("-lr", "--learning_rate", required=True, help="learning rate") 231 arg_parser.add_argument(
232 "-ut", "--units", required=True, help="number of hidden recurrent units"
233 )
234 arg_parser.add_argument(
235 "-es",
236 "--embedding_size",
237 required=True,
238 help="size of the fixed vector learned for each tool",
239 )
240 arg_parser.add_argument(
241 "-dt", "--dropout", required=True, help="percentage of neurons to be dropped"
242 )
243 arg_parser.add_argument(
244 "-sd",
245 "--spatial_dropout",
246 required=True,
247 help="1d dropout used for embedding layer",
248 )
249 arg_parser.add_argument(
250 "-rd",
251 "--recurrent_dropout",
252 required=True,
253 help="dropout for the recurrent layers",
254 )
255 arg_parser.add_argument(
256 "-lr", "--learning_rate", required=True, help="learning rate"
257 )
137 258
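For reference, all of the flags above are required, so an invocation of this script takes the following shape (angle-bracket values are placeholders; their expected formats are defined by the downstream prepare_data / optimise_hyperparameters modules and are not reproduced here):

    python main.py \
        -wf <workflows.tsv> \
        -tu <tool_usage.tsv> \
        -om <trained_model.hdf5> \
        -cd <cutoff_date> \
        -pl <maximum_path_length> \
        -ep <n_epochs> \
        -oe <optimize_n_epochs> \
        -me <max_evals> \
        -ts <test_share> \
        -bs <batch_size> \
        -ut <units> \
        -es <embedding_size> \
        -dt <dropout> \
        -sd <spatial_dropout> \
        -rd <recurrent_dropout> \
        -lr <learning_rate>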
138 # get argument values 259 # get argument values
139 args = vars(arg_parser.parse_args()) 260 args = vars(arg_parser.parse_args())
140 tool_usage_path = args["tool_usage_file"] 261 tool_usage_path = args["tool_usage_file"]
141 workflows_path = args["workflow_file"] 262 workflows_path = args["workflow_file"]
154 recurrent_dropout = args["recurrent_dropout"] 275 recurrent_dropout = args["recurrent_dropout"]
155 learning_rate = args["learning_rate"] 276 learning_rate = args["learning_rate"]
156 num_cpus = 16 277 num_cpus = 16
157 278
158 config = { 279 config = {
159 'cutoff_date': cutoff_date, 280 "cutoff_date": cutoff_date,
160 'maximum_path_length': maximum_path_length, 281 "maximum_path_length": maximum_path_length,
161 'n_epochs': n_epochs, 282 "n_epochs": n_epochs,
162 'optimize_n_epochs': optimize_n_epochs, 283 "optimize_n_epochs": optimize_n_epochs,
163 'max_evals': max_evals, 284 "max_evals": max_evals,
164 'test_share': test_share, 285 "test_share": test_share,
165 'batch_size': batch_size, 286 "batch_size": batch_size,
166 'units': units, 287 "units": units,
167 'embedding_size': embedding_size, 288 "embedding_size": embedding_size,
168 'dropout': dropout, 289 "dropout": dropout,
169 'spatial_dropout': spatial_dropout, 290 "spatial_dropout": spatial_dropout,
170 'recurrent_dropout': recurrent_dropout, 291 "recurrent_dropout": recurrent_dropout,
171 'learning_rate': learning_rate 292 "learning_rate": learning_rate,
172 } 293 }
173 294
174 # Extract and process workflows 295 # Extract and process workflows
175 connections = extract_workflow_connections.ExtractWorkflowConnections() 296 connections = extract_workflow_connections.ExtractWorkflowConnections()
176 workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path) 297 (
298 workflow_paths,
299 compatible_next_tools,
300 standard_connections,
301 ) = connections.read_tabular_file(workflows_path)
177 # Process the paths from workflows 302 # Process the paths from workflows
178 print("Dividing data...") 303 print("Dividing data...")
179 data = prepare_data.PrepareData(maximum_path_length, test_share) 304 data = prepare_data.PrepareData(maximum_path_length, test_share)
180 train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, train_tool_freq, tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections) 305 (
306 train_data,
307 train_labels,
308 test_data,
309 test_labels,
310 data_dictionary,
311 reverse_dictionary,
312 class_weights,
313 usage_pred,
314 train_tool_freq,
315 tool_tr_samples,
316 ) = data.get_data_labels_matrices(
317 workflow_paths,
318 tool_usage_path,
319 cutoff_date,
320 compatible_next_tools,
321 standard_connections,
322 )
181 # find the best model and start training 323 # find the best model and start training
182 predict_tool = PredictTool(num_cpus) 324 predict_tool = PredictTool(num_cpus)
183 # start training with weighted classes 325 # start training with weighted classes
184 print("Training with weighted classes and samples ...") 326 print("Training with weighted classes and samples ...")
185 results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, train_tool_freq, tool_tr_samples) 327 results_weighted = predict_tool.find_train_best_network(
186 utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections) 328 config,
329 reverse_dictionary,
330 train_data,
331 train_labels,
332 test_data,
333 test_labels,
334 n_epochs,
335 class_weights,
336 usage_pred,
337 standard_connections,
338 train_tool_freq,
339 tool_tr_samples,
340 )
341 utils.save_model(
342 results_weighted,
343 data_dictionary,
344 compatible_next_tools,
345 trained_model_path,
346 class_weights,
347 standard_connections,
348 )
187 end_time = time.time() 349 end_time = time.time()
188 print("Program finished in %s seconds" % str(end_time - start_time)) 350 print("Program finished in %s seconds" % str(end_time - start_time))