Mercurial > repos > bgruening > model_prediction
diff search_model_validation.py @ 9:4aa701f5a393 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 18:00:54 +0000 |
parents | 6efb9bc6bf32 |
children | 22f9cbcf1582 |
line wrap: on
line diff
--- a/search_model_validation.py Thu Oct 01 20:35:52 2020 +0000 +++ b/search_model_validation.py Tue Apr 13 18:00:54 2021 +0000 @@ -11,45 +11,57 @@ import sys import warnings from scipy.io import mmread -from sklearn import (cluster, decomposition, feature_selection, - kernel_approximation, model_selection, preprocessing) +from sklearn import ( + cluster, + decomposition, + feature_selection, + kernel_approximation, + model_selection, + preprocessing, +) from sklearn.exceptions import FitFailedWarning from sklearn.model_selection._validation import _score, cross_validate from sklearn.model_selection import _search, _validation from sklearn.pipeline import Pipeline -from galaxy_ml.utils import (SafeEval, get_cv, get_scoring, load_model, - read_columns, try_get_attr, get_module, - clean_params, get_main_estimator) +from galaxy_ml.utils import ( + SafeEval, + get_cv, + get_scoring, + load_model, + read_columns, + try_get_attr, + get_module, + clean_params, + get_main_estimator, +) -_fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') -setattr(_search, '_fit_and_score', _fit_and_score) -setattr(_validation, '_fit_and_score', _fit_and_score) +_fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") +setattr(_search, "_fit_and_score", _fit_and_score) +setattr(_validation, "_fit_and_score", _fit_and_score) -N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) +N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1)) # handle disk cache -CACHE_DIR = os.path.join(os.getcwd(), 'cached') +CACHE_DIR = os.path.join(os.getcwd(), "cached") del os -NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', - 'nthread', 'callbacks') +NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks") def _eval_search_params(params_builder): search_params = {} - for p in params_builder['param_set']: - search_list = p['sp_list'].strip() - if search_list == '': + for p in params_builder["param_set"]: + search_list = p["sp_list"].strip() + if search_list == "": continue - param_name = p['sp_name'] + param_name = p["sp_name"] if param_name.lower().endswith(NON_SEARCHABLE): - print("Warning: `%s` is not eligible for search and was " - "omitted!" % param_name) + print("Warning: `%s` is not eligible for search and was " "omitted!" % param_name) continue - if not search_list.startswith(':'): + if not search_list.startswith(":"): safe_eval = SafeEval(load_scipy=True, load_numpy=True) ev = safe_eval(search_list) search_params[param_name] = ev @@ -60,26 +72,27 @@ # TODO maybe add regular express check ev = safe_eval_es(search_list) preprocessings = ( - preprocessing.StandardScaler(), preprocessing.Binarizer(), + preprocessing.StandardScaler(), + preprocessing.Binarizer(), preprocessing.MaxAbsScaler(), - preprocessing.Normalizer(), preprocessing.MinMaxScaler(), + preprocessing.Normalizer(), + preprocessing.MinMaxScaler(), preprocessing.PolynomialFeatures(), - preprocessing.RobustScaler(), feature_selection.SelectKBest(), + preprocessing.RobustScaler(), + feature_selection.SelectKBest(), feature_selection.GenericUnivariateSelect(), feature_selection.SelectPercentile(), - feature_selection.SelectFpr(), feature_selection.SelectFdr(), + feature_selection.SelectFpr(), + feature_selection.SelectFdr(), feature_selection.SelectFwe(), feature_selection.VarianceThreshold(), decomposition.FactorAnalysis(random_state=0), decomposition.FastICA(random_state=0), decomposition.IncrementalPCA(), decomposition.KernelPCA(random_state=0, n_jobs=N_JOBS), - decomposition.LatentDirichletAllocation( - random_state=0, n_jobs=N_JOBS), - decomposition.MiniBatchDictionaryLearning( - random_state=0, n_jobs=N_JOBS), - decomposition.MiniBatchSparsePCA( - random_state=0, n_jobs=N_JOBS), + decomposition.LatentDirichletAllocation(random_state=0, n_jobs=N_JOBS), + decomposition.MiniBatchDictionaryLearning(random_state=0, n_jobs=N_JOBS), + decomposition.MiniBatchSparsePCA(random_state=0, n_jobs=N_JOBS), decomposition.NMF(random_state=0), decomposition.PCA(random_state=0), decomposition.SparsePCA(random_state=0, n_jobs=N_JOBS), @@ -94,59 +107,48 @@ skrebate.SURFstar(n_jobs=N_JOBS), skrebate.MultiSURF(n_jobs=N_JOBS), skrebate.MultiSURFstar(n_jobs=N_JOBS), - imblearn.under_sampling.ClusterCentroids( - random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.CondensedNearestNeighbour( - random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.EditedNearestNeighbours( - random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.RepeatedEditedNearestNeighbours( - random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.CondensedNearestNeighbour(random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.EditedNearestNeighbours(random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.RepeatedEditedNearestNeighbours(random_state=0, n_jobs=N_JOBS), imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.InstanceHardnessThreshold( - random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.NearMiss( - random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.NeighbourhoodCleaningRule( - random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.OneSidedSelection( - random_state=0, n_jobs=N_JOBS), - imblearn.under_sampling.RandomUnderSampler( - random_state=0), - imblearn.under_sampling.TomekLinks( - random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.InstanceHardnessThreshold(random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.NeighbourhoodCleaningRule(random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.OneSidedSelection(random_state=0, n_jobs=N_JOBS), + imblearn.under_sampling.RandomUnderSampler(random_state=0), + imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS), imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), imblearn.over_sampling.RandomOverSampler(random_state=0), imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), - imblearn.over_sampling.BorderlineSMOTE( - random_state=0, n_jobs=N_JOBS), - imblearn.over_sampling.SMOTENC( - categorical_features=[], random_state=0, n_jobs=N_JOBS), + imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), + imblearn.over_sampling.SMOTENC(categorical_features=[], random_state=0, n_jobs=N_JOBS), imblearn.combine.SMOTEENN(random_state=0), - imblearn.combine.SMOTETomek(random_state=0)) + imblearn.combine.SMOTETomek(random_state=0), + ) newlist = [] for obj in ev: if obj is None: newlist.append(None) - elif obj == 'all_0': + elif obj == "all_0": newlist.extend(preprocessings[0:35]) - elif obj == 'sk_prep_all': # no KernalCenter() + elif obj == "sk_prep_all": # no KernalCenter() newlist.extend(preprocessings[0:7]) - elif obj == 'fs_all': + elif obj == "fs_all": newlist.extend(preprocessings[7:14]) - elif obj == 'decomp_all': + elif obj == "decomp_all": newlist.extend(preprocessings[14:25]) - elif obj == 'k_appr_all': + elif obj == "k_appr_all": newlist.extend(preprocessings[25:29]) - elif obj == 'reb_all': + elif obj == "reb_all": newlist.extend(preprocessings[30:35]) - elif obj == 'imb_all': + elif obj == "imb_all": newlist.extend(preprocessings[35:54]) elif type(obj) is int and -1 < obj < len(preprocessings): newlist.append(preprocessings[obj]) - elif hasattr(obj, 'get_params'): # user uploaded object - if 'n_jobs' in obj.get_params(): + elif hasattr(obj, "get_params"): # user uploaded object + if "n_jobs" in obj.get_params(): newlist.append(obj.set_params(n_jobs=N_JOBS)) else: newlist.append(obj) @@ -158,9 +160,17 @@ return search_params -def _handle_X_y(estimator, params, infile1, infile2, loaded_df={}, - ref_seq=None, intervals=None, targets=None, - fasta_path=None): +def _handle_X_y( + estimator, + params, + infile1, + infile2, + loaded_df={}, + ref_seq=None, + intervals=None, + targets=None, + fasta_path=None, +): """read inputs Params @@ -192,15 +202,18 @@ """ estimator_params = estimator.get_params() - input_type = params['input_options']['selected_input'] + input_type = params["input_options"]["selected_input"] # tabular input - if input_type == 'tabular': - header = 'infer' if params['input_options']['header1'] else None - column_option = (params['input_options']['column_selector_options_1'] - ['selected_column_selector_option']) - if column_option in ['by_index_number', 'all_but_by_index_number', - 'by_header_name', 'all_but_by_header_name']: - c = params['input_options']['column_selector_options_1']['col1'] + if input_type == "tabular": + header = "infer" if params["input_options"]["header1"] else None + column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: + c = params["input_options"]["column_selector_options_1"]["col1"] else: c = None @@ -209,25 +222,23 @@ if df_key in loaded_df: infile1 = loaded_df[df_key] - df = pd.read_csv(infile1, sep='\t', header=header, - parse_dates=True) + df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) loaded_df[df_key] = df X = read_columns(df, c=c, c_option=column_option).astype(float) # sparse input - elif input_type == 'sparse': - X = mmread(open(infile1, 'r')) + elif input_type == "sparse": + X = mmread(open(infile1, "r")) # fasta_file input - elif input_type == 'seq_fasta': - pyfaidx = get_module('pyfaidx') + elif input_type == "seq_fasta": + pyfaidx = get_module("pyfaidx") sequences = pyfaidx.Fasta(fasta_path) n_seqs = len(sequences.keys()) X = np.arange(n_seqs)[:, np.newaxis] for param in estimator_params.keys(): - if param.endswith('fasta_path'): - estimator.set_params( - **{param: fasta_path}) + if param.endswith("fasta_path"): + estimator.set_params(**{param: fasta_path}) break else: raise ValueError( @@ -236,25 +247,29 @@ "KerasGBatchClassifier with " "FastaDNABatchGenerator/FastaProteinBatchGenerator " "or having GenomeOneHotEncoder/ProteinOneHotEncoder " - "in pipeline!") + "in pipeline!" + ) - elif input_type == 'refseq_and_interval': + elif input_type == "refseq_and_interval": path_params = { - 'data_batch_generator__ref_genome_path': ref_seq, - 'data_batch_generator__intervals_path': intervals, - 'data_batch_generator__target_path': targets + "data_batch_generator__ref_genome_path": ref_seq, + "data_batch_generator__intervals_path": intervals, + "data_batch_generator__target_path": targets, } estimator.set_params(**path_params) n_intervals = sum(1 for line in open(intervals)) X = np.arange(n_intervals)[:, np.newaxis] # Get target y - header = 'infer' if params['input_options']['header2'] else None - column_option = (params['input_options']['column_selector_options_2'] - ['selected_column_selector_option2']) - if column_option in ['by_index_number', 'all_but_by_index_number', - 'by_header_name', 'all_but_by_header_name']: - c = params['input_options']['column_selector_options_2']['col2'] + header = "infer" if params["input_options"]["header2"] else None + column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"] + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: + c = params["input_options"]["column_selector_options_2"]["col2"] else: c = None @@ -262,30 +277,21 @@ if df_key in loaded_df: infile2 = loaded_df[df_key] else: - infile2 = pd.read_csv(infile2, sep='\t', - header=header, parse_dates=True) + infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) loaded_df[df_key] = infile2 - y = read_columns( - infile2, - c=c, - c_option=column_option, - sep='\t', - header=header, - parse_dates=True) + y = read_columns(infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True) if len(y.shape) == 2 and y.shape[1] == 1: y = y.ravel() - if input_type == 'refseq_and_interval': - estimator.set_params( - data_batch_generator__features=y.ravel().tolist()) + if input_type == "refseq_and_interval": + estimator.set_params(data_batch_generator__features=y.ravel().tolist()) y = None # end y return estimator, X, y -def _do_outer_cv(searcher, X, y, outer_cv, scoring, error_score='raise', - outfile=None): +def _do_outer_cv(searcher, X, y, outer_cv, scoring, error_score="raise", outfile=None): """Do outer cross-validation for nested CV Parameters @@ -305,21 +311,31 @@ outfile : str File path to store the restuls """ - if error_score == 'raise': + if error_score == "raise": rval = cross_validate( - searcher, X, y, scoring=scoring, - cv=outer_cv, n_jobs=N_JOBS, verbose=0, - error_score=error_score) + searcher, + X, + y, + scoring=scoring, + cv=outer_cv, + n_jobs=N_JOBS, + verbose=0, + error_score=error_score, + ) else: - warnings.simplefilter('always', FitFailedWarning) + warnings.simplefilter("always", FitFailedWarning) with warnings.catch_warnings(record=True) as w: try: rval = cross_validate( - searcher, X, y, + searcher, + X, + y, scoring=scoring, - cv=outer_cv, n_jobs=N_JOBS, + cv=outer_cv, + n_jobs=N_JOBS, verbose=0, - error_score=error_score) + error_score=error_score, + ) except ValueError: pass for warning in w: @@ -327,55 +343,57 @@ keys = list(rval.keys()) for k in keys: - if k.startswith('test'): - rval['mean_' + k] = np.mean(rval[k]) - rval['std_' + k] = np.std(rval[k]) - if k.endswith('time'): + if k.startswith("test"): + rval["mean_" + k] = np.mean(rval[k]) + rval["std_" + k] = np.std(rval[k]) + if k.endswith("time"): rval.pop(k) rval = pd.DataFrame(rval) rval = rval[sorted(rval.columns)] - rval.to_csv(path_or_buf=outfile, sep='\t', header=True, index=False) + rval.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False) -def _do_train_test_split_val(searcher, X, y, params, error_score='raise', - primary_scoring=None, groups=None, - outfile=None): - """ do train test split, searchCV validates on the train and then use +def _do_train_test_split_val( + searcher, + X, + y, + params, + error_score="raise", + primary_scoring=None, + groups=None, + outfile=None, +): + """do train test split, searchCV validates on the train and then use the best_estimator_ to evaluate on the test Returns -------- Fitted SearchCV object """ - train_test_split = try_get_attr( - 'galaxy_ml.model_validations', 'train_test_split') - split_options = params['outer_split'] + train_test_split = try_get_attr("galaxy_ml.model_validations", "train_test_split") + split_options = params["outer_split"] # splits - if split_options['shuffle'] == 'stratified': - split_options['labels'] = y + if split_options["shuffle"] == "stratified": + split_options["labels"] = y X, X_test, y, y_test = train_test_split(X, y, **split_options) - elif split_options['shuffle'] == 'group': + elif split_options["shuffle"] == "group": if groups is None: - raise ValueError("No group based CV option was choosen for " - "group shuffle!") - split_options['labels'] = groups + raise ValueError("No group based CV option was choosen for " "group shuffle!") + split_options["labels"] = groups if y is None: - X, X_test, groups, _ =\ - train_test_split(X, groups, **split_options) + X, X_test, groups, _ = train_test_split(X, groups, **split_options) else: - X, X_test, y, y_test, groups, _ =\ - train_test_split(X, y, groups, **split_options) + X, X_test, y, y_test, groups, _ = train_test_split(X, y, groups, **split_options) else: - if split_options['shuffle'] == 'None': - split_options['shuffle'] = None - X, X_test, y, y_test =\ - train_test_split(X, y, **split_options) + if split_options["shuffle"] == "None": + split_options["shuffle"] = None + X, X_test, y, y_test = train_test_split(X, y, **split_options) - if error_score == 'raise': + if error_score == "raise": searcher.fit(X, y, groups=groups) else: - warnings.simplefilter('always', FitFailedWarning) + warnings.simplefilter("always", FitFailedWarning) with warnings.catch_warnings(record=True) as w: try: searcher.fit(X, y, groups=groups) @@ -390,33 +408,38 @@ else: is_multimetric = False - best_estimator_ = getattr(searcher, 'best_estimator_') + best_estimator_ = getattr(searcher, "best_estimator_") # TODO Solve deep learning models in pipeline - if best_estimator_.__class__.__name__ == 'KerasGBatchClassifier': - test_score = best_estimator_.evaluate( - X_test, scorer=scorer_, is_multimetric=is_multimetric) + if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": + test_score = best_estimator_.evaluate(X_test, scorer=scorer_, is_multimetric=is_multimetric) else: - test_score = _score(best_estimator_, X_test, - y_test, scorer_, - is_multimetric=is_multimetric) + test_score = _score(best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric) if not is_multimetric: test_score = {primary_scoring: test_score} for key, value in test_score.items(): test_score[key] = [value] result_df = pd.DataFrame(test_score) - result_df.to_csv(path_or_buf=outfile, sep='\t', header=True, - index=False) + result_df.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False) return searcher -def main(inputs, infile_estimator, infile1, infile2, - outfile_result, outfile_object=None, - outfile_weights=None, groups=None, - ref_seq=None, intervals=None, targets=None, - fasta_path=None): +def main( + inputs, + infile_estimator, + infile1, + infile2, + outfile_result, + outfile_object=None, + outfile_weights=None, + groups=None, + ref_seq=None, + intervals=None, + targets=None, + fasta_path=None, +): """ Parameter --------- @@ -456,154 +479,174 @@ fasta_path : str File path to dataset containing fasta file """ - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") # store read dataframe object loaded_df = {} - with open(inputs, 'r') as param_handler: + with open(inputs, "r") as param_handler: params = json.load(param_handler) # Override the refit parameter - params['search_schemes']['options']['refit'] = True \ - if params['save'] != 'nope' else False + params["search_schemes"]["options"]["refit"] = True if params["save"] != "nope" else False - with open(infile_estimator, 'rb') as estimator_handler: + with open(infile_estimator, "rb") as estimator_handler: estimator = load_model(estimator_handler) - optimizer = params['search_schemes']['selected_search_scheme'] + optimizer = params["search_schemes"]["selected_search_scheme"] optimizer = getattr(model_selection, optimizer) # handle gridsearchcv options - options = params['search_schemes']['options'] + options = params["search_schemes"]["options"] if groups: - header = 'infer' if (options['cv_selector']['groups_selector'] - ['header_g']) else None - column_option = (options['cv_selector']['groups_selector'] - ['column_selector_options_g'] - ['selected_column_selector_option_g']) - if column_option in ['by_index_number', 'all_but_by_index_number', - 'by_header_name', 'all_but_by_header_name']: - c = (options['cv_selector']['groups_selector'] - ['column_selector_options_g']['col_g']) + header = "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None + column_option = options["cv_selector"]["groups_selector"]["column_selector_options_g"][ + "selected_column_selector_option_g" + ] + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: + c = options["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"] else: c = None df_key = groups + repr(header) - groups = pd.read_csv(groups, sep='\t', header=header, - parse_dates=True) + groups = pd.read_csv(groups, sep="\t", header=header, parse_dates=True) loaded_df[df_key] = groups groups = read_columns( - groups, - c=c, - c_option=column_option, - sep='\t', - header=header, - parse_dates=True) + groups, + c=c, + c_option=column_option, + sep="\t", + header=header, + parse_dates=True, + ) groups = groups.ravel() - options['cv_selector']['groups_selector'] = groups + options["cv_selector"]["groups_selector"] = groups - splitter, groups = get_cv(options.pop('cv_selector')) - options['cv'] = splitter - primary_scoring = options['scoring']['primary_scoring'] - options['scoring'] = get_scoring(options['scoring']) - if options['error_score']: - options['error_score'] = 'raise' + splitter, groups = get_cv(options.pop("cv_selector")) + options["cv"] = splitter + primary_scoring = options["scoring"]["primary_scoring"] + # get_scoring() expects secondary_scoring to be a comma separated string (not a list) + # Check if secondary_scoring is specified + secondary_scoring = options["scoring"].get("secondary_scoring", None) + if secondary_scoring is not None: + # If secondary_scoring is specified, convert the list into comman separated string + options["scoring"]["secondary_scoring"] = ",".join(options["scoring"]["secondary_scoring"]) + options["scoring"] = get_scoring(options["scoring"]) + if options["error_score"]: + options["error_score"] = "raise" else: - options['error_score'] = np.NaN - if options['refit'] and isinstance(options['scoring'], dict): - options['refit'] = primary_scoring - if 'pre_dispatch' in options and options['pre_dispatch'] == '': - options['pre_dispatch'] = None + options["error_score"] = np.NaN + if options["refit"] and isinstance(options["scoring"], dict): + options["refit"] = primary_scoring + if "pre_dispatch" in options and options["pre_dispatch"] == "": + options["pre_dispatch"] = None - params_builder = params['search_schemes']['search_params_builder'] + params_builder = params["search_schemes"]["search_params_builder"] param_grid = _eval_search_params(params_builder) estimator = clean_params(estimator) # save the SearchCV object without fit - if params['save'] == 'save_no_fit': + if params["save"] == "save_no_fit": searcher = optimizer(estimator, param_grid, **options) print(searcher) - with open(outfile_object, 'wb') as output_handler: - pickle.dump(searcher, output_handler, - pickle.HIGHEST_PROTOCOL) + with open(outfile_object, "wb") as output_handler: + pickle.dump(searcher, output_handler, pickle.HIGHEST_PROTOCOL) return 0 # read inputs and loads new attributes, like paths - estimator, X, y = _handle_X_y(estimator, params, infile1, infile2, - loaded_df=loaded_df, ref_seq=ref_seq, - intervals=intervals, targets=targets, - fasta_path=fasta_path) + estimator, X, y = _handle_X_y( + estimator, + params, + infile1, + infile2, + loaded_df=loaded_df, + ref_seq=ref_seq, + intervals=intervals, + targets=targets, + fasta_path=fasta_path, + ) # cache iraps_core fits could increase search speed significantly memory = joblib.Memory(location=CACHE_DIR, verbose=0) main_est = get_main_estimator(estimator) - if main_est.__class__.__name__ == 'IRAPSClassifier': + if main_est.__class__.__name__ == "IRAPSClassifier": main_est.set_params(memory=memory) searcher = optimizer(estimator, param_grid, **options) - split_mode = params['outer_split'].pop('split_mode') + split_mode = params["outer_split"].pop("split_mode") - if split_mode == 'nested_cv': + if split_mode == "nested_cv": # make sure refit is choosen # this could be True for sklearn models, but not the case for # deep learning models - if not options['refit'] and \ - not all(hasattr(estimator, attr) - for attr in ('config', 'model_type')): + if not options["refit"] and not all(hasattr(estimator, attr) for attr in ("config", "model_type")): warnings.warn("Refit is change to `True` for nested validation!") - setattr(searcher, 'refit', True) + setattr(searcher, "refit", True) - outer_cv, _ = get_cv(params['outer_split']['cv_selector']) + outer_cv, _ = get_cv(params["outer_split"]["cv_selector"]) # nested CV, outer cv using cross_validate - if options['error_score'] == 'raise': + if options["error_score"] == "raise": rval = cross_validate( - searcher, X, y, scoring=options['scoring'], - cv=outer_cv, n_jobs=N_JOBS, - verbose=options['verbose'], - return_estimator=(params['save'] == 'save_estimator'), - error_score=options['error_score'], - return_train_score=True) + searcher, + X, + y, + scoring=options["scoring"], + cv=outer_cv, + n_jobs=N_JOBS, + verbose=options["verbose"], + return_estimator=(params["save"] == "save_estimator"), + error_score=options["error_score"], + return_train_score=True, + ) else: - warnings.simplefilter('always', FitFailedWarning) + warnings.simplefilter("always", FitFailedWarning) with warnings.catch_warnings(record=True) as w: try: rval = cross_validate( - searcher, X, y, - scoring=options['scoring'], - cv=outer_cv, n_jobs=N_JOBS, - verbose=options['verbose'], - return_estimator=(params['save'] == 'save_estimator'), - error_score=options['error_score'], - return_train_score=True) + searcher, + X, + y, + scoring=options["scoring"], + cv=outer_cv, + n_jobs=N_JOBS, + verbose=options["verbose"], + return_estimator=(params["save"] == "save_estimator"), + error_score=options["error_score"], + return_train_score=True, + ) except ValueError: pass for warning in w: print(repr(warning.message)) - fitted_searchers = rval.pop('estimator', []) + fitted_searchers = rval.pop("estimator", []) if fitted_searchers: import os + pwd = os.getcwd() - save_dir = os.path.join(pwd, 'cv_results_in_folds') + save_dir = os.path.join(pwd, "cv_results_in_folds") try: os.mkdir(save_dir) for idx, obj in enumerate(fitted_searchers): - target_name = 'cv_results_' + '_' + 'split%d' % idx + target_name = "cv_results_" + "_" + "split%d" % idx target_path = os.path.join(pwd, save_dir, target_name) - cv_results_ = getattr(obj, 'cv_results_', None) + cv_results_ = getattr(obj, "cv_results_", None) if not cv_results_: print("%s is not available" % target_name) continue cv_results_ = pd.DataFrame(cv_results_) cv_results_ = cv_results_[sorted(cv_results_.columns)] - cv_results_.to_csv(target_path, sep='\t', header=True, - index=False) + cv_results_.to_csv(target_path, sep="\t", header=True, index=False) except Exception as e: print(e) finally: @@ -611,18 +654,14 @@ keys = list(rval.keys()) for k in keys: - if k.startswith('test'): - rval['mean_' + k] = np.mean(rval[k]) - rval['std_' + k] = np.std(rval[k]) - if k.endswith('time'): + if k.startswith("test"): + rval["mean_" + k] = np.mean(rval[k]) + rval["std_" + k] = np.std(rval[k]) + if k.endswith("time"): rval.pop(k) rval = pd.DataFrame(rval) rval = rval[sorted(rval.columns)] - rval.to_csv(path_or_buf=outfile_result, sep='\t', header=True, - index=False) - - return 0 - + rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) # deprecate train test split mode """searcher = _do_train_test_split_val( searcher, X, y, params, @@ -630,14 +669,15 @@ error_score=options['error_score'], groups=groups, outfile=outfile_result)""" + return 0 # no outer split else: searcher.set_params(n_jobs=N_JOBS) - if options['error_score'] == 'raise': + if options["error_score"] == "raise": searcher.fit(X, y, groups=groups) else: - warnings.simplefilter('always', FitFailedWarning) + warnings.simplefilter("always", FitFailedWarning) with warnings.catch_warnings(record=True) as w: try: searcher.fit(X, y, groups=groups) @@ -648,18 +688,19 @@ cv_results = pd.DataFrame(searcher.cv_results_) cv_results = cv_results[sorted(cv_results.columns)] - cv_results.to_csv(path_or_buf=outfile_result, sep='\t', - header=True, index=False) + cv_results.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) memory.clear(warn=False) # output best estimator, and weights if applicable if outfile_object: - best_estimator_ = getattr(searcher, 'best_estimator_', None) + best_estimator_ = getattr(searcher, "best_estimator_", None) if not best_estimator_: - warnings.warn("GridSearchCV object has no attribute " - "'best_estimator_', because either it's " - "nested gridsearch or `refit` is False!") + warnings.warn( + "GridSearchCV object has no attribute " + "'best_estimator_', because either it's " + "nested gridsearch or `refit` is False!" + ) return # clean prams @@ -667,24 +708,22 @@ main_est = get_main_estimator(best_estimator_) - if hasattr(main_est, 'model_') \ - and hasattr(main_est, 'save_weights'): + if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"): if outfile_weights: main_est.save_weights(outfile_weights) del main_est.model_ del main_est.fit_params del main_est.model_class_ del main_est.validation_data - if getattr(main_est, 'data_generator_', None): + if getattr(main_est, "data_generator_", None): del main_est.data_generator_ - with open(outfile_object, 'wb') as output_handler: + with open(outfile_object, "wb") as output_handler: print("Best estimator is saved: %s " % repr(best_estimator_)) - pickle.dump(best_estimator_, output_handler, - pickle.HIGHEST_PROTOCOL) + pickle.dump(best_estimator_, output_handler, pickle.HIGHEST_PROTOCOL) -if __name__ == '__main__': +if __name__ == "__main__": aparser = argparse.ArgumentParser() aparser.add_argument("-i", "--inputs", dest="inputs", required=True) aparser.add_argument("-e", "--estimator", dest="infile_estimator") @@ -700,8 +739,17 @@ aparser.add_argument("-f", "--fasta_path", dest="fasta_path") args = aparser.parse_args() - main(args.inputs, args.infile_estimator, args.infile1, args.infile2, - args.outfile_result, outfile_object=args.outfile_object, - outfile_weights=args.outfile_weights, groups=args.groups, - ref_seq=args.ref_seq, intervals=args.intervals, - targets=args.targets, fasta_path=args.fasta_path) + main( + args.inputs, + args.infile_estimator, + args.infile1, + args.infile2, + args.outfile_result, + outfile_object=args.outfile_object, + outfile_weights=args.outfile_weights, + groups=args.groups, + ref_seq=args.ref_seq, + intervals=args.intervals, + targets=args.targets, + fasta_path=args.fasta_path, + )