Mercurial > repos > bgruening > create_tool_recommendation_model
comparison main.py @ 4:afec8c595124 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author | bgruening |
---|---|
date | Tue, 07 Jul 2020 03:25:49 -0400 |
parents | 5b3c08710e47 |
children | 4f7e6612906b |
comparison
equal
deleted
inserted
replaced
3:5b3c08710e47 | 4:afec8c595124 |
---|---|
29 inter_op_parallelism_threads=num_cpus, | 29 inter_op_parallelism_threads=num_cpus, |
30 allow_soft_placement=True | 30 allow_soft_placement=True |
31 ) | 31 ) |
32 K.set_session(tf.Session(config=cpu_config)) | 32 K.set_session(tf.Session(config=cpu_config)) |
33 | 33 |
34 def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples): | 34 def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, tool_freq, tool_tr_samples): |
35 """ | 35 """ |
36 Define recurrent neural network and train sequential data | 36 Define recurrent neural network and train sequential data |
37 """ | 37 """ |
38 # get tools with lowest representation | 38 # get tools with lowest representation |
39 lowest_tool_ids = utils.get_lowest_tools(l_tool_freq) | 39 lowest_tool_ids = utils.get_lowest_tools(tool_freq) |
40 | 40 |
41 print("Start hyperparameter optimisation...") | 41 print("Start hyperparameter optimisation...") |
42 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() | 42 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() |
43 best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, l_tool_tr_samples, class_weights) | 43 best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights) |
44 | 44 |
45 # define callbacks | 45 # define callbacks |
46 early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True) | 46 early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True) |
47 predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids) | 47 predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids) |
48 | 48 |
49 callbacks_list = [predict_callback_test, early_stopping] | 49 callbacks_list = [predict_callback_test, early_stopping] |
50 | |
51 batch_size = int(best_params["batch_size"]) | 50 batch_size = int(best_params["batch_size"]) |
52 | 51 |
53 print("Start training on the best model...") | 52 print("Start training on the best model...") |
54 train_performance = dict() | 53 train_performance = dict() |
55 trained_model = best_model.fit_generator( | 54 trained_model = best_model.fit_generator( |
56 utils.balanced_sample_generator( | 55 utils.balanced_sample_generator( |
57 train_data, | 56 train_data, |
58 train_labels, | 57 train_labels, |
59 batch_size, | 58 batch_size, |
60 l_tool_tr_samples | 59 tool_tr_samples, |
| 60 reverse_dictionary |
61 ), | 61 ), |
62 steps_per_epoch=len(train_data) // batch_size, | 62 steps_per_epoch=len(train_data) // batch_size, |
63 epochs=n_epochs, | 63 epochs=n_epochs, |
64 callbacks=callbacks_list, | 64 callbacks=callbacks_list, |
65 validation_data=(test_data, test_labels), | 65 validation_data=(test_data, test_labels), |
175 connections = extract_workflow_connections.ExtractWorkflowConnections() | 175 connections = extract_workflow_connections.ExtractWorkflowConnections() |
176 workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path) | 176 workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path) |
177 # Process the paths from workflows | 177 # Process the paths from workflows |
178 print("Dividing data...") | 178 print("Dividing data...") |
179 data = prepare_data.PrepareData(maximum_path_length, test_share) | 179 data = prepare_data.PrepareData(maximum_path_length, test_share) |
180 train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, l_tool_freq, l_tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections) | 180 train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, train_tool_freq, tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections) |
181 # find the best model and start training | 181 # find the best model and start training |
182 predict_tool = PredictTool(num_cpus) | 182 predict_tool = PredictTool(num_cpus) |
183 # start training with weighted classes | 183 # start training with weighted classes |
184 print("Training with weighted classes and samples ...") | 184 print("Training with weighted classes and samples ...") |
185 results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples) | 185 results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, train_tool_freq, tool_tr_samples) |
186 utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections) | 186 utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections) |
187 end_time = time.time() | 187 end_time = time.time() |
188 print() | |
189 print("Program finished in %s seconds" % str(end_time - start_time)) | 188 print("Program finished in %s seconds" % str(end_time - start_time)) |