comparison predict_tool_usage.py @ 5:4f7e6612906b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author | bgruening |
date | Fri, 06 May 2022 09:05:18 +0000 |
parents | 5b3c08710e47 |
children | e94dc7945639 |
comparison
4:afec8c595124 | 5:4f7e6612906b |
---|---|
1 """ | 1 """ |
2 Predict tool usage to weigh the predicted tools | 2 Predict tool usage to weigh the predicted tools |
3 """ | 3 """ |
4 | 4 |
| 5 import collections |
| 6 import csv |
5 import os | 7 import os |
| 8 import warnings |
| 9 |
6 import numpy as np | 10 import numpy as np |
7 import warnings | 11 import utils |
8 import csv | |
9 import collections | |
10 | |
11 from sklearn.svm import SVR | |
12 from sklearn.model_selection import GridSearchCV | 12 from sklearn.model_selection import GridSearchCV |
13 from sklearn.pipeline import Pipeline | 13 from sklearn.pipeline import Pipeline |
14 | 14 from sklearn.svm import SVR |
15 import utils | |
16 | 15 |
17 warnings.filterwarnings("ignore") | 16 warnings.filterwarnings("ignore") |
18 | 17 |
19 main_path = os.getcwd() | 18 main_path = os.getcwd() |
20 | 19 |
21 | 20 |
22 class ToolPopularity: | 21 class ToolPopularity: |
23 | |
24 def __init__(self): | 22 def __init__(self): |
25 """ Init method. """ | 23 """ Init method. """ |
26 | 24 |
27 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): | 25 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): |
28 """ | 26 """ |
29 Extract the tool usage over time for each tool | 27 Extract the tool usage over time for each tool |
30 """ | 28 """ |
31 tool_usage_dict = dict() | 29 tool_usage_dict = dict() |
32 all_dates = list() | 30 all_dates = list() |
33 all_tool_list = list(dictionary.keys()) | 31 all_tool_list = list(dictionary.keys()) |
34 with open(tool_usage_file, 'rt') as usage_file: | 32 with open(tool_usage_file, "rt") as usage_file: |
35 tool_usage = csv.reader(usage_file, delimiter='\t') | 33 tool_usage = csv.reader(usage_file, delimiter="\t") |
36 for index, row in enumerate(tool_usage): | 34 for index, row in enumerate(tool_usage): |
37 if (str(row[1]) > cutoff_date) is True: | 35 row = [item.strip() for item in row] |
| 36 if (str(row[1]).strip() > cutoff_date) is True: |
38 tool_id = utils.format_tool_id(row[0]) | 37 tool_id = utils.format_tool_id(row[0]) |
39 if tool_id in all_tool_list: | 38 if tool_id in all_tool_list: |
40 all_dates.append(row[1]) | 39 all_dates.append(row[1]) |
41 if tool_id not in tool_usage_dict: | 40 if tool_id not in tool_usage_dict: |
42 tool_usage_dict[tool_id] = dict() | 41 tool_usage_dict[tool_id] = dict() |
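For readers without the full file, the hunk above changes `extract_tool_usage` so that every field of a usage row is stripped of surrounding whitespace before the date comparison. The sketch below reproduces just that filtering logic under stated assumptions: the function name, the `tool_ids` set, and the idea that the third column holds a usage count are illustrative, and the repository's `utils.format_tool_id` normalisation is omitted.

```python
import csv


def filter_tool_usage(tool_usage_file, cutoff_date, tool_ids):
    """Group usage rows by tool, keeping only rows newer than cutoff_date."""
    tool_usage = {}
    with open(tool_usage_file, "rt") as usage_file:
        for row in csv.reader(usage_file, delimiter="\t"):
            # Mirror the new revision: strip whitespace from every field first
            row = [item.strip() for item in row]
            tool_id, date = row[0], row[1]
            if date > cutoff_date and tool_id in tool_ids:
                counts = tool_usage.setdefault(tool_id, {})
                # Accumulate the usage-count column (assumed here to be row[2])
                counts[date] = counts.get(date, 0) + int(row[2])
    return tool_usage


# Hypothetical call: dates are ISO-formatted strings, so the plain string
# comparison against cutoff_date matches chronological order.
# usage = filter_tool_usage("tool_popularity.tsv", "2021-12-01", {"toolA", "toolB"})
```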
65 """ | 64 """ |
66 Fit a curve for the tool usage over time to predict future tool usage | 65 Fit a curve for the tool usage over time to predict future tool usage |
67 """ | 66 """ |
68 epsilon = 0.0 | 67 epsilon = 0.0 |
69 cv = 5 | 68 cv = 5 |
70 s_typ = 'neg_mean_absolute_error' | 69 s_typ = "neg_mean_absolute_error" |
71 n_jobs = 4 | 70 n_jobs = 4 |
72 s_error = 1 | 71 s_error = 1 |
73 iid = True | |
74 tr_score = False | 72 tr_score = False |
75 try: | 73 try: |
76 pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))]) | 74 pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))]) |
77 param_grid = { | 75 param_grid = { |
78 'regressor__kernel': ['rbf', 'poly', 'linear'], | 76 "regressor__kernel": ["rbf", "poly", "linear"], |
79 'regressor__degree': [2, 3] | 77 "regressor__degree": [2, 3], |
80 } | 78 } |
81 search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score) | 79 search = GridSearchCV( |
| 80 pipe, |
| 81 param_grid, |
| 82 cv=cv, |
| 83 scoring=s_typ, |
| 84 n_jobs=n_jobs, |
| 85 error_score=s_error, |
| 86 return_train_score=tr_score, |
| 87 ) |
82 search.fit(x_reshaped, y_reshaped.ravel()) | 88 search.fit(x_reshaped, y_reshaped.ravel()) |
83 model = search.best_estimator_ | 89 model = search.best_estimator_ |
84 # set the next time point to get prediction for | 90 # set the next time point to get prediction for |
85 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) | 91 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) |
86 prediction = model.predict(prediction_point) | 92 prediction = model.predict(prediction_point) |
87 if prediction < epsilon: | 93 if prediction < epsilon: |
88 prediction = [epsilon] | 94 prediction = [epsilon] |
89 return prediction[0] | 95 return prediction[0] |
90 except Exception: | 96 except Exception as e: |
| 97 print(e) |
91 return epsilon | 98 return epsilon |
92 | 99 |
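The right-hand revision also reworks the curve-fitting call: the `iid` keyword is dropped from `GridSearchCV` (it was deprecated in scikit-learn 0.22 and removed in 0.24, so the old call raises a `TypeError` on current releases), the argument list is re-wrapped, and the caught exception is now printed. A minimal, self-contained sketch of the same grid-searched SVR fit is shown below; the function name, the fold count, and the toy usage series are assumptions made for illustration, not repository code.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR


def predict_next_usage(time_points, usage_counts):
    """Grid-search an SVR over a usage time series and predict the next point."""
    x = np.asarray(time_points, dtype=float).reshape(-1, 1)  # column vector of time steps
    y = np.asarray(usage_counts, dtype=float).ravel()
    pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
    param_grid = {
        "regressor__kernel": ["rbf", "poly", "linear"],
        "regressor__degree": [2, 3],
    }
    search = GridSearchCV(
        pipe,
        param_grid,
        cv=3,  # smaller fold count than the repository's cv=5, for a short toy series
        scoring="neg_mean_absolute_error",
        error_score=1,
        return_train_score=False,
    )
    search.fit(x, y)
    next_point = np.reshape([x[-1][0] + 1], (1, 1))
    prediction = search.best_estimator_.predict(next_point)
    return max(float(prediction[0]), 0.0)  # clamp negatives, as the diff does with epsilon


# Example with made-up monthly usage counts for six consecutive months
# print(predict_next_usage(range(6), [10, 12, 15, 14, 20, 22]))
```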
93 def get_pupularity_prediction(self, tools_usage): | 100 def get_pupularity_prediction(self, tools_usage): |
94 """ | 101 """ |
95 Get the popularity prediction for each tool | 102 Get the popularity prediction for each tool |