Mercurial > repos > bgruening > create_tool_recommendation_model
comparison main.py @ 0:9bf25dbe00ad draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
author | bgruening |
---|---|
date | Wed, 28 Aug 2019 07:19:38 -0400 |
parents | |
children | 12764915e1c5 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:9bf25dbe00ad |
---|---|
1 """ | |
2 Predict next tools in the Galaxy workflows | |
3 using machine learning (recurrent neural network) | |
4 """ | |
5 | |
6 import numpy as np | |
7 import argparse | |
8 import time | |
9 | |
10 # machine learning library | |
11 import keras.callbacks as callbacks | |
12 | |
13 import extract_workflow_connections | |
14 import prepare_data | |
15 import optimise_hyperparameters | |
16 import utils | |
17 | |
18 | |
class PredictTool:
    """Find the best recurrent network configuration and train it on tool paths."""

    def __init__(self):
        """Init method (no state is kept on the instance)."""

    @classmethod
    def find_train_best_network(cls, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools):
        """
        Define recurrent neural network and train sequential data.

        First runs hyperparameter optimisation, then rebuilds the network with the
        best parameters found and trains it for ``n_epochs`` epochs. After each
        epoch, a ``PredictCallback`` evaluates the model on the test data.

        :param network_config: configuration dict handed to
            ``HyperparameterOptimisation.train_model``.
        :param reverse_dictionary: passed to the optimiser, to
            ``utils.set_recurrent_network`` and to the evaluation callback.
        :param train_data: training inputs.
        :param train_labels: training targets.
        :param test_data: test inputs; may be empty, in which case no validation
            is performed.
        :param test_labels: test targets.
        :param n_epochs: number of epochs for the final training run.
        :param class_weights: passed to the optimiser and to
            ``utils.set_recurrent_network``.
        :param usage_pred: usage scores forwarded to the evaluation callback.
        :param compatible_next_tools: forwarded to the evaluation callback.
        :return: dict with keys ``"train_loss"``, ``"model"`` and
            ``"best_parameters"``. When test data is present it also has
            ``"validation_loss"``, ``"precision"`` and ``"usage_weights"``
            (the last two are per-epoch lists collected by the callback).
        """
        print("Start hyperparameter optimisation...")
        hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
        best_params = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, class_weights)

        # rebuild the network with the best parameters and train it on the full training set
        model, best_params = utils.set_recurrent_network(best_params, reverse_dictionary, class_weights)

        has_test_data = len(test_data) > 0

        # records precision / usage weights on the test data after every epoch
        predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, compatible_next_tools, usage_pred)
        callbacks_list = [predict_callback_test]

        print("Start training on the best model...")
        model_fit = model.fit(
            train_data,
            train_labels,
            batch_size=int(best_params["batch_size"]),
            epochs=n_epochs,
            verbose=2,
            callbacks=callbacks_list,
            shuffle="batch",
            # do not hand an empty validation set to fit(); "val_loss" is
            # only read below when test data exists
            validation_data=(test_data, test_labels) if has_test_data else None
        )

        train_performance = {
            "train_loss": np.array(model_fit.history["loss"]),
            "model": model,
            "best_parameters": best_params
        }

        # if there is test data, add more information
        if has_test_data:
            train_performance["validation_loss"] = np.array(model_fit.history["val_loss"])
            train_performance["precision"] = predict_callback_test.precision
            train_performance["usage_weights"] = predict_callback_test.usage_weights
        return train_performance
66 | |
67 | |
class PredictCallback(callbacks.Callback):
    """Callback that evaluates the model on the test data at the end of every epoch.

    Results are accumulated in ``self.precision`` and ``self.usage_weights``,
    with one entry appended per epoch when test data is non-empty.
    """

    def __init__(self, test_data, test_labels, reverse_data_dictionary, n_epochs, next_compatible_tools, usg_scores):
        """
        :param test_data: test inputs (may be empty, which disables evaluation).
        :param test_labels: test targets.
        :param reverse_data_dictionary: forwarded to ``utils.verify_model``.
        :param n_epochs: total number of epochs (stored, not used here).
        :param next_compatible_tools: forwarded to ``utils.verify_model``.
        :param usg_scores: usage scores forwarded to ``utils.verify_model``.
        """
        # let the Keras base class initialise its own state
        super(PredictCallback, self).__init__()
        self.test_data = test_data
        self.test_labels = test_labels
        self.reverse_data_dictionary = reverse_data_dictionary
        self.precision = list()
        self.usage_weights = list()
        self.n_epochs = n_epochs
        self.next_compatible_tools = next_compatible_tools
        self.pred_usage_scores = usg_scores

    def on_epoch_end(self, epoch, logs=None):
        """
        Compute absolute and compatible precision for test data.

        ``self.model`` is not set in this class; it is expected to be attached
        by the training framework before this hook runs.

        :param epoch: zero-based epoch index (printed as ``epoch + 1``).
        :param logs: metrics dict supplied by the framework (unused here).
        """
        if len(self.test_data) > 0:
            precision, usage_weights = utils.verify_model(self.model, self.test_data, self.test_labels, self.reverse_data_dictionary, self.next_compatible_tools, self.pred_usage_scores)
            self.precision.append(precision)
            self.usage_weights.append(usage_weights)
            print("Epoch %d precision: %s" % (epoch + 1, precision))
            print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights))
89 | |
90 | |
if __name__ == "__main__":
    start_time = time.time()
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file")
    arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file")
    arg_parser.add_argument("-om", "--output_model", required=True, help="trained model file")
    # data parameters (numeric ones are converted by argparse so bad input
    # yields a usage error rather than a traceback)
    arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage")
    arg_parser.add_argument("-pl", "--maximum_path_length", required=True, type=int, help="maximum length of tool path")
    arg_parser.add_argument("-ep", "--n_epochs", required=True, type=int, help="number of iterations to run to create model")
    arg_parser.add_argument("-oe", "--optimize_n_epochs", required=True, type=int, help="number of iterations to run to find best model parameters")
    arg_parser.add_argument("-me", "--max_evals", required=True, type=int, help="maximum number of configuration evaluations")
    arg_parser.add_argument("-ts", "--test_share", required=True, type=float, help="share of data to be used for testing")
    arg_parser.add_argument("-vs", "--validation_share", required=True, type=float, help="share of data to be used for validation")
    # neural network parameters: kept as raw strings and handed to the
    # hyperparameter optimiser through `config` unchanged
    arg_parser.add_argument("-bs", "--batch_size", required=True, help="size of the training batch i.e. the number of samples per batch")
    arg_parser.add_argument("-ut", "--units", required=True, help="number of hidden recurrent units")
    arg_parser.add_argument("-es", "--embedding_size", required=True, help="size of the fixed vector learned for each tool")
    arg_parser.add_argument("-dt", "--dropout", required=True, help="percentage of neurons to be dropped")
    arg_parser.add_argument("-sd", "--spatial_dropout", required=True, help="1d dropout used for embedding layer")
    arg_parser.add_argument("-rd", "--recurrent_dropout", required=True, help="dropout for the recurrent layers")
    arg_parser.add_argument("-lr", "--learning_rate", required=True, help="learning rate")
    arg_parser.add_argument("-ar", "--activation_recurrent", required=True, help="activation function for recurrent layers")
    arg_parser.add_argument("-ao", "--activation_output", required=True, help="activation function for output layers")
    arg_parser.add_argument("-lt", "--loss_type", required=True, help="type of the loss/error function")
    # get argument values
    args = vars(arg_parser.parse_args())
    tool_usage_path = args["tool_usage_file"]
    workflows_path = args["workflow_file"]
    trained_model_path = args["output_model"]
    cutoff_date = args["cutoff_date"]
    maximum_path_length = args["maximum_path_length"]
    n_epochs = args["n_epochs"]
    test_share = args["test_share"]

    # configuration for hyperparameter optimisation and training (key order preserved)
    config_keys = (
        'cutoff_date',
        'maximum_path_length',
        'n_epochs',
        'optimize_n_epochs',
        'max_evals',
        'test_share',
        'validation_share',
        'batch_size',
        'units',
        'embedding_size',
        'dropout',
        'spatial_dropout',
        'recurrent_dropout',
        'learning_rate',
        'activation_recurrent',
        'activation_output',
        'loss_type',
    )
    config = {key: args[key] for key in config_keys}

    # Extract and process workflows
    connections = extract_workflow_connections.ExtractWorkflowConnections()
    workflow_paths, compatible_next_tools = connections.read_tabular_file(workflows_path)
    # Process the paths from workflows
    print("Dividing data...")
    data = prepare_data.PrepareData(maximum_path_length, test_share)
    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools)
    # find the best model and start training
    predict_tool = PredictTool()
    # start training with weighted classes
    print("Training with weighted classes and samples ...")
    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools)
    print()
    print("Best parameters \n")
    print(results_weighted["best_parameters"])
    print()
    utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights)
    end_time = time.time()
    print()
    print("Program finished in %s seconds" % str(end_time - start_time))