comparison train_test_split.py @ 0:af2624d5ab32 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:24:32 +0000
parents
children 9349ed2749c6
comparison
equal deleted inserted replaced
-1:000000000000 0:af2624d5ab32
1 import argparse
2 import json
3 import warnings
4
5 import pandas as pd
6 from galaxy_ml.model_validations import train_test_split
7 from galaxy_ml.utils import get_cv, read_columns
8
9
10 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None):
11 """output (train, test) subset from a cv splitter
12
13 Parameters
14 ----------
15 params : dict
16 Galaxy tool inputs
17 array : pandas DataFrame object
18 The target dataset to split
19 infile_labels : str
20 File path to dataset containing target values
21 infile_groups : str
22 File path to dataset containing group values
23 """
24 y = None
25 groups = None
26
27 nth_split = params["mode_selection"]["nth_split"]
28
29 # read groups
30 if infile_groups:
31 header = (
32 "infer"
33 if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"])
34 else None
35 )
36 column_option = params["mode_selection"]["cv_selector"]["groups_selector"][
37 "column_selector_options_g"
38 ]["selected_column_selector_option_g"]
39 if column_option in [
40 "by_index_number",
41 "all_but_by_index_number",
42 "by_header_name",
43 "all_but_by_header_name",
44 ]:
45 c = params["mode_selection"]["cv_selector"]["groups_selector"][
46 "column_selector_options_g"
47 ]["col_g"]
48 else:
49 c = None
50
51 groups = read_columns(
52 infile_groups,
53 c=c,
54 c_option=column_option,
55 sep="\t",
56 header=header,
57 parse_dates=True,
58 )
59 groups = groups.ravel()
60
61 params["mode_selection"]["cv_selector"]["groups_selector"] = groups
62
63 # read labels
64 if infile_labels:
65 target_input = params["mode_selection"]["cv_selector"].pop("target_input")
66 header = "infer" if target_input["header1"] else None
67 col_index = target_input["col"][0] - 1
68 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
69 y = df.iloc[:, col_index].values
70
71 # construct the cv splitter object
72 splitter, groups = get_cv(params["mode_selection"]["cv_selector"])
73
74 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
75 if nth_split > total_n_splits:
76 raise ValueError(
77 "Total number of splits is {}, but got `nth_split` "
78 "= {}".format(total_n_splits, nth_split)
79 )
80
81 i = 1
82 for train_index, test_index in splitter.split(array.values, y=y, groups=groups):
83 # suppose nth_split >= 1
84 if i == nth_split:
85 break
86 else:
87 i += 1
88
89 train = array.iloc[train_index, :]
90 test = array.iloc[test_index, :]
91
92 return train, test
93
94
def main(
    inputs,
    infile_array,
    outfile_train,
    outfile_test,
    infile_labels=None,
    infile_groups=None,
):
    """Split a tabular dataset into a train file and a test file.

    Parameters
    ----------
    inputs : str
        File path to the JSON file holding the galaxy tool parameters

    infile_array : str
        File path to the tab-separated dataset to split

    outfile_train : str
        File path the train split is written to (tab-separated)

    outfile_test : str
        File path the test split is written to (tab-separated)

    infile_labels : str
        File path to dataset containing labels

    infile_groups : str
        File path to dataset containing groups
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as fh:
        params = json.load(fh)

    # `header0` decides both how the input is parsed and whether the
    # outputs get a header row
    input_header = params["header0"]
    array = pd.read_csv(
        infile_array,
        sep="\t",
        header="infer" if input_header else None,
        parse_dates=True,
    )

    mode_selection = params["mode_selection"]
    if mode_selection["selected_mode"] != "train_test_split":
        # cv splitter: output one split of the configured splitter
        train, test = _get_single_cv_split(
            params, array, infile_labels=infile_labels, infile_groups=infile_groups
        )
    else:
        # train test split: flatten the nested shuffle section into the
        # keyword arguments passed to `train_test_split`
        options = mode_selection["options"]
        shuffle_selection = options.pop("shuffle_selection")
        options["shuffle"] = shuffle_selection["shuffle"]
        if infile_labels:
            labels_header = "infer" if shuffle_selection["header1"] else None
            # the selected column number is 1-based; iloc is 0-based
            labels_col = shuffle_selection["col"][0] - 1
            labels_frame = pd.read_csv(
                infile_labels, sep="\t", header=labels_header, parse_dates=True
            )
            options["labels"] = labels_frame.iloc[:, labels_col].values

        train, test = train_test_split(array, **options)

    print("Input shape: %s" % repr(array.shape))
    print("Train shape: %s" % repr(train.shape))
    print("Test shape: %s" % repr(test.shape))

    for subset, outfile in ((train, outfile_train), (test, outfile_test)):
        subset.to_csv(outfile, sep="\t", header=input_header, index=False)
158
159
if __name__ == "__main__":
    # Command-line entry point: only the tool parameter file is mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    parser.add_argument("-X", "--infile_array", dest="infile_array")
    parser.add_argument("-y", "--infile_labels", dest="infile_labels")
    parser.add_argument("-g", "--infile_groups", dest="infile_groups")
    parser.add_argument("-o", "--outfile_train", dest="outfile_train")
    parser.add_argument("-t", "--outfile_test", dest="outfile_test")
    cli = parser.parse_args()

    main(
        inputs=cli.inputs,
        infile_array=cli.infile_array,
        outfile_train=cli.outfile_train,
        outfile_test=cli.outfile_test,
        infile_labels=cli.infile_labels,
        infile_groups=cli.infile_groups,
    )