diff utils.py @ 4:afec8c595124 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author bgruening
date Tue, 07 Jul 2020 03:25:49 -0400
parents 5b3c08710e47
children 4f7e6612906b
line wrap: on
line diff
--- a/utils.py	Sat May 09 05:38:23 2020 -0400
+++ b/utils.py	Tue Jul 07 03:25:49 2020 -0400
@@ -1,8 +1,8 @@
-import os
 import numpy as np
 import json
 import h5py
 import random
+from numpy.random import choice
 
 from keras import backend as K
 
@@ -54,7 +54,6 @@
     """
     weight_values = list(class_weights.values())
     weight_values.extend(weight_values)
-
     def weighted_binary_crossentropy(y_true, y_pred):
         # add another dimension to compute dot product
         expanded_weights = K.expand_dims(weight_values, axis=-1)
@@ -62,16 +61,17 @@
     return weighted_binary_crossentropy
 
 
-def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples):
+def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary):
     while True:
         dimension = train_data.shape[1]
         n_classes = train_labels.shape[1]
         tool_ids = list(l_tool_tr_samples.keys())
+        random.shuffle(tool_ids)
         generator_batch_data = np.zeros([batch_size, dimension])
         generator_batch_labels = np.zeros([batch_size, n_classes])
+        generated_tool_ids = choice(tool_ids, batch_size)
         for i in range(batch_size):
-            random_toolid_index = random.sample(range(0, len(tool_ids)), 1)[0]
-            random_toolid = tool_ids[random_toolid_index]
+            random_toolid = generated_tool_ids[i]
             sample_indices = l_tool_tr_samples[str(random_toolid)]
             random_index = random.sample(range(0, len(sample_indices)), 1)[0]
             random_tr_index = sample_indices[random_index]
@@ -129,12 +129,20 @@
         pred_t_name = reverse_data_dictionary[int(standard_topk_prediction_pos)]
         if last_tool_name in standard_conn:
             pub_tools = standard_conn[last_tool_name]
-        if pred_t_name in pub_tools:
-            pub_precision = 1.0
-            if last_tool_id in lowest_tool_ids:
-                lowest_pub_prec = 1.0
-            if standard_topk_prediction_pos in usage_scores:
-                usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0))
+            if pred_t_name in pub_tools:
+                pub_precision = 1.0
+                # count precision only when there is actually true published tools
+                if last_tool_id in lowest_tool_ids:
+                    lowest_pub_prec = 1.0
+                else:
+                    lowest_pub_prec = np.nan
+                if standard_topk_prediction_pos in usage_scores:
+                    usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0))
+        else:
+            # count precision only when there is actually true published tools
+            # else set to np.nan. Set to 0 only when there is wrong prediction
+            pub_precision = np.nan
+            lowest_pub_prec = np.nan
     # compute scores for normal recommendations
     if normal_topk_prediction_pos in reverse_data_dictionary:
         pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)]
@@ -144,6 +152,8 @@
             top_precision = 1.0
             if last_tool_id in lowest_tool_ids:
                 lowest_norm_prec = 1.0
+            else:
+                lowest_norm_prec = np.nan
     if len(usage_wt_score) > 0:
         mean_usage = np.mean(usage_wt_score)
     return mean_usage, top_precision, pub_precision, lowest_pub_prec, lowest_norm_prec
@@ -168,7 +178,7 @@
     epo_pub_prec = np.zeros([len(y), len(topk_list)])
     epo_lowest_tools_pub_prec = list()
     epo_lowest_tools_norm_prec = list()
-
+    lowest_counter = 0
     # loop over all the test samples and find prediction precision
     for i in range(size):
         lowest_pub_topk = list()
@@ -181,18 +191,18 @@
             precision[i][index] = absolute_precision
             usage_weights[i][index] = usg_wt_score
             epo_pub_prec[i][index] = pub_prec
-            if last_tool_id in lowest_tool_ids:
-                lowest_pub_topk.append(lowest_p_prec)
-                lowest_norm_topk.append(lowest_n_prec)
+            lowest_pub_topk.append(lowest_p_prec)
+            lowest_norm_topk.append(lowest_n_prec)
+        epo_lowest_tools_pub_prec.append(lowest_pub_topk)
+        epo_lowest_tools_norm_prec.append(lowest_norm_topk)
         if last_tool_id in lowest_tool_ids:
-            epo_lowest_tools_pub_prec.append(lowest_pub_topk)
-            epo_lowest_tools_norm_prec.append(lowest_norm_topk)
+            lowest_counter += 1
     mean_precision = np.mean(precision, axis=0)
     mean_usage = np.mean(usage_weights, axis=0)
-    mean_pub_prec = np.mean(epo_pub_prec, axis=0)
-    mean_lowest_pub_prec = np.mean(epo_lowest_tools_pub_prec, axis=0)
-    mean_lowest_norm_prec = np.mean(epo_lowest_tools_norm_prec, axis=0)
-    return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, len(epo_lowest_tools_pub_prec)
+    mean_pub_prec = np.nanmean(epo_pub_prec, axis=0)
+    mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0)
+    mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0)
+    return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter
 
 
 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections):