Mercurial > repos > bgruening > scipy_sparse
diff train_test_split.py @ 36:92e09b827300 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author | bgruening |
---|---|
date | Sat, 01 May 2021 01:43:23 +0000 |
parents | 318484f56b6a |
children | 5af054432771 |
line wrap: on
line diff
--- a/train_test_split.py Tue Apr 13 22:34:09 2021 +0000 +++ b/train_test_split.py Sat May 01 01:43:23 2021 +0000 @@ -28,17 +28,23 @@ # read groups if infile_groups: - header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None - column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][ - "selected_column_selector_option_g" - ] + header = ( + "infer" + if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) + else None + ) + column_option = params["mode_selection"]["cv_selector"]["groups_selector"][ + "column_selector_options_g" + ]["selected_column_selector_option_g"] if column_option in [ "by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name", ]: - c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"] + c = params["mode_selection"]["cv_selector"]["groups_selector"][ + "column_selector_options_g" + ]["col_g"] else: c = None @@ -67,7 +73,10 @@ total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) if nth_split > total_n_splits: - raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split)) + raise ValueError( + "Total number of splits is {}, but got `nth_split` " + "= {}".format(total_n_splits, nth_split) + ) i = 1 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): @@ -137,7 +146,9 @@ # cv splitter else: - train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups) + train, test = _get_single_cv_split( + params, array, infile_labels=infile_labels, infile_groups=infile_groups + ) print("Input shape: %s" % repr(array.shape)) print("Train shape: %s" % repr(train.shape))