comparison utils.py @ 6:e94dc7945639 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author | bgruening |
---|---|
date | Sun, 16 Oct 2022 11:52:10 +0000 |
parents | 4f7e6612906b |
children | |
5:4f7e6612906b | 6:e94dc7945639 |
---|---|
1 import json | 1 import json |
2 import os | |
2 import random | 3 import random |
3 | 4 |
4 import h5py | 5 import h5py |
5 import numpy as np | 6 import numpy as np |
7 import pandas as pd | |
6 import tensorflow as tf | 8 import tensorflow as tf |
7 from numpy.random import choice | 9 |
8 from tensorflow.keras import backend | 10 binary_ce = tf.keras.losses.BinaryCrossentropy() |
11 binary_acc = tf.keras.metrics.BinaryAccuracy() | |
12 categorical_ce = tf.keras.metrics.CategoricalCrossentropy(from_logits=True) | |
9 | 13 |
10 | 14 |
11 def read_file(file_path): | 15 def read_file(file_path): |
12 """ | 16 """ |
13 Read a file | 17 Read a file |
14 """ | 18 """ |
15 with open(file_path, "r") as json_file: | 19 with open(file_path, "r") as json_file: |
16 file_content = json.loads(json_file.read()) | 20 file_content = json.loads(json_file.read()) |
17 return file_content | 21 return file_content |
18 | 22 |
19 | 23 |
24 def write_file(file_path, content): | |
25 """ | |
26 Write a file | |
27 """ | |
28 remove_file(file_path) | |
29 with open(file_path, "w") as json_file: | |
30 json_file.write(json.dumps(content)) | |
31 | |
32 | |
33 def save_h5_data(inp, tar, filename): | |
34 hf_file = h5py.File(filename, 'w') | |
35 hf_file.create_dataset("input", data=inp) | |
36 hf_file.create_dataset("target", data=tar) | |
37 hf_file.close() | |
38 | |
39 | |
40 def get_low_freq_te_samples(te_data, te_target, tr_freq_dict): | |
41 lowest_tool_te_ids = list() | |
42 lowest_t_ids = get_lowest_tools(tr_freq_dict) | |
43 for i, te_labels in enumerate(te_target): | |
44 tools_pos = np.where(te_labels > 0)[0] | |
45 tools_pos = [str(int(item)) for item in tools_pos] | |
46 intersection = list(set(tools_pos).intersection(set(lowest_t_ids))) | |
47 if len(intersection) > 0: | |
48 lowest_tool_te_ids.append(i) | |
49 lowest_t_ids = [item for item in lowest_t_ids if item not in intersection] | |
50 return lowest_tool_te_ids | |
51 | |
52 | |
53 def save_processed_workflows(file_path, unique_paths): | |
54 workflow_paths_unique = "" | |
55 for path in unique_paths: | |
56 workflow_paths_unique += path + "\n" | |
57 with open(file_path, "w") as workflows_file: | |
58 workflows_file.write(workflow_paths_unique) | |
59 | |
60 | |
20 def format_tool_id(tool_link): | 61 def format_tool_id(tool_link): |
21 """ | 62 """ |
22 Extract tool id from tool link | 63 Extract tool id from tool link |
23 """ | 64 """ |
24 tool_id_split = tool_link.split("/") | 65 tool_id_split = tool_link.split("/") |
25 tool_id = tool_id_split[-2] if len(tool_id_split) > 1 else tool_link | 66 tool_id = tool_id_split[-2] if len(tool_id_split) > 1 else tool_link |
26 return tool_id | 67 return tool_id |
27 | 68 |
28 | 69 |
29 def set_trained_model(dump_file, model_values): | 70 def save_model_file(model, r_dict, c_wts, c_tools, s_conn, model_file): |
30 """ | 71 model.save_weights(model_file, save_format="h5") |
31 Create an h5 file with the trained weights and associated dicts | 72 hf_file = h5py.File(model_file, 'r+') |
32 """ | 73 model_values = { |
33 hf_file = h5py.File(dump_file, "w") | 74 "reverse_dict": r_dict, |
34 for key in model_values: | 75 "class_weights": c_wts, |
35 value = model_values[key] | 76 "compatible_tools": c_tools, |
36 if key == "model_weights": | 77 "standard_connections": s_conn |
37 for idx, item in enumerate(value): | 78 } |
38 w_key = "weight_" + str(idx) | 79 for k in model_values: |
39 if w_key in hf_file: | 80 hf_file.create_dataset(k, data=json.dumps(model_values[k])) |
40 hf_file.modify(w_key, item) | 81 hf_file.close() |
41 else: | 82 |
42 hf_file.create_dataset(w_key, data=item) | 83 |
84 def remove_file(file_path): | |
85 if os.path.exists(file_path): | |
86 os.remove(file_path) | |
87 | |
88 | |
89 def verify_oversampling_freq(oversampled_tr_data, rev_dict): | |
90 """ | |
91 Compute the frequency of tool sequences after oversampling | |
92 """ | |
93 freq_dict = dict() | |
94 freq_dict_names = dict() | |
95 for tr_data in oversampled_tr_data: | |
96 t_pos = np.where(tr_data > 0)[0] | |
97 last_tool_id = str(int(tr_data[t_pos[-1]])) | |
98 if last_tool_id not in freq_dict: | |
99 freq_dict[last_tool_id] = 0 | |
100 freq_dict_names[rev_dict[int(last_tool_id)]] = 0 | |
101 freq_dict[last_tool_id] += 1 | |
102 freq_dict_names[rev_dict[int(last_tool_id)]] += 1 | |
103 s_freq = dict(sorted(freq_dict_names.items(), key=lambda kv: kv[1], reverse=True)) | |
104 return s_freq | |
105 | |
106 | |
107 def collect_sampled_tool_freq(collected_dict, c_freq): | |
108 for t in c_freq: | |
109 if t not in collected_dict: | |
110 collected_dict[t] = int(c_freq[t]) | |
43 else: | 111 else: |
44 if key in hf_file: | 112 collected_dict[t] += int(c_freq[t]) |
45 hf_file.modify(key, json.dumps(value)) | 113 return collected_dict |
46 else: | 114 |
47 hf_file.create_dataset(key, data=json.dumps(value)) | 115 |
48 hf_file.close() | 116 def save_data_as_dict(f_dict, r_dict, inp, tar, save_path): |
49 | 117 inp_tar = dict() |
50 | 118 for index, (i, t) in enumerate(zip(inp, tar)): |
51 def weighted_loss(class_weights): | 119 i_pos = np.where(i > 0)[0] |
52 """ | 120 i_seq = ",".join([str(int(item)) for item in i[1:i_pos[-1] + 1]]) |
53 Create a weighted loss function. Penalise the misclassification | 121 t_pos = np.where(t > 0)[0] |
54 of classes more with the higher usage | 122 t_seq = ",".join([str(int(item)) for item in t[1:t_pos[-1] + 1]]) |
55 """ | 123 if i_seq not in inp_tar: |
56 weight_values = list(class_weights.values()) | 124 inp_tar[i_seq] = list() |
57 weight_values.extend(weight_values) | 125 inp_tar[i_seq].append(t_seq) |
58 | 126 size = 0 |
59 def weighted_binary_crossentropy(y_true, y_pred): | 127 for item in inp_tar: |
60 # add another dimension to compute dot product | 128 size += len(inp_tar[item]) |
61 expanded_weights = tf.expand_dims(weight_values, axis=-1) | 129 print("Size saved file: ", size) |
62 bce = backend.binary_crossentropy(y_true, y_pred) | 130 write_file(save_path, inp_tar) |
63 return backend.dot(bce, expanded_weights) | 131 |
64 | 132 |
65 return weighted_binary_crossentropy | 133 def read_train_test(datapath): |
66 | 134 file_obj = h5py.File(datapath, 'r') |
67 | 135 data_input = np.array(file_obj["input"]) |
68 def balanced_sample_generator( | 136 data_target = np.array(file_obj["target"]) |
69 train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary | 137 return data_input, data_target |
70 ): | 138 |
71 while True: | 139 |
72 dimension = train_data.shape[1] | 140 def sample_balanced_tr_y(x_seqs, y_labels, ulabels_tr_y_dict, b_size, tr_t_freq, prev_sel_tools): |
73 n_classes = train_labels.shape[1] | 141 batch_y_tools = list(ulabels_tr_y_dict.keys()) |
74 tool_ids = list(l_tool_tr_samples.keys()) | 142 random.shuffle(batch_y_tools) |
75 random.shuffle(tool_ids) | 143 label_tools = list() |
76 generator_batch_data = np.zeros([batch_size, dimension]) | 144 rand_batch_indices = list() |
77 generator_batch_labels = np.zeros([batch_size, n_classes]) | 145 sel_tools = list() |
78 generated_tool_ids = choice(tool_ids, batch_size) | 146 |
79 for i in range(batch_size): | 147 unselected_tools = [t for t in batch_y_tools if t not in prev_sel_tools] |
80 random_toolid = generated_tool_ids[i] | 148 rand_selected_tools = unselected_tools[:b_size] |
81 sample_indices = l_tool_tr_samples[str(random_toolid)] | 149 |
82 random_index = random.sample(range(0, len(sample_indices)), 1)[0] | 150 for l_tool in rand_selected_tools: |
83 random_tr_index = sample_indices[random_index] | 151 seq_indices = ulabels_tr_y_dict[l_tool] |
84 generator_batch_data[i] = train_data[random_tr_index] | 152 random.shuffle(seq_indices) |
85 generator_batch_labels[i] = train_labels[random_tr_index] | 153 rand_s_index = np.random.randint(0, len(seq_indices), 1)[0] |
86 yield generator_batch_data, generator_batch_labels | 154 rand_sample = seq_indices[rand_s_index] |
87 | 155 sel_tools.append(l_tool) |
88 | 156 rand_batch_indices.append(rand_sample) |
89 def compute_precision( | 157 label_tools.append(l_tool) |
90 model, | 158 |
91 x, | 159 x_batch_train = x_seqs[rand_batch_indices] |
92 y, | 160 y_batch_train = y_labels[rand_batch_indices] |
93 reverse_data_dictionary, | 161 |
94 usage_scores, | 162 unrolled_x = tf.convert_to_tensor(x_batch_train, dtype=tf.int64) |
95 actual_classes_pos, | 163 unrolled_y = tf.convert_to_tensor(y_batch_train, dtype=tf.int64) |
96 topk, | 164 return unrolled_x, unrolled_y, sel_tools |
97 standard_conn, | 165 |
98 last_tool_id, | 166 |
99 lowest_tool_ids, | 167 def sample_balanced_te_y(x_seqs, y_labels, ulabels_tr_y_dict, b_size): |
100 ): | 168 batch_y_tools = list(ulabels_tr_y_dict.keys()) |
101 """ | 169 random.shuffle(batch_y_tools) |
102 Compute absolute and compatible precision | 170 label_tools = list() |
103 """ | 171 rand_batch_indices = list() |
104 pred_t_name = "" | 172 sel_tools = list() |
105 top_precision = 0.0 | 173 for l_tool in batch_y_tools: |
106 mean_usage = 0.0 | 174 seq_indices = ulabels_tr_y_dict[l_tool] |
107 usage_wt_score = list() | 175 random.shuffle(seq_indices) |
108 pub_precision = 0.0 | 176 rand_s_index = np.random.randint(0, len(seq_indices), 1)[0] |
109 lowest_pub_prec = 0.0 | 177 rand_sample = seq_indices[rand_s_index] |
110 lowest_norm_prec = 0.0 | 178 sel_tools.append(l_tool) |
111 pub_tools = list() | 179 if rand_sample not in rand_batch_indices: |
112 actual_next_tool_names = list() | 180 rand_batch_indices.append(rand_sample) |
113 test_sample = np.reshape(x, (1, len(x))) | 181 label_tools.append(l_tool) |
114 | 182 if len(rand_batch_indices) == b_size: |
115 # predict next tools for a test path | 183 break |
116 prediction = model.predict(test_sample, verbose=0) | 184 x_batch_train = x_seqs[rand_batch_indices] |
117 | 185 y_batch_train = y_labels[rand_batch_indices] |
118 # divide the predicted vector into two halves - one for published and | 186 |
119 # another for normal workflows | 187 unrolled_x = tf.convert_to_tensor(x_batch_train, dtype=tf.int64) |
120 nw_dimension = prediction.shape[1] | 188 unrolled_y = tf.convert_to_tensor(y_batch_train, dtype=tf.int64) |
121 half_len = int(nw_dimension / 2) | 189 return unrolled_x, unrolled_y, sel_tools |
122 | 190 |
123 # predict tools | 191 |
124 prediction = np.reshape(prediction, (nw_dimension,)) | 192 def get_u_tr_labels(y_tr): |
125 # get predictions of tools from published workflows | 193 labels = list() |
126 standard_pred = prediction[:half_len] | 194 labels_pos_dict = dict() |
127 # get predictions of tools from normal workflows | 195 for i, item in enumerate(y_tr): |
128 normal_pred = prediction[half_len:] | 196 label_pos = np.where(item > 0)[0] |
129 | 197 labels.extend(label_pos) |
130 standard_prediction_pos = np.argsort(standard_pred, axis=-1) | 198 for label in label_pos: |
131 standard_topk_prediction_pos = standard_prediction_pos[-topk] | 199 if label not in labels_pos_dict: |
132 | 200 labels_pos_dict[label] = list() |
133 normal_prediction_pos = np.argsort(normal_pred, axis=-1) | 201 labels_pos_dict[label].append(i) |
134 normal_topk_prediction_pos = normal_prediction_pos[-topk] | 202 u_labels = list(set(labels)) |
135 | 203 for item in labels_pos_dict: |
136 # get true tools names | 204 labels_pos_dict[item] = list(set(labels_pos_dict[item])) |
137 for a_t_pos in actual_classes_pos: | 205 return u_labels, labels_pos_dict |
138 if a_t_pos > half_len: | 206 |
139 t_name = reverse_data_dictionary[int(a_t_pos - half_len)] | 207 |
140 else: | 208 def compute_loss(y_true, y_pred, class_weights=None): |
141 t_name = reverse_data_dictionary[int(a_t_pos)] | 209 y_true = tf.cast(y_true, dtype=tf.float32) |
142 actual_next_tool_names.append(t_name) | 210 loss = binary_ce(y_true, y_pred) |
143 last_tool_name = reverse_data_dictionary[x[-1]] | 211 categorical_loss = categorical_ce(y_true, y_pred) |
144 # compute scores for published recommendations | 212 if class_weights is None: |
145 if standard_topk_prediction_pos in reverse_data_dictionary: | 213 return tf.reduce_mean(loss), categorical_loss |
146 pred_t_name = reverse_data_dictionary[int(standard_topk_prediction_pos)] | 214 return tf.tensordot(loss, class_weights, axes=1), categorical_loss |
147 if last_tool_name in standard_conn: | 215 |
148 pub_tools = standard_conn[last_tool_name] | 216 |
149 if pred_t_name in pub_tools: | 217 def compute_acc(y_true, y_pred): |
150 pub_precision = 1.0 | 218 return binary_acc(y_true, y_pred) |
151 # count precision only when there is actually true published tools | 219 |
152 if last_tool_id in lowest_tool_ids: | 220 |
153 lowest_pub_prec = 1.0 | 221 def validate_model(te_x, te_y, te_batch_size, model, f_dict, r_dict, ulabels_te_dict, tr_labels, lowest_t_ids): |
154 else: | 222 te_x_batch, y_train_batch, _ = sample_balanced_te_y(te_x, te_y, ulabels_te_dict, te_batch_size) |
155 lowest_pub_prec = np.nan | 223 print("Total test data size: ", te_x.shape, te_y.shape) |
156 if standard_topk_prediction_pos in usage_scores: | 224 print("Batch test data size: ", te_x_batch.shape, y_train_batch.shape) |
157 usage_wt_score.append( | 225 te_pred_batch, _ = model(te_x_batch, training=False) |
158 np.log(usage_scores[standard_topk_prediction_pos] + 1.0) | 226 test_err, _ = compute_loss(y_train_batch, te_pred_batch) |
159 ) | 227 print("Test loss:") |
160 else: | 228 print(test_err.numpy()) |
161 # count precision only when there is actually true published tools | 229 print("Test finished") |
162 # else set to np.nan. Set to 0 only when there is wrong prediction | |
163 pub_precision = np.nan | |
164 lowest_pub_prec = np.nan | |
165 # compute scores for normal recommendations | |
166 if normal_topk_prediction_pos in reverse_data_dictionary: | |
167 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] | |
168 if pred_t_name in actual_next_tool_names: | |
169 if normal_topk_prediction_pos in usage_scores: | |
170 usage_wt_score.append( | |
171 np.log(usage_scores[normal_topk_prediction_pos] + 1.0) | |
172 ) | |
173 top_precision = 1.0 | |
174 if last_tool_id in lowest_tool_ids: | |
175 lowest_norm_prec = 1.0 | |
176 else: | |
177 lowest_norm_prec = np.nan | |
178 if len(usage_wt_score) > 0: | |
179 mean_usage = np.mean(usage_wt_score) | |
180 return mean_usage, top_precision, pub_precision, lowest_pub_prec, lowest_norm_prec | |
181 | 230 |
182 | 231 |
183 def get_lowest_tools(l_tool_freq, fraction=0.25): | 232 def get_lowest_tools(l_tool_freq, fraction=0.25): |
184 l_tool_freq = dict(sorted(l_tool_freq.items(), key=lambda kv: kv[1], reverse=True)) | 233 l_tool_freq = dict(sorted(l_tool_freq.items(), key=lambda kv: kv[1], reverse=True)) |
185 tool_ids = list(l_tool_freq.keys()) | 234 tool_ids = list(l_tool_freq.keys()) |
186 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] | 235 lowest_ids = tool_ids[-int(len(tool_ids) * fraction):] |
187 return lowest_ids | 236 return lowest_ids |
188 | 237 |
189 | 238 |
190 def verify_model( | 239 def remove_pipe(file_path): |
191 model, | 240 dataframe = pd.read_csv(file_path, sep="|", header=None) |
192 x, | 241 dataframe = dataframe[1:len(dataframe.index) - 1] |
193 y, | 242 return dataframe[1:] |
194 reverse_data_dictionary, | |
195 usage_scores, | |
196 standard_conn, | |
197 lowest_tool_ids, | |
198 topk_list=[1, 2, 3], | |
199 ): | |
200 """ | |
201 Verify the model on test data | |
202 """ | |
203 print("Evaluating performance on test data...") | |
204 print("Test data size: %d" % len(y)) | |
205 size = y.shape[0] | |
206 precision = np.zeros([len(y), len(topk_list)]) | |
207 usage_weights = np.zeros([len(y), len(topk_list)]) | |
208 epo_pub_prec = np.zeros([len(y), len(topk_list)]) | |
209 epo_lowest_tools_pub_prec = list() | |
210 epo_lowest_tools_norm_prec = list() | |
211 lowest_counter = 0 | |
212 # loop over all the test samples and find prediction precision | |
213 for i in range(size): | |
214 lowest_pub_topk = list() | |
215 lowest_norm_topk = list() | |
216 actual_classes_pos = np.where(y[i] > 0)[0] | |
217 test_sample = x[i, :] | |
218 last_tool_id = str(int(test_sample[-1])) | |
219 for index, abs_topk in enumerate(topk_list): | |
220 ( | |
221 usg_wt_score, | |
222 absolute_precision, | |
223 pub_prec, | |
224 lowest_p_prec, | |
225 lowest_n_prec, | |
226 ) = compute_precision( | |
227 model, | |
228 test_sample, | |
229 y, | |
230 reverse_data_dictionary, | |
231 usage_scores, | |
232 actual_classes_pos, | |
233 abs_topk, | |
234 standard_conn, | |
235 last_tool_id, | |
236 lowest_tool_ids, | |
237 ) | |
238 precision[i][index] = absolute_precision | |
239 usage_weights[i][index] = usg_wt_score | |
240 epo_pub_prec[i][index] = pub_prec | |
241 lowest_pub_topk.append(lowest_p_prec) | |
242 lowest_norm_topk.append(lowest_n_prec) | |
243 epo_lowest_tools_pub_prec.append(lowest_pub_topk) | |
244 epo_lowest_tools_norm_prec.append(lowest_norm_topk) | |
245 if last_tool_id in lowest_tool_ids: | |
246 lowest_counter += 1 | |
247 mean_precision = np.mean(precision, axis=0) | |
248 mean_usage = np.mean(usage_weights, axis=0) | |
249 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) | |
250 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) | |
251 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) | |
252 return ( | |
253 mean_usage, | |
254 mean_precision, | |
255 mean_pub_prec, | |
256 mean_lowest_pub_prec, | |
257 mean_lowest_norm_prec, | |
258 lowest_counter, | |
259 ) | |
260 | |
261 | |
262 def save_model( | |
263 results, | |
264 data_dictionary, | |
265 compatible_next_tools, | |
266 trained_model_path, | |
267 class_weights, | |
268 standard_connections, | |
269 ): | |
270 # save files | |
271 trained_model = results["model"] | |
272 best_model_parameters = results["best_parameters"] | |
273 model_config = trained_model.to_json() | |
274 model_weights = trained_model.get_weights() | |
275 model_values = { | |
276 "data_dictionary": data_dictionary, | |
277 "model_config": model_config, | |
278 "best_parameters": best_model_parameters, | |
279 "model_weights": model_weights, | |
280 "compatible_tools": compatible_next_tools, | |
281 "class_weights": class_weights, | |
282 "standard_connections": standard_connections, | |
283 } | |
284 set_trained_model(trained_model_path, model_values) |
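
To show how the helpers added in this revision fit together, here is a minimal, self-contained sketch (not part of the repository). It assumes the new utils.py is importable as `utils`, uses small synthetic arrays in place of real tool-sequence data, and passes `tr_t_freq=None` because that argument is unused in the function body shown above; the file name `train.h5` and the dummy prediction tensor are illustrative only.

```python
# Sketch only: exercises the h5 round-trip, balanced batch sampling and the
# combined loss introduced in this revision of utils.py, on synthetic data.
import numpy as np
import tensorflow as tf

import utils  # assumption: this file is on the path as `utils`

# Synthetic "tool sequence" data: 20 samples, sequence length 10, 8 tool classes.
n_samples, seq_len, n_classes = 20, 10, 8
rng = np.random.default_rng(42)
x = rng.integers(0, n_classes, size=(n_samples, seq_len)).astype(np.int64)
y = (rng.random((n_samples, n_classes)) > 0.7).astype(np.int64)

# Round-trip through the h5 helpers added in this commit.
utils.save_h5_data(x, y, "train.h5")
tr_x, tr_y = utils.read_train_test("train.h5")

# Map each label (tool id) to the sample indices where it occurs.
u_labels, labels_pos_dict = utils.get_u_tr_labels(tr_y)

# Draw one balanced training batch; prev_sel_tools tracks tools already sampled
# in earlier batches, and tr_t_freq is not used in the body of this revision.
batch_x, batch_y, sel_tools = utils.sample_balanced_tr_y(
    tr_x, tr_y, labels_pos_dict, b_size=4, tr_t_freq=None, prev_sel_tools=[]
)

# Evaluate the combined loss against a dummy prediction of matching shape.
fake_pred = tf.random.uniform((batch_x.shape[0], n_classes))
bce, cce = utils.compute_loss(batch_y, fake_pred)
print("batch:", batch_x.shape, "binary CE:", float(bce))
```

The same `labels_pos_dict` structure built here for the training split is what `validate_model` expects (as `ulabels_te_dict`) for the test split, where it is passed on to `sample_balanced_te_y`.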