comparison keras_train_and_eval.py @ 9:e3b420d0b71a draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:42:14 +0000
parents 449a757be9c9
children 9b6faa256f15
comparison
equal deleted inserted replaced
8:449a757be9c9 9:e3b420d0b71a
1 import argparse 1 import argparse
2 import joblib
3 import json 2 import json
4 import numpy as np
5 import os 3 import os
6 import pandas as pd
7 import pickle 4 import pickle
8 import warnings 5 import warnings
9 from itertools import chain 6 from itertools import chain
10 from scipy.io import mmread 7
11 from sklearn.pipeline import Pipeline 8 import joblib
12 from sklearn.metrics.scorer import _check_multimetric_scoring 9 import numpy as np
13 from sklearn.model_selection._validation import _score 10 import pandas as pd
14 from sklearn.model_selection import _search, _validation
15 from sklearn.utils import indexable, safe_indexing
16
17 from galaxy_ml.externals.selene_sdk.utils import compute_score 11 from galaxy_ml.externals.selene_sdk.utils import compute_score
12 from galaxy_ml.keras_galaxy_models import _predict_generator
18 from galaxy_ml.model_validations import train_test_split 13 from galaxy_ml.model_validations import train_test_split
19 from galaxy_ml.keras_galaxy_models import _predict_generator
20 from galaxy_ml.utils import ( 14 from galaxy_ml.utils import (
21 SafeEval, 15 clean_params,
16 get_main_estimator,
17 get_module,
22 get_scoring, 18 get_scoring,
23 load_model, 19 load_model,
24 read_columns, 20 read_columns,
21 SafeEval,
25 try_get_attr, 22 try_get_attr,
26 get_module,
27 clean_params,
28 get_main_estimator,
29 ) 23 )
24 from scipy.io import mmread
25 from sklearn.metrics.scorer import _check_multimetric_scoring
26 from sklearn.model_selection import _search, _validation
27 from sklearn.model_selection._validation import _score
28 from sklearn.pipeline import Pipeline
29 from sklearn.utils import indexable, safe_indexing
30 30
31 31
32 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") 32 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
33 setattr(_search, "_fit_and_score", _fit_and_score) 33 setattr(_search, "_fit_and_score", _fit_and_score)
34 setattr(_validation, "_fit_and_score", _fit_and_score) 34 setattr(_validation, "_fit_and_score", _fit_and_score)
102 rval = list(chain.from_iterable((safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays)) 102 rval = list(chain.from_iterable((safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays))
103 else: 103 else:
104 rval = train_test_split(*new_arrays, **kwargs) 104 rval = train_test_split(*new_arrays, **kwargs)
105 105
106 for pos in nones: 106 for pos in nones:
107 rval[pos * 2 : 2] = [None, None] 107 rval[pos * 2: 2] = [None, None]
108 108
109 return rval 109 return rval
110 110
111 111
112 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True): 112 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True):