diff prepare_data.py @ 4:afec8c595124 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author bgruening
date Tue, 07 Jul 2020 03:25:49 -0400
parents 5b3c08710e47
children 4f7e6612906b
line wrap: on
line diff
--- a/prepare_data.py	Sat May 09 05:38:23 2020 -0400
+++ b/prepare_data.py	Tue Jul 07 03:25:49 2020 -0400
@@ -10,7 +10,6 @@
 import random
 
 import predict_tool_usage
-import utils
 
 main_path = os.getcwd()
 
@@ -211,16 +210,15 @@
         to estimate the frequency of tool sequences
         """
         last_tool_freq = dict()
-        inv_freq = dict()
+        freq_dict_names = dict()
         for path in train_paths:
             last_tool = path.split(",")[-1]
             if last_tool not in last_tool_freq:
                 last_tool_freq[last_tool] = 0
+                freq_dict_names[reverse_dictionary[int(last_tool)]] = 0
             last_tool_freq[last_tool] += 1
-        max_freq = max(last_tool_freq.values())
-        for t in last_tool_freq:
-            inv_freq[t] = int(np.round(max_freq / float(last_tool_freq[t]), 0))
-        return last_tool_freq, inv_freq
+            freq_dict_names[reverse_dictionary[int(last_tool)]] += 1
+        return last_tool_freq
 
     def get_toolid_samples(self, train_data, l_tool_freq):
         l_tool_tr_samples = dict()
@@ -254,9 +252,6 @@
         print("Complete data: %d" % len(multilabels_paths))
         train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths)
 
-        # get sample frequency
-        l_tool_freq, inv_last_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
-
         print("Train data: %d" % len(train_paths_dict))
         print("Test data: %d" % len(test_paths_dict))
 
@@ -265,6 +260,8 @@
         test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict)
         train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict)
 
+        print("Estimating sample frequency...")
+        l_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
         l_tool_tr_samples = self.get_toolid_samples(train_data, l_tool_freq)
 
         # Predict tools usage