Mercurial > repos > bgruening > sklearn_train_test_eval
comparison keras_train_and_eval.py @ 10:a9e0b963b7bb draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 22:04:06 +0000 |
parents | ead7adad8d0e |
children | caf7d2b71a48 |
comparison
equal
deleted
inserted
replaced
9:ead7adad8d0e | 10:a9e0b963b7bb |
---|---|
1 import argparse | 1 import argparse |
2 import joblib | |
3 import json | 2 import json |
4 import numpy as np | |
5 import os | 3 import os |
6 import pandas as pd | |
7 import pickle | 4 import pickle |
8 import warnings | 5 import warnings |
9 from itertools import chain | 6 from itertools import chain |
10 from scipy.io import mmread | 7 |
11 from sklearn.pipeline import Pipeline | 8 import joblib |
12 from sklearn.metrics.scorer import _check_multimetric_scoring | 9 import numpy as np |
13 from sklearn.model_selection._validation import _score | 10 import pandas as pd |
14 from sklearn.model_selection import _search, _validation | |
15 from sklearn.utils import indexable, safe_indexing | |
16 | |
17 from galaxy_ml.externals.selene_sdk.utils import compute_score | 11 from galaxy_ml.externals.selene_sdk.utils import compute_score |
12 from galaxy_ml.keras_galaxy_models import _predict_generator | |
18 from galaxy_ml.model_validations import train_test_split | 13 from galaxy_ml.model_validations import train_test_split |
19 from galaxy_ml.keras_galaxy_models import _predict_generator | |
20 from galaxy_ml.utils import ( | 14 from galaxy_ml.utils import ( |
21 SafeEval, | 15 clean_params, |
16 get_main_estimator, | |
17 get_module, | |
22 get_scoring, | 18 get_scoring, |
23 load_model, | 19 load_model, |
24 read_columns, | 20 read_columns, |
21 SafeEval, | |
25 try_get_attr, | 22 try_get_attr, |
26 get_module, | |
27 clean_params, | |
28 get_main_estimator, | |
29 ) | 23 ) |
24 from scipy.io import mmread | |
25 from sklearn.metrics.scorer import _check_multimetric_scoring | |
26 from sklearn.model_selection import _search, _validation | |
27 from sklearn.model_selection._validation import _score | |
28 from sklearn.pipeline import Pipeline | |
29 from sklearn.utils import indexable, safe_indexing | |
30 | 30 |
31 | 31 |
32 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 32 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") |
33 setattr(_search, "_fit_and_score", _fit_and_score) | 33 setattr(_search, "_fit_and_score", _fit_and_score) |
34 setattr(_validation, "_fit_and_score", _fit_and_score) | 34 setattr(_validation, "_fit_and_score", _fit_and_score) |
102 rval = list(chain.from_iterable((safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays)) | 102 rval = list(chain.from_iterable((safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays)) |
103 else: | 103 else: |
104 rval = train_test_split(*new_arrays, **kwargs) | 104 rval = train_test_split(*new_arrays, **kwargs) |
105 | 105 |
106 for pos in nones: | 106 for pos in nones: |
107 rval[pos * 2 : 2] = [None, None] | 107 rval[pos * 2: 2] = [None, None] |
108 | 108 |
109 return rval | 109 return rval |
110 | 110 |
111 | 111 |
112 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True): | 112 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True): |