Mercurial > repos > bgruening > sklearn_model_validation
comparison train_test_eval.py @ 30:4b359039f09f draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author | bgruening |
---|---|
date | Sat, 01 May 2021 01:03:56 +0000 |
parents | de360b57a5ab |
children | 1fe00785190d |
comparison
equal
deleted
inserted
replaced
29:de360b57a5ab | 30:4b359039f09f |
---|---|
7 | 7 |
8 import joblib | 8 import joblib |
9 import numpy as np | 9 import numpy as np |
10 import pandas as pd | 10 import pandas as pd |
11 from galaxy_ml.model_validations import train_test_split | 11 from galaxy_ml.model_validations import train_test_split |
12 from galaxy_ml.utils import ( | 12 from galaxy_ml.utils import (get_module, get_scoring, load_model, |
13 get_module, | 13 read_columns, SafeEval, try_get_attr) |
14 get_scoring, | |
15 load_model, | |
16 read_columns, | |
17 SafeEval, | |
18 try_get_attr, | |
19 ) | |
20 from scipy.io import mmread | 14 from scipy.io import mmread |
21 from sklearn import pipeline | 15 from sklearn import pipeline |
22 from sklearn.metrics.scorer import _check_multimetric_scoring | 16 from sklearn.metrics.scorer import _check_multimetric_scoring |
23 from sklearn.model_selection import _search, _validation | 17 from sklearn.model_selection import _search, _validation |
24 from sklearn.model_selection._validation import _score | 18 from sklearn.model_selection._validation import _score |
25 from sklearn.utils import indexable, safe_indexing | 19 from sklearn.utils import indexable, safe_indexing |
26 | |
27 | 20 |
28 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 21 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") |
29 setattr(_search, "_fit_and_score", _fit_and_score) | 22 setattr(_search, "_fit_and_score", _fit_and_score) |
30 setattr(_validation, "_fit_and_score", _fit_and_score) | 23 setattr(_validation, "_fit_and_score", _fit_and_score) |
31 | 24 |
260 infile2 = loaded_df[df_key] | 253 infile2 = loaded_df[df_key] |
261 else: | 254 else: |
262 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) | 255 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) |
263 loaded_df[df_key] = infile2 | 256 loaded_df[df_key] = infile2 |
264 | 257 |
265 y = read_columns(infile2, | 258 y = read_columns( |
266 c=c, | 259 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True |
267 c_option=column_option, | 260 ) |
268 sep='\t', | |
269 header=header, | |
270 parse_dates=True) | |
271 if len(y.shape) == 2 and y.shape[1] == 1: | 261 if len(y.shape) == 2 and y.shape[1] == 1: |
272 y = y.ravel() | 262 y = y.ravel() |
273 if input_type == "refseq_and_interval": | 263 if input_type == "refseq_and_interval": |
274 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) | 264 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) |
275 y = None | 265 y = None |
297 | 287 |
298 df_key = groups + repr(header) | 288 df_key = groups + repr(header) |
299 if df_key in loaded_df: | 289 if df_key in loaded_df: |
300 groups = loaded_df[df_key] | 290 groups = loaded_df[df_key] |
301 | 291 |
302 groups = read_columns(groups, | 292 groups = read_columns( |
303 c=c, | 293 groups, |
304 c_option=column_option, | 294 c=c, |
305 sep='\t', | 295 c_option=column_option, |
306 header=header, | 296 sep="\t", |
307 parse_dates=True) | 297 header=header, |
298 parse_dates=True, | |
299 ) | |
308 groups = groups.ravel() | 300 groups = groups.ravel() |
309 | 301 |
310 # del loaded_df | 302 # del loaded_df |
311 del loaded_df | 303 del loaded_df |
312 | 304 |
369 else: | 361 else: |
370 raise ValueError( | 362 raise ValueError( |
371 "Stratified shuffle split is not " "applicable on empty target values!" | 363 "Stratified shuffle split is not " "applicable on empty target values!" |
372 ) | 364 ) |
373 | 365 |
374 X_train, X_test, y_train, y_test, groups_train, _groups_test = train_test_split_none( | 366 ( |
375 X, y, groups, **test_split_options | 367 X_train, |
376 ) | 368 X_test, |
369 y_train, | |
370 y_test, | |
371 groups_train, | |
372 _groups_test, | |
373 ) = train_test_split_none(X, y, groups, **test_split_options) | |
377 | 374 |
378 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"] | 375 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"] |
379 | 376 |
380 # handle validation (second) split | 377 # handle validation (second) split |
381 if exp_scheme == "train_val_test": | 378 if exp_scheme == "train_val_test": |