Mercurial > repos > bgruening > create_tool_recommendation_model
comparison predict_tool_usage.py @ 6:e94dc7945639 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author | bgruening |
---|---|
date | Sun, 16 Oct 2022 11:52:10 +0000 |
parents | 4f7e6612906b |
children |
comparison
equal
deleted
inserted
replaced
5:4f7e6612906b | 6:e94dc7945639 |
---|---|
1 """ | 1 """ |
2 Predict tool usage to weigh the predicted tools | 2 Predict tool usage to weigh the predicted tools |
3 """ | 3 """ |
4 | 4 |
5 import collections | 5 import collections |
6 import csv | |
7 import os | |
8 import warnings | |
9 | 6 |
10 import numpy as np | 7 import numpy as np |
11 import utils | 8 import utils |
12 from sklearn.model_selection import GridSearchCV | 9 from sklearn.model_selection import GridSearchCV |
13 from sklearn.pipeline import Pipeline | 10 from sklearn.pipeline import Pipeline |
14 from sklearn.svm import SVR | 11 from sklearn.svm import SVR |
15 | 12 |
16 warnings.filterwarnings("ignore") | |
17 | |
18 main_path = os.getcwd() | |
19 | |
20 | 13 |
21 class ToolPopularity: | 14 class ToolPopularity: |
15 | |
22 def __init__(self): | 16 def __init__(self): |
23 """ Init method. """ | 17 """ Init method. """ |
24 | 18 |
25 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): | 19 def extract_tool_usage(self, tool_usage_df, cutoff_date, dictionary): |
26 """ | 20 """ |
27 Extract the tool usage over time for each tool | 21 Extract the tool usage over time for each tool |
28 """ | 22 """ |
29 tool_usage_dict = dict() | 23 tool_usage_dict = dict() |
30 all_dates = list() | 24 all_dates = list() |
31 all_tool_list = list(dictionary.keys()) | 25 all_tool_list = list(dictionary.keys()) |
32 with open(tool_usage_file, "rt") as usage_file: | 26 for index, row in tool_usage_df.iterrows(): |
33 tool_usage = csv.reader(usage_file, delimiter="\t") | 27 row = row.tolist() |
34 for index, row in enumerate(tool_usage): | 28 row = [str(item).strip() for item in row] |
35 row = [item.strip() for item in row] | 29 if (row[1] > cutoff_date) is True: |
36 if (str(row[1]).strip() > cutoff_date) is True: | 30 tool_id = utils.format_tool_id(row[0]) |
37 tool_id = utils.format_tool_id(row[0]) | 31 if tool_id in all_tool_list: |
38 if tool_id in all_tool_list: | 32 all_dates.append(row[1]) |
39 all_dates.append(row[1]) | 33 if tool_id not in tool_usage_dict: |
40 if tool_id not in tool_usage_dict: | 34 tool_usage_dict[tool_id] = dict() |
41 tool_usage_dict[tool_id] = dict() | 35 tool_usage_dict[tool_id][row[1]] = int(float(row[2])) |
42 tool_usage_dict[tool_id][row[1]] = int(row[2]) | 36 else: |
37 curr_date = row[1] | |
38 # merge the usage of different version of tools into one | |
39 if curr_date in tool_usage_dict[tool_id]: | |
40 tool_usage_dict[tool_id][curr_date] += int(float(row[2])) | |
43 else: | 41 else: |
44 curr_date = row[1] | 42 tool_usage_dict[tool_id][curr_date] = int(float(row[2])) |
45 # merge the usage of different version of tools into one | |
46 if curr_date in tool_usage_dict[tool_id]: | |
47 tool_usage_dict[tool_id][curr_date] += int(row[2]) | |
48 else: | |
49 tool_usage_dict[tool_id][curr_date] = int(row[2]) | |
50 # get unique dates | 43 # get unique dates |
51 unique_dates = list(set(all_dates)) | 44 unique_dates = list(set(all_dates)) |
52 for tool in tool_usage_dict: | 45 for tool in tool_usage_dict: |
53 usage = tool_usage_dict[tool] | 46 usage = tool_usage_dict[tool] |
54 # extract those dates for which tool's usage is not present in raw data | 47 # extract those dates for which tool's usage is not present in raw data |
64 """ | 57 """ |
65 Fit a curve for the tool usage over time to predict future tool usage | 58 Fit a curve for the tool usage over time to predict future tool usage |
66 """ | 59 """ |
67 epsilon = 0.0 | 60 epsilon = 0.0 |
68 cv = 5 | 61 cv = 5 |
69 s_typ = "neg_mean_absolute_error" | 62 s_typ = 'neg_mean_absolute_error' |
70 n_jobs = 4 | 63 n_jobs = 4 |
71 s_error = 1 | 64 s_error = 1 |
72 tr_score = False | 65 tr_score = False |
73 try: | 66 try: |
74 pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))]) | 67 pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))]) |
75 param_grid = { | 68 param_grid = { |
76 "regressor__kernel": ["rbf", "poly", "linear"], | 69 'regressor__kernel': ['rbf', 'poly', 'linear'], |
77 "regressor__degree": [2, 3], | 70 'regressor__degree': [2, 3] |
78 } | 71 } |
79 search = GridSearchCV( | 72 search = GridSearchCV(pipe, param_grid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score) |
80 pipe, | |
81 param_grid, | |
82 cv=cv, | |
83 scoring=s_typ, | |
84 n_jobs=n_jobs, | |
85 error_score=s_error, | |
86 return_train_score=tr_score, | |
87 ) | |
88 search.fit(x_reshaped, y_reshaped.ravel()) | 73 search.fit(x_reshaped, y_reshaped.ravel()) |
89 model = search.best_estimator_ | 74 model = search.best_estimator_ |
90 # set the next time point to get prediction for | 75 # set the next time point to get prediction for |
91 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) | 76 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) |
92 prediction = model.predict(prediction_point) | 77 prediction = model.predict(prediction_point) |