comparison utils.py @ 5:4f7e6612906b draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author bgruening
date Fri, 06 May 2022 09:05:18 +0000
parents afec8c595124
children e94dc7945639
comparison
equal deleted inserted replaced
4:afec8c595124 5:4f7e6612906b
1 import json
2 import random
3
4 import h5py
1 import numpy as np 5 import numpy as np
2 import json 6 import tensorflow as tf
3 import h5py
4 import random
5 from numpy.random import choice 7 from numpy.random import choice
6 8 from tensorflow.keras import backend
7 from keras import backend as K
8 9
9 10
10 def read_file(file_path): 11 def read_file(file_path):
11 """ 12 """
12 Read a file 13 Read a file
27 28
28 def set_trained_model(dump_file, model_values): 29 def set_trained_model(dump_file, model_values):
29 """ 30 """
30 Create an h5 file with the trained weights and associated dicts 31 Create an h5 file with the trained weights and associated dicts
31 """ 32 """
32 hf_file = h5py.File(dump_file, 'w') 33 hf_file = h5py.File(dump_file, "w")
33 for key in model_values: 34 for key in model_values:
34 value = model_values[key] 35 value = model_values[key]
35 if key == 'model_weights': 36 if key == "model_weights":
36 for idx, item in enumerate(value): 37 for idx, item in enumerate(value):
37 w_key = "weight_" + str(idx) 38 w_key = "weight_" + str(idx)
38 if w_key in hf_file: 39 if w_key in hf_file:
39 hf_file.modify(w_key, item) 40 hf_file.modify(w_key, item)
40 else: 41 else:
52 Create a weighted loss function. Penalise the misclassification 53 Create a weighted loss function. Penalise the misclassification
53 of classes more with the higher usage 54 of classes more with the higher usage
54 """ 55 """
55 weight_values = list(class_weights.values()) 56 weight_values = list(class_weights.values())
56 weight_values.extend(weight_values) 57 weight_values.extend(weight_values)
58
57 def weighted_binary_crossentropy(y_true, y_pred): 59 def weighted_binary_crossentropy(y_true, y_pred):
58 # add another dimension to compute dot product 60 # add another dimension to compute dot product
59 expanded_weights = K.expand_dims(weight_values, axis=-1) 61 expanded_weights = tf.expand_dims(weight_values, axis=-1)
60 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) 62 bce = backend.binary_crossentropy(y_true, y_pred)
63 return backend.dot(bce, expanded_weights)
64
61 return weighted_binary_crossentropy 65 return weighted_binary_crossentropy
62 66
63 67
64 def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary): 68 def balanced_sample_generator(
69 train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary
70 ):
65 while True: 71 while True:
66 dimension = train_data.shape[1] 72 dimension = train_data.shape[1]
67 n_classes = train_labels.shape[1] 73 n_classes = train_labels.shape[1]
68 tool_ids = list(l_tool_tr_samples.keys()) 74 tool_ids = list(l_tool_tr_samples.keys())
69 random.shuffle(tool_ids) 75 random.shuffle(tool_ids)
78 generator_batch_data[i] = train_data[random_tr_index] 84 generator_batch_data[i] = train_data[random_tr_index]
79 generator_batch_labels[i] = train_labels[random_tr_index] 85 generator_batch_labels[i] = train_labels[random_tr_index]
80 yield generator_batch_data, generator_batch_labels 86 yield generator_batch_data, generator_batch_labels
81 87
82 88
83 def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids): 89 def compute_precision(
90 model,
91 x,
92 y,
93 reverse_data_dictionary,
94 usage_scores,
95 actual_classes_pos,
96 topk,
97 standard_conn,
98 last_tool_id,
99 lowest_tool_ids,
100 ):
84 """ 101 """
85 Compute absolute and compatible precision 102 Compute absolute and compatible precision
86 """ 103 """
87 pred_t_name = "" 104 pred_t_name = ""
88 top_precision = 0.0 105 top_precision = 0.0
135 if last_tool_id in lowest_tool_ids: 152 if last_tool_id in lowest_tool_ids:
136 lowest_pub_prec = 1.0 153 lowest_pub_prec = 1.0
137 else: 154 else:
138 lowest_pub_prec = np.nan 155 lowest_pub_prec = np.nan
139 if standard_topk_prediction_pos in usage_scores: 156 if standard_topk_prediction_pos in usage_scores:
140 usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0)) 157 usage_wt_score.append(
158 np.log(usage_scores[standard_topk_prediction_pos] + 1.0)
159 )
141 else: 160 else:
142 # count precision only when there is actually true published tools 161 # count precision only when there is actually true published tools
143 # else set to np.nan. Set to 0 only when there is wrong prediction 162 # else set to np.nan. Set to 0 only when there is wrong prediction
144 pub_precision = np.nan 163 pub_precision = np.nan
145 lowest_pub_prec = np.nan 164 lowest_pub_prec = np.nan
146 # compute scores for normal recommendations 165 # compute scores for normal recommendations
147 if normal_topk_prediction_pos in reverse_data_dictionary: 166 if normal_topk_prediction_pos in reverse_data_dictionary:
148 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] 167 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)]
149 if pred_t_name in actual_next_tool_names: 168 if pred_t_name in actual_next_tool_names:
150 if normal_topk_prediction_pos in usage_scores: 169 if normal_topk_prediction_pos in usage_scores:
151 usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0)) 170 usage_wt_score.append(
171 np.log(usage_scores[normal_topk_prediction_pos] + 1.0)
172 )
152 top_precision = 1.0 173 top_precision = 1.0
153 if last_tool_id in lowest_tool_ids: 174 if last_tool_id in lowest_tool_ids:
154 lowest_norm_prec = 1.0 175 lowest_norm_prec = 1.0
155 else: 176 else:
156 lowest_norm_prec = np.nan 177 lowest_norm_prec = np.nan
164 tool_ids = list(l_tool_freq.keys()) 185 tool_ids = list(l_tool_freq.keys())
165 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] 186 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):]
166 return lowest_ids 187 return lowest_ids
167 188
168 189
169 def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]): 190 def verify_model(
191 model,
192 x,
193 y,
194 reverse_data_dictionary,
195 usage_scores,
196 standard_conn,
197 lowest_tool_ids,
198 topk_list=[1, 2, 3],
199 ):
170 """ 200 """
171 Verify the model on test data 201 Verify the model on test data
172 """ 202 """
173 print("Evaluating performance on test data...") 203 print("Evaluating performance on test data...")
174 print("Test data size: %d" % len(y)) 204 print("Test data size: %d" % len(y))
185 lowest_norm_topk = list() 215 lowest_norm_topk = list()
186 actual_classes_pos = np.where(y[i] > 0)[0] 216 actual_classes_pos = np.where(y[i] > 0)[0]
187 test_sample = x[i, :] 217 test_sample = x[i, :]
188 last_tool_id = str(int(test_sample[-1])) 218 last_tool_id = str(int(test_sample[-1]))
189 for index, abs_topk in enumerate(topk_list): 219 for index, abs_topk in enumerate(topk_list):
190 usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids) 220 (
221 usg_wt_score,
222 absolute_precision,
223 pub_prec,
224 lowest_p_prec,
225 lowest_n_prec,
226 ) = compute_precision(
227 model,
228 test_sample,
229 y,
230 reverse_data_dictionary,
231 usage_scores,
232 actual_classes_pos,
233 abs_topk,
234 standard_conn,
235 last_tool_id,
236 lowest_tool_ids,
237 )
191 precision[i][index] = absolute_precision 238 precision[i][index] = absolute_precision
192 usage_weights[i][index] = usg_wt_score 239 usage_weights[i][index] = usg_wt_score
193 epo_pub_prec[i][index] = pub_prec 240 epo_pub_prec[i][index] = pub_prec
194 lowest_pub_topk.append(lowest_p_prec) 241 lowest_pub_topk.append(lowest_p_prec)
195 lowest_norm_topk.append(lowest_n_prec) 242 lowest_norm_topk.append(lowest_n_prec)
200 mean_precision = np.mean(precision, axis=0) 247 mean_precision = np.mean(precision, axis=0)
201 mean_usage = np.mean(usage_weights, axis=0) 248 mean_usage = np.mean(usage_weights, axis=0)
202 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) 249 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0)
203 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) 250 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0)
204 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) 251 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0)
205 return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter 252 return (
206 253 mean_usage,
207 254 mean_precision,
208 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections): 255 mean_pub_prec,
256 mean_lowest_pub_prec,
257 mean_lowest_norm_prec,
258 lowest_counter,
259 )
260
261
262 def save_model(
263 results,
264 data_dictionary,
265 compatible_next_tools,
266 trained_model_path,
267 class_weights,
268 standard_connections,
269 ):
209 # save files 270 # save files
210 trained_model = results["model"] 271 trained_model = results["model"]
211 best_model_parameters = results["best_parameters"] 272 best_model_parameters = results["best_parameters"]
212 model_config = trained_model.to_json() 273 model_config = trained_model.to_json()
213 model_weights = trained_model.get_weights() 274 model_weights = trained_model.get_weights()
214 model_values = { 275 model_values = {
215 'data_dictionary': data_dictionary, 276 "data_dictionary": data_dictionary,
216 'model_config': model_config, 277 "model_config": model_config,
217 'best_parameters': best_model_parameters, 278 "best_parameters": best_model_parameters,
218 'model_weights': model_weights, 279 "model_weights": model_weights,
219 "compatible_tools": compatible_next_tools, 280 "compatible_tools": compatible_next_tools,
220 "class_weights": class_weights, 281 "class_weights": class_weights,
221 "standard_connections": standard_connections 282 "standard_connections": standard_connections,
222 } 283 }
223 set_trained_model(trained_model_path, model_values) 284 set_trained_model(trained_model_path, model_values)