Mercurial > repos > bgruening > sklearn_discriminant_classifier
diff simple_model_fit.py @ 41:d769d83ec796 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:14:12 +0000 |
parents | e76f6dfea5c9 |
children |
line wrap: on
line diff
--- a/simple_model_fit.py Thu Aug 11 08:53:29 2022 +0000 +++ b/simple_model_fit.py Wed Aug 09 13:14:12 2023 +0000 @@ -1,9 +1,9 @@ import argparse import json -import pickle import pandas as pd -from galaxy_ml.utils import load_model, read_columns +from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 +from galaxy_ml.utils import read_columns from scipy.io import mmread from sklearn.pipeline import Pipeline @@ -148,9 +148,9 @@ params = json.load(param_handler) # load model - with open(infile_estimator, "rb") as est_handler: - estimator = load_model(est_handler) - estimator = clean_params(estimator, n_jobs=N_JOBS) + estimator = load_model_from_h5(infile_estimator) + + estimator = clean_params(estimator) X_train, y_train = _get_X_y(params, infile1, infile2) @@ -170,8 +170,7 @@ if getattr(main_est, "data_generator_", None): del main_est.data_generator_ - with open(out_object, "wb") as output_handler: - pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL) + dump_model_to_h5(estimator, out_object) if __name__ == "__main__":