Mercurial > repos > bgruening > sklearn_estimator_attributes
comparison simple_model_fit.py @ 17:a01fa4e8fe4f draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 12:54:40 +0000 |
parents | c9ddd20d25d0 |
children |
comparison
equal
deleted
inserted
replaced
16:d0352e8b4c10 | 17:a01fa4e8fe4f |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import pickle | |
4 | 3 |
5 import pandas as pd | 4 import pandas as pd |
6 from galaxy_ml.utils import load_model, read_columns | 5 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
| 6 from galaxy_ml.utils import read_columns |
7 from scipy.io import mmread | 7 from scipy.io import mmread |
8 from sklearn.pipeline import Pipeline | 8 from sklearn.pipeline import Pipeline |
9 | 9 |
10 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 10 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) |
11 | 11 |
146 """ | 146 """ |
147 with open(inputs, "r") as param_handler: | 147 with open(inputs, "r") as param_handler: |
148 params = json.load(param_handler) | 148 params = json.load(param_handler) |
149 | 149 |
150 # load model | 150 # load model |
151 with open(infile_estimator, "rb") as est_handler: | 151 estimator = load_model_from_h5(infile_estimator) |
152 estimator = load_model(est_handler) | 152 |
153 estimator = clean_params(estimator, n_jobs=N_JOBS) | 153 estimator = clean_params(estimator) |
154 | 154 |
155 X_train, y_train = _get_X_y(params, infile1, infile2) | 155 X_train, y_train = _get_X_y(params, infile1, infile2) |
156 | 156 |
157 estimator.fit(X_train, y_train) | 157 estimator.fit(X_train, y_train) |
158 | 158 |
168 if getattr(main_est, "validation_data", None): | 168 if getattr(main_est, "validation_data", None): |
169 del main_est.validation_data | 169 del main_est.validation_data |
170 if getattr(main_est, "data_generator_", None): | 170 if getattr(main_est, "data_generator_", None): |
171 del main_est.data_generator_ | 171 del main_est.data_generator_ |
172 | 172 |
173 with open(out_object, "wb") as output_handler: | 173 dump_model_to_h5(estimator, out_object) |
174 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL) | |
175 | 174 |
176 | 175 |
177 if __name__ == "__main__": | 176 if __name__ == "__main__": |
178 aparser = argparse.ArgumentParser() | 177 aparser = argparse.ArgumentParser() |
179 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 178 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |