Mercurial > repos > bgruening > sklearn_clf_metrics
annotate model_validations.py @ 25:68afcd163b3d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 49522db5f2dc8a571af49e3f38e80c22571068f4
| author | bgruening | 
|---|---|
| date | Tue, 09 Jul 2019 19:36:00 -0400 | 
| parents | 9bf11bbeccc3 | 
| children | 
| rev | line source | 
|---|---|
| 
24
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
1 """ | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
2 class | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
3 ----- | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
4 OrderedKFold | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
5 RepeatedOrderedKold | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
6 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
7 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
8 function | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
9 -------- | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
10 train_test_split | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
11 """ | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
12 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
13 import numpy as np | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
14 import warnings | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
15 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
16 from itertools import chain | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
17 from math import ceil, floor | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
18 from sklearn.model_selection import (GroupShuffleSplit, ShuffleSplit, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
19 StratifiedShuffleSplit) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
20 from sklearn.model_selection._split import _BaseKFold, _RepeatedSplits | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
21 from sklearn.utils import check_random_state, indexable, safe_indexing | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
22 from sklearn.utils.validation import _num_samples, check_array | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
23 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
24 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
25 def _validate_shuffle_split(n_samples, test_size, train_size, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
26 default_test_size=None): | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
27 """ | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
28 Validation helper to check if the test/test sizes are meaningful wrt to the | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
29 size of the data (n_samples) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
30 """ | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
31 if test_size is None and train_size is None: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
32 test_size = default_test_size | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
33 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
34 test_size_type = np.asarray(test_size).dtype.kind | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
35 train_size_type = np.asarray(train_size).dtype.kind | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
36 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
37 if (test_size_type == 'i' and (test_size >= n_samples or test_size <= 0) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
38 or test_size_type == 'f' and (test_size <= 0 or test_size >= 1)): | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
39 raise ValueError('test_size={0} should be either positive and smaller' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
40 ' than the number of samples {1} or a float in the ' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
41 '(0, 1) range'.format(test_size, n_samples)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
42 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
43 if (train_size_type == 'i' and (train_size >= n_samples or train_size <= 0) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
44 or train_size_type == 'f' and (train_size <= 0 or train_size >= 1)): | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
45 raise ValueError('train_size={0} should be either positive and smaller' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
46 ' than the number of samples {1} or a float in the ' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
47 '(0, 1) range'.format(train_size, n_samples)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
48 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
49 if train_size is not None and train_size_type not in ('i', 'f'): | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
50 raise ValueError("Invalid value for train_size: {}".format(train_size)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
51 if test_size is not None and test_size_type not in ('i', 'f'): | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
52 raise ValueError("Invalid value for test_size: {}".format(test_size)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
53 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
54 if (train_size_type == 'f' and test_size_type == 'f' and | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
55 train_size + test_size > 1): | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
56 raise ValueError( | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
57 'The sum of test_size and train_size = {}, should be in the (0, 1)' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
58 ' range. Reduce test_size and/or train_size.' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
59 .format(train_size + test_size)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
60 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
61 if test_size_type == 'f': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
62 n_test = ceil(test_size * n_samples) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
63 elif test_size_type == 'i': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
64 n_test = float(test_size) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
65 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
66 if train_size_type == 'f': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
67 n_train = floor(train_size * n_samples) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
68 elif train_size_type == 'i': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
69 n_train = float(train_size) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
70 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
71 if train_size is None: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
72 n_train = n_samples - n_test | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
73 elif test_size is None: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
74 n_test = n_samples - n_train | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
75 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
76 if n_train + n_test > n_samples: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
77 raise ValueError('The sum of train_size and test_size = %d, ' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
78 'should be smaller than the number of ' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
79 'samples %d. Reduce test_size and/or ' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
80 'train_size.' % (n_train + n_test, n_samples)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
81 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
82 n_train, n_test = int(n_train), int(n_test) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
83 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
84 if n_train == 0: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
85 raise ValueError( | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
86 'With n_samples={}, test_size={} and train_size={}, the ' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
87 'resulting train set will be empty. Adjust any of the ' | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
88 'aforementioned parameters.'.format(n_samples, test_size, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
89 train_size) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
90 ) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
91 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
92 return n_train, n_test | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
93 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
94 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
95 def train_test_split(*arrays, **options): | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
96 """Extend sklearn.model_selection.train_test_slit to have group split. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
97 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
98 Parameters | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
99 ---------- | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
100 *arrays : sequence of indexables with same length / shape[0] | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
101 Allowed inputs are lists, numpy arrays, scipy-sparse | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
102 matrices or pandas dataframes. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
103 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
104 test_size : float, int or None, optional (default=None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
105 If float, should be between 0.0 and 1.0 and represent the proportion | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
106 of the dataset to include in the test split. If int, represents the | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
107 absolute number of test samples. If None, the value is set to the | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
108 complement of the train size. If ``train_size`` is also None, it will | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
109 be set to 0.25. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
110 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
111 train_size : float, int, or None, (default=None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
112 If float, should be between 0.0 and 1.0 and represent the | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
113 proportion of the dataset to include in the train split. If | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
114 int, represents the absolute number of train samples. If None, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
115 the value is automatically set to the complement of the test size. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
116 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
117 random_state : int, RandomState instance or None, optional (default=None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
118 If int, random_state is the seed used by the random number generator; | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
119 If RandomState instance, random_state is the random number generator; | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
120 If None, the random number generator is the RandomState instance used | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
121 by `np.random`. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
122 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
123 shuffle : None or str (default='simple') | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
124 How to shuffle the data before splitting. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
125 None, no shuffle. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
126 For str, one of 'simple', 'stratified' and 'group', corresponding to | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
127 `ShuffleSplit`, `StratifiedShuffleSplit` and `GroupShuffleSplit`, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
128 respectively. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
129 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
130 labels : array-like or None (default=None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
131 Ignored if shuffle is None or 'simple'. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
132 When shuffle='stratified', this array is used as class labels. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
133 When shuffle='group', this array is used as groups. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
134 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
135 Returns | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
136 ------- | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
137 splitting : list, length=2 * len(arrays) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
138 List containing train-test split of inputs. | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
139 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
140 """ | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
141 n_arrays = len(arrays) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
142 if n_arrays == 0: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
143 raise ValueError("At least one array required as input") | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
144 test_size = options.pop('test_size', None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
145 train_size = options.pop('train_size', None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
146 random_state = options.pop('random_state', None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
147 shuffle = options.pop('shuffle', 'simple') | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
148 labels = options.pop('labels', None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
149 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
150 if options: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
151 raise TypeError("Invalid parameters passed: %s" % str(options)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
152 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
153 arrays = indexable(*arrays) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
154 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
155 n_samples = _num_samples(arrays[0]) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
156 if shuffle == 'group': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
157 if labels is None: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
158 raise ValueError("When shuffle='group', " | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
159 "labels should not be None!") | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
160 labels = check_array(labels, ensure_2d=False, dtype=None) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
161 uniques = np.unique(labels) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
162 n_samples = uniques.size | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
163 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
164 n_train, n_test = _validate_shuffle_split(n_samples, test_size, train_size, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
165 default_test_size=0.25) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
166 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
167 shuffle_options = dict(test_size=n_test, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
168 train_size=n_train, | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
169 random_state=random_state) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
170 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
171 if shuffle is None: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
172 if labels is not None: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
173 warnings.warn("The `labels` is ignored for " | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
174 "shuffle being None!") | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
175 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
176 train = np.arange(n_train) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
177 test = np.arange(n_train, n_train + n_test) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
178 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
179 elif shuffle == 'simple': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
180 if labels is not None: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
181 warnings.warn("The `labels` is not needed and therefore " | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
182 "ignored for ShuffleSplit, as shuffle='simple'!") | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
183 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
184 cv = ShuffleSplit(**shuffle_options) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
185 train, test = next(cv.split(X=arrays[0], y=None)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
186 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
187 elif shuffle == 'stratified': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
188 cv = StratifiedShuffleSplit(**shuffle_options) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
189 train, test = next(cv.split(X=arrays[0], y=labels)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
190 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
191 elif shuffle == 'group': | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
192 cv = GroupShuffleSplit(**shuffle_options) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
193 train, test = next(cv.split(X=arrays[0], y=None, groups=labels)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
194 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
195 else: | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
196 raise ValueError("The argument `shuffle` only supports None, " | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
197 "'simple', 'stratified' and 'group', but got `%s`!" | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
198 % shuffle) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
199 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
200 return list(chain.from_iterable((safe_indexing(a, train), | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
201 safe_indexing(a, test)) for a in arrays)) | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
202 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
203 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
class OrderedKFold(_BaseKFold):
    """
    Split into K fold based on ordered target value

    Samples are sorted by their target value ``y`` and then dealt out
    round-robin: the sample at sorted position ``p`` goes to test fold
    ``p % n_splits``. Each test fold therefore spans the whole range of
    ``y``.

    Parameters
    ----------
    n_splits : int, default=3
        Number of folds. Must be at least 2.
    shuffle: bool
        If True, the sorted order is shuffled within each consecutive
        chunk of ``n_splits`` samples (and within the trailing remainder)
        before the samples are dealt out to folds.
    random_state: None or int
        Passed to ``check_random_state``; only used when ``shuffle`` is
        True.
    """

    def __init__(self, n_splits=3, shuffle=False, random_state=None):
        # Pass by keyword: newer scikit-learn releases declare `shuffle`
        # and `random_state` keyword-only on _BaseKFold, so positional
        # arguments raise TypeError there. Keywords also work on old releases.
        super(OrderedKFold, self).__init__(n_splits=n_splits,
                                           shuffle=shuffle,
                                           random_state=random_state)

    def _iter_test_indices(self, X, y, groups=None):
        """Yield the test-set indices of each fold.

        Raises
        ------
        ValueError
            If ``y`` is None, since the folds are defined by ordering ``y``.
        """
        if y is None:
            raise ValueError("OrderedKFold requires the target `y` to "
                             "order the samples, but got None!")
        n_samples = _num_samples(X)
        n_splits = self.n_splits
        y = np.asarray(y)
        # A single-column 2-D target is treated as 1-D. Otherwise argsort
        # would sort along the length-1 last axis and return all zeros.
        if y.ndim == 2 and y.shape[1] == 1:
            y = y.ravel()
        sorted_index = np.argsort(y)

        if self.shuffle:
            rng = check_random_state(self.random_state)
            current = 0
            for _ in range(n_samples // int(n_splits)):
                start, stop = current, current + n_splits
                # Basic slicing of a numpy array returns a view, so this
                # shuffles that chunk of `sorted_index` in place.
                rng.shuffle(sorted_index[start:stop])
                current = stop
            rng.shuffle(sorted_index[current:])

        # Fold i takes every n_splits-th sample in sorted order, from i.
        for i in range(n_splits):
            yield sorted_index[i:n_samples:n_splits]
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
235 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
236 | 
| 
 
9bf11bbeccc3
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
 
bgruening 
parents:  
diff
changeset
 | 
class RepeatedOrderedKFold(_RepeatedSplits):
    """ Repeated OrderedKFold runs multiple times with different randomization.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=5
        Number of times cross-validator to be repeated.

    random_state: int, RandomState instance or None. Optional
        Forwarded to ``_RepeatedSplits``, which uses it to randomize each
        repetition.
    """
    def __init__(self, n_splits=5, n_repeats=5, random_state=None):
        # `n_repeats` and `random_state` are passed by keyword. Newer
        # scikit-learn releases make them keyword-only on _RepeatedSplits,
        # and keywords also work on older releases.
        super(RepeatedOrderedKFold, self).__init__(
            OrderedKFold, n_repeats=n_repeats, random_state=random_state,
            n_splits=n_splits)
