diff predict_tool_usage.py @ 6:e94dc7945639 (draft, default, tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author | bgruening
---|---
date | Sun, 16 Oct 2022 11:52:10 +0000
parents | 4f7e6612906b
children | (none)
```diff
--- a/predict_tool_usage.py	Fri May 06 09:05:18 2022 +0000
+++ b/predict_tool_usage.py	Sun Oct 16 11:52:10 2022 +0000
@@ -3,9 +3,6 @@
 """
 
 import collections
-import csv
-import os
-import warnings
 
 import numpy as np
 import utils
@@ -13,40 +10,36 @@
 from sklearn.pipeline import Pipeline
 from sklearn.svm import SVR
 
-warnings.filterwarnings("ignore")
-
-main_path = os.getcwd()
-
 
 class ToolPopularity:
+
     def __init__(self):
         """ Init method. """
 
-    def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary):
+    def extract_tool_usage(self, tool_usage_df, cutoff_date, dictionary):
         """
         Extract the tool usage over time for each tool
         """
         tool_usage_dict = dict()
         all_dates = list()
         all_tool_list = list(dictionary.keys())
-        with open(tool_usage_file, "rt") as usage_file:
-            tool_usage = csv.reader(usage_file, delimiter="\t")
-            for index, row in enumerate(tool_usage):
-                row = [item.strip() for item in row]
-                if (str(row[1]).strip() > cutoff_date) is True:
-                    tool_id = utils.format_tool_id(row[0])
-                    if tool_id in all_tool_list:
-                        all_dates.append(row[1])
-                        if tool_id not in tool_usage_dict:
-                            tool_usage_dict[tool_id] = dict()
-                            tool_usage_dict[tool_id][row[1]] = int(row[2])
+        for index, row in tool_usage_df.iterrows():
+            row = row.tolist()
+            row = [str(item).strip() for item in row]
+            if (row[1] > cutoff_date) is True:
+                tool_id = utils.format_tool_id(row[0])
+                if tool_id in all_tool_list:
+                    all_dates.append(row[1])
+                    if tool_id not in tool_usage_dict:
+                        tool_usage_dict[tool_id] = dict()
+                        tool_usage_dict[tool_id][row[1]] = int(float(row[2]))
+                    else:
+                        curr_date = row[1]
+                        # merge the usage of different version of tools into one
+                        if curr_date in tool_usage_dict[tool_id]:
+                            tool_usage_dict[tool_id][curr_date] += int(float(row[2]))
                         else:
-                            curr_date = row[1]
-                            # merge the usage of different version of tools into one
-                            if curr_date in tool_usage_dict[tool_id]:
-                                tool_usage_dict[tool_id][curr_date] += int(row[2])
-                            else:
-                                tool_usage_dict[tool_id][curr_date] = int(row[2])
+                            tool_usage_dict[tool_id][curr_date] = int(float(row[2]))
         # get unique dates
         unique_dates = list(set(all_dates))
         for tool in tool_usage_dict:
@@ -66,25 +59,17 @@
         """
         epsilon = 0.0
         cv = 5
-        s_typ = "neg_mean_absolute_error"
+        s_typ = 'neg_mean_absolute_error'
         n_jobs = 4
         s_error = 1
         tr_score = False
         try:
-            pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
+            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
             param_grid = {
-                "regressor__kernel": ["rbf", "poly", "linear"],
-                "regressor__degree": [2, 3],
+                'regressor__kernel': ['rbf', 'poly', 'linear'],
+                'regressor__degree': [2, 3]
             }
-            search = GridSearchCV(
-                pipe,
-                param_grid,
-                cv=cv,
-                scoring=s_typ,
-                n_jobs=n_jobs,
-                error_score=s_error,
-                return_train_score=tr_score,
-            )
+            search = GridSearchCV(pipe, param_grid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score)
             search.fit(x_reshaped, y_reshaped.ravel())
             model = search.best_estimator_
             # set the next time point to get prediction for
```
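The first functional change above swaps the tab-separated usage file (read with `csv.reader`) for a pandas DataFrame (`tool_usage_df`) iterated via `iterrows()`, and casts counts with `int(float(...))` so values exported as floats (e.g. `"7.0"`) no longer break the integer cast. Below is a minimal, self-contained sketch of that aggregation logic on made-up rows; `format_tool_id` here is a hypothetical stand-in for the repository's `utils.format_tool_id`, and the sample tool ids and dates are illustrative only.

```python
import pandas as pd


def format_tool_id(tool_link: str) -> str:
    """Hypothetical stand-in for utils.format_tool_id: drop the trailing
    version component of a toolshed id so tool versions can be merged."""
    parts = tool_link.split("/")
    return parts[-2] if len(parts) > 1 else tool_link


# Made-up usage rows in the expected column order: (tool id, date, count).
tool_usage_df = pd.DataFrame(
    [
        ["toolshed.g2.bx.psu.edu/repos/devteam/fastqc/fastqc/0.72", "2022-01-01", "15"],
        ["toolshed.g2.bx.psu.edu/repos/devteam/fastqc/fastqc/0.73", "2022-01-01", "7.0"],
    ]
)

cutoff_date = "2021-12-31"
usage = {}
for _, row in tool_usage_df.iterrows():
    tool_link, date, count = [str(item).strip() for item in row.tolist()]
    if date > cutoff_date:  # ISO-formatted date strings compare correctly as strings
        tool_id = format_tool_id(tool_link)
        # merge the usage of different versions of a tool on the same date
        usage.setdefault(tool_id, {})
        usage[tool_id][date] = usage[tool_id].get(date, 0) + int(float(count))

print(usage)  # {'fastqc': {'2022-01-01': 22}}
```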
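The second hunk is purely cosmetic: the `GridSearchCV` call is collapsed onto one line and double quotes become single quotes, with no behavioural change. For reference, here is a self-contained sketch of that grid search over an SVR pipeline on synthetic data; the `x_reshaped`/`y_reshaped` arrays are invented to match the names visible in the diff context.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

# Synthetic time series: x = time points, y = usage counts.
x_reshaped = np.arange(10, dtype=float).reshape(-1, 1)
y_reshaped = (2.5 * np.arange(10) + 3.0).reshape(-1, 1)

# Same pipeline and search space as in the diff.
pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
param_grid = {
    "regressor__kernel": ["rbf", "poly", "linear"],
    "regressor__degree": [2, 3],  # only used by the poly kernel
}
search = GridSearchCV(
    pipe,
    param_grid,
    cv=5,
    scoring="neg_mean_absolute_error",
    n_jobs=4,
    error_score=1,
    return_train_score=False,
)
search.fit(x_reshaped, y_reshaped.ravel())
model = search.best_estimator_

# Predict the next time point, as the surrounding (elided) code goes on to do.
print(model.predict(np.array([[10.0]])))
```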