Mercurial > repos > bgruening > sklearn_svm_classifier
diff stacking_ensembles.py @ 10:153f237ddb36 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 60f0fbc0eafd7c11bc60fb6c77f2937782efd8a9-dirty
author | bgruening |
---|---|
date | Fri, 09 Aug 2019 07:08:07 -0400 |
parents | 1a9d5a8fff12 |
children | 73d2ef652879 |
line wrap: on
line diff
--- a/stacking_ensembles.py Tue Jul 09 19:25:14 2019 -0400 +++ b/stacking_ensembles.py Fri Aug 09 07:08:07 2019 -0400 @@ -1,26 +1,17 @@ import argparse +import ast import json +import mlxtend.regressor +import mlxtend.classifier import pandas as pd import pickle -import xgboost +import sklearn +import sys import warnings -from sklearn import (cluster, compose, decomposition, ensemble, - feature_extraction, feature_selection, - gaussian_process, kernel_approximation, metrics, - model_selection, naive_bayes, neighbors, - pipeline, preprocessing, svm, linear_model, - tree, discriminant_analysis) -from sklearn.model_selection._split import check_cv -from feature_selectors import (DyRFE, DyRFECV, - MyPipeline, MyimbPipeline) -from iraps_classifier import (IRAPSCore, IRAPSClassifier, - BinarizeTargetClassifier, - BinarizeTargetRegressor) -from preprocessors import Z_RandomOverSampler -from utils import load_model, get_cv, get_estimator, get_search_params +from sklearn import ensemble -from mlxtend.regressor import StackingCVRegressor, StackingRegressor -from mlxtend.classifier import StackingCVClassifier, StackingClassifier +from galaxy_ml.utils import (load_model, get_cv, get_estimator, + get_search_params) warnings.filterwarnings('ignore') @@ -51,6 +42,8 @@ with open(inputs_path, 'r') as param_handler: params = json.load(param_handler) + estimator_type = params['algo_selection']['estimator_type'] + # get base estimators base_estimators = [] for idx, base_file in enumerate(base_paths.split(',')): if base_file and base_file != 'None': @@ -60,14 +53,23 @@ estimator_json = (params['base_est_builder'][idx] ['estimator_selector']) model = get_estimator(estimator_json) - base_estimators.append(model) + + if estimator_type.startswith('sklearn'): + named = model.__class__.__name__.lower() + named = 'base_%d_%s' % (idx, named) + base_estimators.append((named, model)) + else: + base_estimators.append(model) - if meta_path: - with open(meta_path, 'rb') as f: - meta_estimator = load_model(f) - else: - estimator_json = params['meta_estimator']['estimator_selector'] - meta_estimator = get_estimator(estimator_json) + # get meta estimator, if applicable + if estimator_type.startswith('mlxtend'): + if meta_path: + with open(meta_path, 'rb') as f: + meta_estimator = load_model(f) + else: + estimator_json = (params['algo_selection'] + ['meta_estimator']['estimator_selector']) + meta_estimator = get_estimator(estimator_json) options = params['algo_selection']['options'] @@ -78,26 +80,26 @@ # set n_jobs options['n_jobs'] = N_JOBS - if params['algo_selection']['estimator_type'] == 'StackingCVClassifier': - ensemble_estimator = StackingCVClassifier( + weights = options.pop('weights', None) + if weights: + options['weights'] = ast.literal_eval(weights) + + mod_and_name = estimator_type.split('_') + mod = sys.modules[mod_and_name[0]] + klass = getattr(mod, mod_and_name[1]) + + if estimator_type.startswith('sklearn'): + options['n_jobs'] = N_JOBS + ensemble_estimator = klass(base_estimators, **options) + + elif mod == mlxtend.classifier: + ensemble_estimator = klass( classifiers=base_estimators, meta_classifier=meta_estimator, **options) - elif params['algo_selection']['estimator_type'] == 'StackingClassifier': - ensemble_estimator = StackingClassifier( - classifiers=base_estimators, - meta_classifier=meta_estimator, - **options) - - elif params['algo_selection']['estimator_type'] == 'StackingCVRegressor': - ensemble_estimator = StackingCVRegressor( - regressors=base_estimators, - meta_regressor=meta_estimator, - **options) - else: - ensemble_estimator = StackingRegressor( + ensemble_estimator = klass( regressors=base_estimators, meta_regressor=meta_estimator, **options)