diff optimise_hyperparameters.py @ 2:76251d1ccdcc draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 6fa2a0294d615c9f267b766337dca0b2d3637219"
author   | bgruening
date     | Fri, 11 Oct 2019 18:24:54 -0400
parents  | 9bf25dbe00ad
children | 5b3c08710e47
--- a/optimise_hyperparameters.py	Wed Sep 25 06:42:40 2019 -0400
+++ b/optimise_hyperparameters.py	Fri Oct 11 18:24:54 2019 -0400
@@ -22,7 +22,7 @@
         """ Init method. """
 
     @classmethod
-    def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, class_weights):
+    def train_model(self, config, reverse_dictionary, train_data, train_labels, class_weights):
         """
         Train a model and report accuracy
         """
@@ -46,7 +46,7 @@
         # get dimensions
         dimensions = len(reverse_dictionary) + 1
         best_model_params = dict()
-        early_stopping = EarlyStopping(monitor='val_loss', mode='min', min_delta=1e-4, verbose=1, patience=1)
+        early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, min_delta=1e-4)
 
         # specify the search space for finding the best combination of parameters using Bayesian optimisation
         params = {
@@ -82,11 +82,12 @@
                 validation_split=validation_split,
                 callbacks=[early_stopping]
             )
-            return {'loss': model_fit.history["val_loss"][-1], 'status': STATUS_OK}
-        # minimize the objective function using the set of parameters above4
+            return {'loss': model_fit.history["val_loss"][-1], 'status': STATUS_OK, 'model': model}
+        # minimize the objective function using the set of parameters above
         trials = Trials()
         learned_params = fmin(create_model, params, trials=trials, algo=tpe.suggest, max_evals=int(config["max_evals"]))
-        print(learned_params)
+        best_model = trials.results[np.argmin([r['loss'] for r in trials.results])]['model']
+
         # set the best params with respective values
         for item in learned_params:
             item_val = learned_params[item]
@@ -96,4 +97,4 @@
                 best_model_params[item] = l_recurrent_activations[item_val]
             else:
                 best_model_params[item] = item_val
-        return best_model_params
+        return best_model_params, best_model
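
The key change above is that the hyperopt objective (create_model) now returns the fitted model alongside its validation loss, so the best model can be pulled back out of the Trials object instead of only printing the tuned parameter values. Below is a minimal, self-contained sketch of that retrieval pattern; the quadratic objective and the dummy "model" dict are illustrative stand-ins, not the tool's actual Keras network.

import numpy as np
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe


def objective(params):
    # Toy stand-in for create_model: "train", score, and return the artefact
    # in the result dict so it can be recovered later from Trials.
    loss = (params["x"] - 3.0) ** 2
    return {"loss": loss, "status": STATUS_OK, "model": {"x": params["x"]}}


search_space = {"x": hp.uniform("x", -10, 10)}
trials = Trials()
learned_params = fmin(objective, search_space, algo=tpe.suggest, max_evals=20, trials=trials)

# trials.results holds the dict each objective call returned; pick the model
# from the trial with the smallest loss, mirroring the np.argmin selection
# added to train_model in the diff.
best_model = trials.results[np.argmin([r["loss"] for r in trials.results])]["model"]
print(learned_params, best_model)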