comparison train_test_split.py @ 11:0c933465d70e draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 00:59:48 +0000
parents 64bbfa592868
children 6eb4e7fb0f91
comparison
equal deleted inserted replaced
10:64bbfa592868 11:0c933465d70e
26 26
27 nth_split = params["mode_selection"]["nth_split"] 27 nth_split = params["mode_selection"]["nth_split"]
28 28
29 # read groups 29 # read groups
30 if infile_groups: 30 if infile_groups:
31 header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None 31 header = (
32 column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][ 32 "infer"
33 "selected_column_selector_option_g" 33 if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"])
34 ] 34 else None
35 )
36 column_option = params["mode_selection"]["cv_selector"]["groups_selector"][
37 "column_selector_options_g"
38 ]["selected_column_selector_option_g"]
35 if column_option in [ 39 if column_option in [
36 "by_index_number", 40 "by_index_number",
37 "all_but_by_index_number", 41 "all_but_by_index_number",
38 "by_header_name", 42 "by_header_name",
39 "all_but_by_header_name", 43 "all_but_by_header_name",
40 ]: 44 ]:
41 c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"] 45 c = params["mode_selection"]["cv_selector"]["groups_selector"][
46 "column_selector_options_g"
47 ]["col_g"]
42 else: 48 else:
43 c = None 49 c = None
44 50
45 groups = read_columns( 51 groups = read_columns(
46 infile_groups, 52 infile_groups,
65 # construct the cv splitter object 71 # construct the cv splitter object
66 splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) 72 splitter, groups = get_cv(params["mode_selection"]["cv_selector"])
67 73
68 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) 74 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
69 if nth_split > total_n_splits: 75 if nth_split > total_n_splits:
70 raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split)) 76 raise ValueError(
77 "Total number of splits is {}, but got `nth_split` "
78 "= {}".format(total_n_splits, nth_split)
79 )
71 80
72 i = 1 81 i = 1
73 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): 82 for train_index, test_index in splitter.split(array.values, y=y, groups=groups):
74 # suppose nth_split >= 1 83 # suppose nth_split >= 1
75 if i == nth_split: 84 if i == nth_split:
135 144
136 train, test = train_test_split(array, **options) 145 train, test = train_test_split(array, **options)
137 146
138 # cv splitter 147 # cv splitter
139 else: 148 else:
140 train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups) 149 train, test = _get_single_cv_split(
150 params, array, infile_labels=infile_labels, infile_groups=infile_groups
151 )
141 152
142 print("Input shape: %s" % repr(array.shape)) 153 print("Input shape: %s" % repr(array.shape))
143 print("Train shape: %s" % repr(train.shape)) 154 print("Train shape: %s" % repr(train.shape))
144 print("Test shape: %s" % repr(test.shape)) 155 print("Test shape: %s" % repr(test.shape))
145 train.to_csv(outfile_train, sep="\t", header=input_header, index=False) 156 train.to_csv(outfile_train, sep="\t", header=input_header, index=False)