Mercurial > repos > bgruening > create_tool_recommendation_model
diff prepare_data.py @ 4:afec8c595124 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author | bgruening |
---|---|
date | Tue, 07 Jul 2020 03:25:49 -0400 |
parents | 5b3c08710e47 |
children | 4f7e6612906b |
line wrap: on
line diff
```diff
--- a/prepare_data.py Sat May 09 05:38:23 2020 -0400
+++ b/prepare_data.py Tue Jul 07 03:25:49 2020 -0400
@@ -10,7 +10,6 @@
 import random
 
 import predict_tool_usage
-import utils
 
 main_path = os.getcwd()
 
@@ -211,16 +210,15 @@
         to estimate the frequency of tool sequences
         """
         last_tool_freq = dict()
-        inv_freq = dict()
+        freq_dict_names = dict()
         for path in train_paths:
             last_tool = path.split(",")[-1]
             if last_tool not in last_tool_freq:
                 last_tool_freq[last_tool] = 0
+                freq_dict_names[reverse_dictionary[int(last_tool)]] = 0
             last_tool_freq[last_tool] += 1
-        max_freq = max(last_tool_freq.values())
-        for t in last_tool_freq:
-            inv_freq[t] = int(np.round(max_freq / float(last_tool_freq[t]), 0))
-        return last_tool_freq, inv_freq
+            freq_dict_names[reverse_dictionary[int(last_tool)]] += 1
+        return last_tool_freq
 
     def get_toolid_samples(self, train_data, l_tool_freq):
         l_tool_tr_samples = dict()
@@ -254,9 +252,6 @@
         print("Complete data: %d" % len(multilabels_paths))
         train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths)
 
-        # get sample frequency
-        l_tool_freq, inv_last_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
-
         print("Train data: %d" % len(train_paths_dict))
         print("Test data: %d" % len(test_paths_dict))
 
@@ -265,6 +260,8 @@
         test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict)
         train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict)
 
+        print("Estimating sample frequency...")
+        l_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
         l_tool_tr_samples = self.get_toolid_samples(train_data, l_tool_freq)
 
         # Predict tools usage
```