diff utils.py @ 5:4f7e6612906b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author:   bgruening
date:     Fri, 06 May 2022 09:05:18 +0000
parents:  afec8c595124
children: e94dc7945639
--- a/utils.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/utils.py	Fri May 06 09:05:18 2022 +0000
@@ -1,10 +1,11 @@
-import numpy as np
 import json
-import h5py
 import random
+
+import h5py
+import numpy as np
+import tensorflow as tf
 from numpy.random import choice
-
-from keras import backend as K
+from tensorflow.keras import backend
 
 
 def read_file(file_path):
@@ -29,10 +30,10 @@
     """
     Create an h5 file with the trained weights and associated dicts
     """
-    hf_file = h5py.File(dump_file, 'w')
+    hf_file = h5py.File(dump_file, "w")
     for key in model_values:
         value = model_values[key]
-        if key == 'model_weights':
+        if key == "model_weights":
             for idx, item in enumerate(value):
                 w_key = "weight_" + str(idx)
                 if w_key in hf_file:
@@ -54,14 +55,19 @@
     """
     weight_values = list(class_weights.values())
     weight_values.extend(weight_values)
+
     def weighted_binary_crossentropy(y_true, y_pred):
         # add another dimension to compute dot product
-        expanded_weights = K.expand_dims(weight_values, axis=-1)
-        return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights)
+        expanded_weights = tf.expand_dims(weight_values, axis=-1)
+        bce = backend.binary_crossentropy(y_true, y_pred)
+        return backend.dot(bce, expanded_weights)
+
     return weighted_binary_crossentropy
 
 
-def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary):
+def balanced_sample_generator(
+    train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary
+):
     while True:
         dimension = train_data.shape[1]
         n_classes = train_labels.shape[1]
@@ -80,7 +86,18 @@
         yield generator_batch_data, generator_batch_labels
 
 
-def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids):
+def compute_precision(
+    model,
+    x,
+    y,
+    reverse_data_dictionary,
+    usage_scores,
+    actual_classes_pos,
+    topk,
+    standard_conn,
+    last_tool_id,
+    lowest_tool_ids,
+):
     """
     Compute absolute and compatible precision
     """
@@ -137,7 +154,9 @@
             else:
                 lowest_pub_prec = np.nan
             if standard_topk_prediction_pos in usage_scores:
-                usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0))
+                usage_wt_score.append(
+                    np.log(usage_scores[standard_topk_prediction_pos] + 1.0)
+                )
         else:
             # count precision only when there is actually true published tools
             # else set to np.nan. Set to 0 only when there is wrong prediction
@@ -148,7 +167,9 @@
         pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)]
         if pred_t_name in actual_next_tool_names:
             if normal_topk_prediction_pos in usage_scores:
-                usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0))
+                usage_wt_score.append(
+                    np.log(usage_scores[normal_topk_prediction_pos] + 1.0)
+                )
             top_precision = 1.0
             if last_tool_id in lowest_tool_ids:
                 lowest_norm_prec = 1.0
@@ -166,7 +187,16 @@
     return lowest_ids
 
 
-def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]):
+def verify_model(
+    model,
+    x,
+    y,
+    reverse_data_dictionary,
+    usage_scores,
+    standard_conn,
+    lowest_tool_ids,
+    topk_list=[1, 2, 3],
+):
     """
     Verify the model on test data
     """
@@ -187,7 +217,24 @@
         test_sample = x[i, :]
         last_tool_id = str(int(test_sample[-1]))
         for index, abs_topk in enumerate(topk_list):
-            usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids)
+            (
+                usg_wt_score,
+                absolute_precision,
+                pub_prec,
+                lowest_p_prec,
+                lowest_n_prec,
+            ) = compute_precision(
+                model,
+                test_sample,
+                y,
+                reverse_data_dictionary,
+                usage_scores,
+                actual_classes_pos,
+                abs_topk,
+                standard_conn,
+                last_tool_id,
+                lowest_tool_ids,
+            )
             precision[i][index] = absolute_precision
             usage_weights[i][index] = usg_wt_score
             epo_pub_prec[i][index] = pub_prec
@@ -202,22 +249,36 @@
     mean_pub_prec = np.nanmean(epo_pub_prec, axis=0)
     mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0)
     mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0)
-    return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter
+    return (
+        mean_usage,
+        mean_precision,
+        mean_pub_prec,
+        mean_lowest_pub_prec,
+        mean_lowest_norm_prec,
+        lowest_counter,
+    )
 
 
-def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections):
+def save_model(
+    results,
+    data_dictionary,
+    compatible_next_tools,
+    trained_model_path,
+    class_weights,
+    standard_connections,
+):
     # save files
     trained_model = results["model"]
     best_model_parameters = results["best_parameters"]
     model_config = trained_model.to_json()
    model_weights = trained_model.get_weights()
     model_values = {
-        'data_dictionary': data_dictionary,
-        'model_config': model_config,
-        'best_parameters': best_model_parameters,
-        'model_weights': model_weights,
+        "data_dictionary": data_dictionary,
+        "model_config": model_config,
+        "best_parameters": best_model_parameters,
+        "model_weights": model_weights,
         "compatible_tools": compatible_next_tools,
         "class_weights": class_weights,
-        "standard_connections": standard_connections
+        "standard_connections": standard_connections,
    }
     set_trained_model(trained_model_path, model_values)
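A note on the loss refactor in the @@ -54,14 +55,19 @@ hunk: that hunk is the body of a loss factory (its def line sits outside the hunk's context) that closes over per-class weights and returns a Keras-compatible loss, now built on tensorflow rather than the standalone keras backend. Below is a minimal sketch of the pattern and of passing the returned closure to model.compile; the factory name make_weighted_loss, the toy model, and the class_weights values are stand-ins, not repository code.

import tensorflow as tf
from tensorflow.keras import backend
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential


def make_weighted_loss(class_weights):
    # Same pattern as the hunk above: double the weight list, then close
    # over it in a Keras-compatible loss function.
    weight_values = list(class_weights.values())
    weight_values.extend(weight_values)

    def weighted_binary_crossentropy(y_true, y_pred):
        # Turn the weights into a column vector so the dot product reduces
        # the per-class losses to one value per sample.
        expanded_weights = tf.expand_dims(weight_values, axis=-1)
        bce = backend.binary_crossentropy(y_true, y_pred)
        return backend.dot(bce, expanded_weights)

    return weighted_binary_crossentropy


# Toy wiring: 2 class weights double to length 4, so the output layer
# needs 4 sigmoid units for the dot-product shapes to line up.
class_weights = {0: 1.0, 1: 2.5}
model = Sequential([Dense(4, activation="sigmoid", input_shape=(8,))])
model.compile(optimizer="rmsprop", loss=make_weighted_loss(class_weights))

Because weight_values is doubled by extend, the weight vector is twice the length of class_weights, so the closure only type-checks against a model whose output dimension matches that doubled length.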
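The @@ -29,10 +30,10 @@ hunk shows how set_trained_model lays out weights in the HDF5 dump: each layer's array becomes its own "weight_<idx>" dataset. Here is a small round-trip sketch of that layout, assuming the write happens via h5py's create_dataset (the call itself falls outside the visible hunk); the arrays and the "model.hdf5" path are placeholders.

import h5py
import numpy as np

# Placeholder arrays; the real ones come from trained_model.get_weights()
# in save_model above.
model_weights = [np.random.rand(4, 8), np.random.rand(8)]

# Write: one "weight_<idx>" dataset per layer array.
with h5py.File("model.hdf5", "w") as hf_file:
    for idx, item in enumerate(model_weights):
        hf_file.create_dataset("weight_" + str(idx), data=item)

# Read back and confirm the round trip is lossless.
with h5py.File("model.hdf5", "r") as hf_file:
    restored = [hf_file["weight_" + str(idx)][()] for idx in range(len(model_weights))]

assert all(np.array_equal(a, b) for a, b in zip(model_weights, restored))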
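Finally, verify_model aggregates with np.nanmean rather than np.mean because, per the comment in the @@ -137,7 +154,9 @@ hunk, published-tool precision is stored as np.nan when a sample has no true published tools, and such samples must be skipped rather than counted as zeros. A toy illustration with invented values:

import numpy as np

# Two test samples x three top-k settings; np.nan marks "no published
# tools for this sample", as compute_precision stores it.
epo_pub_prec = np.array([[1.0, 1.0, np.nan],
                         [0.0, np.nan, np.nan]])

# nanmean skips the NaNs instead of averaging them in as zeros; an
# all-NaN column stays NaN (and emits a RuntimeWarning).
print(np.nanmean(epo_pub_prec, axis=0))  # [0.5 1.  nan]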