Mercurial > repos > bgruening > create_tool_recommendation_model
comparison optimise_hyperparameters.py @ 2:76251d1ccdcc draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 6fa2a0294d615c9f267b766337dca0b2d3637219"
author | bgruening |
---|---|
date | Fri, 11 Oct 2019 18:24:54 -0400 |
parents | 9bf25dbe00ad |
children | 5b3c08710e47 |
comparison
equal
deleted
inserted
replaced
1:12764915e1c5 | 2:76251d1ccdcc |
---|---|
20 @classmethod | 20 @classmethod |
21 def __init__(self): | 21 def __init__(self): |
22 """ Init method. """ | 22 """ Init method. """ |
23 | 23 |
24 @classmethod | 24 @classmethod |
25 def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, class_weights): | 25 def train_model(self, config, reverse_dictionary, train_data, train_labels, class_weights): |
26 """ | 26 """ |
27 Train a model and report accuracy | 27 Train a model and report accuracy |
28 """ | 28 """ |
29 l_recurrent_activations = config["activation_recurrent"].split(",") | 29 l_recurrent_activations = config["activation_recurrent"].split(",") |
30 l_output_activations = config["activation_output"].split(",") | 30 l_output_activations = config["activation_output"].split(",") |
44 validation_split = float(config["validation_share"]) | 44 validation_split = float(config["validation_share"]) |
45 | 45 |
46 # get dimensions | 46 # get dimensions |
47 dimensions = len(reverse_dictionary) + 1 | 47 dimensions = len(reverse_dictionary) + 1 |
48 best_model_params = dict() | 48 best_model_params = dict() |
49 early_stopping = EarlyStopping(monitor='val_loss', mode='min', min_delta=1e-4, verbose=1, patience=1) | 49 early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, min_delta=1e-4) |
50 | 50 |
51 # specify the search space for finding the best combination of parameters using Bayesian optimisation | 51 # specify the search space for finding the best combination of parameters using Bayesian optimisation |
52 params = { | 52 params = { |
53 "embedding_size": hp.quniform("embedding_size", l_embedding_size[0], l_embedding_size[1], 1), | 53 "embedding_size": hp.quniform("embedding_size", l_embedding_size[0], l_embedding_size[1], 1), |
54 "units": hp.quniform("units", l_units[0], l_units[1], 1), | 54 "units": hp.quniform("units", l_units[0], l_units[1], 1), |
80 shuffle="batch", | 80 shuffle="batch", |
81 verbose=2, | 81 verbose=2, |
82 validation_split=validation_split, | 82 validation_split=validation_split, |
83 callbacks=[early_stopping] | 83 callbacks=[early_stopping] |
84 ) | 84 ) |
85 return {'loss': model_fit.history["val_loss"][-1], 'status': STATUS_OK} | 85 return {'loss': model_fit.history["val_loss"][-1], 'status': STATUS_OK, 'model': model} |
86 # minimize the objective function using the set of parameters above4 | 86 # minimize the objective function using the set of parameters above |
87 trials = Trials() | 87 trials = Trials() |
88 learned_params = fmin(create_model, params, trials=trials, algo=tpe.suggest, max_evals=int(config["max_evals"])) | 88 learned_params = fmin(create_model, params, trials=trials, algo=tpe.suggest, max_evals=int(config["max_evals"])) |
89 print(learned_params) | 89 best_model = trials.results[np.argmin([r['loss'] for r in trials.results])]['model'] |
90 | |
90 # set the best params with respective values | 91 # set the best params with respective values |
91 for item in learned_params: | 92 for item in learned_params: |
92 item_val = learned_params[item] | 93 item_val = learned_params[item] |
93 if item == 'activation_output': | 94 if item == 'activation_output': |
94 best_model_params[item] = l_output_activations[item_val] | 95 best_model_params[item] = l_output_activations[item_val] |
95 elif item == 'activation_recurrent': | 96 elif item == 'activation_recurrent': |
96 best_model_params[item] = l_recurrent_activations[item_val] | 97 best_model_params[item] = l_recurrent_activations[item_val] |
97 else: | 98 else: |
98 best_model_params[item] = item_val | 99 best_model_params[item] = item_val |
99 return best_model_params | 100 return best_model_params, best_model |