comparison utils.py @ 3:5b3c08710e47 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
author bgruening
date Sat, 09 May 2020 05:38:23 -0400
parents 76251d1ccdcc
children afec8c595124
comparing 2:76251d1ccdcc with 3:5b3c08710e47
 import os
 import numpy as np
 import json
 import h5py
+import random

 from keras import backend as K


 def read_file(file_path):
     """
     Read a file
     """
     with open(file_path, "r") as json_file:
         file_content = json.loads(json_file.read())
     return file_content
-
-
-def write_file(file_path, content):
-    """
-    Write a file
-    """
-    remove_file(file_path)
-    with open(file_path, "w") as json_file:
-        json_file.write(json.dumps(content))
-
-
-def save_processed_workflows(file_path, unique_paths):
-    workflow_paths_unique = ""
-    for path in unique_paths:
-        workflow_paths_unique += path + "\n"
-    with open(file_path, "w") as workflows_file:
-        workflows_file.write(workflow_paths_unique)


 def format_tool_id(tool_link):
     """
     Extract tool id from tool link

[... unchanged lines omitted by the comparison view ...]

     else:
         hf_file.create_dataset(key, data=json.dumps(value))
     hf_file.close()


-def remove_file(file_path):
-    if os.path.exists(file_path):
-        os.remove(file_path)
-
-
 def weighted_loss(class_weights):
     """
     Create a weighted loss function. Penalise the misclassification
     of classes more with the higher usage
     """
     weight_values = list(class_weights.values())
+    weight_values.extend(weight_values)

     def weighted_binary_crossentropy(y_true, y_pred):
         # add another dimension to compute dot product
         expanded_weights = K.expand_dims(weight_values, axis=-1)
         return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights)
     return weighted_binary_crossentropy


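The weight vector is duplicated here because the network's output now concatenates two equal-sized halves, one scored against published workflows and one against normal workflows, so each half reuses the same usage-based class weights. A minimal sketch of wiring this loss into a model; the layer size, input dimension and optimizer below are illustrative assumptions, not taken from this changeset:

    from keras.models import Sequential
    from keras.layers import Dense

    n_tools = 100                                     # hypothetical number of tool classes
    class_weights = {i: 1.0 for i in range(n_tools)}  # placeholder usage-based weights

    model = Sequential()
    model.add(Dense(2 * n_tools, activation="sigmoid", input_shape=(25,)))
    model.compile(loss=weighted_loss(class_weights), optimizer="rmsprop")
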
-def compute_precision(model, x, y, reverse_data_dictionary, next_compatible_tools, usage_scores, actual_classes_pos, topk):
-    """
-    Compute absolute and compatible precision
-    """
-    absolute_precision = 0.0
-    test_sample = np.reshape(x, (1, len(x)))
-
-    # predict next tools for a test path
-    prediction = model.predict(test_sample, verbose=0)
-
-    nw_dimension = prediction.shape[1]
-
-    # remove the 0th position as there is no tool at this index
-    prediction = np.reshape(prediction, (nw_dimension,))
-
-    prediction_pos = np.argsort(prediction, axis=-1)
-    topk_prediction_pos = prediction_pos[-topk:]
-
-    # remove the wrong tool position from the predicted list of tool positions
-    topk_prediction_pos = [x for x in topk_prediction_pos if x > 0]
-
-    # read tool names using reverse dictionary
-    actual_next_tool_names = [reverse_data_dictionary[int(tool_pos)] for tool_pos in actual_classes_pos]
-    top_predicted_next_tool_names = [reverse_data_dictionary[int(tool_pos)] for tool_pos in topk_prediction_pos]
-
-    # compute the class weights of predicted tools
-    mean_usg_score = 0
-    usg_wt_scores = list()
-    for t_id in topk_prediction_pos:
-        t_name = reverse_data_dictionary[int(t_id)]
-        if t_id in usage_scores and t_name in actual_next_tool_names:
-            usg_wt_scores.append(np.log(usage_scores[t_id] + 1.0))
-    if len(usg_wt_scores) > 0:
-        mean_usg_score = np.sum(usg_wt_scores) / float(topk)
-    false_positives = [tool_name for tool_name in top_predicted_next_tool_names if tool_name not in actual_next_tool_names]
-    absolute_precision = 1 - (len(false_positives) / float(topk))
-    return mean_usg_score, absolute_precision
-
-
-def verify_model(model, x, y, reverse_data_dictionary, next_compatible_tools, usage_scores, topk_list=[1, 2, 3]):
+def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples):
+    while True:
+        dimension = train_data.shape[1]
+        n_classes = train_labels.shape[1]
+        tool_ids = list(l_tool_tr_samples.keys())
+        generator_batch_data = np.zeros([batch_size, dimension])
+        generator_batch_labels = np.zeros([batch_size, n_classes])
+        for i in range(batch_size):
+            random_toolid_index = random.sample(range(0, len(tool_ids)), 1)[0]
+            random_toolid = tool_ids[random_toolid_index]
+            sample_indices = l_tool_tr_samples[str(random_toolid)]
+            random_index = random.sample(range(0, len(sample_indices)), 1)[0]
+            random_tr_index = sample_indices[random_index]
+            generator_batch_data[i] = train_data[random_tr_index]
+            generator_batch_labels[i] = train_labels[random_tr_index]
+        yield generator_batch_data, generator_batch_labels
+
+
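The generator yields batches indefinitely, first sampling a tool uniformly and then sampling one training example for that tool, which balances rare and frequent tools within each batch. A sketch of plugging it into Keras training; train_data, train_labels and l_tool_tr_samples (a dict mapping a tool id string to the indices of its training samples) are assumed to come from the data-preparation step:

    batch_size = 32
    generator = balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples)
    model.fit_generator(generator,
                        steps_per_epoch=len(train_data) // batch_size,
                        epochs=10)
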
+def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids):
+    """
+    Compute absolute and compatible precision
+    """
+    pred_t_name = ""
+    top_precision = 0.0
+    mean_usage = 0.0
+    usage_wt_score = list()
+    pub_precision = 0.0
+    lowest_pub_prec = 0.0
+    lowest_norm_prec = 0.0
+    pub_tools = list()
+    actual_next_tool_names = list()
+    test_sample = np.reshape(x, (1, len(x)))
+
+    # predict next tools for a test path
+    prediction = model.predict(test_sample, verbose=0)
+
+    # divide the predicted vector into two halves - one for published and
+    # another for normal workflows
+    nw_dimension = prediction.shape[1]
+    half_len = int(nw_dimension / 2)
+
+    # predict tools
+    prediction = np.reshape(prediction, (nw_dimension,))
+    # get predictions of tools from published workflows
+    standard_pred = prediction[:half_len]
+    # get predictions of tools from normal workflows
+    normal_pred = prediction[half_len:]
+
+    standard_prediction_pos = np.argsort(standard_pred, axis=-1)
+    standard_topk_prediction_pos = standard_prediction_pos[-topk]
+
+    normal_prediction_pos = np.argsort(normal_pred, axis=-1)
+    normal_topk_prediction_pos = normal_prediction_pos[-topk]
+
+    # get true tools names
+    for a_t_pos in actual_classes_pos:
+        if a_t_pos > half_len:
+            t_name = reverse_data_dictionary[int(a_t_pos - half_len)]
+        else:
+            t_name = reverse_data_dictionary[int(a_t_pos)]
+        actual_next_tool_names.append(t_name)
+    last_tool_name = reverse_data_dictionary[x[-1]]
+    # compute scores for published recommendations
+    if standard_topk_prediction_pos in reverse_data_dictionary:
+        pred_t_name = reverse_data_dictionary[int(standard_topk_prediction_pos)]
+        if last_tool_name in standard_conn:
+            pub_tools = standard_conn[last_tool_name]
+            if pred_t_name in pub_tools:
+                pub_precision = 1.0
+                if last_tool_id in lowest_tool_ids:
+                    lowest_pub_prec = 1.0
+                if standard_topk_prediction_pos in usage_scores:
+                    usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0))
+    # compute scores for normal recommendations
+    if normal_topk_prediction_pos in reverse_data_dictionary:
+        pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)]
+        if pred_t_name in actual_next_tool_names:
+            if normal_topk_prediction_pos in usage_scores:
+                usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0))
+            top_precision = 1.0
+            if last_tool_id in lowest_tool_ids:
+                lowest_norm_prec = 1.0
+    if len(usage_wt_score) > 0:
+        mean_usage = np.mean(usage_wt_score)
+    return mean_usage, top_precision, pub_precision, lowest_pub_prec, lowest_norm_prec
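Note that standard_prediction_pos[-topk] selects a single index, the k-th highest-scoring tool, rather than a top-k slice; verify_model below walks topk_list = [1, 2, 3], so each rank is checked separately. A small illustration of the indexing with made-up scores:

    import numpy as np

    scores = np.array([0.1, 0.7, 0.2, 0.9])
    order = np.argsort(scores)  # ascending by score -> array([0, 2, 1, 3])
    print(order[-1])            # 3: index of the highest-scoring tool
    print(order[-2])            # 1: index of the second-highest
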
+
+
+def get_lowest_tools(l_tool_freq, fraction=0.25):
+    l_tool_freq = dict(sorted(l_tool_freq.items(), key=lambda kv: kv[1], reverse=True))
+    tool_ids = list(l_tool_freq.keys())
+    lowest_ids = tool_ids[-int(len(tool_ids) * fraction):]
+    return lowest_ids
+
+
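get_lowest_tools sorts tools by usage frequency in descending order and returns the trailing fraction, i.e. the least-used quarter by default. For illustration, with made-up frequencies:

    freq = {"bwa": 90, "samtools": 60, "cutadapt": 30, "rare_tool": 2}
    print(get_lowest_tools(freq))  # -> ['rare_tool']
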
+def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]):
     """
     Verify the model on test data
     """
     print("Evaluating performance on test data...")
     print("Test data size: %d" % len(y))
     size = y.shape[0]
     precision = np.zeros([len(y), len(topk_list)])
     usage_weights = np.zeros([len(y), len(topk_list)])
+    epo_pub_prec = np.zeros([len(y), len(topk_list)])
+    epo_lowest_tools_pub_prec = list()
+    epo_lowest_tools_norm_prec = list()
+
     # loop over all the test samples and find prediction precision
     for i in range(size):
+        lowest_pub_topk = list()
+        lowest_norm_topk = list()
         actual_classes_pos = np.where(y[i] > 0)[0]
+        test_sample = x[i, :]
+        last_tool_id = str(int(test_sample[-1]))
         for index, abs_topk in enumerate(topk_list):
-            abs_mean_usg_score, absolute_precision = compute_precision(model, x[i, :], y, reverse_data_dictionary, next_compatible_tools, usage_scores, actual_classes_pos, abs_topk)
+            usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids)
             precision[i][index] = absolute_precision
-            usage_weights[i][index] = abs_mean_usg_score
+            usage_weights[i][index] = usg_wt_score
+            epo_pub_prec[i][index] = pub_prec
+            if last_tool_id in lowest_tool_ids:
+                lowest_pub_topk.append(lowest_p_prec)
+                lowest_norm_topk.append(lowest_n_prec)
+        if last_tool_id in lowest_tool_ids:
+            epo_lowest_tools_pub_prec.append(lowest_pub_topk)
+            epo_lowest_tools_norm_prec.append(lowest_norm_topk)
     mean_precision = np.mean(precision, axis=0)
     mean_usage = np.mean(usage_weights, axis=0)
-    return mean_precision, mean_usage
-
-
-def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights):
+    mean_pub_prec = np.mean(epo_pub_prec, axis=0)
+    mean_lowest_pub_prec = np.mean(epo_lowest_tools_pub_prec, axis=0)
+    mean_lowest_norm_prec = np.mean(epo_lowest_tools_norm_prec, axis=0)
+    return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, len(epo_lowest_tools_pub_prec)
+
+
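verify_model now returns six values instead of two. A sketch of consuming them; the variable names for the test split and dictionaries are assumptions:

    mean_usage, mean_precision, mean_pub_prec, low_pub_prec, low_norm_prec, n_lowest = verify_model(
        model, test_data, test_labels, reverse_dictionary,
        usage_scores, standard_connections, lowest_tool_ids)
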
+def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections):
     # save files
     trained_model = results["model"]
     best_model_parameters = results["best_parameters"]
     model_config = trained_model.to_json()
     model_weights = trained_model.get_weights()
-
     model_values = {
         'data_dictionary': data_dictionary,
         'model_config': model_config,
         'best_parameters': best_model_parameters,
         'model_weights': model_weights,
         "compatible_tools": compatible_next_tools,
-        "class_weights": class_weights
+        "class_weights": class_weights,
+        "standard_connections": standard_connections
     }
     set_trained_model(trained_model_path, model_values)
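set_trained_model (its body is elided in this view) appears to store each entry of model_values as an HDF5 dataset, JSON-encoding non-array values via create_dataset as shown near the top of the diff. A sketch of reading one entry back under that assumption; the file name is illustrative:

    import json
    import h5py

    with h5py.File("tool_recommendation_model.hdf5", "r") as hf_file:
        class_weights = json.loads(hf_file["class_weights"][()])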