comparison utils.py @ 5:4f7e6612906b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
| author | bgruening |
|---|---|
| date | Fri, 06 May 2022 09:05:18 +0000 |
| parents | afec8c595124 |
| children | e94dc7945639 |
| 4:afec8c595124 | 5:4f7e6612906b |
|---|---|
| 1 import json | |
| 2 import random | |
| 3 | |
| 4 import h5py | |
| 1 import numpy as np | 5 import numpy as np |
| 2 import json | 6 import tensorflow as tf |
| 3 import h5py | |
| 4 import random | |
| 5 from numpy.random import choice | 7 from numpy.random import choice |
| 6 | 8 from tensorflow.keras import backend |
| 7 from keras import backend as K | |
| 8 | 9 |
| 9 | 10 |
| 10 def read_file(file_path): | 11 def read_file(file_path): |
| 11 """ | 12 """ |
| 12 Read a file | 13 Read a file |
| 27 | 28 |
| 28 def set_trained_model(dump_file, model_values): | 29 def set_trained_model(dump_file, model_values): |
| 29 """ | 30 """ |
| 30 Create an h5 file with the trained weights and associated dicts | 31 Create an h5 file with the trained weights and associated dicts |
| 31 """ | 32 """ |
| 32 hf_file = h5py.File(dump_file, 'w') | 33 hf_file = h5py.File(dump_file, "w") |
| 33 for key in model_values: | 34 for key in model_values: |
| 34 value = model_values[key] | 35 value = model_values[key] |
| 35 if key == 'model_weights': | 36 if key == "model_weights": |
| 36 for idx, item in enumerate(value): | 37 for idx, item in enumerate(value): |
| 37 w_key = "weight_" + str(idx) | 38 w_key = "weight_" + str(idx) |
| 38 if w_key in hf_file: | 39 if w_key in hf_file: |
| 39 hf_file.modify(w_key, item) | 40 hf_file.modify(w_key, item) |
| 40 else: | 41 else: |
| 52 Create a weighted loss function. Penalise the misclassification | 53 Create a weighted loss function. Penalise the misclassification |
| 53 of classes with higher usage more heavily | 54 of classes with higher usage more heavily |
| 54 """ | 55 """ |
| 55 weight_values = list(class_weights.values()) | 56 weight_values = list(class_weights.values()) |
| 56 weight_values.extend(weight_values) | 57 weight_values.extend(weight_values) |
| 58 | |
| 57 def weighted_binary_crossentropy(y_true, y_pred): | 59 def weighted_binary_crossentropy(y_true, y_pred): |
| 58 # add another dimension to compute dot product | 60 # add another dimension to compute dot product |
| 59 expanded_weights = K.expand_dims(weight_values, axis=-1) | 61 expanded_weights = tf.expand_dims(weight_values, axis=-1) |
| 60 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) | 62 bce = backend.binary_crossentropy(y_true, y_pred) |
| 63 return backend.dot(bce, expanded_weights) | |
| 64 | |
| 61 return weighted_binary_crossentropy | 65 return weighted_binary_crossentropy |
| 62 | 66 |
| 63 | 67 |
| 64 def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary): | 68 def balanced_sample_generator( |
| 69 train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary | |
| 70 ): | |
| 65 while True: | 71 while True: |
| 66 dimension = train_data.shape[1] | 72 dimension = train_data.shape[1] |
| 67 n_classes = train_labels.shape[1] | 73 n_classes = train_labels.shape[1] |
| 68 tool_ids = list(l_tool_tr_samples.keys()) | 74 tool_ids = list(l_tool_tr_samples.keys()) |
| 69 random.shuffle(tool_ids) | 75 random.shuffle(tool_ids) |
| 78 generator_batch_data[i] = train_data[random_tr_index] | 84 generator_batch_data[i] = train_data[random_tr_index] |
| 79 generator_batch_labels[i] = train_labels[random_tr_index] | 85 generator_batch_labels[i] = train_labels[random_tr_index] |
| 80 yield generator_batch_data, generator_batch_labels | 86 yield generator_batch_data, generator_batch_labels |
| 81 | 87 |
| 82 | 88 |
| 83 def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids): | 89 def compute_precision( |
| 90 model, | |
| 91 x, | |
| 92 y, | |
| 93 reverse_data_dictionary, | |
| 94 usage_scores, | |
| 95 actual_classes_pos, | |
| 96 topk, | |
| 97 standard_conn, | |
| 98 last_tool_id, | |
| 99 lowest_tool_ids, | |
| 100 ): | |
| 84 """ | 101 """ |
| 85 Compute absolute and compatible precision | 102 Compute absolute and compatible precision |
| 86 """ | 103 """ |
| 87 pred_t_name = "" | 104 pred_t_name = "" |
| 88 top_precision = 0.0 | 105 top_precision = 0.0 |
| 135 if last_tool_id in lowest_tool_ids: | 152 if last_tool_id in lowest_tool_ids: |
| 136 lowest_pub_prec = 1.0 | 153 lowest_pub_prec = 1.0 |
| 137 else: | 154 else: |
| 138 lowest_pub_prec = np.nan | 155 lowest_pub_prec = np.nan |
| 139 if standard_topk_prediction_pos in usage_scores: | 156 if standard_topk_prediction_pos in usage_scores: |
| 140 usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0)) | 157 usage_wt_score.append( |
| 158 np.log(usage_scores[standard_topk_prediction_pos] + 1.0) | |
| 159 ) | |
| 141 else: | 160 else: |
| 142 # count precision only when there are actually true published tools | 161 # count precision only when there are actually true published tools |
| 143 # else set to np.nan. Set to 0 only when the prediction is wrong | 162 # else set to np.nan. Set to 0 only when the prediction is wrong |
| 144 pub_precision = np.nan | 163 pub_precision = np.nan |
| 145 lowest_pub_prec = np.nan | 164 lowest_pub_prec = np.nan |
| 146 # compute scores for normal recommendations | 165 # compute scores for normal recommendations |
| 147 if normal_topk_prediction_pos in reverse_data_dictionary: | 166 if normal_topk_prediction_pos in reverse_data_dictionary: |
| 148 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] | 167 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] |
| 149 if pred_t_name in actual_next_tool_names: | 168 if pred_t_name in actual_next_tool_names: |
| 150 if normal_topk_prediction_pos in usage_scores: | 169 if normal_topk_prediction_pos in usage_scores: |
| 151 usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0)) | 170 usage_wt_score.append( |
| 171 np.log(usage_scores[normal_topk_prediction_pos] + 1.0) | |
| 172 ) | |
| 152 top_precision = 1.0 | 173 top_precision = 1.0 |
| 153 if last_tool_id in lowest_tool_ids: | 174 if last_tool_id in lowest_tool_ids: |
| 154 lowest_norm_prec = 1.0 | 175 lowest_norm_prec = 1.0 |
| 155 else: | 176 else: |
| 156 lowest_norm_prec = np.nan | 177 lowest_norm_prec = np.nan |
| 164 tool_ids = list(l_tool_freq.keys()) | 185 tool_ids = list(l_tool_freq.keys()) |
| 165 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] | 186 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] |
| 166 return lowest_ids | 187 return lowest_ids |
| 167 | 188 |
| 168 | 189 |
| 169 def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]): | 190 def verify_model( |
| 191 model, | |
| 192 x, | |
| 193 y, | |
| 194 reverse_data_dictionary, | |
| 195 usage_scores, | |
| 196 standard_conn, | |
| 197 lowest_tool_ids, | |
| 198 topk_list=[1, 2, 3], | |
| 199 ): | |
| 170 """ | 200 """ |
| 171 Verify the model on test data | 201 Verify the model on test data |
| 172 """ | 202 """ |
| 173 print("Evaluating performance on test data...") | 203 print("Evaluating performance on test data...") |
| 174 print("Test data size: %d" % len(y)) | 204 print("Test data size: %d" % len(y)) |
| 185 lowest_norm_topk = list() | 215 lowest_norm_topk = list() |
| 186 actual_classes_pos = np.where(y[i] > 0)[0] | 216 actual_classes_pos = np.where(y[i] > 0)[0] |
| 187 test_sample = x[i, :] | 217 test_sample = x[i, :] |
| 188 last_tool_id = str(int(test_sample[-1])) | 218 last_tool_id = str(int(test_sample[-1])) |
| 189 for index, abs_topk in enumerate(topk_list): | 219 for index, abs_topk in enumerate(topk_list): |
| 190 usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids) | 220 ( |
| 221 usg_wt_score, | |
| 222 absolute_precision, | |
| 223 pub_prec, | |
| 224 lowest_p_prec, | |
| 225 lowest_n_prec, | |
| 226 ) = compute_precision( | |
| 227 model, | |
| 228 test_sample, | |
| 229 y, | |
| 230 reverse_data_dictionary, | |
| 231 usage_scores, | |
| 232 actual_classes_pos, | |
| 233 abs_topk, | |
| 234 standard_conn, | |
| 235 last_tool_id, | |
| 236 lowest_tool_ids, | |
| 237 ) | |
| 191 precision[i][index] = absolute_precision | 238 precision[i][index] = absolute_precision |
| 192 usage_weights[i][index] = usg_wt_score | 239 usage_weights[i][index] = usg_wt_score |
| 193 epo_pub_prec[i][index] = pub_prec | 240 epo_pub_prec[i][index] = pub_prec |
| 194 lowest_pub_topk.append(lowest_p_prec) | 241 lowest_pub_topk.append(lowest_p_prec) |
| 195 lowest_norm_topk.append(lowest_n_prec) | 242 lowest_norm_topk.append(lowest_n_prec) |
| 200 mean_precision = np.mean(precision, axis=0) | 247 mean_precision = np.mean(precision, axis=0) |
| 201 mean_usage = np.mean(usage_weights, axis=0) | 248 mean_usage = np.mean(usage_weights, axis=0) |
| 202 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) | 249 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) |
| 203 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) | 250 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) |
| 204 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) | 251 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) |
| 205 return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter | 252 return ( |
| 206 | 253 mean_usage, |
| 207 | 254 mean_precision, |
| 208 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections): | 255 mean_pub_prec, |
| 256 mean_lowest_pub_prec, | |
| 257 mean_lowest_norm_prec, | |
| 258 lowest_counter, | |
| 259 ) | |
| 260 | |
| 261 | |
| 262 def save_model( | |
| 263 results, | |
| 264 data_dictionary, | |
| 265 compatible_next_tools, | |
| 266 trained_model_path, | |
| 267 class_weights, | |
| 268 standard_connections, | |
| 269 ): | |
| 209 # save files | 270 # save files |
| 210 trained_model = results["model"] | 271 trained_model = results["model"] |
| 211 best_model_parameters = results["best_parameters"] | 272 best_model_parameters = results["best_parameters"] |
| 212 model_config = trained_model.to_json() | 273 model_config = trained_model.to_json() |
| 213 model_weights = trained_model.get_weights() | 274 model_weights = trained_model.get_weights() |
| 214 model_values = { | 275 model_values = { |
| 215 'data_dictionary': data_dictionary, | 276 "data_dictionary": data_dictionary, |
| 216 'model_config': model_config, | 277 "model_config": model_config, |
| 217 'best_parameters': best_model_parameters, | 278 "best_parameters": best_model_parameters, |
| 218 'model_weights': model_weights, | 279 "model_weights": model_weights, |
| 219 "compatible_tools": compatible_next_tools, | 280 "compatible_tools": compatible_next_tools, |
| 220 "class_weights": class_weights, | 281 "class_weights": class_weights, |
| 221 "standard_connections": standard_connections | 282 "standard_connections": standard_connections, |
| 222 } | 283 } |
| 223 set_trained_model(trained_model_path, model_values) | 284 set_trained_model(trained_model_path, model_values) |
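
For context on the `weight_<idx>` layout that `set_trained_model` writes, a minimal reader sketch follows. `read_model_weights` is a hypothetical helper, not part of the tool; only standard `h5py` calls are assumed.

```python
import h5py

# Hypothetical helper (not in utils.py): read back the trained weights
# that set_trained_model stored under the keys "weight_0", "weight_1", ...
def read_model_weights(dump_file):
    weights = []
    with h5py.File(dump_file, "r") as hf_file:
        idx = 0
        while "weight_" + str(idx) in hf_file:
            # load each dataset as a numpy array, in index order
            weights.append(hf_file["weight_" + str(idx)][()])
            idx += 1
    return weights
```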
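The main change in this revision swaps the standalone `keras` backend for `tensorflow.keras`. Below is a minimal sketch of how the rewritten `weighted_binary_crossentropy` evaluates, assuming an illustrative three-class weight dict; in the tool the weights come from usage scores, and the weight list is doubled as in the loss factory above (whose `def` line falls outside this hunk).

```python
import tensorflow as tf
from tensorflow.keras import backend

# Illustrative class weights; the tool derives these from tool usage scores.
class_weights = {0: 1.0, 1: 2.5, 2: 0.5}
weight_values = list(class_weights.values())
weight_values.extend(weight_values)  # doubled, as in the loss factory above

def weighted_binary_crossentropy(y_true, y_pred):
    # add another dimension to compute dot product
    expanded_weights = tf.expand_dims(weight_values, axis=-1)  # shape (6, 1)
    bce = backend.binary_crossentropy(y_true, y_pred)          # shape (batch, 6)
    return backend.dot(bce, expanded_weights)                  # shape (batch, 1)

# Toy batch with one sample; the column count must match len(weight_values).
y_true = tf.constant([[1.0, 0.0, 1.0, 0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.9, 0.1, 0.6, 0.2, 0.1, 0.8]])
print(weighted_binary_crossentropy(y_true, y_pred))  # weighted per-sample loss
```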
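The body of `balanced_sample_generator` is elided in this hunk. As a hedged illustration only (the actual upstream loop may differ), the balancing idea is to cycle over shuffled tool ids, one per batch slot, and draw a random training sample belonging to that tool, so rarely used tools appear as often as common ones. All data below is a toy stand-in, and `draw_balanced_batch` is a hypothetical name.

```python
import random
import numpy as np

# Toy stand-ins; in the tool these come from the prepared training matrices.
train_data = np.random.rand(6, 4)                            # 6 samples, sequence dimension 4
train_labels = np.eye(6)[:, :3]                              # 6 samples, 3 classes
l_tool_tr_samples = {"1": [0, 1], "2": [2], "3": [3, 4, 5]}  # tool id -> sample rows

def draw_balanced_batch(batch_size):
    # One tool id per batch slot, cycling over a shuffled list, then one of
    # that tool's samples: rare tools are drawn as often as common ones.
    tool_ids = list(l_tool_tr_samples.keys())
    random.shuffle(tool_ids)
    batch_data = np.zeros([batch_size, train_data.shape[1]])
    batch_labels = np.zeros([batch_size, train_labels.shape[1]])
    for i in range(batch_size):
        tool_id = tool_ids[i % len(tool_ids)]
        random_tr_index = random.choice(l_tool_tr_samples[tool_id])
        batch_data[i] = train_data[random_tr_index]
        batch_labels[i] = train_labels[random_tr_index]
    return batch_data, batch_labels

batch_x, batch_y = draw_balanced_batch(8)
print(batch_x.shape, batch_y.shape)  # (8, 4) (8, 3)
```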
