comparison utils.py @ 2:76251d1ccdcc draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 6fa2a0294d615c9f267b766337dca0b2d3637219"
author bgruening
date Fri, 11 Oct 2019 18:24:54 -0400
parents 9bf25dbe00ad
children 5b3c08710e47
comparison
equal deleted inserted replaced
1:12764915e1c5 2:76251d1ccdcc
1 import os 1 import os
2 import numpy as np 2 import numpy as np
3 import json 3 import json
4 import h5py 4 import h5py
5 5
6 from keras.models import model_from_json, Sequential
7 from keras.layers import Dense, GRU, Dropout
8 from keras.layers.embeddings import Embedding
9 from keras.layers.core import SpatialDropout1D
10 from keras.optimizers import RMSprop
11 from keras import backend as K 6 from keras import backend as K
12 7
13 8
14 def read_file(file_path): 9 def read_file(file_path):
15 """ 10 """
35 workflow_paths_unique += path + "\n" 30 workflow_paths_unique += path + "\n"
36 with open(file_path, "w") as workflows_file: 31 with open(file_path, "w") as workflows_file:
37 workflows_file.write(workflow_paths_unique) 32 workflows_file.write(workflow_paths_unique)
38 33
39 34
def load_saved_model(model_config, model_weights):
    """
    Rebuild a trained model from its serialised network and its weights.

    :param model_config: network architecture, in the form accepted by
        keras ``model_from_json``.
    :param model_weights: weights handed to the model's ``set_weights``.
    :returns: the reconstructed model with the weights applied.
    """
    network = model_from_json(model_config)
    network.set_weights(model_weights)
    return network
49
50
def format_tool_id(tool_link):
    """
    Return the tool id contained in a tool link.

    If the link contains at least one '/', its second-to-last segment is
    returned (e.g. ".../repos/owner/repo/tool_name/1.0" -> "tool_name").
    A link without any '/' is returned unchanged.
    """
    segments = tool_link.split("/")
    if len(segments) < 2:
        return tool_link
    return segments[-2]
58
59
def get_HDF5(hf, d_key):
    """
    Read a dataset from an open h5 file (train/test data) and return it.

    :param hf: an open h5py file or group to read from.
    :param d_key: name of the dataset inside ``hf``.
    :returns: the full dataset contents, as read by ``dataset[()]``.
    :raises KeyError: if ``d_key`` is not present in ``hf``.
    """
    dataset = hf.get(d_key)
    if dataset is None:
        # get() returns None for a missing name; report it clearly instead of
        # failing later with an AttributeError on None
        raise KeyError("Dataset '%s' not found in h5 file" % d_key)
    # Dataset.value was deprecated in h5py 2.x and removed in h5py 3.0;
    # indexing with an empty tuple reads the whole dataset (array or scalar)
    return dataset[()]
65
66
def save_HDF5(hf_file, d_key, data, d_type=""):
    """
    Store ``data`` in ``hf_file`` as a new dataset named ``d_key``.

    When ``d_type`` is 'json', ``data`` is first serialised with
    ``json.dumps`` and the resulting string is stored. For any other
    ``d_type``, ``data`` is stored unchanged.
    """
    payload = json.dumps(data) if d_type == 'json' else data
    hf_file.create_dataset(d_key, data=payload)
74 42
75 43
76 def set_trained_model(dump_file, model_values): 44 def set_trained_model(dump_file, model_values):
77 """ 45 """
78 Create an h5 file with the trained weights and associated dicts 46 Create an h5 file with the trained weights and associated dicts
def remove_file(file_path):
    """
    Delete ``file_path``; do nothing if it does not exist.

    The removal is attempted directly and FileNotFoundError is ignored.
    A separate existence check followed by a delete would be racy: the file
    could disappear between the two calls and make os.remove raise.
    Other errors (e.g. the path is a directory) still propagate.
    """
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass
101 69
102 70
def extract_configuration(config_object):
    """
    Collect the loss and parameter values recorded for each trial.

    :param config_object: iterable of trial records. Each record must
        provide ``item['result']['loss']`` and ``item['misc']['vals']``
        (presumably hyperopt trial entries - confirm against the caller).
    :returns: dict mapping each record's 0-based position to a one-element
        list holding ``{'loss': ..., 'params_config': ...}``.
    """
    return {
        position: [{'loss': trial['result']['loss'], 'params_config': trial['misc']['vals']}]
        for position, trial in enumerate(config_object)
    }
112
113
def get_best_parameters(mdl_dict):
    """
    Resolve the model hyper-parameters from ``mdl_dict``.

    Missing entries fall back to built-in defaults. Numeric entries are
    converted with float/int; the activation names are returned as given.

    :param mdl_dict: mapping of parameter name to value.
    :returns: dict with keys lr, embedding_size, dropout, recurrent_dropout,
        spatial_dropout, units, batch_size, activation_recurrent and
        activation_output. The input key for "lr" is "learning_rate".
    """
    def as_is(value):
        return value

    # (output key, input key, default, converter). Defaults are given as
    # strings and pass through the same conversion as supplied values.
    spec = (
        ("lr", "learning_rate", "0.001", float),
        ("embedding_size", "embedding_size", "512", int),
        ("dropout", "dropout", "0.2", float),
        ("recurrent_dropout", "recurrent_dropout", "0.2", float),
        ("spatial_dropout", "spatial_dropout", "0.2", float),
        ("units", "units", "512", int),
        ("batch_size", "batch_size", "512", int),
        ("activation_recurrent", "activation_recurrent", "elu", as_is),
        ("activation_output", "activation_output", "sigmoid", as_is),
    )
    params = dict()
    for out_key, in_key, default, convert in spec:
        params[out_key] = convert(mdl_dict.get(in_key, default))
    return params
139
140
141 def weighted_loss(class_weights): 71 def weighted_loss(class_weights):
142 """ 72 """
143 Create a weighted loss function. Penalise the misclassification 73 Create a weighted loss function. Penalise the misclassification
144 of classes more with the higher usage 74 of classes more with the higher usage
145 """ 75 """
148 def weighted_binary_crossentropy(y_true, y_pred): 78 def weighted_binary_crossentropy(y_true, y_pred):
149 # add another dimension to compute dot product 79 # add another dimension to compute dot product
150 expanded_weights = K.expand_dims(weight_values, axis=-1) 80 expanded_weights = K.expand_dims(weight_values, axis=-1)
151 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) 81 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights)
152 return weighted_binary_crossentropy 82 return weighted_binary_crossentropy
153
154
def set_recurrent_network(mdl_dict, reverse_dictionary, class_weights):
    """
    Build and compile the stacked-GRU recommendation network.

    :param mdl_dict: hyper-parameter mapping, resolved through
        get_best_parameters (missing entries take its defaults).
    :param reverse_dictionary: its size fixes the input/output dimension,
        which is len(reverse_dictionary) + 1.
    :param class_weights: passed to weighted_loss to build the loss.
    :returns: tuple (compiled model, resolved hyper-parameter dict).
    """
    n_dims = len(reverse_dictionary) + 1
    params = get_best_parameters(mdl_dict)

    network = Sequential()
    network.add(Embedding(n_dims, params["embedding_size"], mask_zero=True))
    network.add(SpatialDropout1D(params["spatial_dropout"]))
    # Two GRU layers, each followed by Dropout. Only the first returns the
    # full sequence, so the second GRU receives a sequence as input.
    # NOTE(review): the GRU input dropout uses "spatial_dropout" rather than
    # "dropout" - confirm this is intended.
    for keep_sequences in (True, False):
        network.add(GRU(
            params["units"],
            dropout=params["spatial_dropout"],
            recurrent_dropout=params["recurrent_dropout"],
            activation=params["activation_recurrent"],
            return_sequences=keep_sequences,
        ))
        network.add(Dropout(params["dropout"]))
    network.add(Dense(n_dims, activation=params["activation_output"]))

    rmsprop = RMSprop(lr=params["lr"])
    network.compile(loss=weighted_loss(class_weights), optimizer=rmsprop)
    return network, params
174 83
175 84
176 def compute_precision(model, x, y, reverse_data_dictionary, next_compatible_tools, usage_scores, actual_classes_pos, topk): 85 def compute_precision(model, x, y, reverse_data_dictionary, next_compatible_tools, usage_scores, actual_classes_pos, topk):
177 """ 86 """
178 Compute absolute and compatible precision 87 Compute absolute and compatible precision