comparison model_validation.xml @ 34:1fe00785190d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:44:18 +0000 |
parents | 4b359039f09f |
children | |
33:5d5d9cc554f9 | 34:1fe00785190d |
---|---|
1 <tool id="sklearn_model_validation" name="Model Validation" version="@VERSION@" profile="20.05"> | 1 <tool id="sklearn_model_validation" name="Model Validation" version="@VERSION@" profile="@PROFILE@"> |
2 <description>includes cross_validate, cross_val_predict, learning_curve, and more</description> | 2 <description>includes cross_validate, cross_val_predict, learning_curve, and more</description> |
3 <macros> | 3 <macros> |
4 <import>main_macros.xml</import> | 4 <import>main_macros.xml</import> |
5 </macros> | 5 </macros> |
6 <expand macro="python_requirements" /> | 6 <expand macro="python_requirements" /> |
20 import joblib | 20 import joblib |
21 import json | 21 import json |
22 import numpy as np | 22 import numpy as np |
23 import os | 23 import os |
24 import pandas as pd | 24 import pandas as pd |
25 import pickle | |
26 import pprint | 25 import pprint |
27 import skrebate | 26 import skrebate |
28 import sys | 27 import sys |
29 import warnings | 28 import warnings |
30 import xgboost | 29 import xgboost |
33 cluster, compose, decomposition, ensemble, feature_extraction, | 32 cluster, compose, decomposition, ensemble, feature_extraction, |
34 feature_selection, gaussian_process, kernel_approximation, metrics, | 33 feature_selection, gaussian_process, kernel_approximation, metrics, |
35 model_selection, naive_bayes, neighbors, pipeline, preprocessing, | 34 model_selection, naive_bayes, neighbors, pipeline, preprocessing, |
36 svm, linear_model, tree, discriminant_analysis) | 35 svm, linear_model, tree, discriminant_analysis) |
37 from sklearn.model_selection import _validation | 36 from sklearn.model_selection import _validation |
38 | 37 from sklearn.preprocessing import LabelEncoder |
39 from galaxy_ml.utils import (SafeEval, get_cv, get_scoring, load_model, | 38 |
40 read_columns, get_module) | 39 from distutils.version import LooseVersion as Version |
41 from galaxy_ml.model_validations import _fit_and_score | 40 from galaxy_ml import __version__ as galaxy_ml_version |
42 | 41 from galaxy_ml.model_persist import load_model_from_h5 |
43 | 42 from galaxy_ml.utils import (SafeEval, get_cv, get_scoring, |
44 setattr(_validation, '_fit_and_score', _fit_and_score) | 43 read_columns, get_module, |
44 clean_params, get_main_estimator) | |
45 | |
45 | 46 |
46 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) | 47 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) |
47 CACHE_DIR = os.path.join(os.getcwd(), 'cached') | 48 CACHE_DIR = os.path.join(os.getcwd(), 'cached') |
48 del os | |
49 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', | |
50 'CSVLogger', 'None') | |
51 | 49 |
52 warnings.filterwarnings('ignore') | 50 warnings.filterwarnings('ignore') |
53 | 51 |
54 safe_eval = SafeEval() | 52 safe_eval = SafeEval() |
55 | 53 |
56 input_json_path = sys.argv[1] | 54 input_json_path = sys.argv[1] |
57 with open(input_json_path, 'r') as param_handler: | 55 with open(input_json_path, 'r') as param_handler: |
58 params = json.load(param_handler) | 56 params = json.load(param_handler) |
59 | 57 |
60 ## load estimator | 58 ## load estimator |
61 with open('$infile_estimator', 'rb') as estimator_handler: | 59 estimator = load_model_from_h5('$infile_estimator') |
62 estimator = load_model(estimator_handler) | 60 estimator = clean_params(estimator) |
61 | |
62 if estimator.__class__.__name__ == 'KerasGBatchClassifier': | |
63 _fit_and_score = try_get_attr('galaxy_ml.model_validations', | |
64 '_fit_and_score') | |
65 | |
66 setattr(_search, '_fit_and_score', _fit_and_score) | |
67 setattr(_validation, '_fit_and_score', _fit_and_score) | |
63 | 68 |
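Editor's note on the loading hunk above: the right-hand column replaces the pickled-zip `load_model` path with HDF5 loading plus a parameter clean-up, and the `_fit_and_score` monkey-patch is now applied only for `KerasGBatchClassifier`. A minimal sketch of that flow, assuming the galaxy_ml helpers behave as their names suggest and that `try_get_attr` lives in `galaxy_ml.utils` (the input file name is hypothetical):

    # sketch only; not the tool's exact code
    from sklearn.model_selection import _validation
    from galaxy_ml.model_persist import load_model_from_h5
    from galaxy_ml.utils import clean_params, try_get_attr

    estimator = load_model_from_h5('pipeline.h5mlm')   # hypothetical h5mlm dataset path
    estimator = clean_params(estimator)                # drop parameters the tool does not accept

    if estimator.__class__.__name__ == 'KerasGBatchClassifier':
        # only the Keras batch classifier needs galaxy_ml's patched fit-and-score
        _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score')
        setattr(_validation, '_fit_and_score', _fit_and_score)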
64 estimator_params = estimator.get_params() | 69 estimator_params = estimator.get_params() |
65 | 70 |
66 ## check estimator hyperparameters | 71 ## check estimator hyperparameters |
67 memory = joblib.Memory(location=CACHE_DIR, verbose=0) | 72 memory = joblib.Memory(location=CACHE_DIR, verbose=0) |
69 if estimator.__class__.__name__ == 'IRAPSClassifier': | 74 if estimator.__class__.__name__ == 'IRAPSClassifier': |
70 estimator.set_params(memory=memory) | 75 estimator.set_params(memory=memory) |
71 else: | 76 else: |
72 # For iraps buried in pipeline | 77 # For iraps buried in pipeline |
73 for p, v in estimator_params.items(): | 78 for p, v in estimator_params.items(): |
74 if p.endswith('memory'): | 79 if p.endswith('__irapsclassifier__memory'): |
75 # for case of `__irapsclassifier__memory` | 80 new_params = {p: memory} |
76 if len(p) > 8 and p[:-8].endswith('irapsclassifier'): | |
77 # cache iraps_core fits could increase search | |
78 # speed significantly | |
79 new_params = {p: memory} | |
80 estimator.set_params(**new_params) | |
81 # security reason, we don't want memory being | |
82 # modified unexpectedly | |
83 elif v: | |
84 new_params = {p, None} | |
85 estimator.set_params(**new_params) | |
86 # For now, 1 CPU is suggested for iprasclassifier | |
87 elif p.endswith('n_jobs'): | |
88 new_params = {p: 1} | |
89 estimator.set_params(**new_params) | 81 estimator.set_params(**new_params) |
90 # for security reason, types of callback are limited | |
91 elif p.endswith('callbacks'): | |
92 for cb in v: | |
93 cb_type = cb['callback_selection']['callback_type'] | |
94 if cb_type not in ALLOWED_CALLBACKS: | |
95 raise ValueError( | |
96 "Prohibited callback type: %s!" % cb_type) | |
97 | 82 |
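Editor's note: the hyperparameter check above was simplified so that only parameters ending in `__irapsclassifier__memory` are wired to a joblib cache; the old callback whitelist and `n_jobs` overrides are gone. A rough sketch of the remaining pattern, assuming `estimator` is the pipeline object loaded earlier in the script:

    import os
    import joblib

    CACHE_DIR = os.path.join(os.getcwd(), 'cached')
    memory = joblib.Memory(location=CACHE_DIR, verbose=0)

    for p in estimator.get_params():                 # estimator: pipeline loaded above
        if p.endswith('__irapsclassifier__memory'):
            # caching the IRAPS core fits lets repeated CV fits reuse results
            estimator.set_params(**{p: memory})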
98 ## store read dataframe object | 83 ## store read dataframe object |
99 loaded_df = {} | 84 loaded_df = {} |
100 | 85 |
101 #if $input_options.selected_input == 'tabular' | 86 #if $input_options.selected_input == 'tabular' |
160 infile2 = loaded_df[df_key] | 145 infile2 = loaded_df[df_key] |
161 else: | 146 else: |
162 infile2 = pd.read_csv(infile2, sep='\t', header=header, parse_dates=True) | 147 infile2 = pd.read_csv(infile2, sep='\t', header=header, parse_dates=True) |
163 loaded_df[df_key] = infile2 | 148 loaded_df[df_key] = infile2 |
164 y = read_columns( | 149 y = read_columns( |
165 infile2, | 150 infile2, |
166 c = c, | 151 c = c, |
167 c_option = column_option, | 152 c_option = column_option, |
168 sep='\t', | 153 sep='\t', |
169 header=header, | 154 header=header, |
170 parse_dates=True) | 155 parse_dates=True) |
171 if len(y.shape) == 2 and y.shape[1] == 1: | 156 if len(y.shape) == 2 and y.shape[1] == 1: |
172 y = y.ravel() | 157 y = y.ravel() |
173 #if $input_options.selected_input == 'refseq_and_interval' | 158 #if $input_options.selected_input == 'refseq_and_interval' |
174 estimator.set_params( | 159 estimator.set_params( |
175 data_batch_generator__features=y.ravel().tolist()) | 160 data_batch_generator__features=y.ravel().tolist()) |
176 y = None | 161 y = None |
162 label_encoder = LabelEncoder() | |
163 if get_main_estimator(estimator).__class__.__name__ == "XGBClassifier": | |
164 y = label_encoder.fit_transform(y) | |
165 print(label_encoder.classes_) | |
177 #end if | 166 #end if |
178 | 167 |
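Editor's note: the inserted `LabelEncoder` lines above are additions, not replacements. When `get_main_estimator` reports an `XGBClassifier`, the target column is integer-encoded with scikit-learn's `LabelEncoder`, which recent xgboost releases expect. A self-contained illustration (the label values are made up):

    import numpy as np
    from sklearn.preprocessing import LabelEncoder

    y = np.array(['pos', 'neg', 'pos', 'neg'])   # illustrative class labels
    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(y)           # -> array([1, 0, 1, 0])
    print(label_encoder.classes_)                # -> ['neg' 'pos']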
179 ## handle options | 168 ## handle options |
180 options = params['model_validation_functions']['options'] | 169 options = params['model_validation_functions']['options'] |
181 | 170 |
200 #end if | 189 #end if |
201 | 190 |
202 ## del loaded_df | 191 ## del loaded_df |
203 del loaded_df | 192 del loaded_df |
204 | 193 |
205 splitter, groups = get_cv( options.pop('cv_selector') ) | 194 cv_selector = options.pop('cv_selector') |
195 if Version(galaxy_ml_version) < Version('0.8.3'): | |
196 cv_selector.pop('n_stratification_bins', None) | |
197 splitter, groups = get_cv( cv_selector ) | |
206 options['cv'] = splitter | 198 options['cv'] = splitter |
207 options['groups'] = groups | 199 options['groups'] = groups |
208 options['n_jobs'] = N_JOBS | 200 options['n_jobs'] = N_JOBS |
209 if 'scoring' in options: | 201 if 'scoring' in options: |
210 primary_scoring = options['scoring']['primary_scoring'] | 202 primary_scoring = options['scoring']['primary_scoring'] |
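Editor's note: the cross-validation selector is now popped out before building the splitter so that `n_stratification_bins` can be dropped when the installed galaxy_ml predates 0.8.3, where that option did not exist. A hedged sketch of the gate, reusing the imports shown above (the selector dict below is illustrative, not the exact structure Galaxy passes):

    from distutils.version import LooseVersion as Version
    from galaxy_ml import __version__ as galaxy_ml_version
    from galaxy_ml.utils import get_cv

    cv_selector = {'selected_cv': 'StratifiedKFold', 'n_splits': 5,
                   'n_stratification_bins': 10}          # illustrative selector
    if Version(galaxy_ml_version) < Version('0.8.3'):
        cv_selector.pop('n_stratification_bins', None)   # older galaxy_ml rejects it
    splitter, groups = get_cv(cv_selector)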
236 else: | 228 else: |
237 rval = pd.DataFrame(predicted) | 229 rval = pd.DataFrame(predicted) |
238 elif selected_function == 'learning_curve': | 230 elif selected_function == 'learning_curve': |
239 try: | 231 try: |
240 train_sizes = safe_eval(options['train_sizes']) | 232 train_sizes = safe_eval(options['train_sizes']) |
241 except Exception: | 233 except: |
242 sys.exit("Unsupported train_sizes input! Supports int/float in tuple and array-like structure.") | 234 sys.exit("Unsupported train_sizes input! Supports int/float in tuple and array-like structure.") |
243 if type(train_sizes) is tuple: | 235 if type(train_sizes) is tuple: |
244 train_sizes = np.linspace(*train_sizes) | 236 train_sizes = np.linspace(*train_sizes) |
245 options['train_sizes'] = train_sizes | 237 options['train_sizes'] = train_sizes |
246 train_sizes_abs, train_scores, test_scores = validator(estimator, X, y, **options) | 238 train_sizes_abs, train_scores, test_scores = validator(estimator, X, y, **options) |
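Editor's note: for `learning_curve`, the `train_sizes` text parameter is evaluated with galaxy_ml's `SafeEval`; a tuple is expanded through `np.linspace`, while an explicit list is used as-is. The same expansion, sketched with `ast.literal_eval` standing in for SafeEval:

    import ast
    import numpy as np

    raw = "(0.1, 1.0, 5)"                  # the tool's default train_sizes value
    train_sizes = ast.literal_eval(raw)
    if isinstance(train_sizes, tuple):
        # (start, stop, num) -> evenly spaced fractions of the training set
        train_sizes = np.linspace(*train_sizes)
    print(train_sizes)                     # [0.1  0.325 0.55  0.775 1. ]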
265 | 257 |
266 ]]> | 258 ]]> |
267 </configfile> | 259 </configfile> |
268 </configfiles> | 260 </configfiles> |
269 <inputs> | 261 <inputs> |
270 <param name="infile_estimator" type="data" format="zip" label="Choose the dataset containing model/pipeline object" /> | 262 <param name="infile_estimator" type="data" format="h5mlm" label="Choose the dataset containing model/pipeline object" /> |
271 <conditional name="model_validation_functions"> | 263 <conditional name="model_validation_functions"> |
272 <param name="selected_function" type="select" label="Select a model validation function"> | 264 <param name="selected_function" type="select" label="Select a model validation function"> |
273 <option value="cross_validate">cross_validate - Evaluate metric(s) by cross-validation and also record fit/score times</option> | 265 <option value="cross_validate">cross_validate - Evaluate metric(s) by cross-validation and also record fit/score times</option> |
274 <option value="cross_val_predict">cross_val_predict - Generate cross-validated estimates for each input data point</option> | 266 <option value="cross_val_predict">cross_val_predict - Generate cross-validated estimates for each input data point</option> |
275 <option value="learning_curve">learning_curve - Learning curve</option> | 267 <option value="learning_curve">learning_curve - Learning curve</option> |
279 <when value="cross_validate"> | 271 <when value="cross_validate"> |
280 <section name="options" title="Other Options" expanded="false"> | 272 <section name="options" title="Other Options" expanded="false"> |
281 <expand macro="scoring_selection" /> | 273 <expand macro="scoring_selection" /> |
282 <expand macro="model_validation_common_options" /> | 274 <expand macro="model_validation_common_options" /> |
283 <param argument="return_train_score" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolfalse" checked="false" help="Whether to include train scores." /> | 275 <param argument="return_train_score" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolfalse" checked="false" help="Whether to include train scores." /> |
284 <!--param argument="return_estimator" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolfalse" checked="false" help="Whether to return the estimators fitted on each split."/> --> | 276 <!--param argument="return_estimator" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolfalse" checked="false" help="Whether to return the estimators fitted on each split." /> --> |
285 <!--param argument="error_score" type="boolean" truevalue="booltrue" falsevalue="boolfalse" checked="true" label="Raise fit error:" help="If false, the metric score is assigned to NaN if an error occurs in estimator fitting and FitFailedWarning is raised."/> --> | 277 <!--param argument="error_score" type="boolean" truevalue="booltrue" falsevalue="boolfalse" checked="true" label="Raise fit error:" help="If false, the metric score is assigned to NaN if an error occurs in estimator fitting and FitFailedWarning is raised." /> --> |
286 <!--fit_params--> | 278 <!--fit_params--> |
287 <expand macro="pre_dispatch" /> | 279 <expand macro="pre_dispatch" /> |
288 </section> | 280 </section> |
289 </when> | 281 </when> |
290 <when value="cross_val_predict"> | 282 <when value="cross_val_predict"> |
300 </when> | 292 </when> |
301 <when value="learning_curve"> | 293 <when value="learning_curve"> |
302 <section name="options" title="Other Options" expanded="false"> | 294 <section name="options" title="Other Options" expanded="false"> |
303 <expand macro="scoring_selection" /> | 295 <expand macro="scoring_selection" /> |
304 <expand macro="model_validation_common_options" /> | 296 <expand macro="model_validation_common_options" /> |
305 <param argument="train_sizes" type="text" value="(0.1, 1.0, 5)" label="train_sizes" help="Relative or absolute numbers of training examples that will be used to generate the learning curve. Supports 1) tuple, to be evaled by np.linspace, e.g. (0.1, 1.0, 5); 2) array-like, e.g. [0.1 , 0.325, 0.55 , 0.775, 1.]"> | 297 <param argument="train_sizes" type="text" value="(0.1, 1.0, 5)" label="train_sizes" |
298 help="Relative or absolute numbers of training examples that will be used to generate the learning curve. Supports 1) tuple, to be evaled by np.linspace, e.g. (0.1, 1.0, 5); 2) array-like, e.g. [0.1 , 0.325, 0.55 , 0.775, 1.]"> | |
306 <sanitizer> | 299 <sanitizer> |
307 <valid initial="default"> | 300 <valid initial="default"> |
308 <add value="[" /> | 301 <add value="[" /> |
309 <add value="]" /> | 302 <add value="]" /> |
310 </valid> | 303 </valid> |
341 <param name="infile2" value="regression_train.tabular" ftype="tabular" /> | 334 <param name="infile2" value="regression_train.tabular" ftype="tabular" /> |
342 <param name="col2" value="6" /> | 335 <param name="col2" value="6" /> |
343 <output name="outfile"> | 336 <output name="outfile"> |
344 <assert_contents> | 337 <assert_contents> |
345 <has_n_columns n="6" /> | 338 <has_n_columns n="6" /> |
346 <has_text text="0.9999961390418067" /> | 339 <has_text text="0.9998136508657879" /> |
347 <has_text text="0.9944541531269271" /> | 340 <has_text text="0.9999980090366614" /> |
348 <has_text text="0.9999193322454393" /> | 341 <has_text text="0.9999977541353663" /> |
349 </assert_contents> | 342 </assert_contents> |
350 </output> | 343 </output> |
351 </test> | 344 </test> |
352 <test> | 345 <test> |
353 <param name="infile_estimator" value="pipeline02" /> | 346 <param name="infile_estimator" value="pipeline02" /> |
354 <param name="selected_function" value="cross_val_predict" /> | 347 <param name="selected_function" value="cross_val_predict" /> |
355 <param name="infile1" value="regression_train.tabular" ftype="tabular" /> | 348 <param name="infile1" value="regression_train.tabular" ftype="tabular" /> |
356 <param name="col1" value="1,2,3,4,5" /> | 349 <param name="col1" value="1,2,3,4,5" /> |
357 <param name="infile2" value="regression_train.tabular" ftype="tabular" /> | 350 <param name="infile2" value="regression_train.tabular" ftype="tabular" /> |
358 <param name="col2" value="6" /> | 351 <param name="col2" value="6" /> |
359 <output name="outfile" file="mv_result02.tabular" lines_diff="14" /> | 352 <output name="outfile"> |
353 <assert_contents> | |
354 <has_n_columns n="1" /> | |
355 <has_text text="1.5781414" /> | |
356 <has_text text="-1.19994559787" /> | |
357 <has_text text="-0.7187446" /> | |
358 <has_text text="0.324693926" /> | |
359 <has_text text="1.25823227" /> | |
360 </assert_contents> | |
361 </output> | |
360 </test> | 362 </test> |
361 <test> | 363 <test> |
362 <param name="infile_estimator" value="pipeline05" /> | 364 <param name="infile_estimator" value="pipeline05" /> |
363 <param name="selected_function" value="learning_curve" /> | 365 <param name="selected_function" value="learning_curve" /> |
364 <param name="infile1" value="regression_X.tabular" ftype="tabular" /> | 366 <param name="infile1" value="regression_X.tabular" ftype="tabular" /> |
377 <param name="infile2" value="regression_train.tabular" ftype="tabular" /> | 379 <param name="infile2" value="regression_train.tabular" ftype="tabular" /> |
378 <param name="col2" value="6" /> | 380 <param name="col2" value="6" /> |
379 <output name="outfile"> | 381 <output name="outfile"> |
380 <assert_contents> | 382 <assert_contents> |
381 <has_n_columns n="3" /> | 383 <has_n_columns n="3" /> |
382 <has_text text="0.25697059258228816" /> | 384 <has_text text="-2.7453395018288753" /> |
383 </assert_contents> | 385 </assert_contents> |
384 </output> | 386 </output> |
385 </test> | 387 </test> |
386 <test> | 388 <test> |
387 <param name="infile_estimator" value="pipeline05" /> | 389 <param name="infile_estimator" value="pipeline05" /> |