Mercurial > repos > bgruening > create_tool_recommendation_model
diff predict_tool_usage.py @ 5:4f7e6612906b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author: bgruening
date: Fri, 06 May 2022 09:05:18 +0000
parents | 5b3c08710e47 |
children | e94dc7945639 |
line wrap: on
line diff
--- a/predict_tool_usage.py Tue Jul 07 03:25:49 2020 -0400 +++ b/predict_tool_usage.py Fri May 06 09:05:18 2022 +0000 @@ -2,17 +2,16 @@ Predict tool usage to weigh the predicted tools """ -import os -import numpy as np -import warnings +import collections import csv -import collections +import os +import warnings -from sklearn.svm import SVR +import numpy as np +import utils from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline - -import utils +from sklearn.svm import SVR warnings.filterwarnings("ignore") @@ -20,7 +19,6 @@ class ToolPopularity: - def __init__(self): """ Init method. """ @@ -31,10 +29,11 @@ tool_usage_dict = dict() all_dates = list() all_tool_list = list(dictionary.keys()) - with open(tool_usage_file, 'rt') as usage_file: - tool_usage = csv.reader(usage_file, delimiter='\t') + with open(tool_usage_file, "rt") as usage_file: + tool_usage = csv.reader(usage_file, delimiter="\t") for index, row in enumerate(tool_usage): - if (str(row[1]) > cutoff_date) is True: + row = [item.strip() for item in row] + if (str(row[1]).strip() > cutoff_date) is True: tool_id = utils.format_tool_id(row[0]) if tool_id in all_tool_list: all_dates.append(row[1]) @@ -67,18 +66,25 @@ """ epsilon = 0.0 cv = 5 - s_typ = 'neg_mean_absolute_error' + s_typ = "neg_mean_absolute_error" n_jobs = 4 s_error = 1 - iid = True tr_score = False try: - pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))]) + pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))]) param_grid = { - 'regressor__kernel': ['rbf', 'poly', 'linear'], - 'regressor__degree': [2, 3] + "regressor__kernel": ["rbf", "poly", "linear"], + "regressor__degree": [2, 3], } - search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score) + search = GridSearchCV( + pipe, + param_grid, + cv=cv, + scoring=s_typ, + n_jobs=n_jobs, + error_score=s_error, + return_train_score=tr_score, + ) search.fit(x_reshaped, 
y_reshaped.ravel()) model = search.best_estimator_ # set the next time point to get prediction for @@ -87,7 +93,8 @@ if prediction < epsilon: prediction = [epsilon] return prediction[0] - except Exception: + except Exception as e: + print(e) return epsilon def get_pupularity_prediction(self, tools_usage):