Mercurial > repos > bgruening > create_tool_recommendation_model
diff predict_tool_usage.py @ 0:9bf25dbe00ad draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
author | bgruening |
---|---|
date | Wed, 28 Aug 2019 07:19:38 -0400 |
parents | |
children | 5b3c08710e47 |
line wrap: on
line diff
"""
Predict tool usage to weigh the predicted tools
"""

import collections
import csv
import os
import warnings

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

import utils

warnings.filterwarnings("ignore")

main_path = os.getcwd()


class ToolPopularity:

    @classmethod
    def __init__(self):
        """ Init method. """

    @classmethod
    def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary):
        """
        Extract the usage-over-time counts for each known tool.

        Parameters
        ----------
        tool_usage_file : str
            Path to a tab-separated file whose rows are
            (tool id, date, usage count).
        cutoff_date : str
            Date string; only rows whose date compares strictly greater are
            kept.  Comparison is lexicographic, which is valid for
            ISO-formatted (YYYY-MM-DD) dates — presumably the file uses that
            format; verify against the data source.
        dictionary : dict
            Mapping whose keys are the tool ids of interest.

        Returns
        -------
        dict
            tool id -> OrderedDict(date -> usage count), sorted by date,
            with dates missing for a tool imputed as 0.
        """
        tool_usage_dict = dict()
        all_dates = list()
        all_tool_list = list(dictionary.keys())
        with open(tool_usage_file, 'rt') as usage_file:
            tool_usage = csv.reader(usage_file, delimiter='\t')
            for row in tool_usage:
                # keep only rows newer than the cutoff date
                if row[1] > cutoff_date:
                    tool_id = utils.format_tool_id(row[0])
                    if tool_id in all_tool_list:
                        all_dates.append(row[1])
                        if tool_id not in tool_usage_dict:
                            tool_usage_dict[tool_id] = dict()
                            tool_usage_dict[tool_id][row[1]] = int(row[2])
                        else:
                            curr_date = row[1]
                            # merge the usage of different versions of a tool
                            # (same formatted id) into one entry per date
                            if curr_date in tool_usage_dict[tool_id]:
                                tool_usage_dict[tool_id][curr_date] += int(row[2])
                            else:
                                tool_usage_dict[tool_id][curr_date] = int(row[2])
        # get the unique dates seen across all tools
        unique_dates = list(set(all_dates))
        for tool in tool_usage_dict:
            usage = tool_usage_dict[tool]
            # extract those dates for which this tool has no usage record
            dates_not_present = list(set(unique_dates) ^ set(usage.keys()))
            # impute the missing values by 0 so every tool shares the same
            # time axis
            for dt in dates_not_present:
                tool_usage_dict[tool][dt] = 0
            # sort the usage series chronologically
            tool_usage_dict[tool] = collections.OrderedDict(sorted(usage.items()))
        return tool_usage_dict

    @classmethod
    def learn_tool_popularity(self, x_reshaped, y_reshaped):
        """
        Fit a curve for the tool usage over time and predict the usage at
        the next time point.

        Parameters
        ----------
        x_reshaped : numpy array, shape (n, 1)
            Time positions of the observed usage counts.
        y_reshaped : numpy array, shape (n, 1)
            Observed usage counts.

        Returns
        -------
        float
            Predicted usage for time point x[-1] + 1, clamped below at 0.0;
            0.0 on any fitting failure.
        """
        epsilon = 0.0
        cv = 5
        s_typ = 'neg_mean_absolute_error'
        n_jobs = 4
        s_error = 1
        tr_score = False
        try:
            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
            param_grid = {
                'regressor__kernel': ['rbf', 'poly', 'linear'],
                'regressor__degree': [2, 3]
            }
            # NOTE: the original code passed iid=True to GridSearchCV.  That
            # parameter was deprecated in scikit-learn 0.22 and removed in
            # 0.24, so on modern scikit-learn it raised TypeError, which the
            # except-clause below silently turned into a 0.0 prediction for
            # every tool.  The parameter is dropped here.
            search = GridSearchCV(pipe, param_grid, cv=cv, scoring=s_typ,
                                  n_jobs=n_jobs, error_score=s_error,
                                  return_train_score=tr_score)
            search.fit(x_reshaped, y_reshaped.ravel())
            model = search.best_estimator_
            # set the next time point to get a prediction for
            prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1))
            prediction = model.predict(prediction_point)
            # clamp negative predictions to the floor value
            if prediction < epsilon:
                prediction = [epsilon]
            return prediction[0]
        except Exception:
            # best-effort: any failure while fitting falls back to zero usage
            return epsilon

    @classmethod
    def get_pupularity_prediction(self, tools_usage):
        """
        Get the popularity prediction for each tool.

        The method name (including the 'pupularity' misspelling) is kept
        as-is because external callers reference it.

        Parameters
        ----------
        tools_usage : dict
            tool id -> OrderedDict(date -> usage count), as produced by
            extract_tool_usage.

        Returns
        -------
        dict
            tool id -> predicted next-time-point usage, rounded to 8 decimals.
        """
        usage_prediction = dict()
        for tool_name, usage in tools_usage.items():
            # dates only provide ordering; positions 0..n-1 serve as x values
            x_pos = np.arange(len(usage))
            x_reshaped = x_pos.reshape(len(x_pos), 1)
            y_reshaped = np.reshape(list(usage.values()), (len(x_pos), 1))
            prediction = np.round(self.learn_tool_popularity(x_reshaped, y_reshaped), 8)
            usage_prediction[tool_name] = prediction
        return usage_prediction