Mercurial > repos > bgruening > create_tool_recommendation_model
comparison main.py @ 6:e94dc7945639 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author | bgruening |
---|---|
date | Sun, 16 Oct 2022 11:52:10 +0000 |
parents | 4f7e6612906b |
children |
comparison
equal
deleted
inserted
replaced
5:4f7e6612906b | 6:e94dc7945639 |
---|---|
1 """ | 1 """ |
2 Predict next tools in the Galaxy workflows | 2 Predict next tools in the Galaxy workflows |
3 using machine learning (recurrent neural network) | 3 using deep learning (Transformers) |
4 """ | 4 """ |
5 | |
6 import argparse | 5 import argparse |
7 import time | 6 import time |
8 | 7 |
9 import extract_workflow_connections | 8 import extract_workflow_connections |
10 import keras.callbacks as callbacks | |
11 import numpy as np | |
12 import optimise_hyperparameters | |
13 import prepare_data | 9 import prepare_data |
14 import utils | 10 import train_transformer |
15 | |
16 | |
17 class PredictTool: | |
18 def __init__(self, num_cpus): | |
19 """ Init method. """ | |
20 | |
21 def find_train_best_network( | |
22 self, | |
23 network_config, | |
24 reverse_dictionary, | |
25 train_data, | |
26 train_labels, | |
27 test_data, | |
28 test_labels, | |
29 n_epochs, | |
30 class_weights, | |
31 usage_pred, | |
32 standard_connections, | |
33 tool_freq, | |
34 tool_tr_samples, | |
35 ): | |
36 """ | |
37 Define recurrent neural network and train sequential data | |
38 """ | |
39 # get tools with lowest representation | |
40 lowest_tool_ids = utils.get_lowest_tools(tool_freq) | |
41 | |
42 print("Start hyperparameter optimisation...") | |
43 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() | |
44 best_params, best_model = hyper_opt.train_model( | |
45 network_config, | |
46 reverse_dictionary, | |
47 train_data, | |
48 train_labels, | |
49 test_data, | |
50 test_labels, | |
51 tool_tr_samples, | |
52 class_weights, | |
53 ) | |
54 | |
55 # define callbacks | |
56 early_stopping = callbacks.EarlyStopping( | |
57 monitor="loss", | |
58 mode="min", | |
59 verbose=1, | |
60 min_delta=1e-1, | |
61 restore_best_weights=True, | |
62 ) | |
63 predict_callback_test = PredictCallback( | |
64 test_data, | |
65 test_labels, | |
66 reverse_dictionary, | |
67 n_epochs, | |
68 usage_pred, | |
69 standard_connections, | |
70 lowest_tool_ids, | |
71 ) | |
72 | |
73 callbacks_list = [predict_callback_test, early_stopping] | |
74 batch_size = int(best_params["batch_size"]) | |
75 | |
76 print("Start training on the best model...") | |
77 train_performance = dict() | |
78 trained_model = best_model.fit_generator( | |
79 utils.balanced_sample_generator( | |
80 train_data, | |
81 train_labels, | |
82 batch_size, | |
83 tool_tr_samples, | |
84 reverse_dictionary, | |
85 ), | |
86 steps_per_epoch=len(train_data) // batch_size, | |
87 epochs=n_epochs, | |
88 callbacks=callbacks_list, | |
89 validation_data=(test_data, test_labels), | |
90 verbose=2, | |
91 shuffle=True, | |
92 ) | |
93 train_performance["validation_loss"] = np.array( | |
94 trained_model.history["val_loss"] | |
95 ) | |
96 train_performance["precision"] = predict_callback_test.precision | |
97 train_performance["usage_weights"] = predict_callback_test.usage_weights | |
98 train_performance[ | |
99 "published_precision" | |
100 ] = predict_callback_test.published_precision | |
101 train_performance[ | |
102 "lowest_pub_precision" | |
103 ] = predict_callback_test.lowest_pub_precision | |
104 train_performance[ | |
105 "lowest_norm_precision" | |
106 ] = predict_callback_test.lowest_norm_precision | |
107 train_performance["train_loss"] = np.array(trained_model.history["loss"]) | |
108 train_performance["model"] = best_model | |
109 train_performance["best_parameters"] = best_params | |
110 return train_performance | |
111 | |
112 | |
113 class PredictCallback(callbacks.Callback): | |
114 def __init__( | |
115 self, | |
116 test_data, | |
117 test_labels, | |
118 reverse_data_dictionary, | |
119 n_epochs, | |
120 usg_scores, | |
121 standard_connections, | |
122 lowest_tool_ids, | |
123 ): | |
124 self.test_data = test_data | |
125 self.test_labels = test_labels | |
126 self.reverse_data_dictionary = reverse_data_dictionary | |
127 self.precision = list() | |
128 self.usage_weights = list() | |
129 self.published_precision = list() | |
130 self.n_epochs = n_epochs | |
131 self.pred_usage_scores = usg_scores | |
132 self.standard_connections = standard_connections | |
133 self.lowest_tool_ids = lowest_tool_ids | |
134 self.lowest_pub_precision = list() | |
135 self.lowest_norm_precision = list() | |
136 | |
137 def on_epoch_end(self, epoch, logs={}): | |
138 """ | |
139 Compute absolute and compatible precision for test data | |
140 """ | |
141 if len(self.test_data) > 0: | |
142 ( | |
143 usage_weights, | |
144 precision, | |
145 precision_pub, | |
146 low_pub_prec, | |
147 low_norm_prec, | |
148 low_num, | |
149 ) = utils.verify_model( | |
150 self.model, | |
151 self.test_data, | |
152 self.test_labels, | |
153 self.reverse_data_dictionary, | |
154 self.pred_usage_scores, | |
155 self.standard_connections, | |
156 self.lowest_tool_ids, | |
157 ) | |
158 self.precision.append(precision) | |
159 self.usage_weights.append(usage_weights) | |
160 self.published_precision.append(precision_pub) | |
161 self.lowest_pub_precision.append(low_pub_prec) | |
162 self.lowest_norm_precision.append(low_norm_prec) | |
163 print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights)) | |
164 print("Epoch %d normal precision: %s" % (epoch + 1, precision)) | |
165 print("Epoch %d published precision: %s" % (epoch + 1, precision_pub)) | |
166 print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec)) | |
167 print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec)) | |
168 print( | |
169 "Epoch %d number of test samples with lowest tool ids: %s" | |
170 % (epoch + 1, low_num) | |
171 ) | |
172 | |
173 | 11 |
174 if __name__ == "__main__": | 12 if __name__ == "__main__": |
175 start_time = time.time() | 13 start_time = time.time() |
176 | 14 |
177 arg_parser = argparse.ArgumentParser() | 15 arg_parser = argparse.ArgumentParser() |
178 arg_parser.add_argument( | 16 arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file") |
179 "-wf", "--workflow_file", required=True, help="workflows tabular file" | 17 arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file") |
180 ) | |
181 arg_parser.add_argument( | |
182 "-tu", "--tool_usage_file", required=True, help="tool usage file" | |
183 ) | |
184 arg_parser.add_argument( | |
185 "-om", "--output_model", required=True, help="trained model file" | |
186 ) | |
187 # data parameters | 18 # data parameters |
188 arg_parser.add_argument( | 19 arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage") |
189 "-cd", | 20 arg_parser.add_argument("-pl", "--maximum_path_length", required=True, help="maximum length of tool path") |
190 "--cutoff_date", | 21 arg_parser.add_argument("-om", "--output_model", required=True, help="trained model path") |
191 required=True, | |
192 help="earliest date for taking tool usage", | |
193 ) | |
194 arg_parser.add_argument( | |
195 "-pl", | |
196 "--maximum_path_length", | |
197 required=True, | |
198 help="maximum length of tool path", | |
199 ) | |
200 arg_parser.add_argument( | |
201 "-ep", | |
202 "--n_epochs", | |
203 required=True, | |
204 help="number of iterations to run to create model", | |
205 ) | |
206 arg_parser.add_argument( | |
207 "-oe", | |
208 "--optimize_n_epochs", | |
209 required=True, | |
210 help="number of iterations to run to find best model parameters", | |
211 ) | |
212 arg_parser.add_argument( | |
213 "-me", | |
214 "--max_evals", | |
215 required=True, | |
216 help="maximum number of configuration evaluations", | |
217 ) | |
218 arg_parser.add_argument( | |
219 "-ts", | |
220 "--test_share", | |
221 required=True, | |
222 help="share of data to be used for testing", | |
223 ) | |
224 # neural network parameters | 22 # neural network parameters |
225 arg_parser.add_argument( | 23 arg_parser.add_argument("-ti", "--n_train_iter", required=True, help="Number of training iterations run to create model") |
226 "-bs", | 24 arg_parser.add_argument("-nhd", "--n_heads", required=True, help="Number of head in transformer's multi-head attention") |
227 "--batch_size", | 25 arg_parser.add_argument("-ed", "--n_embed_dim", required=True, help="Embedding dimension") |
228 required=True, | 26 arg_parser.add_argument("-fd", "--n_feed_forward_dim", required=True, help="Feed forward network dimension") |
229 help="size of the tranining batch i.e. the number of samples per batch", | 27 arg_parser.add_argument("-dt", "--dropout", required=True, help="Percentage of neurons to be dropped") |
230 ) | 28 arg_parser.add_argument("-lr", "--learning_rate", required=True, help="Learning rate") |
231 arg_parser.add_argument( | 29 arg_parser.add_argument("-ts", "--te_share", required=True, help="Share of data to be used for testing") |
232 "-ut", "--units", required=True, help="number of hidden recurrent units" | 30 arg_parser.add_argument("-trbs", "--tr_batch_size", required=True, help="Train batch size") |
233 ) | 31 arg_parser.add_argument("-trlg", "--tr_logging_step", required=True, help="Train logging frequency") |
234 arg_parser.add_argument( | 32 arg_parser.add_argument("-telg", "--te_logging_step", required=True, help="Test logging frequency") |
235 "-es", | 33 arg_parser.add_argument("-tebs", "--te_batch_size", required=True, help="Test batch size") |
236 "--embedding_size", | |
237 required=True, | |
238 help="size of the fixed vector learned for each tool", | |
239 ) | |
240 arg_parser.add_argument( | |
241 "-dt", "--dropout", required=True, help="percentage of neurons to be dropped" | |
242 ) | |
243 arg_parser.add_argument( | |
244 "-sd", | |
245 "--spatial_dropout", | |
246 required=True, | |
247 help="1d dropout used for embedding layer", | |
248 ) | |
249 arg_parser.add_argument( | |
250 "-rd", | |
251 "--recurrent_dropout", | |
252 required=True, | |
253 help="dropout for the recurrent layers", | |
254 ) | |
255 arg_parser.add_argument( | |
256 "-lr", "--learning_rate", required=True, help="learning rate" | |
257 ) | |
258 | 34 |
259 # get argument values | 35 # get argument values |
260 args = vars(arg_parser.parse_args()) | 36 args = vars(arg_parser.parse_args()) |
261 tool_usage_path = args["tool_usage_file"] | 37 tool_usage_path = args["tool_usage_file"] |
262 workflows_path = args["workflow_file"] | 38 workflows_path = args["workflow_file"] |
263 cutoff_date = args["cutoff_date"] | 39 cutoff_date = args["cutoff_date"] |
264 maximum_path_length = int(args["maximum_path_length"]) | 40 maximum_path_length = int(args["maximum_path_length"]) |
41 | |
42 n_train_iter = int(args["n_train_iter"]) | |
43 te_share = float(args["te_share"]) | |
44 tr_batch_size = int(args["tr_batch_size"]) | |
45 te_batch_size = int(args["te_batch_size"]) | |
46 | |
47 n_heads = int(args["n_heads"]) | |
48 feed_forward_dim = int(args["n_feed_forward_dim"]) | |
49 embedding_dim = int(args["n_embed_dim"]) | |
50 dropout = float(args["dropout"]) | |
51 learning_rate = float(args["learning_rate"]) | |
52 te_logging_step = int(args["te_logging_step"]) | |
53 tr_logging_step = int(args["tr_logging_step"]) | |
265 trained_model_path = args["output_model"] | 54 trained_model_path = args["output_model"] |
266 n_epochs = int(args["n_epochs"]) | |
267 optimize_n_epochs = int(args["optimize_n_epochs"]) | |
268 max_evals = int(args["max_evals"]) | |
269 test_share = float(args["test_share"]) | |
270 batch_size = args["batch_size"] | |
271 units = args["units"] | |
272 embedding_size = args["embedding_size"] | |
273 dropout = args["dropout"] | |
274 spatial_dropout = args["spatial_dropout"] | |
275 recurrent_dropout = args["recurrent_dropout"] | |
276 learning_rate = args["learning_rate"] | |
277 num_cpus = 16 | |
278 | 55 |
279 config = { | 56 config = { |
280 "cutoff_date": cutoff_date, | 57 'cutoff_date': cutoff_date, |
281 "maximum_path_length": maximum_path_length, | 58 'maximum_path_length': maximum_path_length, |
282 "n_epochs": n_epochs, | 59 'n_train_iter': n_train_iter, |
283 "optimize_n_epochs": optimize_n_epochs, | 60 'n_heads': n_heads, |
284 "max_evals": max_evals, | 61 'feed_forward_dim': feed_forward_dim, |
285 "test_share": test_share, | 62 'embedding_dim': embedding_dim, |
286 "batch_size": batch_size, | 63 'dropout': dropout, |
287 "units": units, | 64 'learning_rate': learning_rate, |
288 "embedding_size": embedding_size, | 65 'te_share': te_share, |
289 "dropout": dropout, | 66 'te_logging_step': te_logging_step, |
290 "spatial_dropout": spatial_dropout, | 67 'tr_logging_step': tr_logging_step, |
291 "recurrent_dropout": recurrent_dropout, | 68 'tr_batch_size': tr_batch_size, |
292 "learning_rate": learning_rate, | 69 'te_batch_size': te_batch_size, |
70 'trained_model_path': trained_model_path | |
293 } | 71 } |
294 | 72 print("Preprocessing workflows...") |
295 # Extract and process workflows | 73 # Extract and process workflows |
296 connections = extract_workflow_connections.ExtractWorkflowConnections() | 74 connections = extract_workflow_connections.ExtractWorkflowConnections() |
297 ( | 75 # Process raw workflow file |
298 workflow_paths, | 76 wf_dataframe, usage_df = connections.process_raw_files(workflows_path, tool_usage_path, config) |
299 compatible_next_tools, | 77 workflow_paths, pub_conn = connections.read_tabular_file(wf_dataframe, config) |
300 standard_connections, | |
301 ) = connections.read_tabular_file(workflows_path) | |
302 # Process the paths from workflows | 78 # Process the paths from workflows |
303 print("Dividing data...") | 79 print("Dividing data...") |
304 data = prepare_data.PrepareData(maximum_path_length, test_share) | 80 data = prepare_data.PrepareData(maximum_path_length, te_share) |
305 ( | 81 train_data, train_labels, test_data, test_labels, f_dict, r_dict, c_wts, c_tools, tr_tool_freq = data.get_data_labels_matrices(workflow_paths, usage_df, cutoff_date, pub_conn) |
306 train_data, | 82 print(train_data.shape, train_labels.shape, test_data.shape, test_labels.shape) |
307 train_labels, | 83 train_transformer.create_enc_transformer(train_data, train_labels, test_data, test_labels, f_dict, r_dict, c_wts, c_tools, pub_conn, tr_tool_freq, config) |
308 test_data, | |
309 test_labels, | |
310 data_dictionary, | |
311 reverse_dictionary, | |
312 class_weights, | |
313 usage_pred, | |
314 train_tool_freq, | |
315 tool_tr_samples, | |
316 ) = data.get_data_labels_matrices( | |
317 workflow_paths, | |
318 tool_usage_path, | |
319 cutoff_date, | |
320 compatible_next_tools, | |
321 standard_connections, | |
322 ) | |
323 # find the best model and start training | |
324 predict_tool = PredictTool(num_cpus) | |
325 # start training with weighted classes | |
326 print("Training with weighted classes and samples ...") | |
327 results_weighted = predict_tool.find_train_best_network( | |
328 config, | |
329 reverse_dictionary, | |
330 train_data, | |
331 train_labels, | |
332 test_data, | |
333 test_labels, | |
334 n_epochs, | |
335 class_weights, | |
336 usage_pred, | |
337 standard_connections, | |
338 train_tool_freq, | |
339 tool_tr_samples, | |
340 ) | |
341 utils.save_model( | |
342 results_weighted, | |
343 data_dictionary, | |
344 compatible_next_tools, | |
345 trained_model_path, | |
346 class_weights, | |
347 standard_connections, | |
348 ) | |
349 end_time = time.time() | 84 end_time = time.time() |
350 print("Program finished in %s seconds" % str(end_time - start_time)) | 85 print("Program finished in %s seconds" % str(end_time - start_time)) |