Mercurial > repos > bgruening > create_tool_recommendation_model
comparison utils.py @ 3:5b3c08710e47 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
author | bgruening |
---|---|
date | Sat, 09 May 2020 05:38:23 -0400 |
parents | 76251d1ccdcc |
children | afec8c595124 |
comparison
equal
deleted
inserted
replaced
2:76251d1ccdcc | 3:5b3c08710e47 |
---|---|
1 import os | 1 import os |
2 import numpy as np | 2 import numpy as np |
3 import json | 3 import json |
4 import h5py | 4 import h5py |
5 import random | |
5 | 6 |
6 from keras import backend as K | 7 from keras import backend as K |
7 | 8 |
8 | 9 |
9 def read_file(file_path): | 10 def read_file(file_path): |
11 Read a file | 12 Read a file |
12 """ | 13 """ |
13 with open(file_path, "r") as json_file: | 14 with open(file_path, "r") as json_file: |
14 file_content = json.loads(json_file.read()) | 15 file_content = json.loads(json_file.read()) |
15 return file_content | 16 return file_content |
16 | |
17 | |
18 def write_file(file_path, content): | |
19 """ | |
20 Write a file | |
21 """ | |
22 remove_file(file_path) | |
23 with open(file_path, "w") as json_file: | |
24 json_file.write(json.dumps(content)) | |
25 | |
26 | |
27 def save_processed_workflows(file_path, unique_paths): | |
28 workflow_paths_unique = "" | |
29 for path in unique_paths: | |
30 workflow_paths_unique += path + "\n" | |
31 with open(file_path, "w") as workflows_file: | |
32 workflows_file.write(workflow_paths_unique) | |
33 | 17 |
34 | 18 |
35 def format_tool_id(tool_link): | 19 def format_tool_id(tool_link): |
36 """ | 20 """ |
37 Extract tool id from tool link | 21 Extract tool id from tool link |
61 else: | 45 else: |
62 hf_file.create_dataset(key, data=json.dumps(value)) | 46 hf_file.create_dataset(key, data=json.dumps(value)) |
63 hf_file.close() | 47 hf_file.close() |
64 | 48 |
65 | 49 |
66 def remove_file(file_path): | |
67 if os.path.exists(file_path): | |
68 os.remove(file_path) | |
69 | |
70 | |
71 def weighted_loss(class_weights): | 50 def weighted_loss(class_weights): |
72 """ | 51 """ |
73 Create a weighted loss function. Penalise the misclassification | 52 Create a weighted loss function. Penalise the misclassification |
74 of classes more with the higher usage | 53 of classes more with the higher usage |
75 """ | 54 """ |
76 weight_values = list(class_weights.values()) | 55 weight_values = list(class_weights.values()) |
56 weight_values.extend(weight_values) | |
77 | 57 |
78 def weighted_binary_crossentropy(y_true, y_pred): | 58 def weighted_binary_crossentropy(y_true, y_pred): |
79 # add another dimension to compute dot product | 59 # add another dimension to compute dot product |
80 expanded_weights = K.expand_dims(weight_values, axis=-1) | 60 expanded_weights = K.expand_dims(weight_values, axis=-1) |
81 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) | 61 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) |
82 return weighted_binary_crossentropy | 62 return weighted_binary_crossentropy |
83 | 63 |
84 | 64 |
85 def compute_precision(model, x, y, reverse_data_dictionary, next_compatible_tools, usage_scores, actual_classes_pos, topk): | 65 def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples): |
66 while True: | |
67 dimension = train_data.shape[1] | |
68 n_classes = train_labels.shape[1] | |
69 tool_ids = list(l_tool_tr_samples.keys()) | |
70 generator_batch_data = np.zeros([batch_size, dimension]) | |
71 generator_batch_labels = np.zeros([batch_size, n_classes]) | |
72 for i in range(batch_size): | |
73 random_toolid_index = random.sample(range(0, len(tool_ids)), 1)[0] | |
74 random_toolid = tool_ids[random_toolid_index] | |
75 sample_indices = l_tool_tr_samples[str(random_toolid)] | |
76 random_index = random.sample(range(0, len(sample_indices)), 1)[0] | |
77 random_tr_index = sample_indices[random_index] | |
78 generator_batch_data[i] = train_data[random_tr_index] | |
79 generator_batch_labels[i] = train_labels[random_tr_index] | |
80 yield generator_batch_data, generator_batch_labels | |
81 | |
82 | |
83 def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids): | |
86 """ | 84 """ |
87 Compute absolute and compatible precision | 85 Compute absolute and compatible precision |
88 """ | 86 """ |
89 absolute_precision = 0.0 | 87 pred_t_name = "" |
88 top_precision = 0.0 | |
89 mean_usage = 0.0 | |
90 usage_wt_score = list() | |
91 pub_precision = 0.0 | |
92 lowest_pub_prec = 0.0 | |
93 lowest_norm_prec = 0.0 | |
94 pub_tools = list() | |
95 actual_next_tool_names = list() | |
90 test_sample = np.reshape(x, (1, len(x))) | 96 test_sample = np.reshape(x, (1, len(x))) |
91 | 97 |
92 # predict next tools for a test path | 98 # predict next tools for a test path |
93 prediction = model.predict(test_sample, verbose=0) | 99 prediction = model.predict(test_sample, verbose=0) |
94 | 100 |
101 # divide the predicted vector into two halves - one for published and | |
102 # another for normal workflows | |
95 nw_dimension = prediction.shape[1] | 103 nw_dimension = prediction.shape[1] |
96 | 104 half_len = int(nw_dimension / 2) |
97 # remove the 0th position as there is no tool at this index | 105 |
106 # predict tools | |
98 prediction = np.reshape(prediction, (nw_dimension,)) | 107 prediction = np.reshape(prediction, (nw_dimension,)) |
99 | 108 # get predictions of tools from published workflows |
100 prediction_pos = np.argsort(prediction, axis=-1) | 109 standard_pred = prediction[:half_len] |
101 topk_prediction_pos = prediction_pos[-topk:] | 110 # get predictions of tools from normal workflows |
102 | 111 normal_pred = prediction[half_len:] |
103 # remove the wrong tool position from the predicted list of tool positions | 112 |
104 topk_prediction_pos = [x for x in topk_prediction_pos if x > 0] | 113 standard_prediction_pos = np.argsort(standard_pred, axis=-1) |
105 | 114 standard_topk_prediction_pos = standard_prediction_pos[-topk] |
106 # read tool names using reverse dictionary | 115 |
107 actual_next_tool_names = [reverse_data_dictionary[int(tool_pos)] for tool_pos in actual_classes_pos] | 116 normal_prediction_pos = np.argsort(normal_pred, axis=-1) |
108 top_predicted_next_tool_names = [reverse_data_dictionary[int(tool_pos)] for tool_pos in topk_prediction_pos] | 117 normal_topk_prediction_pos = normal_prediction_pos[-topk] |
109 | 118 |
110 # compute the class weights of predicted tools | 119 # get true tools names |
111 mean_usg_score = 0 | 120 for a_t_pos in actual_classes_pos: |
112 usg_wt_scores = list() | 121 if a_t_pos > half_len: |
113 for t_id in topk_prediction_pos: | 122 t_name = reverse_data_dictionary[int(a_t_pos - half_len)] |
114 t_name = reverse_data_dictionary[int(t_id)] | 123 else: |
115 if t_id in usage_scores and t_name in actual_next_tool_names: | 124 t_name = reverse_data_dictionary[int(a_t_pos)] |
116 usg_wt_scores.append(np.log(usage_scores[t_id] + 1.0)) | 125 actual_next_tool_names.append(t_name) |
117 if len(usg_wt_scores) > 0: | 126 last_tool_name = reverse_data_dictionary[x[-1]] |
118 mean_usg_score = np.sum(usg_wt_scores) / float(topk) | 127 # compute scores for published recommendations |
119 false_positives = [tool_name for tool_name in top_predicted_next_tool_names if tool_name not in actual_next_tool_names] | 128 if standard_topk_prediction_pos in reverse_data_dictionary: |
120 absolute_precision = 1 - (len(false_positives) / float(topk)) | 129 pred_t_name = reverse_data_dictionary[int(standard_topk_prediction_pos)] |
121 return mean_usg_score, absolute_precision | 130 if last_tool_name in standard_conn: |
122 | 131 pub_tools = standard_conn[last_tool_name] |
123 | 132 if pred_t_name in pub_tools: |
124 def verify_model(model, x, y, reverse_data_dictionary, next_compatible_tools, usage_scores, topk_list=[1, 2, 3]): | 133 pub_precision = 1.0 |
134 if last_tool_id in lowest_tool_ids: | |
135 lowest_pub_prec = 1.0 | |
136 if standard_topk_prediction_pos in usage_scores: | |
137 usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0)) | |
138 # compute scores for normal recommendations | |
139 if normal_topk_prediction_pos in reverse_data_dictionary: | |
140 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] | |
141 if pred_t_name in actual_next_tool_names: | |
142 if normal_topk_prediction_pos in usage_scores: | |
143 usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0)) | |
144 top_precision = 1.0 | |
145 if last_tool_id in lowest_tool_ids: | |
146 lowest_norm_prec = 1.0 | |
147 if len(usage_wt_score) > 0: | |
148 mean_usage = np.mean(usage_wt_score) | |
149 return mean_usage, top_precision, pub_precision, lowest_pub_prec, lowest_norm_prec | |
150 | |
151 | |
152 def get_lowest_tools(l_tool_freq, fraction=0.25): | |
153 l_tool_freq = dict(sorted(l_tool_freq.items(), key=lambda kv: kv[1], reverse=True)) | |
154 tool_ids = list(l_tool_freq.keys()) | |
155 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] | |
156 return lowest_ids | |
157 | |
158 | |
159 def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]): | |
125 """ | 160 """ |
126 Verify the model on test data | 161 Verify the model on test data |
127 """ | 162 """ |
128 print("Evaluating performance on test data...") | 163 print("Evaluating performance on test data...") |
129 print("Test data size: %d" % len(y)) | 164 print("Test data size: %d" % len(y)) |
130 size = y.shape[0] | 165 size = y.shape[0] |
131 precision = np.zeros([len(y), len(topk_list)]) | 166 precision = np.zeros([len(y), len(topk_list)]) |
132 usage_weights = np.zeros([len(y), len(topk_list)]) | 167 usage_weights = np.zeros([len(y), len(topk_list)]) |
168 epo_pub_prec = np.zeros([len(y), len(topk_list)]) | |
169 epo_lowest_tools_pub_prec = list() | |
170 epo_lowest_tools_norm_prec = list() | |
171 | |
133 # loop over all the test samples and find prediction precision | 172 # loop over all the test samples and find prediction precision |
134 for i in range(size): | 173 for i in range(size): |
174 lowest_pub_topk = list() | |
175 lowest_norm_topk = list() | |
135 actual_classes_pos = np.where(y[i] > 0)[0] | 176 actual_classes_pos = np.where(y[i] > 0)[0] |
177 test_sample = x[i, :] | |
178 last_tool_id = str(int(test_sample[-1])) | |
136 for index, abs_topk in enumerate(topk_list): | 179 for index, abs_topk in enumerate(topk_list): |
137 abs_mean_usg_score, absolute_precision = compute_precision(model, x[i, :], y, reverse_data_dictionary, next_compatible_tools, usage_scores, actual_classes_pos, abs_topk) | 180 usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids) |
138 precision[i][index] = absolute_precision | 181 precision[i][index] = absolute_precision |
139 usage_weights[i][index] = abs_mean_usg_score | 182 usage_weights[i][index] = usg_wt_score |
183 epo_pub_prec[i][index] = pub_prec | |
184 if last_tool_id in lowest_tool_ids: | |
185 lowest_pub_topk.append(lowest_p_prec) | |
186 lowest_norm_topk.append(lowest_n_prec) | |
187 if last_tool_id in lowest_tool_ids: | |
188 epo_lowest_tools_pub_prec.append(lowest_pub_topk) | |
189 epo_lowest_tools_norm_prec.append(lowest_norm_topk) | |
140 mean_precision = np.mean(precision, axis=0) | 190 mean_precision = np.mean(precision, axis=0) |
141 mean_usage = np.mean(usage_weights, axis=0) | 191 mean_usage = np.mean(usage_weights, axis=0) |
142 return mean_precision, mean_usage | 192 mean_pub_prec = np.mean(epo_pub_prec, axis=0) |
143 | 193 mean_lowest_pub_prec = np.mean(epo_lowest_tools_pub_prec, axis=0) |
144 | 194 mean_lowest_norm_prec = np.mean(epo_lowest_tools_norm_prec, axis=0) |
145 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights): | 195 return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, len(epo_lowest_tools_pub_prec) |
196 | |
197 | |
198 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections): | |
146 # save files | 199 # save files |
147 trained_model = results["model"] | 200 trained_model = results["model"] |
148 best_model_parameters = results["best_parameters"] | 201 best_model_parameters = results["best_parameters"] |
149 model_config = trained_model.to_json() | 202 model_config = trained_model.to_json() |
150 model_weights = trained_model.get_weights() | 203 model_weights = trained_model.get_weights() |
151 | |
152 model_values = { | 204 model_values = { |
153 'data_dictionary': data_dictionary, | 205 'data_dictionary': data_dictionary, |
154 'model_config': model_config, | 206 'model_config': model_config, |
155 'best_parameters': best_model_parameters, | 207 'best_parameters': best_model_parameters, |
156 'model_weights': model_weights, | 208 'model_weights': model_weights, |
157 "compatible_tools": compatible_next_tools, | 209 "compatible_tools": compatible_next_tools, |
158 "class_weights": class_weights | 210 "class_weights": class_weights, |
211 "standard_connections": standard_connections | |
159 } | 212 } |
160 set_trained_model(trained_model_path, model_values) | 213 set_trained_model(trained_model_path, model_values) |