Annotated `model_validations.py` at revision 25:68afcd163b3d (draft), Mercurial repository bgruening/sklearn_clf_metrics
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 49522db5f2dc8a571af49e3f38e80c22571068f4
| field | value |
|---|---|
| author | bgruening |
| date | Tue, 09 Jul 2019 19:36:00 -0400 |
| parents | 9bf11bbeccc3 |
| children | |
Every annotated source line below was last changed in revision 24 (changeset 9bf11bbeccc3, author bgruening), commit message: planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
```python
"""
class
-----
OrderedKFold
RepeatedOrderedKFold


function
--------
train_test_split
"""

import numpy as np
import warnings

from itertools import chain
from math import ceil, floor
from sklearn.model_selection import (GroupShuffleSplit, ShuffleSplit,
                                     StratifiedShuffleSplit)
# Private scikit-learn helpers; their location may change between releases.
from sklearn.model_selection._split import _BaseKFold, _RepeatedSplits
# NOTE: `safe_indexing` is importable from sklearn.utils only in older
# scikit-learn releases (it was later deprecated and removed).
from sklearn.utils import check_random_state, indexable, safe_indexing
from sklearn.utils.validation import _num_samples, check_array
```
```python
def _validate_shuffle_split(n_samples, test_size, train_size,
                            default_test_size=None):
    """
    Validation helper to check if the train/test sizes are meaningful with
    respect to the size of the data (n_samples).
    """
    if test_size is None and train_size is None:
        test_size = default_test_size

    # dtype.kind is 'i' for ints, 'f' for floats and 'O' for None
    test_size_type = np.asarray(test_size).dtype.kind
    train_size_type = np.asarray(train_size).dtype.kind

    if (test_size_type == 'i' and (test_size >= n_samples or test_size <= 0)
            or test_size_type == 'f' and (test_size <= 0 or test_size >= 1)):
        raise ValueError('test_size={0} should be either positive and smaller'
                         ' than the number of samples {1} or a float in the '
                         '(0, 1) range'.format(test_size, n_samples))

    if (train_size_type == 'i' and (train_size >= n_samples or train_size <= 0)
            or train_size_type == 'f' and (train_size <= 0 or train_size >= 1)):
        raise ValueError('train_size={0} should be either positive and smaller'
                         ' than the number of samples {1} or a float in the '
                         '(0, 1) range'.format(train_size, n_samples))

    if train_size is not None and train_size_type not in ('i', 'f'):
        raise ValueError("Invalid value for train_size: {}".format(train_size))
    if test_size is not None and test_size_type not in ('i', 'f'):
        raise ValueError("Invalid value for test_size: {}".format(test_size))

    if (train_size_type == 'f' and test_size_type == 'f' and
            train_size + test_size > 1):
        raise ValueError(
            'The sum of test_size and train_size = {}, should be in the (0, 1)'
            ' range. Reduce test_size and/or train_size.'
            .format(train_size + test_size))

    # Fractional sizes: round the test set up and the train set down.
    if test_size_type == 'f':
        n_test = ceil(test_size * n_samples)
    elif test_size_type == 'i':
        n_test = float(test_size)

    if train_size_type == 'f':
        n_train = floor(train_size * n_samples)
    elif train_size_type == 'i':
        n_train = float(train_size)

    # A missing size becomes the complement of the other one.
    if train_size is None:
        n_train = n_samples - n_test
    elif test_size is None:
        n_test = n_samples - n_train

    if n_train + n_test > n_samples:
        raise ValueError('The sum of train_size and test_size = %d, '
                         'should be smaller than the number of '
                         'samples %d. Reduce test_size and/or '
                         'train_size.' % (n_train + n_test, n_samples))

    n_train, n_test = int(n_train), int(n_test)

    if n_train == 0:
        raise ValueError(
            'With n_samples={}, test_size={} and train_size={}, the '
            'resulting train set will be empty. Adjust any of the '
            'aforementioned parameters.'.format(n_samples, test_size,
                                                train_size)
        )

    return n_train, n_test
```
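`_validate_shuffle_split` rounds a fractional `test_size` up (`ceil`) and a fractional `train_size` down (`floor`), then fills a missing size with the complement. The following pure-Python sketch restates only that arithmetic so it can be run without scikit-learn. `split_sizes` is a hypothetical name, the range and type checks are omitted, and Python `int`/`float` checks stand in for the NumPy `dtype.kind` tests used in the module.

```python
from math import ceil, floor


def split_sizes(n_samples, test_size=None, train_size=None,
                default_test_size=0.25):
    """Hypothetical, simplified restatement of the size arithmetic in
    _validate_shuffle_split (range and type validation omitted)."""
    if test_size is None and train_size is None:
        test_size = default_test_size

    n_test = n_train = None
    # Fractions are rounded: the test side up, the train side down.
    if isinstance(test_size, float):
        n_test = ceil(test_size * n_samples)
    elif isinstance(test_size, int):
        n_test = test_size
    if isinstance(train_size, float):
        n_train = floor(train_size * n_samples)
    elif isinstance(train_size, int):
        n_train = train_size

    # A missing size becomes the complement of the other one.
    if train_size is None:
        n_train = n_samples - n_test
    elif test_size is None:
        n_test = n_samples - n_train

    if n_train + n_test > n_samples:
        raise ValueError("train and test sizes exceed n_samples")
    if n_train == 0:
        raise ValueError("the resulting train set would be empty")
    return n_train, n_test


print(split_sizes(10))                                 # (7, 3): ceil(2.5) = 3
print(split_sizes(10, train_size=0.75))                # (7, 3): floor(7.5) = 7
print(split_sizes(10, test_size=4))                    # (6, 4): int is a count
print(split_sizes(10, test_size=0.2, train_size=0.5))  # (5, 2): 3 samples unused
```

Rounding the test side up and the train side down means two fractions that sum to at most 1 never request more than `n_samples` samples in total (up to floating-point rounding of the products).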
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 95 def train_test_split(*arrays, **options): | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 96 """Extend sklearn.model_selection.train_test_slit to have group split. | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 97 | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 98 Parameters | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 99 ---------- | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 100 *arrays : sequence of indexables with same length / shape[0] | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 101 Allowed inputs are lists, numpy arrays, scipy-sparse | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 102 matrices or pandas dataframes. | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 103 | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 104 test_size : float, int or None, optional (default=None) | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 105 If float, should be between 0.0 and 1.0 and represent the proportion | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 106 of the dataset to include in the test split. If int, represents the | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 107 absolute number of test samples. If None, the value is set to the | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 108 complement of the train size. If ``train_size`` is also None, it will | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 109 be set to 0.25. | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 110 | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 111 train_size : float, int, or None, (default=None) | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 112 If float, should be between 0.0 and 1.0 and represent the | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 113 proportion of the dataset to include in the train split. If | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 114 int, represents the absolute number of train samples. If None, | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 115 the value is automatically set to the complement of the test size. | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 116 | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 117 random_state : int, RandomState instance or None, optional (default=None) | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 118 If int, random_state is the seed used by the random number generator; | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 119 If RandomState instance, random_state is the random number generator; | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 120 If None, the random number generator is the RandomState instance used | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 121 by `np.random`. | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 122 | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 123 shuffle : None or str (default='simple') | 
| 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 bgruening parents: diff
changeset | 124 How to shuffle the data before splitting. | 
| 
9bf11bbeccc3
        None, no shuffle.
        For str, one of 'simple', 'stratified' and 'group', corresponding to
        `ShuffleSplit`, `StratifiedShuffleSplit` and `GroupShuffleSplit`,
        respectively.

    labels : array-like or None (default=None)
        Ignored if shuffle is None or 'simple'.
        When shuffle='stratified', this array is used as class labels.
        When shuffle='group', this array is used as groups.

    Returns
    -------
    splitting : list, length=2 * len(arrays)
        List containing train-test split of inputs.

    """
    n_arrays = len(arrays)
    if n_arrays == 0:
        raise ValueError("At least one array required as input")
    test_size = options.pop('test_size', None)
    train_size = options.pop('train_size', None)
    random_state = options.pop('random_state', None)
    shuffle = options.pop('shuffle', 'simple')
    labels = options.pop('labels', None)

    if options:
        raise TypeError("Invalid parameters passed: %s" % str(options))

    arrays = indexable(*arrays)

    n_samples = _num_samples(arrays[0])
    if shuffle == 'group':
        if labels is None:
            raise ValueError("When shuffle='group', "
                             "labels should not be None!")
        labels = check_array(labels, ensure_2d=False, dtype=None)
        uniques = np.unique(labels)
        # For group shuffling, split sizes are counted in unique groups,
        # not in rows.
        n_samples = uniques.size

    n_train, n_test = _validate_shuffle_split(n_samples, test_size, train_size,
                                              default_test_size=0.25)

    shuffle_options = dict(test_size=n_test,
                           train_size=n_train,
                           random_state=random_state)

    if shuffle is None:
        if labels is not None:
            warnings.warn("The `labels` is ignored for "
                          "shuffle being None!")

        # No shuffling: the first n_train rows form the training set and
        # the following n_test rows form the test set.
        train = np.arange(n_train)
        test = np.arange(n_train, n_train + n_test)

    elif shuffle == 'simple':
        if labels is not None:
            warnings.warn("The `labels` is not needed and therefore "
                          "ignored for ShuffleSplit, as shuffle='simple'!")

        cv = ShuffleSplit(**shuffle_options)
        train, test = next(cv.split(X=arrays[0], y=None))

    elif shuffle == 'stratified':
        cv = StratifiedShuffleSplit(**shuffle_options)
        train, test = next(cv.split(X=arrays[0], y=labels))

    elif shuffle == 'group':
        cv = GroupShuffleSplit(**shuffle_options)
        train, test = next(cv.split(X=arrays[0], y=None, groups=labels))

    else:
        raise ValueError("The argument `shuffle` only supports None, "
                         "'simple', 'stratified' and 'group', but got `%s`!"
                         % shuffle)

    # Interleave the results: [a_train, a_test, b_train, b_test, ...].
    return list(chain.from_iterable((safe_indexing(a, train),
                                     safe_indexing(a, test)) for a in arrays))
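With `shuffle=None` the function takes the first `n_train` rows as the training set and the next `n_test` rows as the test set, then interleaves the train and test pieces of every input array. The following is a minimal pure-Python sketch of that behavior, not the tool's own code: `ordered_split` is a hypothetical helper, integer sizes stand in for the output of `_validate_shuffle_split`, and list comprehensions stand in for `safe_indexing`.

```python
from itertools import chain


def ordered_split(*arrays, n_train, n_test):
    """Hypothetical stand-in for the shuffle=None branch.

    Rows 0..n_train-1 go to train and the next n_test rows go to test;
    plain list indexing replaces sklearn's safe_indexing here.
    """
    if not arrays:
        raise ValueError("At least one array required as input")
    train = range(n_train)
    test = range(n_train, n_train + n_test)
    # Same layout as the original return value:
    # [a_train, a_test, b_train, b_test, ...], i.e. 2 * len(arrays) items.
    return list(chain.from_iterable(
        ([a[i] for i in train], [a[i] for i in test]) for a in arrays))


X = [[0], [1], [2], [3], [4]]
y = ['a', 'b', 'c', 'd', 'e']
X_train, X_test, y_train, y_test = ordered_split(X, y, n_train=3, n_test=2)
print(X_train, X_test)  # [[0], [1], [2]] [[3], [4]]
print(y_train, y_test)  # ['a', 'b', 'c'] ['d', 'e']
```

Unpacking four values from two inputs reflects the documented return length of `2 * len(arrays)`.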
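For `shuffle='group'`, `n_samples` is reset to the number of unique labels before `_validate_shuffle_split` runs, so `test_size` and `train_size` count groups rather than rows, and `GroupShuffleSplit` keeps every row of a group on the same side of the split. The sketch below illustrates that idea with the standard library only; `group_split` and `n_test_groups` are hypothetical names, and `random.Random.sample` replaces scikit-learn's internal group permutation, so the groups it picks will not match `GroupShuffleSplit` for the same seed.

```python
import random


def group_split(groups, n_test_groups, seed=None):
    """Hypothetical sketch: pick whole groups for the test side.

    Sizes are counted in unique labels (like n_samples = uniques.size),
    and no group is divided between train and test.
    """
    uniques = sorted(set(groups))  # analogous to np.unique(labels)
    if not 0 < n_test_groups < len(uniques):
        raise ValueError("n_test_groups must be between 1 and n_groups - 1")
    rng = random.Random(seed)
    test_groups = set(rng.sample(uniques, n_test_groups))
    train = [i for i, g in enumerate(groups) if g not in test_groups]
    test = [i for i, g in enumerate(groups) if g in test_groups]
    return train, test


groups = ['p1', 'p1', 'p2', 'p2', 'p2', 'p3']
train, test = group_split(groups, n_test_groups=1, seed=0)
# Every group lands entirely on one side of the split.
assert not {groups[i] for i in train} & {groups[i] for i in test}
```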
class OrderedKFold(_BaseKFold):
    """
    Split into K folds based on the ordered target value.

    Samples are sorted by `y` and dealt out in turn, so that fold i
    receives the samples ranked i, i + n_splits, i + 2 * n_splits, ...

    Parameters
    ----------
    n_splits : int, default=3
        Number of folds. Must be at least 2.
    shuffle : bool, default=False
        Whether to shuffle samples within each consecutive block of
        `n_splits` sorted samples before assigning them to folds.
    random_state : int, RandomState instance or None, default=None
        Seed for the shuffling. Only used when shuffle=True.
    """

    def __init__(self, n_splits=3, shuffle=False, random_state=None):
        super(OrderedKFold, self).__init__(n_splits, shuffle, random_state)

    def _iter_test_indices(self, X, y, groups=None):
        n_samples = _num_samples(X)
        n_splits = self.n_splits
        y = np.asarray(y)
        # Sample indices ordered by increasing target value.
        sorted_index = np.argsort(y)
        if self.shuffle:
            current = 0
            rng = check_random_state(self.random_state)
            # Shuffle in place within each full block of n_splits
            # neighbouring samples (slices are views), then the remainder.
            for _ in range(n_samples // int(n_splits)):
                start, stop = current, current + n_splits
                rng.shuffle(sorted_index[start:stop])
                current = stop
            rng.shuffle(sorted_index[current:])

        # Fold i takes every n_splits-th sample starting at rank i.
        for i in range(n_splits):
            yield sorted_index[i:n_samples:n_splits]
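`OrderedKFold` sorts the samples by target value and deals them out round-robin, so each test fold spans the whole range of `y`; with `shuffle=True` the order is only randomized within blocks of `n_splits` neighbouring samples, which preserves that balance. The pure-Python re-statement below is an illustrative sketch: `ordered_kfold_test_folds` is a hypothetical name, `random.Random` replaces NumPy's `RandomState`, and Python's stable `sorted` may break ties in `y` differently from `np.argsort`.

```python
import random


def ordered_kfold_test_folds(y, n_splits=3, shuffle=False, seed=None):
    """Hypothetical pure-Python re-statement of OrderedKFold's test folds."""
    n_samples = len(y)
    # Indices ordered by increasing target (stable sort, unlike np.argsort's
    # default, so ties may be ordered differently than in OrderedKFold).
    sorted_index = sorted(range(n_samples), key=lambda i: y[i])
    if shuffle:
        rng = random.Random(seed)
        # Shuffle within each full block of n_splits neighbours, then the
        # leftover tail, mirroring the loop in _iter_test_indices.
        n_full = n_samples - n_samples % n_splits
        for start in range(0, n_full, n_splits):
            block = sorted_index[start:start + n_splits]
            rng.shuffle(block)
            sorted_index[start:start + n_splits] = block
        tail = sorted_index[n_full:]
        rng.shuffle(tail)
        sorted_index[n_full:] = tail
    # Fold i takes every n_splits-th sample starting at rank i.
    return [sorted_index[i::n_splits] for i in range(n_splits)]


y = [5.0, 1.0, 4.0, 2.0, 3.0, 0.0]
print(ordered_kfold_test_folds(y, n_splits=3))  # [[5, 4], [1, 2], [3, 0]]
```

With six samples and three folds, every fold receives one sample from the lower half of the sorted targets and one from the upper half, whether or not shuffling is enabled.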
class RepeatedOrderedKFold(_RepeatedSplits):
    """Repeat OrderedKFold multiple times with a different randomization
    in each repetition.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=5
        Number of times the cross-validator is repeated.

    random_state : int, RandomState instance or None, default=None
        Controls the randomization of each repetition.
    """
    def __init__(self, n_splits=5, n_repeats=5, random_state=None):
        super(RepeatedOrderedKFold, self).__init__(
            OrderedKFold, n_repeats, random_state, n_splits=n_splits)
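`RepeatedOrderedKFold` does not implement any splitting itself; it hands `OrderedKFold` to scikit-learn's private `_RepeatedSplits`, which (in the scikit-learn versions this file was written against) builds a fresh splitter with `shuffle=True` and one shared random state for every repeat. The generator below sketches the resulting stream of `(train, test)` index pairs in pure Python; `repeated_ordered_kfold` and `_shuffled_ordered_folds` are hypothetical names, and `random.Random` stands in for NumPy's `RandomState`, so the exact folds will differ from the real class.

```python
import random


def _shuffled_ordered_folds(y, n_splits, rng):
    """Test folds of OrderedKFold with shuffle=True, in pure Python."""
    n_samples = len(y)
    order = sorted(range(n_samples), key=lambda i: y[i])
    # Blocks of n_splits neighbours are shuffled; the final block may be
    # shorter, which matches shuffling the leftover tail separately.
    for start in range(0, n_samples, n_splits):
        block = order[start:start + n_splits]
        rng.shuffle(block)
        order[start:start + n_splits] = block
    return [order[i::n_splits] for i in range(n_splits)]


def repeated_ordered_kfold(y, n_splits=5, n_repeats=5, seed=None):
    """Hypothetical sketch of the (train, test) pairs RepeatedOrderedKFold yields."""
    rng = random.Random(seed)  # one generator shared by all repeats
    n_samples = len(y)
    for _ in range(n_repeats):
        for test in _shuffled_ordered_folds(y, n_splits, rng):
            test_set = set(test)
            train = [i for i in range(n_samples) if i not in test_set]
            yield train, sorted(test)


y = [0.5, 2.0, 1.5, 3.0, 0.1, 2.5]
splits = list(repeated_ordered_kfold(y, n_splits=3, n_repeats=2, seed=0))
print(len(splits))  # 6, i.e. n_splits * n_repeats
```

Sharing a single generator across repeats is what makes each repetition differ while the whole sequence stays reproducible for a fixed seed.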
