diff predict_tool_usage.py @ 5:4f7e6612906b draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author bgruening
date Fri, 06 May 2022 09:05:18 +0000
parents 5b3c08710e47
children e94dc7945639
--- a/predict_tool_usage.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/predict_tool_usage.py	Fri May 06 09:05:18 2022 +0000
@@ -2,17 +2,16 @@
 Predict tool usage to weigh the predicted tools
 """
 
-import os
-import numpy as np
-import warnings
+import collections
 import csv
-import collections
+import os
+import warnings
 
-from sklearn.svm import SVR
+import numpy as np
+import utils
 from sklearn.model_selection import GridSearchCV
 from sklearn.pipeline import Pipeline
-
-import utils
+from sklearn.svm import SVR
 
 warnings.filterwarnings("ignore")
 
@@ -20,7 +19,6 @@
 
 
 class ToolPopularity:
-
     def __init__(self):
         """ Init method. """
 
@@ -31,10 +29,11 @@
         tool_usage_dict = dict()
         all_dates = list()
         all_tool_list = list(dictionary.keys())
-        with open(tool_usage_file, 'rt') as usage_file:
-            tool_usage = csv.reader(usage_file, delimiter='\t')
+        with open(tool_usage_file, "rt") as usage_file:
+            tool_usage = csv.reader(usage_file, delimiter="\t")
             for index, row in enumerate(tool_usage):
-                if (str(row[1]) > cutoff_date) is True:
+                row = [item.strip() for item in row]
+                if (str(row[1]).strip() > cutoff_date) is True:
                     tool_id = utils.format_tool_id(row[0])
                     if tool_id in all_tool_list:
                         all_dates.append(row[1])
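The change above trims whitespace from every field of the tab-separated usage report before the date column is compared against the cutoff, so stray spaces in an exported row no longer break the lexicographic date comparison. A minimal sketch of that parsing pattern, assuming a three-column layout (tool id, ISO date, usage count) and a hypothetical file name:

    import csv

    # Sketch only: the file name and column order are assumptions for illustration.
    cutoff_date = "2021-01-01"
    usage = dict()
    with open("tool_usage_sample.tsv", "rt") as usage_file:
        for row in csv.reader(usage_file, delimiter="\t"):
            row = [item.strip() for item in row]      # drop stray whitespace around each field
            if row[1] > cutoff_date:                  # lexicographic compare works for ISO dates
                usage.setdefault(row[0], []).append((row[1], float(row[2])))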
@@ -67,18 +66,25 @@
         """
         epsilon = 0.0
         cv = 5
-        s_typ = 'neg_mean_absolute_error'
+        s_typ = "neg_mean_absolute_error"
         n_jobs = 4
         s_error = 1
-        iid = True
         tr_score = False
         try:
-            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
+            pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
             param_grid = {
-                'regressor__kernel': ['rbf', 'poly', 'linear'],
-                'regressor__degree': [2, 3]
+                "regressor__kernel": ["rbf", "poly", "linear"],
+                "regressor__degree": [2, 3],
             }
-            search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score)
+            search = GridSearchCV(
+                pipe,
+                param_grid,
+                cv=cv,
+                scoring=s_typ,
+                n_jobs=n_jobs,
+                error_score=s_error,
+                return_train_score=tr_score,
+            )
             search.fit(x_reshaped, y_reshaped.ravel())
             model = search.best_estimator_
             # set the next time point to get prediction for
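The rewritten call above drops the iid keyword, which scikit-learn deprecated in 0.22 and removed in 0.24, so the hyperparameter search now runs on current releases. A minimal, self-contained sketch of the same pipeline-plus-grid-search pattern on toy data (the array shapes and values are assumptions for illustration):

    import numpy as np
    from sklearn.model_selection import GridSearchCV
    from sklearn.pipeline import Pipeline
    from sklearn.svm import SVR

    # Toy monthly-usage series: x is the time index, y the usage counts (assumed data).
    x = np.arange(12).reshape(-1, 1).astype(float)
    y = np.array([5, 7, 6, 9, 12, 11, 15, 14, 18, 21, 20, 24], dtype=float)

    pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
    param_grid = {
        "regressor__kernel": ["rbf", "poly", "linear"],
        "regressor__degree": [2, 3],
    }
    # Same search configuration as the patched code, minus the removed iid argument.
    search = GridSearchCV(
        pipe,
        param_grid,
        cv=5,
        scoring="neg_mean_absolute_error",
        n_jobs=4,
        error_score=1,
        return_train_score=False,
    )
    search.fit(x, y.ravel())
    best_model = search.best_estimator_
    next_month = np.array([[12.0]])            # next time point to predict
    prediction = best_model.predict(next_month)
    print(max(prediction[0], 0.0))             # clamp to the epsilon floor used above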
@@ -87,7 +93,8 @@
             if prediction < epsilon:
                 prediction = [epsilon]
             return prediction[0]
-        except Exception:
+        except Exception as e:
+            print(e)
             return epsilon
 
     def get_pupularity_prediction(self, tools_usage):
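The final hunk binds and prints the caught exception instead of silently swallowing it, so a failed fit leaves a trace in the job output before the epsilon fallback is returned. A minimal sketch of that fallback pattern, assuming epsilon is the neutral usage weight:

    # Sketch only: model and x stand in for the fitted estimator and the next time point.
    epsilon = 0.0

    def safe_predict(model, x):
        try:
            return float(model.predict(x)[0])
        except Exception as e:
            print(e)          # surface the failure in the log instead of hiding it
            return epsilon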