diff search_model_validation.py @ 23:e3bc646e63b2 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 57f4407e278a615f47a377a3328782b1d8e0b54d
author bgruening
date Sun, 30 Dec 2018 02:01:29 -0500
parents
children b628de0d101f
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/search_model_validation.py	Sun Dec 30 02:01:29 2018 -0500
@@ -0,0 +1,234 @@
+import imblearn
+import json
+import numpy as np
+import os
+import pandas
+import pickle
+import skrebate
+import sklearn
+import sys
+import xgboost
+import warnings
+from imblearn import under_sampling, over_sampling, combine
+from imblearn.pipeline import Pipeline as imbPipeline
+from sklearn import (cluster, compose, decomposition, ensemble, feature_extraction,
+                    feature_selection, gaussian_process, kernel_approximation, metrics,
+                    model_selection, naive_bayes, neighbors, pipeline, preprocessing,
+                    svm, linear_model, tree, discriminant_analysis)
+from sklearn.exceptions import FitFailedWarning
+from sklearn.externals import joblib
+from utils import get_cv, get_scoring, get_X_y, load_model, read_columns, SafeEval
+
+
+N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
+
+
+def get_search_params(params_builder):
+    search_params = {}
+    safe_eval = SafeEval(load_scipy=True, load_numpy=True)
+    safe_eval_es = SafeEval(load_estimators=True)
+
+    for p in params_builder['param_set']:
+        search_p = p['search_param_selector']['search_p']
+        if search_p.strip() == '':
+            continue
+        param_type = p['search_param_selector']['selected_param_type']
+
+        lst = search_p.split(':')
+        assert (len(lst) == 2), "Error, make sure there is one and only one colon in search parameter input."
+        literal = lst[1].strip()
+        param_name = lst[0].strip()
+        if param_name:
+            if param_name.lower() == 'n_jobs':
+                sys.exit("Parameter `%s` is invalid for search." %param_name)
+            elif not param_name.endswith('-'):
+                ev = safe_eval(literal)
+                if param_type == 'final_estimator_p':
+                    search_params['estimator__' + param_name] = ev
+                else:
+                    search_params['preprocessing_' + param_type[5:6] + '__' + param_name] = ev
+            else:
+                # only for estimator eval, add `-` to the end of param
+                #TODO maybe add regular express check
+                ev = safe_eval_es(literal)
+                for obj in ev:
+                    if 'n_jobs' in obj.get_params():
+                        obj.set_params( n_jobs=N_JOBS )
+                if param_type == 'final_estimator_p':
+                    search_params['estimator__' + param_name[:-1]] = ev
+                else:
+                    search_params['preprocessing_' + param_type[5:6] + '__' + param_name[:-1]] = ev
+        elif param_type != 'final_estimator_p':
+            #TODO regular express check ?
+            ev = safe_eval_es(literal)
+            preprocessors = [preprocessing.StandardScaler(), preprocessing.Binarizer(), preprocessing.Imputer(),
+                            preprocessing.MaxAbsScaler(), preprocessing.Normalizer(), preprocessing.MinMaxScaler(),
+                            preprocessing.PolynomialFeatures(),preprocessing.RobustScaler(),
+                            feature_selection.SelectKBest(), feature_selection.GenericUnivariateSelect(),
+                            feature_selection.SelectPercentile(), feature_selection.SelectFpr(), feature_selection.SelectFdr(),
+                            feature_selection.SelectFwe(), feature_selection.VarianceThreshold(),
+                            decomposition.FactorAnalysis(random_state=0), decomposition.FastICA(random_state=0), decomposition.IncrementalPCA(),
+                            decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS), decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS),
+                            decomposition.MiniBatchDictionaryLearning(random_state=0, n_jobs=N_JOBS),
+                            decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS), decomposition.NMF(random_state=0),
+                            decomposition.PCA(random_state=0), decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS),
+                            decomposition.TruncatedSVD(random_state=0),
+                            kernel_approximation.Nystroem(random_state=0), kernel_approximation.RBFSampler(random_state=0),
+                            kernel_approximation.AdditiveChi2Sampler(), kernel_approximation.SkewedChi2Sampler(random_state=0),
+                            cluster.FeatureAgglomeration(),
+                            skrebate.ReliefF(n_jobs=N_JOBS), skrebate.SURF(n_jobs=N_JOBS), skrebate.SURFstar(n_jobs=N_JOBS),
+                            skrebate.MultiSURF(n_jobs=N_JOBS), skrebate.MultiSURFstar(n_jobs=N_JOBS),
+                            imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.CondensedNearestNeighbour(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.EditedNearestNeighbours(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.RepeatedEditedNearestNeighbours(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.InstanceHardnessThreshold(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.NeighbourhoodCleaningRule(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.OneSidedSelection(random_state=0, n_jobs=N_JOBS),
+                            imblearn.under_sampling.RandomUnderSampler(random_state=0),
+                            imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS),
+                            imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS),
+                            imblearn.over_sampling.RandomOverSampler(random_state=0),
+                            imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS),
+                            imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS),
+                            imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS),
+                            imblearn.over_sampling.SMOTENC(categorical_features=[], random_state=0, n_jobs=N_JOBS),
+                            imblearn.combine.SMOTEENN(random_state=0), imblearn.combine.SMOTETomek(random_state=0)]
+            newlist = []
+            for obj in ev:
+                if obj is None:
+                    newlist.append(None)
+                elif obj == 'all_0':
+                    newlist.extend(preprocessors[0:36])
+                elif obj == 'sk_prep_all':      # no KernalCenter()
+                    newlist.extend(preprocessors[0:8])
+                elif obj == 'fs_all':
+                    newlist.extend(preprocessors[8:15])
+                elif obj == 'decomp_all':
+                    newlist.extend(preprocessors[15:26])
+                elif obj == 'k_appr_all':
+                    newlist.extend(preprocessors[26:30])
+                elif obj == 'reb_all':
+                    newlist.extend(preprocessors[31:36])
+                elif obj == 'imb_all':
+                    newlist.extend(preprocessors[36:55])
+                elif  type(obj) is int and -1 < obj < len(preprocessors):
+                    newlist.append(preprocessors[obj])
+                elif hasattr(obj, 'get_params'):       # user object
+                    if 'n_jobs' in obj.get_params():
+                        newlist.append( obj.set_params(n_jobs=N_JOBS) )
+                    else:
+                        newlist.append(obj)
+                else:
+                    sys.exit("Unsupported preprocessor type: %r" %(obj))
+            search_params['preprocessing_' + param_type[5:6]] = newlist
+        else:
+            sys.exit("Parameter name of the final estimator can't be skipped!")
+
+    return search_params
+
+
+if __name__ == '__main__':
+
+    warnings.simplefilter('ignore')
+
+    input_json_path = sys.argv[1]
+    with open(input_json_path, 'r') as param_handler:
+        params = json.load(param_handler)
+
+    infile_pipeline = sys.argv[2]
+    infile1 = sys.argv[3]
+    infile2 = sys.argv[4]
+    outfile_result = sys.argv[5]
+    if len(sys.argv) > 6:
+        outfile_estimator = sys.argv[6]
+    else:
+        outfile_estimator = None
+
+    params_builder = params['search_schemes']['search_params_builder']
+
+    input_type = params['input_options']['selected_input']
+    if input_type == 'tabular':
+        header = 'infer' if params['input_options']['header1'] else None
+        column_option = params['input_options']['column_selector_options_1']['selected_column_selector_option']
+        if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']:
+            c = params['input_options']['column_selector_options_1']['col1']
+        else:
+            c = None
+        X = read_columns(
+                infile1,
+                c = c,
+                c_option = column_option,
+                sep='\t',
+                header=header,
+                parse_dates=True
+        )
+    else:
+        X = mmread(open(infile1, 'r'))
+
+    header = 'infer' if params['input_options']['header2'] else None
+    column_option = params['input_options']['column_selector_options_2']['selected_column_selector_option2']
+    if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']:
+        c = params['input_options']['column_selector_options_2']['col2']
+    else:
+        c = None
+    y = read_columns(
+            infile2,
+            c = c,
+            c_option = column_option,
+            sep='\t',
+            header=header,
+            parse_dates=True
+    )
+    y = y.ravel()
+
+    optimizer = params['search_schemes']['selected_search_scheme']
+    optimizer = getattr(model_selection, optimizer)
+
+    options = params['search_schemes']['options']
+    splitter, groups = get_cv(options.pop('cv_selector'))
+    if groups is None:
+        options['cv'] = splitter
+    elif groups == '':
+        options['cv'] = list( splitter.split(X, y, groups=None) )
+    else:
+        options['cv'] = list( splitter.split(X, y, groups=groups) )
+    options['n_jobs'] = N_JOBS
+    primary_scoring = options['scoring']['primary_scoring']
+    options['scoring'] = get_scoring(options['scoring'])
+    if options['error_score']:
+        options['error_score'] = 'raise'
+    else:
+        options['error_score'] = np.NaN
+    if options['refit'] and isinstance(options['scoring'], dict):
+        options['refit'] = 'primary'
+    if 'pre_dispatch' in options and options['pre_dispatch'] == '':
+        options['pre_dispatch'] = None
+
+    with open(infile_pipeline, 'rb') as pipeline_handler:
+        pipeline = load_model(pipeline_handler)
+
+    search_params = get_search_params(params_builder)
+    searcher = optimizer(pipeline, search_params, **options)
+
+    if options['error_score'] == 'raise':
+        searcher.fit(X, y)
+    else:
+        warnings.simplefilter('always', FitFailedWarning)
+        with warnings.catch_warnings(record=True) as w:
+            try:
+                searcher.fit(X, y)
+            except ValueError:
+                pass
+            for warning in w:
+                print(repr(warning.message))
+
+    cv_result = pandas.DataFrame(searcher.cv_results_)
+    cv_result.rename(inplace=True, columns={'mean_test_primary': 'mean_test_'+primary_scoring, 'rank_test_primary': 'rank_test_'+primary_scoring})
+    cv_result.to_csv(path_or_buf=outfile_result, sep='\t', header=True, index=False)
+
+    if outfile_estimator:
+        with open(outfile_estimator, 'wb') as output_handler:
+            pickle.dump(searcher.best_estimator_, output_handler, pickle.HIGHEST_PROTOCOL)