Mercurial > repos > bgruening > sklearn_discriminant_classifier
Side-by-side comparison of pca.py at changeset 35:eeaf989f1024 (draft)
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 18:09:01 +0000 |
parents | 2d032cff49eb |
children |
Comparison legend (row states): equal, deleted, inserted, replaced
34:2d032cff49eb | 35:eeaf989f1024 |
---|---|
1 import argparse | 1 import argparse |
2 | |
2 import numpy as np | 3 import numpy as np |
3 from sklearn.decomposition import PCA, IncrementalPCA, KernelPCA | |
4 from galaxy_ml.utils import read_columns | 4 from galaxy_ml.utils import read_columns |
5 from sklearn.decomposition import IncrementalPCA, KernelPCA, PCA | |
6 | |
5 | 7 |
6 def main(): | 8 def main(): |
7 parser = argparse.ArgumentParser(description='RDKit screen') | 9 parser = argparse.ArgumentParser(description="RDKit screen") |
8 parser.add_argument('-i', '--infile', | 10 parser.add_argument("-i", "--infile", help="Input file") |
9 help="Input file") | 11 parser.add_argument( |
10 parser.add_argument('--header', action='store_true', help="Include the header row or skip it") | 12 "--header", action="store_true", help="Include the header row or skip it" |
11 parser.add_argument('-c', '--columns', type=str.lower, default='all', choices=['by_index_number', 'all_but_by_index_number',\ | 13 ) |
12 'by_header_name', 'all_but_by_header_name', 'all_columns'], | 14 parser.add_argument( |
13 help="Choose to select all columns, or exclude/include some") | 15 "-c", |
14 parser.add_argument('-ci', '--column_indices', type=str.lower, | 16 "--columns", |
15 help="Choose to select all columns, or exclude/include some") | 17 type=str.lower, |
16 parser.add_argument('-n', '--number', nargs='?', type=int, default=None,\ | 18 default="all", |
17 help="Number of components to keep. If not set, all components are kept") | 19 choices=[ |
18 parser.add_argument('--whiten', action='store_true', help="Whiten the components") | 20 "by_index_number", |
19 parser.add_argument('-t', '--pca_type', type=str.lower, default='classical', choices=['classical', 'incremental', 'kernel'], | 21 "all_but_by_index_number", |
20 help="Choose which flavour of PCA to use") | 22 "by_header_name", |
21 parser.add_argument('-s', '--svd_solver', type=str.lower, default='auto', choices=['auto', 'full', 'arpack', 'randomized'], | 23 "all_but_by_header_name", |
22 help="Choose the type of svd solver.") | 24 "all_columns", |
23 parser.add_argument('-b', '--batch_size', nargs='?', type=int, default=None,\ | 25 ], |
24 help="The number of samples to use for each batch") | 26 help="Choose to select all columns, or exclude/include some", |
25 parser.add_argument('-k', '--kernel', type=str.lower, default='linear',\ | 27 ) |
26 choices=['linear', 'poly', 'rbf', 'sigmoid', 'cosine', 'precomputed'], | 28 parser.add_argument( |
27 help="Choose the type of kernel.") | 29 "-ci", |
28 parser.add_argument('-g', '--gamma', nargs='?', type=float, default=None, | 30 "--column_indices", |
29 help='Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other kernels') | 31 type=str.lower, |
30 parser.add_argument('-tol', '--tolerance', type=float, default=0.0, | 32 help="Choose to select all columns, or exclude/include some", |
31 help='Convergence tolerance for arpack. If 0, optimal value will be chosen by arpack') | 33 ) |
32 parser.add_argument('-mi', '--max_iter', nargs='?', type=int, default=None,\ | 34 parser.add_argument( |
33 help="Maximum number of iterations for arpack") | 35 "-n", |
34 parser.add_argument('-d', '--degree', type=int, default=3,\ | 36 "--number", |
35 help="Degree for poly kernels. Ignored by other kernels") | 37 nargs="?", |
36 parser.add_argument('-cf', '--coef0', type=float, default=1.0, | 38 type=int, |
37 help='Independent term in poly and sigmoid kernels') | 39 default=None, |
38 parser.add_argument('-e', '--eigen_solver', type=str.lower, default='auto', choices=['auto', 'dense', 'arpack'], | 40 help="Number of components to keep. If not set, all components are kept", |
39 help="Choose the type of eigen solver.") | 41 ) |
40 parser.add_argument('-o', '--outfile', | 42 parser.add_argument("--whiten", action="store_true", help="Whiten the components") |
41 help="Base name for output file (no extension).") | 43 parser.add_argument( |
44 "-t", | |
45 "--pca_type", | |
46 type=str.lower, | |
47 default="classical", | |
48 choices=["classical", "incremental", "kernel"], | |
49 help="Choose which flavour of PCA to use", | |
50 ) | |
51 parser.add_argument( | |
52 "-s", | |
53 "--svd_solver", | |
54 type=str.lower, | |
55 default="auto", | |
56 choices=["auto", "full", "arpack", "randomized"], | |
57 help="Choose the type of svd solver.", | |
58 ) | |
59 parser.add_argument( | |
60 "-b", | |
61 "--batch_size", | |
62 nargs="?", | |
63 type=int, | |
64 default=None, | |
65 help="The number of samples to use for each batch", | |
66 ) | |
67 parser.add_argument( | |
68 "-k", | |
69 "--kernel", | |
70 type=str.lower, | |
71 default="linear", | |
72 choices=["linear", "poly", "rbf", "sigmoid", "cosine", "precomputed"], | |
73 help="Choose the type of kernel.", | |
74 ) | |
75 parser.add_argument( | |
76 "-g", | |
77 "--gamma", | |
78 nargs="?", | |
79 type=float, | |
80 default=None, | |
81 help="Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other kernels", | |
82 ) | |
83 parser.add_argument( | |
84 "-tol", | |
85 "--tolerance", | |
86 type=float, | |
87 default=0.0, | |
88 help="Convergence tolerance for arpack. If 0, optimal value will be chosen by arpack", | |
89 ) | |
90 parser.add_argument( | |
91 "-mi", | |
92 "--max_iter", | |
93 nargs="?", | |
94 type=int, | |
95 default=None, | |
96 help="Maximum number of iterations for arpack", | |
97 ) | |
98 parser.add_argument( | |
99 "-d", | |
100 "--degree", | |
101 type=int, | |
102 default=3, | |
103 help="Degree for poly kernels. Ignored by other kernels", | |
104 ) | |
105 parser.add_argument( | |
106 "-cf", | |
107 "--coef0", | |
108 type=float, | |
109 default=1.0, | |
110 help="Independent term in poly and sigmoid kernels", | |
111 ) | |
112 parser.add_argument( | |
113 "-e", | |
114 "--eigen_solver", | |
115 type=str.lower, | |
116 default="auto", | |
117 choices=["auto", "dense", "arpack"], | |
118 help="Choose the type of eigen solver.", | |
119 ) | |
120 parser.add_argument( | |
121 "-o", "--outfile", help="Base name for output file (no extension)." | |
122 ) | |
42 args = parser.parse_args() | 123 args = parser.parse_args() |
43 | 124 |
44 usecols = None | 125 usecols = None |
45 cols = [] | |
46 pca_params = {} | 126 pca_params = {} |
47 | 127 |
48 if args.columns == 'by_index_number' or args.columns == 'all_but_by_index_number': | 128 if args.columns == "by_index_number" or args.columns == "all_but_by_index_number": |
49 usecols = [int(i) for i in args.column_indices.split(',')] | 129 usecols = [int(i) for i in args.column_indices.split(",")] |
50 elif args.columns == 'by_header_name' or args.columns == 'all_but_by_header_name': | 130 elif args.columns == "by_header_name" or args.columns == "all_but_by_header_name": |
51 usecols = args.column_indices | 131 usecols = args.column_indices |
52 | 132 |
53 header = 'infer' if args.header else None | 133 header = "infer" if args.header else None |
54 | 134 |
55 pca_input = read_columns( | 135 pca_input = read_columns( |
56 f=args.infile, | 136 f=args.infile, |
57 c=usecols, | 137 c=usecols, |
58 c_option=args.columns, | 138 c_option=args.columns, |
59 sep='\t', | 139 sep="\t", |
60 header=header, | 140 header=header, |
61 parse_dates=True, | 141 parse_dates=True, |
62 encoding=None, | 142 encoding=None, |
63 index_col=None) | 143 index_col=None, |
144 ) | |
64 | 145 |
65 pca_params.update({'n_components': args.number}) | 146 pca_params.update({"n_components": args.number}) |
66 | 147 |
67 if args.pca_type == 'classical': | 148 if args.pca_type == "classical": |
68 pca_params.update({'svd_solver': args.svd_solver, 'whiten': args.whiten}) | 149 pca_params.update({"svd_solver": args.svd_solver, "whiten": args.whiten}) |
69 if args.svd_solver == 'arpack': | 150 if args.svd_solver == "arpack": |
70 pca_params.update({'tol': args.tolerance}) | 151 pca_params.update({"tol": args.tolerance}) |
71 pca = PCA() | 152 pca = PCA() |
72 | 153 |
73 elif args.pca_type == 'incremental': | 154 elif args.pca_type == "incremental": |
74 pca_params.update({'batch_size': args.batch_size, 'whiten': args.whiten}) | 155 pca_params.update({"batch_size": args.batch_size, "whiten": args.whiten}) |
75 pca = IncrementalPCA() | 156 pca = IncrementalPCA() |
76 | 157 |
77 elif args.pca_type == 'kernel': | 158 elif args.pca_type == "kernel": |
78 pca_params.update({'kernel': args.kernel, 'eigen_solver': args.eigen_solver, 'gamma': args.gamma}) | 159 pca_params.update( |
160 { | |
161 "kernel": args.kernel, | |
162 "eigen_solver": args.eigen_solver, | |
163 "gamma": args.gamma, | |
164 } | |
165 ) | |
79 | 166 |
80 if args.kernel == 'poly': | 167 if args.kernel == "poly": |
81 pca_params.update({'degree': args.degree, 'coef0': args.coef0}) | 168 pca_params.update({"degree": args.degree, "coef0": args.coef0}) |
82 elif args.kernel == 'sigmoid': | 169 elif args.kernel == "sigmoid": |
83 pca_params.update({'coef0': args.coef0}) | 170 pca_params.update({"coef0": args.coef0}) |
84 elif args.kernel == 'precomputed': | 171 elif args.kernel == "precomputed": |
85 pca_input = np.dot(pca_input, pca_input.T) | 172 pca_input = np.dot(pca_input, pca_input.T) |
86 | 173 |
87 if args.eigen_solver == 'arpack': | 174 if args.eigen_solver == "arpack": |
88 pca_params.update({'tol': args.tolerance, 'max_iter': args.max_iter}) | 175 pca_params.update({"tol": args.tolerance, "max_iter": args.max_iter}) |
89 | 176 |
90 pca = KernelPCA() | 177 pca = KernelPCA() |
91 | 178 |
92 print(pca_params) | 179 print(pca_params) |
93 pca.set_params(**pca_params) | 180 pca.set_params(**pca_params) |
94 pca_output = pca.fit_transform(pca_input) | 181 pca_output = pca.fit_transform(pca_input) |
95 np.savetxt(fname=args.outfile, X=pca_output, fmt='%.4f', delimiter='\t') | 182 np.savetxt(fname=args.outfile, X=pca_output, fmt="%.4f", delimiter="\t") |
96 | 183 |
97 | 184 |
98 if __name__ == "__main__": | 185 if __name__ == "__main__": |
99 main() | 186 main() |