Mercurial > repos > bgruening > sklearn_to_categorical
comparison stacking_ensembles.py @ 0:59e8b4328c82 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
| author | bgruening |
|---|---|
| date | Tue, 13 Apr 2021 22:40:10 +0000 |
| parents | |
| children | f93f0cdbaf18 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:59e8b4328c82 |
|---|---|
| 1 import argparse | |
| 2 import ast | |
| 3 import json | |
| 4 import pickle | |
| 5 import sys | |
| 6 import warnings | |
| 7 | |
| 8 import mlxtend.classifier | |
| 9 import mlxtend.regressor | |
| 10 import pandas as pd | |
| 11 from galaxy_ml.utils import get_cv, get_estimator, get_search_params, load_model | |
| 12 | |
| 13 | |
warnings.filterwarnings(action="ignore")

# Parallelism passed to the ensemble estimators: read from the GALAXY_SLOTS
# environment variable, falling back to 1 when it is not set.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
| 17 | |
| 18 | |
def _is_provided(path):
    """Return True when ``path`` is a real file path.

    Empty strings, ``None`` and the literal string ``"None"`` all mean
    "not supplied".
    """
    return bool(path) and path != "None"


def _build_base_estimators(params, base_paths, estimator_type):
    """Load or construct every base estimator of the ensemble.

    Each comma-separated entry of ``base_paths`` is either a path to a saved
    model, or an unset placeholder. For a placeholder, the estimator is built
    from the matching ``params["base_est_builder"][idx]`` entry. When
    ``base_paths`` is None, every estimator is built from the parameters.

    For sklearn ensembles each estimator is returned as a
    ``("base_<idx>_<classname>", model)`` pair. Otherwise the bare models are
    returned.
    """
    if base_paths is None:
        base_files = [None] * len(params.get("base_est_builder", []))
    else:
        base_files = base_paths.split(",")

    is_sklearn = estimator_type.startswith("sklearn")
    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if _is_provided(base_file):
            with open(base_file, "rb") as handler:
                model = load_model(handler)
        else:
            estimator_json = params["base_est_builder"][idx]["estimator_selector"]
            model = get_estimator(estimator_json)

        if is_sklearn:
            # Index prefix keeps names unique when the same class appears twice.
            name = "base_%d_%s" % (idx, model.__class__.__name__.lower())
            base_estimators.append((name, model))
        else:
            base_estimators.append(model)
    return base_estimators


def _build_meta_estimator(params, meta_path):
    """Load the meta estimator from ``meta_path``, or build it from ``params``."""
    if _is_provided(meta_path):
        with open(meta_path, "rb") as handler:
            return load_model(handler)
    estimator_json = params["algo_selection"]["meta_estimator"]["estimator_selector"]
    return get_estimator(estimator_json)


def _prepare_options(params):
    """Return the ensemble constructor kwargs derived from ``params``.

    Works on a copy of ``params["algo_selection"]["options"]``.
    - ``cv_selector`` is replaced by a ``cv`` splitter, and ``n_jobs`` is set
      alongside it.
    - ``weights`` (a string literal) is parsed and kept only if non-empty.
    """
    options = dict(params["algo_selection"]["options"])

    cv_selector = options.pop("cv_selector", None)
    if cv_selector:
        splitter, _groups = get_cv(cv_selector)
        options["cv"] = splitter
        # set n_jobs
        options["n_jobs"] = N_JOBS

    weights = options.pop("weights", None)
    if weights:
        weights = ast.literal_eval(weights)
        if weights:
            options["weights"] = weights
    return options


def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
    """Build an ensemble estimator from Galaxy parameters and pickle it.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (JSON).

    output_obj : str
        File path for the ensemble estimator output (pickled).

    base_paths : str, optional
        File path or paths concatenated by comma, one per base estimator.
        An empty entry or ``"None"`` means that estimator is built from
        ``params["base_est_builder"]``. When omitted, all base estimators
        are built from the parameters.

    meta_path : str, optional
        File path of a saved meta estimator (used by mlxtend ensembles only).
        When omitted or ``"None"``, the meta estimator is built from the
        parameters.

    outfile_params : str, optional
        File path for the tab-separated params output.

    Raises
    ------
    ValueError
        If ``estimator_type`` names a module that is neither sklearn nor
        ``mlxtend.classifier`` / ``mlxtend.regressor``.
    """
    with open(inputs_path, "r") as param_handler:
        params = json.load(param_handler)

    estimator_type = params["algo_selection"]["estimator_type"]
    base_estimators = _build_base_estimators(params, base_paths, estimator_type)

    # A meta estimator only exists for mlxtend stacking ensembles.
    meta_estimator = None
    if estimator_type.startswith("mlxtend"):
        meta_estimator = _build_meta_estimator(params, meta_path)

    options = _prepare_options(params)

    # estimator_type is "<module>_<ClassName>", e.g. "mlxtend.classifier_<ClassName>".
    module_name, _, class_name = estimator_type.partition("_")
    mod = sys.modules[module_name]
    klass = getattr(mod, class_name)

    if estimator_type.startswith("sklearn"):
        options["n_jobs"] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)
    elif mod == mlxtend.classifier:
        ensemble_estimator = klass(classifiers=base_estimators, meta_classifier=meta_estimator, **options)
    elif mod == mlxtend.regressor:
        ensemble_estimator = klass(regressors=base_estimators, meta_regressor=meta_estimator, **options)
    else:
        raise ValueError("Unsupported estimator_type: %r" % estimator_type)

    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, "wb") as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    # Only consult get_params when there is somewhere to write the table.
    if outfile_params and params.get("get_params"):
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
        df.to_csv(outfile_params, sep="\t", index=False)
| 108 | |
| 109 | |
if __name__ == "__main__":
    aparser = argparse.ArgumentParser(description="Build an ensemble estimator and save it as a pickle.")
    aparser.add_argument(
        "-b",
        "--bases",
        dest="bases",
        help="Comma-separated base estimator files; 'None' entries are built from the inputs JSON.",
    )
    aparser.add_argument("-m", "--meta", dest="meta", help="Meta estimator file (mlxtend ensembles only).")
    # inputs and outfile are always opened by main(); require them so a missing
    # flag gives a clear usage error instead of a TypeError from open(None).
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True, help="Galaxy parameters JSON file.")
    aparser.add_argument(
        "-o", "--outfile", dest="outfile", required=True, help="Output path for the pickled ensemble."
    )
    aparser.add_argument("-p", "--outfile_params", dest="outfile_params", help="Output path for the params table.")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.outfile,
        base_paths=args.bases,
        meta_path=args.meta,
        outfile_params=args.outfile_params,
    )
