Mercurial > repos > bgruening > sklearn_model_validation
comparison model_validation.xml @ 19:efbec977a47d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 60f0fbc0eafd7c11bc60fb6c77f2937782efd8a9-dirty
author | bgruening |
---|---|
date | Fri, 09 Aug 2019 07:26:09 -0400 |
parents | cf9aa11b91c8 |
children | 5895fe0b8bde |
comparison
equal
deleted
inserted
replaced
18:492d34a75de6 | 19:efbec977a47d |
---|---|
1 <tool id="sklearn_model_validation" name="Model Validation" version="@VERSION@"> | 1 <tool id="sklearn_model_validation" name="Model Validation" version="@VERSION@"> |
2 <description>evaluates estimator performance by cross-validation</description> | 2 <description>evaluates estimator performances without changing parameters</description> |
3 <macros> | 3 <macros> |
4 <import>main_macros.xml</import> | 4 <import>main_macros.xml</import> |
5 </macros> | 5 </macros> |
6 <expand macro="python_requirements"/> | 6 <expand macro="python_requirements"/> |
7 <expand macro="macro_stdio"/> | 7 <expand macro="macro_stdio"/> |
14 <configfiles> | 14 <configfiles> |
15 <inputs name="inputs" /> | 15 <inputs name="inputs" /> |
16 <configfile name="sklearn_model_validation_script"> | 16 <configfile name="sklearn_model_validation_script"> |
17 <![CDATA[ | 17 <![CDATA[ |
18 import imblearn | 18 import imblearn |
19 import joblib | |
19 import json | 20 import json |
20 import numpy as np | 21 import numpy as np |
21 import pandas as pd | 22 import pandas as pd |
22 import pickle | 23 import pickle |
23 import pprint | 24 import pprint |
29 from sklearn import ( | 30 from sklearn import ( |
30 cluster, compose, decomposition, ensemble, feature_extraction, | 31 cluster, compose, decomposition, ensemble, feature_extraction, |
31 feature_selection, gaussian_process, kernel_approximation, metrics, | 32 feature_selection, gaussian_process, kernel_approximation, metrics, |
32 model_selection, naive_bayes, neighbors, pipeline, preprocessing, | 33 model_selection, naive_bayes, neighbors, pipeline, preprocessing, |
33 svm, linear_model, tree, discriminant_analysis) | 34 svm, linear_model, tree, discriminant_analysis) |
34 | 35 from sklearn.model_selection import _validation |
35 sys.path.insert(0, '$__tool_directory__') | 36 |
36 from utils import SafeEval, get_cv, get_scoring, load_model, read_columns | 37 from galaxy_ml.utils import (SafeEval, get_cv, get_scoring, load_model, |
38 read_columns, get_module) | |
39 from galaxy_ml.model_validations import _fit_and_score | |
40 | |
41 | |
42 setattr(_validation, '_fit_and_score', _fit_and_score) | |
37 | 43 |
38 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 44 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) |
45 CACHE_DIR = './cached' | |
46 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', | |
47 'CSVLogger', 'None') | |
39 | 48 |
40 warnings.filterwarnings('ignore') | 49 warnings.filterwarnings('ignore') |
41 | 50 |
42 safe_eval = SafeEval() | 51 safe_eval = SafeEval() |
43 | 52 |
44 input_json_path = sys.argv[1] | 53 input_json_path = sys.argv[1] |
45 with open(input_json_path, 'r') as param_handler: | 54 with open(input_json_path, 'r') as param_handler: |
46 params = json.load(param_handler) | 55 params = json.load(param_handler) |
47 | 56 |
48 #if $model_validation_functions.options.cv_selector.selected_cv\ | 57 ## load estimator |
49 in ['GroupKFold', 'GroupShuffleSplit', 'LeaveOneGroupOut', 'LeavePGroupsOut']: | 58 with open('$infile_estimator', 'rb') as estimator_handler: |
50 params['model_validation_functions']['options']['cv_selector']['groups_selector']['infile_g'] =\ | 59 estimator = load_model(estimator_handler) |
51 '$model_validation_functions.options.cv_selector.groups_selector.infile_g' | 60 |
61 estimator_params = estimator.get_params() | |
62 | |
63 ## check estimator hyperparameters | |
64 memory = joblib.Memory(location=CACHE_DIR, verbose=0) | |
65 # caching iraps_core fits could increase search speed significantly | |
66 if estimator.__class__.__name__ == 'IRAPSClassifier': | |
67 estimator.set_params(memory=memory) | |
68 else: | |
69 # For an irapsclassifier nested inside a pipeline | |
70 for p, v in estimator_params.items(): | |
71 if p.endswith('memory'): | |
72 # for case of `__irapsclassifier__memory` | |
73 if len(p) > 8 and p[:-8].endswith('irapsclassifier'): | |
74 # caching iraps_core fits could increase search | |
75 # speed significantly | |
76 new_params = {p: memory} | |
77 estimator.set_params(**new_params) | |
78 # for security reasons, we don't want memory to be | |
79 # modified unexpectedly | |
80 elif v: | |
81 new_params = {p, None} | |
82 estimator.set_params(**new_params) | |
83 # For now, 1 CPU is suggested for irapsclassifier | |
84 elif p.endswith('n_jobs'): | |
85 new_params = {p: 1} | |
86 estimator.set_params(**new_params) | |
87 # for security reasons, the allowed callback types are limited | |
88 elif p.endswith('callbacks'): | |
89 for cb in v: | |
90 cb_type = cb['callback_selection']['callback_type'] | |
91 if cb_type not in ALLOWED_CALLBACKS: | |
92 raise ValueError( | |
93 "Prohibited callback type: %s!" % cb_type) | |
94 | |
95 ## store read dataframe object | |
96 loaded_df = {} | |
97 | |
98 #if $input_options.selected_input == 'tabular' | |
99 header = 'infer' if params['input_options']['header1'] else None | |
100 column_option = params['input_options']['column_selector_options_1']['selected_column_selector_option'] | |
101 if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']: | |
102 c = params['input_options']['column_selector_options_1']['col1'] | |
103 else: | |
104 c = None | |
105 infile1 = '$input_options.infile1' | |
106 df_key = infile1 + repr(header) | |
107 df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True) | |
108 loaded_df[df_key] = df | |
109 X = read_columns(df, c=c, c_option=column_option).astype(float) | |
110 | |
111 #elif $input_options.selected_input == 'sparse': | |
112 X = mmread('$input_options.infile1') | |
113 | |
114 #elif $input_options.selected_input == 'seq_fasta' | |
115 fasta_path = '$input_options.fasta_path' | |
116 pyfaidx = get_module('pyfaidx') | |
117 sequences = pyfaidx.Fasta(fasta_path) | |
118 n_seqs = len(sequences.keys()) | |
119 X = np.arange(n_seqs)[:, np.newaxis] | |
120 for param in estimator_params.keys(): | |
121 if param.endswith('fasta_path'): | |
122 estimator.set_params( | |
123 **{param: fasta_path}) | |
124 break | |
125 else: | |
126 raise ValueError( | |
127 "The selected estimator doesn't support " | |
128 "fasta file input! Please consider using " | |
129 "KerasGBatchClassifier with " | |
130 "FastaDNABatchGenerator/FastaProteinBatchGenerator " | |
131 "or having GenomeOneHotEncoder/ProteinOneHotEncoder " | |
132 "in pipeline!") | |
133 #elif $input_options.selected_input == 'refseq_and_interval' | |
134 ref_seq = '$input_options.ref_genome_file' | |
135 intervals = '$input_options.interval_file' | |
136 targets = __import__('os').path.join(__import__('os').getcwd(), | |
137 '${target_file.element_identifier}.gz') | |
138 path_params = { | |
139 'data_batch_generator__ref_genome_path': ref_seq, | |
140 'data_batch_generator__intervals_path': intervals, | |
141 'data_batch_generator__target_path': targets | |
142 } | |
143 estimator.set_params(**path_params) | |
144 n_intervals = sum(1 for line in open(intervals)) | |
145 X = np.arange(n_intervals)[:, np.newaxis] | |
52 #end if | 146 #end if |
53 | |
54 input_type = params['input_options']['selected_input'] | |
55 if input_type == 'tabular': | |
56 header = 'infer' if params['input_options']['header1'] else None | |
57 column_option = params['input_options']['column_selector_options_1']['selected_column_selector_option'] | |
58 if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']: | |
59 c = params['input_options']['column_selector_options_1']['col1'] | |
60 else: | |
61 c = None | |
62 X = read_columns( | |
63 '$input_options.infile1', | |
64 c = c, | |
65 c_option = column_option, | |
66 sep='\t', | |
67 header=header, | |
68 parse_dates=True).astype(float) | |
69 else: | |
70 X = mmread('$input_options.infile1') | |
71 | 147 |
72 header = 'infer' if params['input_options']['header2'] else None | 148 header = 'infer' if params['input_options']['header2'] else None |
73 column_option = params['input_options']['column_selector_options_2']['selected_column_selector_option2'] | 149 column_option = params['input_options']['column_selector_options_2']['selected_column_selector_option2'] |
74 if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']: | 150 if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']: |
75 c = params['input_options']['column_selector_options_2']['col2'] | 151 c = params['input_options']['column_selector_options_2']['col2'] |
76 else: | 152 else: |
77 c = None | 153 c = None |
154 infile2 = '$input_options.infile2' | |
155 df_key = infile2 + repr(header) | |
156 if df_key in loaded_df: | |
157 infile2 = loaded_df[df_key] | |
158 else: | |
159 infile2 = pd.read_csv(infile2, sep='\t', header=header, parse_dates=True) | |
160 loaded_df[df_key] = infile2 | |
78 y = read_columns( | 161 y = read_columns( |
79 '$input_options.infile2', | 162 infile2, |
80 c = c, | 163 c = c, |
81 c_option = column_option, | 164 c_option = column_option, |
82 sep='\t', | 165 sep='\t', |
83 header=header, | 166 header=header, |
84 parse_dates=True) | 167 parse_dates=True) |
85 y = y.ravel() | 168 if len(y.shape) == 2 and y.shape[1] == 1: |
169 y = y.ravel() | |
170 #if $input_options.selected_input == 'refseq_and_interval' | |
171 estimator.set_params( | |
172 data_batch_generator__features=y.ravel().tolist()) | |
173 y = None | |
174 #end if | |
86 | 175 |
87 ## handle options | 176 ## handle options |
88 options = params['model_validation_functions']['options'] | 177 options = params['model_validation_functions']['options'] |
178 | |
179 #if $model_validation_functions.options.cv_selector.selected_cv\ | |
180 in ['GroupKFold', 'GroupShuffleSplit', 'LeaveOneGroupOut', 'LeavePGroupsOut']: | |
181 infile_g = '$model_validation_functions.options.cv_selector.groups_selector.infile_g' | |
182 header = 'infer' if options['cv_selector']['groups_selector']['header_g'] else None | |
183 column_option = (options['cv_selector']['groups_selector']['column_selector_options_g'] | |
184 ['selected_column_selector_option_g']) | |
185 if column_option in ['by_index_number', 'all_but_by_index_number', | |
186 'by_header_name', 'all_but_by_header_name']: | |
187 c = (options['cv_selector']['groups_selector']['column_selector_options_g']['col_g']) | |
188 else: | |
189 c = None | |
190 df_key = infile_g + repr(header) | |
191 if df_key in loaded_df: | |
192 infile_g = loaded_df[df_key] | |
193 groups = read_columns(infile_g, c=c, c_option=column_option, | |
194 sep='\t', header=header, parse_dates=True) | |
195 groups = groups.ravel() | |
196 options['cv_selector']['groups_selector'] = groups | |
197 #end if | |
198 | |
199 ## del loaded_df | |
200 del loaded_df | |
201 | |
89 splitter, groups = get_cv( options.pop('cv_selector') ) | 202 splitter, groups = get_cv( options.pop('cv_selector') ) |
90 options['cv'] = splitter | 203 options['cv'] = splitter |
91 options['groups'] = groups | 204 options['groups'] = groups |
92 options['n_jobs'] = N_JOBS | 205 options['n_jobs'] = N_JOBS |
93 if 'scoring' in options: | 206 if 'scoring' in options: |
94 primary_scoring = options['scoring']['primary_scoring'] | 207 primary_scoring = options['scoring']['primary_scoring'] |
95 options['scoring'] = get_scoring(options['scoring']) | 208 options['scoring'] = get_scoring(options['scoring']) |
96 if 'pre_dispatch' in options and options['pre_dispatch'] == '': | 209 if 'pre_dispatch' in options and options['pre_dispatch'] == '': |
97 options['pre_dispatch'] = None | 210 options['pre_dispatch'] = None |
98 | 211 |
99 ## load pipeline | 212 ## Set up validator, run estimator through validator and return results. |
100 with open('$infile_pipeline', 'rb') as pipeline_handler: | |
101 pipeline = load_model(pipeline_handler) | |
102 | |
103 ## Set up validator, run pipeline through validator and return results. | |
104 | 213 |
105 validator = params['model_validation_functions']['selected_function'] | 214 validator = params['model_validation_functions']['selected_function'] |
106 validator = getattr(model_selection, validator) | 215 validator = getattr(_validation, validator) |
107 | 216 |
108 selected_function = params['model_validation_functions']['selected_function'] | 217 selected_function = params['model_validation_functions']['selected_function'] |
109 | 218 |
110 if selected_function == 'cross_validate': | 219 if selected_function == 'cross_validate': |
111 res = validator(pipeline, X, y, **options) | 220 res = validator(estimator, X, y, **options) |
221 stat = {} | |
222 for k, v in res.items(): | |
223 if k.startswith('test'): | |
224 stat['mean_' + k] = np.mean(v) | |
225 stat['std_' + k] = np.std(v) | |
226 res.update(stat) | |
112 rval = pd.DataFrame(res) | 227 rval = pd.DataFrame(res) |
113 col_rename = {} | 228 rval = rval[sorted(rval.columns)] |
114 for col in rval.columns: | |
115 if col.endswith('_primary'): | |
116 col_rename[col] = col[:-7] + primary_scoring | |
117 rval.rename(inplace=True, columns=col_rename) | |
118 elif selected_function == 'cross_val_predict': | 229 elif selected_function == 'cross_val_predict': |
119 predicted = validator(pipeline, X, y, **options) | 230 predicted = validator(estimator, X, y, **options) |
120 if len(predicted.shape) == 1: | 231 if len(predicted.shape) == 1: |
121 rval = pd.DataFrame(predicted, columns=['Predicted']) | 232 rval = pd.DataFrame(predicted, columns=['Predicted']) |
122 else: | 233 else: |
123 rval = pd.DataFrame(predicted) | 234 rval = pd.DataFrame(predicted) |
124 elif selected_function == 'learning_curve': | 235 elif selected_function == 'learning_curve': |
127 except: | 238 except: |
128 sys.exit("Unsupported train_sizes input! Supports int/float in tuple and array-like structure.") | 239 sys.exit("Unsupported train_sizes input! Supports int/float in tuple and array-like structure.") |
129 if type(train_sizes) is tuple: | 240 if type(train_sizes) is tuple: |
130 train_sizes = np.linspace(*train_sizes) | 241 train_sizes = np.linspace(*train_sizes) |
131 options['train_sizes'] = train_sizes | 242 options['train_sizes'] = train_sizes |
132 train_sizes_abs, train_scores, test_scores = validator(pipeline, X, y, **options) | 243 train_sizes_abs, train_scores, test_scores = validator(estimator, X, y, **options) |
133 rval = pd.DataFrame(dict( | 244 rval = pd.DataFrame(dict( |
134 train_sizes_abs = train_sizes_abs, | 245 train_sizes_abs = train_sizes_abs, |
135 mean_train_scores = np.mean(train_scores, axis=1), | 246 mean_train_scores = np.mean(train_scores, axis=1), |
136 std_train_scores = np.std(train_scores, axis=1), | 247 std_train_scores = np.std(train_scores, axis=1), |
137 mean_test_scores = np.mean(test_scores, axis=1), | 248 mean_test_scores = np.mean(test_scores, axis=1), |
138 std_test_scores = np.std(test_scores, axis=1))) | 249 std_test_scores = np.std(test_scores, axis=1))) |
139 rval = rval[['train_sizes_abs', 'mean_train_scores', 'std_train_scores', | 250 rval = rval[['train_sizes_abs', 'mean_train_scores', 'std_train_scores', |
140 'mean_test_scores', 'std_test_scores']] | 251 'mean_test_scores', 'std_test_scores']] |
141 elif selected_function == 'permutation_test_score': | 252 elif selected_function == 'permutation_test_score': |
142 score, permutation_scores, pvalue = validator(pipeline, X, y, **options) | 253 score, permutation_scores, pvalue = validator(estimator, X, y, **options) |
143 permutation_scores_df = pd.DataFrame(dict( | 254 permutation_scores_df = pd.DataFrame(dict( |
144 permutation_scores = permutation_scores)) | 255 permutation_scores = permutation_scores)) |
145 score_df = pd.DataFrame(dict( | 256 score_df = pd.DataFrame(dict( |
146 score = [score], | 257 score = [score], |
147 pvalue = [pvalue])) | 258 pvalue = [pvalue])) |
151 | 262 |
152 ]]> | 263 ]]> |
153 </configfile> | 264 </configfile> |
154 </configfiles> | 265 </configfiles> |
155 <inputs> | 266 <inputs> |
156 <param name="infile_pipeline" type="data" format="zip" label="Choose the dataset containing model/pipeline object"/> | 267 <param name="infile_estimator" type="data" format="zip" label="Choose the dataset containing model/pipeline object"/> |
157 <conditional name="model_validation_functions"> | 268 <conditional name="model_validation_functions"> |
158 <param name="selected_function" type="select" label="Select a model validation function"> | 269 <param name="selected_function" type="select" label="Select a model validation function"> |
159 <option value="cross_validate">cross_validate - Evaluate metric(s) by cross-validation and also record fit/score times</option> | 270 <option value="cross_validate">cross_validate - Evaluate metric(s) by cross-validation and also record fit/score times</option> |
160 <option value="cross_val_predict">cross_val_predict - Generate cross-validated estimates for each input data point</option> | 271 <option value="cross_val_predict">cross_val_predict - Generate cross-validated estimates for each input data point</option> |
161 <option value="learning_curve">learning_curve - Learning curve</option> | 272 <option value="learning_curve">learning_curve - Learning curve</option> |
218 <outputs> | 329 <outputs> |
219 <data format="tabular" name="outfile"/> | 330 <data format="tabular" name="outfile"/> |
220 </outputs> | 331 </outputs> |
221 <tests> | 332 <tests> |
222 <test> | 333 <test> |
223 <param name="infile_pipeline" value="pipeline02"/> | 334 <param name="infile_estimator" value="pipeline02"/> |
224 <param name="selected_function" value="cross_validate"/> | 335 <param name="selected_function" value="cross_validate"/> |
225 <param name="infile1" value="regression_train.tabular" ftype="tabular"/> | 336 <param name="infile1" value="regression_train.tabular" ftype="tabular"/> |
226 <param name="col1" value="1,2,3,4,5"/> | 337 <param name="col1" value="1,2,3,4,5"/> |
227 <param name="infile2" value="regression_train.tabular" ftype="tabular"/> | 338 <param name="infile2" value="regression_train.tabular" ftype="tabular"/> |
228 <param name="col2" value="6"/> | 339 <param name="col2" value="6"/> |
229 <output name="outfile"> | 340 <output name="outfile"> |
230 <assert_contents> | 341 <assert_contents> |
231 <has_n_columns n="4"/> | 342 <has_n_columns n="6"/> |
232 <has_text text="0.9999961390418067"/> | 343 <has_text text="0.9999961390418067"/> |
233 <has_text text="0.9944541531269271"/> | 344 <has_text text="0.9944541531269271"/> |
234 <has_text text="0.9999193322454393"/> | 345 <has_text text="0.9999193322454393"/> |
235 </assert_contents> | 346 </assert_contents> |
236 </output> | 347 </output> |
237 </test> | 348 </test> |
238 <test> | 349 <test> |
239 <param name="infile_pipeline" value="pipeline02"/> | 350 <param name="infile_estimator" value="pipeline02"/> |
240 <param name="selected_function" value="cross_val_predict"/> | 351 <param name="selected_function" value="cross_val_predict"/> |
241 <param name="infile1" value="regression_train.tabular" ftype="tabular"/> | 352 <param name="infile1" value="regression_train.tabular" ftype="tabular"/> |
242 <param name="col1" value="1,2,3,4,5"/> | 353 <param name="col1" value="1,2,3,4,5"/> |
243 <param name="infile2" value="regression_train.tabular" ftype="tabular"/> | 354 <param name="infile2" value="regression_train.tabular" ftype="tabular"/> |
244 <param name="col2" value="6"/> | 355 <param name="col2" value="6"/> |
245 <output name="outfile" file="mv_result02.tabular" lines_diff="4"/> | 356 <output name="outfile" file="mv_result02.tabular" lines_diff="4"/> |
246 </test> | 357 </test> |
247 <test> | 358 <test> |
248 <param name="infile_pipeline" value="pipeline05"/> | 359 <param name="infile_estimator" value="pipeline05"/> |
249 <param name="selected_function" value="learning_curve"/> | 360 <param name="selected_function" value="learning_curve"/> |
250 <param name="infile1" value="regression_X.tabular" ftype="tabular"/> | 361 <param name="infile1" value="regression_X.tabular" ftype="tabular"/> |
251 <param name="header1" value="true" /> | 362 <param name="header1" value="true" /> |
252 <param name="col1" value="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17"/> | 363 <param name="col1" value="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17"/> |
253 <param name="infile2" value="regression_y.tabular" ftype="tabular"/> | 364 <param name="infile2" value="regression_y.tabular" ftype="tabular"/> |
254 <param name="header2" value="true" /> | 365 <param name="header2" value="true" /> |
255 <param name="col2" value="1"/> | 366 <param name="col2" value="1"/> |
256 <output name="outfile" file="mv_result03.tabular"/> | 367 <output name="outfile" file="mv_result03.tabular"/> |
257 </test> | 368 </test> |
258 <test> | 369 <test> |
259 <param name="infile_pipeline" value="pipeline05"/> | 370 <param name="infile_estimator" value="pipeline05"/> |
260 <param name="selected_function" value="permutation_test_score"/> | 371 <param name="selected_function" value="permutation_test_score"/> |
261 <param name="infile1" value="regression_train.tabular" ftype="tabular"/> | 372 <param name="infile1" value="regression_train.tabular" ftype="tabular"/> |
262 <param name="col1" value="1,2,3,4,5"/> | 373 <param name="col1" value="1,2,3,4,5"/> |
263 <param name="infile2" value="regression_train.tabular" ftype="tabular"/> | 374 <param name="infile2" value="regression_train.tabular" ftype="tabular"/> |
264 <param name="col2" value="6"/> | 375 <param name="col2" value="6"/> |
268 <has_text text="0.25697059258228816"/> | 379 <has_text text="0.25697059258228816"/> |
269 </assert_contents> | 380 </assert_contents> |
270 </output> | 381 </output> |
271 </test> | 382 </test> |
272 <test> | 383 <test> |
273 <param name="infile_pipeline" value="pipeline05"/> | 384 <param name="infile_estimator" value="pipeline05"/> |
274 <param name="selected_function" value="cross_val_predict"/> | 385 <param name="selected_function" value="cross_val_predict"/> |
275 <section name="groups_selector"> | 386 <section name="groups_selector"> |
276 <param name="infile_groups" value="regression_y.tabular" ftype="tabular"/> | 387 <param name="infile_groups" value="regression_y.tabular" ftype="tabular"/> |
277 <param name="header_g" value="true"/> | 388 <param name="header_g" value="true"/> |
278 <param name="selected_column_selector_option_g" value="by_index_number"/> | 389 <param name="selected_column_selector_option_g" value="by_index_number"/> |