Mercurial > repos > bgruening > create_tool_recommendation_model
comparison predict_tool_usage.py @ 6:e94dc7945639 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
| author | bgruening |
|---|---|
| date | Sun, 16 Oct 2022 11:52:10 +0000 |
| parents | 4f7e6612906b |
| children |
comparison
equal
deleted
inserted
replaced
| 5:4f7e6612906b | 6:e94dc7945639 |
|---|---|
| 1 """ | 1 """ |
| 2 Predict tool usage to weigh the predicted tools | 2 Predict tool usage to weigh the predicted tools |
| 3 """ | 3 """ |
| 4 | 4 |
| 5 import collections | 5 import collections |
| 6 import csv | |
| 7 import os | |
| 8 import warnings | |
| 9 | 6 |
| 10 import numpy as np | 7 import numpy as np |
| 11 import utils | 8 import utils |
| 12 from sklearn.model_selection import GridSearchCV | 9 from sklearn.model_selection import GridSearchCV |
| 13 from sklearn.pipeline import Pipeline | 10 from sklearn.pipeline import Pipeline |
| 14 from sklearn.svm import SVR | 11 from sklearn.svm import SVR |
| 15 | 12 |
| 16 warnings.filterwarnings("ignore") | |
| 17 | |
| 18 main_path = os.getcwd() | |
| 19 | |
| 20 | 13 |
| 21 class ToolPopularity: | 14 class ToolPopularity: |
| 15 | |
| 22 def __init__(self): | 16 def __init__(self): |
| 23 """ Init method. """ | 17 """ Init method. """ |
| 24 | 18 |
| 25 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): | 19 def extract_tool_usage(self, tool_usage_df, cutoff_date, dictionary): |
| 26 """ | 20 """ |
| 27 Extract the tool usage over time for each tool | 21 Extract the tool usage over time for each tool |
| 28 """ | 22 """ |
| 29 tool_usage_dict = dict() | 23 tool_usage_dict = dict() |
| 30 all_dates = list() | 24 all_dates = list() |
| 31 all_tool_list = list(dictionary.keys()) | 25 all_tool_list = list(dictionary.keys()) |
| 32 with open(tool_usage_file, "rt") as usage_file: | 26 for index, row in tool_usage_df.iterrows(): |
| 33 tool_usage = csv.reader(usage_file, delimiter="\t") | 27 row = row.tolist() |
| 34 for index, row in enumerate(tool_usage): | 28 row = [str(item).strip() for item in row] |
| 35 row = [item.strip() for item in row] | 29 if (row[1] > cutoff_date) is True: |
| 36 if (str(row[1]).strip() > cutoff_date) is True: | 30 tool_id = utils.format_tool_id(row[0]) |
| 37 tool_id = utils.format_tool_id(row[0]) | 31 if tool_id in all_tool_list: |
| 38 if tool_id in all_tool_list: | 32 all_dates.append(row[1]) |
| 39 all_dates.append(row[1]) | 33 if tool_id not in tool_usage_dict: |
| 40 if tool_id not in tool_usage_dict: | 34 tool_usage_dict[tool_id] = dict() |
| 41 tool_usage_dict[tool_id] = dict() | 35 tool_usage_dict[tool_id][row[1]] = int(float(row[2])) |
| 42 tool_usage_dict[tool_id][row[1]] = int(row[2]) | 36 else: |
| 37 curr_date = row[1] | |
| 38 # merge the usage of different versions of tools into one | |
| 39 if curr_date in tool_usage_dict[tool_id]: | |
| 40 tool_usage_dict[tool_id][curr_date] += int(float(row[2])) | |
| 43 else: | 41 else: |
| 44 curr_date = row[1] | 42 tool_usage_dict[tool_id][curr_date] = int(float(row[2])) |
| 45 # merge the usage of different versions of tools into one | |
| 46 if curr_date in tool_usage_dict[tool_id]: | |
| 47 tool_usage_dict[tool_id][curr_date] += int(row[2]) | |
| 48 else: | |
| 49 tool_usage_dict[tool_id][curr_date] = int(row[2]) | |
| 50 # get unique dates | 43 # get unique dates |
| 51 unique_dates = list(set(all_dates)) | 44 unique_dates = list(set(all_dates)) |
| 52 for tool in tool_usage_dict: | 45 for tool in tool_usage_dict: |
| 53 usage = tool_usage_dict[tool] | 46 usage = tool_usage_dict[tool] |
| 54 # extract those dates for which tool's usage is not present in raw data | 47 # extract those dates for which tool's usage is not present in raw data |
| 64 """ | 57 """ |
| 65 Fit a curve for the tool usage over time to predict future tool usage | 58 Fit a curve for the tool usage over time to predict future tool usage |
| 66 """ | 59 """ |
| 67 epsilon = 0.0 | 60 epsilon = 0.0 |
| 68 cv = 5 | 61 cv = 5 |
| 69 s_typ = "neg_mean_absolute_error" | 62 s_typ = 'neg_mean_absolute_error' |
| 70 n_jobs = 4 | 63 n_jobs = 4 |
| 71 s_error = 1 | 64 s_error = 1 |
| 72 tr_score = False | 65 tr_score = False |
| 73 try: | 66 try: |
| 74 pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))]) | 67 pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))]) |
| 75 param_grid = { | 68 param_grid = { |
| 76 "regressor__kernel": ["rbf", "poly", "linear"], | 69 'regressor__kernel': ['rbf', 'poly', 'linear'], |
| 77 "regressor__degree": [2, 3], | 70 'regressor__degree': [2, 3] |
| 78 } | 71 } |
| 79 search = GridSearchCV( | 72 search = GridSearchCV(pipe, param_grid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score) |
| 80 pipe, | |
| 81 param_grid, | |
| 82 cv=cv, | |
| 83 scoring=s_typ, | |
| 84 n_jobs=n_jobs, | |
| 85 error_score=s_error, | |
| 86 return_train_score=tr_score, | |
| 87 ) | |
| 88 search.fit(x_reshaped, y_reshaped.ravel()) | 73 search.fit(x_reshaped, y_reshaped.ravel()) |
| 89 model = search.best_estimator_ | 74 model = search.best_estimator_ |
| 90 # set the next time point to get prediction for | 75 # set the next time point to get prediction for |
| 91 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) | 76 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) |
| 92 prediction = model.predict(prediction_point) | 77 prediction = model.predict(prediction_point) |
