Mercurial > repos > bgruening > model_prediction
comparison fitted_model_eval.py @ 9:4aa701f5a393 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 18:00:54 +0000 |
parents | fb1fa391189e |
children | 22f9cbcf1582 |
comparison
equal
deleted
inserted
replaced
8:83228baae3c5 | 9:4aa701f5a393 |
---|---|
9 from sklearn.model_selection._validation import _score | 9 from sklearn.model_selection._validation import _score |
10 from galaxy_ml.utils import get_scoring, load_model, read_columns | 10 from galaxy_ml.utils import get_scoring, load_model, read_columns |
11 | 11 |
12 | 12 |
13 def _get_X_y(params, infile1, infile2): | 13 def _get_X_y(params, infile1, infile2): |
14 """ read from inputs and output X and y | 14 """read from inputs and output X and y |
15 | 15 |
16 Parameters | 16 Parameters |
17 ---------- | 17 ---------- |
18 params : dict | 18 params : dict |
19 Tool inputs parameter | 19 Tool inputs parameter |
24 | 24 |
25 """ | 25 """ |
26 # store read dataframe object | 26 # store read dataframe object |
27 loaded_df = {} | 27 loaded_df = {} |
28 | 28 |
29 input_type = params['input_options']['selected_input'] | 29 input_type = params["input_options"]["selected_input"] |
30 # tabular input | 30 # tabular input |
31 if input_type == 'tabular': | 31 if input_type == "tabular": |
32 header = 'infer' if params['input_options']['header1'] else None | 32 header = "infer" if params["input_options"]["header1"] else None |
33 column_option = (params['input_options']['column_selector_options_1'] | 33 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] |
34 ['selected_column_selector_option']) | 34 if column_option in [ |
35 if column_option in ['by_index_number', 'all_but_by_index_number', | 35 "by_index_number", |
36 'by_header_name', 'all_but_by_header_name']: | 36 "all_but_by_index_number", |
37 c = params['input_options']['column_selector_options_1']['col1'] | 37 "by_header_name", |
38 "all_but_by_header_name", | |
39 ]: | |
40 c = params["input_options"]["column_selector_options_1"]["col1"] | |
38 else: | 41 else: |
39 c = None | 42 c = None |
40 | 43 |
41 df_key = infile1 + repr(header) | 44 df_key = infile1 + repr(header) |
42 df = pd.read_csv(infile1, sep='\t', header=header, | 45 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) |
43 parse_dates=True) | |
44 loaded_df[df_key] = df | 46 loaded_df[df_key] = df |
45 | 47 |
46 X = read_columns(df, c=c, c_option=column_option).astype(float) | 48 X = read_columns(df, c=c, c_option=column_option).astype(float) |
47 # sparse input | 49 # sparse input |
48 elif input_type == 'sparse': | 50 elif input_type == "sparse": |
49 X = mmread(open(infile1, 'r')) | 51 X = mmread(open(infile1, "r")) |
50 | 52 |
51 # Get target y | 53 # Get target y |
52 header = 'infer' if params['input_options']['header2'] else None | 54 header = "infer" if params["input_options"]["header2"] else None |
53 column_option = (params['input_options']['column_selector_options_2'] | 55 column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"] |
54 ['selected_column_selector_option2']) | 56 if column_option in [ |
55 if column_option in ['by_index_number', 'all_but_by_index_number', | 57 "by_index_number", |
56 'by_header_name', 'all_but_by_header_name']: | 58 "all_but_by_index_number", |
57 c = params['input_options']['column_selector_options_2']['col2'] | 59 "by_header_name", |
60 "all_but_by_header_name", | |
61 ]: | |
62 c = params["input_options"]["column_selector_options_2"]["col2"] | |
58 else: | 63 else: |
59 c = None | 64 c = None |
60 | 65 |
61 df_key = infile2 + repr(header) | 66 df_key = infile2 + repr(header) |
62 if df_key in loaded_df: | 67 if df_key in loaded_df: |
63 infile2 = loaded_df[df_key] | 68 infile2 = loaded_df[df_key] |
64 else: | 69 else: |
65 infile2 = pd.read_csv(infile2, sep='\t', | 70 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) |
66 header=header, parse_dates=True) | |
67 loaded_df[df_key] = infile2 | 71 loaded_df[df_key] = infile2 |
68 | 72 |
69 y = read_columns( | 73 y = read_columns(infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True) |
70 infile2, | |
71 c=c, | |
72 c_option=column_option, | |
73 sep='\t', | |
74 header=header, | |
75 parse_dates=True) | |
76 if len(y.shape) == 2 and y.shape[1] == 1: | 74 if len(y.shape) == 2 and y.shape[1] == 1: |
77 y = y.ravel() | 75 y = y.ravel() |
78 | 76 |
79 return X, y | 77 return X, y |
80 | 78 |
81 | 79 |
82 def main(inputs, infile_estimator, outfile_eval, | 80 def main( |
83 infile_weights=None, infile1=None, | 81 inputs, |
84 infile2=None): | 82 infile_estimator, |
83 outfile_eval, | |
84 infile_weights=None, | |
85 infile1=None, | |
86 infile2=None, | |
87 ): | |
85 """ | 88 """ |
86 Parameter | 89 Parameter |
87 --------- | 90 --------- |
88 inputs : str | 91 inputs : str |
89 File path to galaxy tool parameter | 92 File path to galaxy tool parameter |
101 File path to dataset containing features | 104 File path to dataset containing features |
102 | 105 |
103 infile2 : str | 106 infile2 : str |
104 File path to dataset containing target values | 107 File path to dataset containing target values |
105 """ | 108 """ |
106 warnings.filterwarnings('ignore') | 109 warnings.filterwarnings("ignore") |
107 | 110 |
108 with open(inputs, 'r') as param_handler: | 111 with open(inputs, "r") as param_handler: |
109 params = json.load(param_handler) | 112 params = json.load(param_handler) |
110 | 113 |
111 X_test, y_test = _get_X_y(params, infile1, infile2) | 114 X_test, y_test = _get_X_y(params, infile1, infile2) |
112 | 115 |
113 # load model | 116 # load model |
114 with open(infile_estimator, 'rb') as est_handler: | 117 with open(infile_estimator, "rb") as est_handler: |
115 estimator = load_model(est_handler) | 118 estimator = load_model(est_handler) |
116 | 119 |
117 main_est = estimator | 120 main_est = estimator |
118 if isinstance(estimator, Pipeline): | 121 if isinstance(estimator, Pipeline): |
119 main_est = estimator.steps[-1][-1] | 122 main_est = estimator.steps[-1][-1] |
120 if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'): | 123 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): |
121 if not infile_weights or infile_weights == 'None': | 124 if not infile_weights or infile_weights == "None": |
122 raise ValueError("The selected model skeleton asks for weights, " | 125 raise ValueError( |
123 "but no dataset for weights was provided!") | 126 "The selected model skeleton asks for weights, " "but no dataset for weights was provided!" |
127 ) | |
124 main_est.load_weights(infile_weights) | 128 main_est.load_weights(infile_weights) |
125 | 129 |
126 # handle scorer, convert to scorer dict | 130 # handle scorer, convert to scorer dict |
127 scoring = params['scoring'] | 131 # Check if scoring is specified |
132 scoring = params["scoring"] | |
133 if scoring is not None: | |
134 # get_scoring() expects secondary_scoring to be a comma separated string (not a list) | |
135 # Check if secondary_scoring is specified | |
136 secondary_scoring = scoring.get("secondary_scoring", None) | |
137 if secondary_scoring is not None: | |
138 # If secondary_scoring is specified, convert the list into a comma separated string | |
139 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) | |
140 | |
128 scorer = get_scoring(scoring) | 141 scorer = get_scoring(scoring) |
129 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) | 142 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) |
130 | 143 |
131 if hasattr(estimator, 'evaluate'): | 144 if hasattr(estimator, "evaluate"): |
132 scores = estimator.evaluate(X_test, y_test=y_test, | 145 scores = estimator.evaluate(X_test, y_test=y_test, scorer=scorer, is_multimetric=True) |
133 scorer=scorer, | |
134 is_multimetric=True) | |
135 else: | 146 else: |
136 scores = _score(estimator, X_test, y_test, scorer, | 147 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) |
137 is_multimetric=True) | |
138 | 148 |
139 # handle output | 149 # handle output |
140 for name, score in scores.items(): | 150 for name, score in scores.items(): |
141 scores[name] = [score] | 151 scores[name] = [score] |
142 df = pd.DataFrame(scores) | 152 df = pd.DataFrame(scores) |
143 df = df[sorted(df.columns)] | 153 df = df[sorted(df.columns)] |
144 df.to_csv(path_or_buf=outfile_eval, sep='\t', | 154 df.to_csv(path_or_buf=outfile_eval, sep="\t", header=True, index=False) |
145 header=True, index=False) | |
146 | 155 |
147 | 156 |
148 if __name__ == '__main__': | 157 if __name__ == "__main__": |
149 aparser = argparse.ArgumentParser() | 158 aparser = argparse.ArgumentParser() |
150 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 159 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
151 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 160 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") |
152 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | 161 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") |
153 aparser.add_argument("-X", "--infile1", dest="infile1") | 162 aparser.add_argument("-X", "--infile1", dest="infile1") |
154 aparser.add_argument("-y", "--infile2", dest="infile2") | 163 aparser.add_argument("-y", "--infile2", dest="infile2") |
155 aparser.add_argument("-O", "--outfile_eval", dest="outfile_eval") | 164 aparser.add_argument("-O", "--outfile_eval", dest="outfile_eval") |
156 args = aparser.parse_args() | 165 args = aparser.parse_args() |
157 | 166 |
158 main(args.inputs, args.infile_estimator, args.outfile_eval, | 167 main( |
159 infile_weights=args.infile_weights, infile1=args.infile1, | 168 args.inputs, |
160 infile2=args.infile2) | 169 args.infile_estimator, |
170 args.outfile_eval, | |
171 infile_weights=args.infile_weights, | |
172 infile1=args.infile1, | |
173 infile2=args.infile2, | |
174 ) |