Mercurial > repos > bgruening > create_tool_recommendation_model
comparison prepare_data.py @ 4:afec8c595124 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author | bgruening |
---|---|
date | Tue, 07 Jul 2020 03:25:49 -0400 |
parents | 5b3c08710e47 |
children | 4f7e6612906b |
comparison
equal
deleted
inserted
replaced
3:5b3c08710e47 | 4:afec8c595124 |
---|---|
8 import collections | 8 import collections |
9 import numpy as np | 9 import numpy as np |
10 import random | 10 import random |
11 | 11 |
12 import predict_tool_usage | 12 import predict_tool_usage |
13 import utils | |
14 | 13 |
15 main_path = os.getcwd() | 14 main_path = os.getcwd() |
16 | 15 |
17 | 16 |
18 class PrepareData: | 17 class PrepareData: |
209 """ | 208 """ |
210 Get the frequency of last tool of each tool sequence | 209 Get the frequency of last tool of each tool sequence |
211 to estimate the frequency of tool sequences | 210 to estimate the frequency of tool sequences |
212 """ | 211 """ |
213 last_tool_freq = dict() | 212 last_tool_freq = dict() |
214 inv_freq = dict() | 213 freq_dict_names = dict() |
215 for path in train_paths: | 214 for path in train_paths: |
216 last_tool = path.split(",")[-1] | 215 last_tool = path.split(",")[-1] |
217 if last_tool not in last_tool_freq: | 216 if last_tool not in last_tool_freq: |
218 last_tool_freq[last_tool] = 0 | 217 last_tool_freq[last_tool] = 0 |
| | 218 freq_dict_names[reverse_dictionary[int(last_tool)]] = 0 |
219 last_tool_freq[last_tool] += 1 | 219 last_tool_freq[last_tool] += 1 |
220 max_freq = max(last_tool_freq.values()) | 220 freq_dict_names[reverse_dictionary[int(last_tool)]] += 1 |
221 for t in last_tool_freq: | 221 return last_tool_freq |
222 inv_freq[t] = int(np.round(max_freq / float(last_tool_freq[t]), 0)) | |
223 return last_tool_freq, inv_freq | |
224 | 222 |
225 def get_toolid_samples(self, train_data, l_tool_freq): | 223 def get_toolid_samples(self, train_data, l_tool_freq): |
226 l_tool_tr_samples = dict() | 224 l_tool_tr_samples = dict() |
227 for tool_id in l_tool_freq: | 225 for tool_id in l_tool_freq: |
228 for index, tr_sample in enumerate(train_data): | 226 for index, tr_sample in enumerate(train_data): |
252 multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools) | 250 multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools) |
253 | 251 |
254 print("Complete data: %d" % len(multilabels_paths)) | 252 print("Complete data: %d" % len(multilabels_paths)) |
255 train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths) | 253 train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths) |
256 | 254 |
257 # get sample frequency | |
258 l_tool_freq, inv_last_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict) | |
259 | |
260 print("Train data: %d" % len(train_paths_dict)) | 255 print("Train data: %d" % len(train_paths_dict)) |
261 print("Test data: %d" % len(test_paths_dict)) | 256 print("Test data: %d" % len(test_paths_dict)) |
262 | 257 |
263 print("Padding train and test data...") | 258 print("Padding train and test data...") |
264 # pad training and test data with leading zeros | 259 # pad training and test data with leading zeros |
265 test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict) | 260 test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict) |
266 train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict) | 261 train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict) |
267 | 262 |
| | 263 print("Estimating sample frequency...") |
| | 264 l_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict) |
268 l_tool_tr_samples = self.get_toolid_samples(train_data, l_tool_freq) | 265 l_tool_tr_samples = self.get_toolid_samples(train_data, l_tool_freq) |
269 | 266 |
270 # Predict tools usage | 267 # Predict tools usage |
271 print("Predicting tools' usage...") | 268 print("Predicting tools' usage...") |
272 usage_pred = predict_tool_usage.ToolPopularity() | 269 usage_pred = predict_tool_usage.ToolPopularity() |