Mercurial > repos > bgruening > sklearn_sample_generator
comparison train_test_split.py @ 35:1e99cfb71f40 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 17:52:15 +0000 |
parents | 6b14fe097541 |
children | 999e07f0a9fa |
comparison
equal
deleted
inserted
replaced
34:7068b5fcd623 | 35:1e99cfb71f40 |
---|---|
5 | 5 |
6 from galaxy_ml.model_validations import train_test_split | 6 from galaxy_ml.model_validations import train_test_split |
7 from galaxy_ml.utils import get_cv, read_columns | 7 from galaxy_ml.utils import get_cv, read_columns |
8 | 8 |
9 | 9 |
10 def _get_single_cv_split(params, array, infile_labels=None, | 10 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None): |
11 infile_groups=None): | 11 """output (train, test) subset from a cv splitter |
12 """ output (train, test) subset from a cv splitter | |
13 | 12 |
14 Parameters | 13 Parameters |
15 ---------- | 14 ---------- |
16 params : dict | 15 params : dict |
17 Galaxy tool inputs | 16 Galaxy tool inputs |
23 File path to dataset containing group values | 22 File path to dataset containing group values |
24 """ | 23 """ |
25 y = None | 24 y = None |
26 groups = None | 25 groups = None |
27 | 26 |
28 nth_split = params['mode_selection']['nth_split'] | 27 nth_split = params["mode_selection"]["nth_split"] |
29 | 28 |
30 # read groups | 29 # read groups |
31 if infile_groups: | 30 if infile_groups: |
32 header = 'infer' if (params['mode_selection']['cv_selector'] | 31 header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None |
33 ['groups_selector']['header_g']) else None | 32 column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][ |
34 column_option = (params['mode_selection']['cv_selector'] | 33 "selected_column_selector_option_g" |
35 ['groups_selector']['column_selector_options_g'] | 34 ] |
36 ['selected_column_selector_option_g']) | 35 if column_option in [ |
37 if column_option in ['by_index_number', 'all_but_by_index_number', | 36 "by_index_number", |
38 'by_header_name', 'all_but_by_header_name']: | 37 "all_but_by_index_number", |
39 c = (params['mode_selection']['cv_selector']['groups_selector'] | 38 "by_header_name", |
40 ['column_selector_options_g']['col_g']) | 39 "all_but_by_header_name", |
40 ]: | |
41 c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"] | |
41 else: | 42 else: |
42 c = None | 43 c = None |
43 | 44 |
44 groups = read_columns(infile_groups, c=c, c_option=column_option, | 45 groups = read_columns( |
45 sep='\t', header=header, parse_dates=True) | 46 infile_groups, |
47 c=c, | |
48 c_option=column_option, | |
49 sep="\t", | |
50 header=header, | |
51 parse_dates=True, | |
52 ) | |
46 groups = groups.ravel() | 53 groups = groups.ravel() |
47 | 54 |
48 params['mode_selection']['cv_selector']['groups_selector'] = groups | 55 params["mode_selection"]["cv_selector"]["groups_selector"] = groups |
49 | 56 |
50 # read labels | 57 # read labels |
51 if infile_labels: | 58 if infile_labels: |
52 target_input = (params['mode_selection'] | 59 target_input = params["mode_selection"]["cv_selector"].pop("target_input") |
53 ['cv_selector'].pop('target_input')) | 60 header = "infer" if target_input["header1"] else None |
54 header = 'infer' if target_input['header1'] else None | 61 col_index = target_input["col"][0] - 1 |
55 col_index = target_input['col'][0] - 1 | 62 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) |
56 df = pd.read_csv(infile_labels, sep='\t', header=header, | |
57 parse_dates=True) | |
58 y = df.iloc[:, col_index].values | 63 y = df.iloc[:, col_index].values |
59 | 64 |
60 # construct the cv splitter object | 65 # construct the cv splitter object |
61 splitter, groups = get_cv(params['mode_selection']['cv_selector']) | 66 splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) |
62 | 67 |
63 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) | 68 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) |
64 if nth_split > total_n_splits: | 69 if nth_split > total_n_splits: |
65 raise ValueError("Total number of splits is {}, but got `nth_split` " | 70 raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split)) |
66 "= {}".format(total_n_splits, nth_split)) | |
67 | 71 |
68 i = 1 | 72 i = 1 |
69 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): | 73 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): |
70 # suppose nth_split >= 1 | 74 # suppose nth_split >= 1 |
71 if i == nth_split: | 75 if i == nth_split: |
77 test = array.iloc[test_index, :] | 81 test = array.iloc[test_index, :] |
78 | 82 |
79 return train, test | 83 return train, test |
80 | 84 |
81 | 85 |
82 def main(inputs, infile_array, outfile_train, outfile_test, | 86 def main( |
83 infile_labels=None, infile_groups=None): | 87 inputs, |
88 infile_array, | |
89 outfile_train, | |
90 outfile_test, | |
91 infile_labels=None, | |
92 infile_groups=None, | |
93 ): | |
84 """ | 94 """ |
85 Parameter | 95 Parameter |
86 --------- | 96 --------- |
87 inputs : str | 97 inputs : str |
88 File path to galaxy tool parameter | 98 File path to galaxy tool parameter |
100 File path to dataset containing train split | 110 File path to dataset containing train split |
101 | 111 |
102 outfile_test : str | 112 outfile_test : str |
103 File path to dataset containing test split | 113 File path to dataset containing test split |
104 """ | 114 """ |
105 warnings.simplefilter('ignore') | 115 warnings.simplefilter("ignore") |
106 | 116 |
107 with open(inputs, 'r') as param_handler: | 117 with open(inputs, "r") as param_handler: |
108 params = json.load(param_handler) | 118 params = json.load(param_handler) |
109 | 119 |
110 input_header = params['header0'] | 120 input_header = params["header0"] |
111 header = 'infer' if input_header else None | 121 header = "infer" if input_header else None |
112 array = pd.read_csv(infile_array, sep='\t', header=header, | 122 array = pd.read_csv(infile_array, sep="\t", header=header, parse_dates=True) |
113 parse_dates=True) | |
114 | 123 |
115 # train test split | 124 # train test split |
116 if params['mode_selection']['selected_mode'] == 'train_test_split': | 125 if params["mode_selection"]["selected_mode"] == "train_test_split": |
117 options = params['mode_selection']['options'] | 126 options = params["mode_selection"]["options"] |
118 shuffle_selection = options.pop('shuffle_selection') | 127 shuffle_selection = options.pop("shuffle_selection") |
119 options['shuffle'] = shuffle_selection['shuffle'] | 128 options["shuffle"] = shuffle_selection["shuffle"] |
120 if infile_labels: | 129 if infile_labels: |
121 header = 'infer' if shuffle_selection['header1'] else None | 130 header = "infer" if shuffle_selection["header1"] else None |
122 col_index = shuffle_selection['col'][0] - 1 | 131 col_index = shuffle_selection["col"][0] - 1 |
123 df = pd.read_csv(infile_labels, sep='\t', header=header, | 132 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) |
124 parse_dates=True) | |
125 labels = df.iloc[:, col_index].values | 133 labels = df.iloc[:, col_index].values |
126 options['labels'] = labels | 134 options["labels"] = labels |
127 | 135 |
128 train, test = train_test_split(array, **options) | 136 train, test = train_test_split(array, **options) |
129 | 137 |
130 # cv splitter | 138 # cv splitter |
131 else: | 139 else: |
132 train, test = _get_single_cv_split(params, array, | 140 train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups) |
133 infile_labels=infile_labels, | |
134 infile_groups=infile_groups) | |
135 | 141 |
136 print("Input shape: %s" % repr(array.shape)) | 142 print("Input shape: %s" % repr(array.shape)) |
137 print("Train shape: %s" % repr(train.shape)) | 143 print("Train shape: %s" % repr(train.shape)) |
138 print("Test shape: %s" % repr(test.shape)) | 144 print("Test shape: %s" % repr(test.shape)) |
139 train.to_csv(outfile_train, sep='\t', header=input_header, index=False) | 145 train.to_csv(outfile_train, sep="\t", header=input_header, index=False) |
140 test.to_csv(outfile_test, sep='\t', header=input_header, index=False) | 146 test.to_csv(outfile_test, sep="\t", header=input_header, index=False) |
141 | 147 |
142 | 148 |
143 if __name__ == '__main__': | 149 if __name__ == "__main__": |
144 aparser = argparse.ArgumentParser() | 150 aparser = argparse.ArgumentParser() |
145 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 151 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
146 aparser.add_argument("-X", "--infile_array", dest="infile_array") | 152 aparser.add_argument("-X", "--infile_array", dest="infile_array") |
147 aparser.add_argument("-y", "--infile_labels", dest="infile_labels") | 153 aparser.add_argument("-y", "--infile_labels", dest="infile_labels") |
148 aparser.add_argument("-g", "--infile_groups", dest="infile_groups") | 154 aparser.add_argument("-g", "--infile_groups", dest="infile_groups") |
149 aparser.add_argument("-o", "--outfile_train", dest="outfile_train") | 155 aparser.add_argument("-o", "--outfile_train", dest="outfile_train") |
150 aparser.add_argument("-t", "--outfile_test", dest="outfile_test") | 156 aparser.add_argument("-t", "--outfile_test", dest="outfile_test") |
151 args = aparser.parse_args() | 157 args = aparser.parse_args() |
152 | 158 |
153 main(args.inputs, args.infile_array, args.outfile_train, | 159 main( |
154 args.outfile_test, args.infile_labels, args.infile_groups) | 160 args.inputs, |
161 args.infile_array, | |
162 args.outfile_train, | |
163 args.outfile_test, | |
164 args.infile_labels, | |
165 args.infile_groups, | |
166 ) |