comparison predict_tool_usage.py @ 0:9bf25dbe00ad draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
author    bgruening
date      Wed, 28 Aug 2019 07:19:38 -0400
parents
children  5b3c08710e47
comparing -1:000000000000 with 0:9bf25dbe00ad
1 """ | |
2 Predict tool usage to weigh the predicted tools | |
3 """ | |
4 | |
5 import os | |
6 import numpy as np | |
7 import warnings | |
8 import csv | |
9 import collections | |
10 | |
11 from sklearn.svm import SVR | |
12 from sklearn.model_selection import GridSearchCV | |
13 from sklearn.pipeline import Pipeline | |
14 | |
15 import utils | |
16 | |
17 warnings.filterwarnings("ignore") | |
18 | |
19 main_path = os.getcwd() | |
20 | |
21 | |
22 class ToolPopularity: | |
23 | |
24 @classmethod | |
25 def __init__(self): | |
26 """ Init method. """ | |
27 | |
28 @classmethod | |
29 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): | |
30 """ | |
31 Extract the tool usage over time for each tool | |
32 """ | |
33 tool_usage_dict = dict() | |
34 all_dates = list() | |
35 all_tool_list = list(dictionary.keys()) | |
36 with open(tool_usage_file, 'rt') as usage_file: | |
37 tool_usage = csv.reader(usage_file, delimiter='\t') | |
38 for index, row in enumerate(tool_usage): | |
39 if (str(row[1]) > cutoff_date) is True: | |
40 tool_id = utils.format_tool_id(row[0]) | |
41 if tool_id in all_tool_list: | |
42 all_dates.append(row[1]) | |
43 if tool_id not in tool_usage_dict: | |
44 tool_usage_dict[tool_id] = dict() | |
45 tool_usage_dict[tool_id][row[1]] = int(row[2]) | |
46 else: | |
47 curr_date = row[1] | |
48 # merge the usage of different version of tools into one | |
49 if curr_date in tool_usage_dict[tool_id]: | |
50 tool_usage_dict[tool_id][curr_date] += int(row[2]) | |
51 else: | |
52 tool_usage_dict[tool_id][curr_date] = int(row[2]) | |
53 # get unique dates | |
54 unique_dates = list(set(all_dates)) | |
55 for tool in tool_usage_dict: | |
56 usage = tool_usage_dict[tool] | |
57 # extract those dates for which tool's usage is not present in raw data | |
58 dates_not_present = list(set(unique_dates) ^ set(usage.keys())) | |
59 # impute the missing values by 0 | |
60 for dt in dates_not_present: | |
61 tool_usage_dict[tool][dt] = 0 | |
62 # sort the usage list by date | |
63 tool_usage_dict[tool] = collections.OrderedDict(sorted(usage.items())) | |
64 return tool_usage_dict | |
65 | |
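    # Note on the input (illustrative, assumed layout): the usage file read above is
    # expected to be a tab-separated export with one row per tool version, date and
    # run count, for example:
    #
    #   toolshed.g2.bx.psu.edu/repos/devteam/bowtie2/bowtie2/2.3.4.2    2019-03-01    532
    #   Cut1                                                            2019-03-01    214
    #
    # The tool ids above are placeholders. The method returns a dictionary mapping
    # each tool id to an OrderedDict of date -> usage count, with dates missing for
    # a tool imputed as 0.
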
    @classmethod
    def learn_tool_popularity(self, x_reshaped, y_reshaped):
        """
        Fit a curve for the tool usage over time to predict future tool usage
        """
        epsilon = 0.0  # lower bound on the prediction and fallback value on failure
        cv = 5
        s_typ = 'neg_mean_absolute_error'
        n_jobs = 4
        s_error = 1
        iid = True
        tr_score = False
        try:
            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
            param_grid = {
                'regressor__kernel': ['rbf', 'poly', 'linear'],
                'regressor__degree': [2, 3]
            }
            # grid search over SVR kernels and polynomial degrees, scored by mean absolute error
            # (the iid argument exists only in older scikit-learn releases; it was removed in 0.24)
            search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score)
            search.fit(x_reshaped, y_reshaped.ravel())
            model = search.best_estimator_
            # set the next time point to get a prediction for
            prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1))
            prediction = model.predict(prediction_point)
            if prediction < epsilon:
                prediction = [epsilon]
            return prediction[0]
        except Exception:
            return epsilon

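    # For illustration (assumed evenly spaced time buckets): if x_reshaped is
    # [[0], [1], ..., [11]] and y_reshaped holds the matching usage counts, the
    # grid-searched SVR above is asked for the count at time index 12, i.e. the
    # next, unseen time point; on any failure the method falls back to epsilon (0.0).
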
    @classmethod
    def get_pupularity_prediction(self, tools_usage):
        """
        Get the popularity prediction for each tool
        """
        usage_prediction = dict()
        for tool_name, usage in tools_usage.items():
            y_val = list()
            x_val = list()
            for x, y in usage.items():
                x_val.append(x)
                y_val.append(y)
            # replace the dates with evenly spaced positional indices for the regressor
            x_pos = np.arange(len(x_val))
            x_reshaped = x_pos.reshape(len(x_pos), 1)
            y_reshaped = np.reshape(y_val, (len(x_pos), 1))
            prediction = np.round(self.learn_tool_popularity(x_reshaped, y_reshaped), 8)
            usage_prediction[tool_name] = prediction
        return usage_prediction
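
A minimal usage sketch for the class above, assuming inputs shaped like those produced by the rest of the training pipeline; the file path, cutoff date and tool dictionary are placeholders for illustration only:

# hypothetical driver code with placeholder inputs
tool_popularity = ToolPopularity()

usage_file = "tool-popularity.tsv"      # assumed tab-separated file: tool id, date, usage count
cutoff = "2019-01-01"                   # keep only usage recorded after this date
dictionary = {"Cut1": 0, "Filter1": 1}  # tool name -> index, as built by the data-preparation step

# date-indexed usage counts per tool, with missing dates imputed as 0
tool_usage = tool_popularity.extract_tool_usage(usage_file, cutoff, dictionary)

# one predicted usage value per tool for the next time point
usage_weights = tool_popularity.get_pupularity_prediction(tool_usage)
print(usage_weights)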