Mercurial > repos > bgruening > create_tool_recommendation_model
comparison utils.py @ 5:4f7e6612906b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author | bgruening |
---|---|
date | Fri, 06 May 2022 09:05:18 +0000 |
parents | afec8c595124 |
children | e94dc7945639 |
comparison
equal
deleted
inserted
replaced
4:afec8c595124 | 5:4f7e6612906b |
---|---|
1 import json | |
2 import random | |
3 | |
4 import h5py | |
1 import numpy as np | 5 import numpy as np |
2 import json | 6 import tensorflow as tf |
3 import h5py | |
4 import random | |
5 from numpy.random import choice | 7 from numpy.random import choice |
6 | 8 from tensorflow.keras import backend |
7 from keras import backend as K | |
8 | 9 |
9 | 10 |
10 def read_file(file_path): | 11 def read_file(file_path): |
11 """ | 12 """ |
12 Read a file | 13 Read a file |
27 | 28 |
28 def set_trained_model(dump_file, model_values): | 29 def set_trained_model(dump_file, model_values): |
29 """ | 30 """ |
30 Create an h5 file with the trained weights and associated dicts | 31 Create an h5 file with the trained weights and associated dicts |
31 """ | 32 """ |
32 hf_file = h5py.File(dump_file, 'w') | 33 hf_file = h5py.File(dump_file, "w") |
33 for key in model_values: | 34 for key in model_values: |
34 value = model_values[key] | 35 value = model_values[key] |
35 if key == 'model_weights': | 36 if key == "model_weights": |
36 for idx, item in enumerate(value): | 37 for idx, item in enumerate(value): |
37 w_key = "weight_" + str(idx) | 38 w_key = "weight_" + str(idx) |
38 if w_key in hf_file: | 39 if w_key in hf_file: |
39 hf_file.modify(w_key, item) | 40 hf_file.modify(w_key, item) |
40 else: | 41 else: |
52 Create a weighted loss function. Penalise the misclassification | 53 Create a weighted loss function. Penalise the misclassification |
53 of classes more with the higher usage | 54 of classes more with the higher usage |
54 """ | 55 """ |
55 weight_values = list(class_weights.values()) | 56 weight_values = list(class_weights.values()) |
56 weight_values.extend(weight_values) | 57 weight_values.extend(weight_values) |
58 | |
57 def weighted_binary_crossentropy(y_true, y_pred): | 59 def weighted_binary_crossentropy(y_true, y_pred): |
58 # add another dimension to compute dot product | 60 # add another dimension to compute dot product |
59 expanded_weights = K.expand_dims(weight_values, axis=-1) | 61 expanded_weights = tf.expand_dims(weight_values, axis=-1) |
60 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) | 62 bce = backend.binary_crossentropy(y_true, y_pred) |
63 return backend.dot(bce, expanded_weights) | |
64 | |
61 return weighted_binary_crossentropy | 65 return weighted_binary_crossentropy |
62 | 66 |
63 | 67 |
64 def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary): | 68 def balanced_sample_generator( |
69 train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary | |
70 ): | |
65 while True: | 71 while True: |
66 dimension = train_data.shape[1] | 72 dimension = train_data.shape[1] |
67 n_classes = train_labels.shape[1] | 73 n_classes = train_labels.shape[1] |
68 tool_ids = list(l_tool_tr_samples.keys()) | 74 tool_ids = list(l_tool_tr_samples.keys()) |
69 random.shuffle(tool_ids) | 75 random.shuffle(tool_ids) |
78 generator_batch_data[i] = train_data[random_tr_index] | 84 generator_batch_data[i] = train_data[random_tr_index] |
79 generator_batch_labels[i] = train_labels[random_tr_index] | 85 generator_batch_labels[i] = train_labels[random_tr_index] |
80 yield generator_batch_data, generator_batch_labels | 86 yield generator_batch_data, generator_batch_labels |
81 | 87 |
82 | 88 |
83 def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids): | 89 def compute_precision( |
90 model, | |
91 x, | |
92 y, | |
93 reverse_data_dictionary, | |
94 usage_scores, | |
95 actual_classes_pos, | |
96 topk, | |
97 standard_conn, | |
98 last_tool_id, | |
99 lowest_tool_ids, | |
100 ): | |
84 """ | 101 """ |
85 Compute absolute and compatible precision | 102 Compute absolute and compatible precision |
86 """ | 103 """ |
87 pred_t_name = "" | 104 pred_t_name = "" |
88 top_precision = 0.0 | 105 top_precision = 0.0 |
135 if last_tool_id in lowest_tool_ids: | 152 if last_tool_id in lowest_tool_ids: |
136 lowest_pub_prec = 1.0 | 153 lowest_pub_prec = 1.0 |
137 else: | 154 else: |
138 lowest_pub_prec = np.nan | 155 lowest_pub_prec = np.nan |
139 if standard_topk_prediction_pos in usage_scores: | 156 if standard_topk_prediction_pos in usage_scores: |
140 usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0)) | 157 usage_wt_score.append( |
158 np.log(usage_scores[standard_topk_prediction_pos] + 1.0) | |
159 ) | |
141 else: | 160 else: |
142 # count precision only when there is actually true published tools | 161 # count precision only when there is actually true published tools |
143 # else set to np.nan. Set to 0 only when there is wrong prediction | 162 # else set to np.nan. Set to 0 only when there is wrong prediction |
144 pub_precision = np.nan | 163 pub_precision = np.nan |
145 lowest_pub_prec = np.nan | 164 lowest_pub_prec = np.nan |
146 # compute scores for normal recommendations | 165 # compute scores for normal recommendations |
147 if normal_topk_prediction_pos in reverse_data_dictionary: | 166 if normal_topk_prediction_pos in reverse_data_dictionary: |
148 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] | 167 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] |
149 if pred_t_name in actual_next_tool_names: | 168 if pred_t_name in actual_next_tool_names: |
150 if normal_topk_prediction_pos in usage_scores: | 169 if normal_topk_prediction_pos in usage_scores: |
151 usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0)) | 170 usage_wt_score.append( |
171 np.log(usage_scores[normal_topk_prediction_pos] + 1.0) | |
172 ) | |
152 top_precision = 1.0 | 173 top_precision = 1.0 |
153 if last_tool_id in lowest_tool_ids: | 174 if last_tool_id in lowest_tool_ids: |
154 lowest_norm_prec = 1.0 | 175 lowest_norm_prec = 1.0 |
155 else: | 176 else: |
156 lowest_norm_prec = np.nan | 177 lowest_norm_prec = np.nan |
164 tool_ids = list(l_tool_freq.keys()) | 185 tool_ids = list(l_tool_freq.keys()) |
165 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] | 186 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] |
166 return lowest_ids | 187 return lowest_ids |
167 | 188 |
168 | 189 |
169 def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]): | 190 def verify_model( |
191 model, | |
192 x, | |
193 y, | |
194 reverse_data_dictionary, | |
195 usage_scores, | |
196 standard_conn, | |
197 lowest_tool_ids, | |
198 topk_list=[1, 2, 3], | |
199 ): | |
170 """ | 200 """ |
171 Verify the model on test data | 201 Verify the model on test data |
172 """ | 202 """ |
173 print("Evaluating performance on test data...") | 203 print("Evaluating performance on test data...") |
174 print("Test data size: %d" % len(y)) | 204 print("Test data size: %d" % len(y)) |
185 lowest_norm_topk = list() | 215 lowest_norm_topk = list() |
186 actual_classes_pos = np.where(y[i] > 0)[0] | 216 actual_classes_pos = np.where(y[i] > 0)[0] |
187 test_sample = x[i, :] | 217 test_sample = x[i, :] |
188 last_tool_id = str(int(test_sample[-1])) | 218 last_tool_id = str(int(test_sample[-1])) |
189 for index, abs_topk in enumerate(topk_list): | 219 for index, abs_topk in enumerate(topk_list): |
190 usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids) | 220 ( |
221 usg_wt_score, | |
222 absolute_precision, | |
223 pub_prec, | |
224 lowest_p_prec, | |
225 lowest_n_prec, | |
226 ) = compute_precision( | |
227 model, | |
228 test_sample, | |
229 y, | |
230 reverse_data_dictionary, | |
231 usage_scores, | |
232 actual_classes_pos, | |
233 abs_topk, | |
234 standard_conn, | |
235 last_tool_id, | |
236 lowest_tool_ids, | |
237 ) | |
191 precision[i][index] = absolute_precision | 238 precision[i][index] = absolute_precision |
192 usage_weights[i][index] = usg_wt_score | 239 usage_weights[i][index] = usg_wt_score |
193 epo_pub_prec[i][index] = pub_prec | 240 epo_pub_prec[i][index] = pub_prec |
194 lowest_pub_topk.append(lowest_p_prec) | 241 lowest_pub_topk.append(lowest_p_prec) |
195 lowest_norm_topk.append(lowest_n_prec) | 242 lowest_norm_topk.append(lowest_n_prec) |
200 mean_precision = np.mean(precision, axis=0) | 247 mean_precision = np.mean(precision, axis=0) |
201 mean_usage = np.mean(usage_weights, axis=0) | 248 mean_usage = np.mean(usage_weights, axis=0) |
202 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) | 249 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) |
203 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) | 250 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) |
204 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) | 251 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) |
205 return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter | 252 return ( |
206 | 253 mean_usage, |
207 | 254 mean_precision, |
208 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections): | 255 mean_pub_prec, |
256 mean_lowest_pub_prec, | |
257 mean_lowest_norm_prec, | |
258 lowest_counter, | |
259 ) | |
260 | |
261 | |
def save_model(
    results,
    data_dictionary,
    compatible_next_tools,
    trained_model_path,
    class_weights,
    standard_connections,
):
    """
    Persist the trained model and its associated lookup dictionaries.

    Pulls the fitted Keras model and its best hyperparameters out of
    ``results``, serialises the architecture to JSON and extracts the
    weight matrices, then hands everything to ``set_trained_model`` to
    be written into the h5 file at ``trained_model_path``.
    """
    # The training run returns both the fitted model and the winning
    # hyperparameter set under these keys.
    model = results["model"]
    model_values = {
        "data_dictionary": data_dictionary,
        "model_config": model.to_json(),
        "best_parameters": results["best_parameters"],
        "model_weights": model.get_weights(),
        "compatible_tools": compatible_next_tools,
        "class_weights": class_weights,
        "standard_connections": standard_connections,
    }
    set_trained_model(trained_model_path, model_values)