diff prepare_data.py @ 5:4f7e6612906b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
| author | bgruening |
|---|---|
| date | Fri, 06 May 2022 09:05:18 +0000 |
| parents | afec8c595124 |
| children | e94dc7945639 |
```diff
--- a/prepare_data.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/prepare_data.py	Fri May 06 09:05:18 2022 +0000
@@ -4,18 +4,17 @@
 into the test and training sets
 """
 
+import collections
 import os
-import collections
-import numpy as np
 import random
 
+import numpy as np
 import predict_tool_usage
 
 main_path = os.getcwd()
 
 
 class PrepareData:
-
     def __init__(self, max_seq_length, test_data_share):
         """ Init method. """
         self.max_tool_sequence_len = max_seq_length
@@ -27,15 +26,20 @@
         """
         tokens = list()
         raw_paths = workflow_paths
-        raw_paths = [x.replace("\n", '') for x in raw_paths]
+        raw_paths = [x.replace("\n", "") for x in raw_paths]
         for item in raw_paths:
             split_items = item.split(",")
             for token in split_items:
-                if token is not "":
+                if token != "":
                     tokens.append(token)
         tokens = list(set(tokens))
         tokens = np.array(tokens)
-        tokens = np.reshape(tokens, [-1, ])
+        tokens = np.reshape(
+            tokens,
+            [
+                -1,
+            ],
+        )
         return tokens, raw_paths
 
     def create_new_dict(self, new_data_dict):
@@ -60,7 +64,10 @@
         dictionary = dict()
         for word, _ in count:
             dictionary[word] = len(dictionary) + 1
-        dictionary, reverse_dictionary = self.assemble_dictionary(dictionary, old_data_dictionary)
+            word = word.strip()
+        dictionary, reverse_dictionary = self.assemble_dictionary(
+            dictionary, old_data_dictionary
+        )
         return dictionary, reverse_dictionary
 
     def decompose_paths(self, paths, dictionary):
@@ -74,13 +81,17 @@
             if len_tools <= self.max_tool_sequence_len:
                 for window in range(1, len_tools):
                     sequence = tools[0: window + 1]
-                    tools_pos = [str(dictionary[str(tool_item)]) for tool_item in sequence]
+                    tools_pos = [
+                        str(dictionary[str(tool_item)]) for tool_item in sequence
+                    ]
                     if len(tools_pos) > 1:
                         sub_paths_pos.append(",".join(tools_pos))
         sub_paths_pos = list(set(sub_paths_pos))
         return sub_paths_pos
 
-    def prepare_paths_labels_dictionary(self, dictionary, reverse_dictionary, paths, compatible_next_tools):
+    def prepare_paths_labels_dictionary(
+        self, dictionary, reverse_dictionary, paths, compatible_next_tools
+    ):
         """
         Create a dictionary of sequences with their labels for training and test paths
         """
@@ -90,14 +101,18 @@
             if item and item not in "":
                 tools = item.split(",")
                 label = tools[-1]
-                train_tools = tools[:len(tools) - 1]
+                train_tools = tools[: len(tools) - 1]
                 last_but_one_name = reverse_dictionary[int(train_tools[-1])]
                 try:
-                    compatible_tools = compatible_next_tools[last_but_one_name].split(",")
+                    compatible_tools = compatible_next_tools[last_but_one_name].split(
+                        ","
+                    )
                 except Exception:
                     continue
                 if len(compatible_tools) > 0:
-                    compatible_tools_ids = [str(dictionary[x]) for x in compatible_tools]
+                    compatible_tools_ids = [
+                        str(dictionary[x]) for x in compatible_tools
+                    ]
                     compatible_tools_ids.append(label)
                     composite_labels = ",".join(compatible_tools_ids)
                 train_tools = ",".join(train_tools)
@@ -127,7 +142,9 @@
                 train_counter += 1
         return data_mat, label_mat
 
-    def pad_paths(self, paths_dictionary, num_classes, standard_connections, reverse_dictionary):
+    def pad_paths(
+        self, paths_dictionary, num_classes, standard_connections, reverse_dictionary
+    ):
         """
         Add padding to the tools sequences and create multi-hot encoded labels
         """
@@ -231,12 +248,22 @@
                 l_tool_tr_samples[last_tool_id].append(index)
         return l_tool_tr_samples
 
-    def get_data_labels_matrices(self, workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections, old_data_dictionary={}):
+    def get_data_labels_matrices(
+        self,
+        workflow_paths,
+        tool_usage_path,
+        cutoff_date,
+        compatible_next_tools,
+        standard_connections,
+        old_data_dictionary={},
+    ):
         """
         Convert the training and test paths into corresponding numpy matrices
         """
         processed_data, raw_paths = self.process_workflow_paths(workflow_paths)
-        dictionary, rev_dict = self.create_data_dictionary(processed_data, old_data_dictionary)
+        dictionary, rev_dict = self.create_data_dictionary(
+            processed_data, old_data_dictionary
+        )
         num_classes = len(dictionary)
 
         print("Raw paths: %d" % len(raw_paths))
@@ -247,18 +274,26 @@
         random.shuffle(all_unique_paths)
 
         print("Creating dictionaries...")
-        multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools)
+        multilabels_paths = self.prepare_paths_labels_dictionary(
+            dictionary, rev_dict, all_unique_paths, compatible_next_tools
+        )
 
         print("Complete data: %d" % len(multilabels_paths))
-        train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths)
+        train_paths_dict, test_paths_dict = self.split_test_train_data(
+            multilabels_paths
+        )
 
         print("Train data: %d" % len(train_paths_dict))
         print("Test data: %d" % len(test_paths_dict))
 
         print("Padding train and test data...")
         # pad training and test data with leading zeros
-        test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict)
-        train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict)
+        test_data, test_labels = self.pad_paths(
+            test_paths_dict, num_classes, standard_connections, rev_dict
+        )
+        train_data, train_labels = self.pad_paths(
+            train_paths_dict, num_classes, standard_connections, rev_dict
+        )
 
         print("Estimating sample frequency...")
         l_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
@@ -274,4 +309,15 @@
         # get class weights using the predicted usage for each tool
         class_weights = self.assign_class_weights(num_classes, t_pred_usage)
 
-        return train_data, train_labels, test_data, test_labels, dictionary, rev_dict, class_weights, t_pred_usage, l_tool_freq, l_tool_tr_samples
+        return (
+            train_data,
+            train_labels,
+            test_data,
+            test_labels,
+            dictionary,
+            rev_dict,
+            class_weights,
+            t_pred_usage,
+            l_tool_freq,
+            l_tool_tr_samples,
+        )
```
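Most of this changeset is mechanical restyling (import reordering and black-style line wrapping); the visible behavioral edits are the switch from `if token is not "":` to `if token != "":` (an identity test where an equality test was intended, which also triggers a `SyntaxWarning` on recent Pythons) and the `word = word.strip()` added in `create_data_dictionary`. The sliding-window decomposition in `decompose_paths` is only re-wrapped, not changed. The sketch below is a simplified, standalone re-implementation of that windowing for illustration only; the helper name `decompose`, the `max_len` parameter, and the example tool names are made up here, and the repository's version works on dictionary ids rather than tool names.

```python
# Simplified sketch of the prefix windowing in decompose_paths():
# each workflow path contributes all of its prefixes of length >= 2,
# keeping the first tool fixed. Illustrative only.
def decompose(path, max_len=25):
    tools = path.split(",")
    sub_paths = set()
    if len(tools) <= max_len:
        for window in range(1, len(tools)):
            # tools[: window + 1] is the prefix of length window + 1
            sub_paths.add(",".join(tools[: window + 1]))
    return sub_paths

print(sorted(decompose("fastqc,trimmomatic,bwa,samtools")))
# ['fastqc,trimmomatic', 'fastqc,trimmomatic,bwa',
#  'fastqc,trimmomatic,bwa,samtools']
```

Each prefix later becomes one training sample whose label is its final tool, which is why `prepare_paths_labels_dictionary` above splits off `label = tools[-1]` and joins the remainder back into `train_tools`.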