Mercurial > repos > nick > allele_counts
comparison allele-counts.py @ 1:49bb46c3a1af
Uploaded script
author | nick |
---|---|
date | Fri, 24 May 2013 10:33:35 -0400 |
parents | |
children | 318fdf77aa54 |
comparison
equal
deleted
inserted
replaced
0:28c40f4b7d2b | 1:49bb46c3a1af |
---|---|
1 #!/usr/bin/python | |
2 # This parses the output of Dan's "Naive Variant Detector" (previously, | |
3 # "BAM Coverage"). It was forked from the code of "bam-coverage.py". | |
4 # | |
5 # New in this version: default to stdin and stdout, override by using -i and -o | |
6 # to specify filenames | |
7 # | |
8 # TODO: | |
9 # - test handling of -c 0 (and -f 0?) | |
10 # - should it technically handle data lines that start with a '#'? | |
11 import os | |
12 import sys | |
13 from optparse import OptionParser | |
14 | |
# Output columns, in order. print_site() writes sample.get(column) for each,
# so a key missing from a sample's summary dict is printed as 'None'.
# A 'bias' column is currently disabled (commented out below).
COLUMNS = ['sample', 'chr', 'pos', 'A', 'C', 'G', 'T', 'coverage', 'alleles',
           'major', 'minor', 'freq'] #, 'bias']
# Bases that are counted; read_site() discards any other variant type.
CANONICAL_VARIANTS = ['A', 'C', 'G', 'T']
# NOTE(review): the examples name the output "alleles.csv", but print_site()
# writes tab-separated rows.
USAGE = """Usage: cat variants.vcf | %prog [options] > alleles.csv
%prog [options] -i variants.vcf -o alleles.csv"""
# Default option values. '-' for infile/outfile means stdin/stdout (main()
# compares against these values). freq_thres is a percentage; main() divides
# it by 100. covg_thres is a per-strand read count.
# NOTE(review): the 'stdin' key is not read anywhere visible in this file.
OPT_DEFAULTS = {'infile':'-', 'outfile':'-', 'freq_thres':1.0, 'covg_thres':100,
  'print_header':False, 'stdin':False}
# Help text shown by optparse (--help) before the option list.
DESCRIPTION = """This will parse the VCF output of Dan's "Naive Variant Caller" (aka "BAM Coverage") Galaxy tool. For each position reported, it counts the number of reads of each base, determines the major allele, minor allele (second most frequent variant), and number of alleles above a threshold. So currently it only considers SNVs (ACGT), including in the coverage figure. By default it reads from stdin and prints to stdout."""
# Help text shown by optparse (--help) after the option list.
EPILOG = """Requirements:
The input VCF must report the variants for each strand.
The variants should be case-sensitive (e.g. all capital base letters).
Strand bias: Both strands must show the same bases passing the frequency threshold (but not necessarily in the same order). If the site fails this test, the number of alleles is reported as 0."""
27 | |
28 | |
def get_options(defaults, usage, description='', epilog='', argv=None):
  """Parse the command line and return (options, arguments).

  Arguments:
  defaults: Dict of default option values keyed by option dest name
    ('infile', 'outfile', 'freq_thres', 'covg_thres', 'print_header').
    Missing keys default to None.
  usage, description, epilog: Help text passed to OptionParser.
  argv: List of command-line arguments to parse. None (the default) makes
    OptionParser use sys.argv[1:], exactly as before this parameter existed.
  Return value:
  (options, arguments): options is the optparse Values object; arguments is a
    dict. In debug mode (-d) the first positional argument is stored under
    'print_loc'. All other positional arguments are ignored.
  """
  parser = OptionParser(usage=usage, description=description, epilog=epilog)

  parser.add_option('-i', '--infile', dest='infile',
    default=defaults.get('infile'),
    help='Read input VCF data from this file instead of stdin.')
  parser.add_option('-o', '--outfile', dest='outfile',
    default=defaults.get('outfile'),
    help='Print output data to this file instead of stdout.')
  parser.add_option('-f', '--freq-thres', dest='freq_thres', type='float',
    default=defaults.get('freq_thres'),
    help='Frequency threshold for counting alleles, given in percentage: -f 1 = 1% frequency. Default is %default%.')
  parser.add_option('-c', '--covg-thres', dest='covg_thres', type='int',
    default=defaults.get('covg_thres'),
    help='Coverage threshold. Each site must be supported by at least this many reads on each strand. Otherwise the site will not be printed in the output. The default is %default reads per strand.')
  # -H flips the default: it stores the opposite of defaults['print_header']
  parser.add_option('-H', '--header', dest='print_header', action='store_const',
    const=not(defaults.get('print_header')), default=defaults.get('print_header'),
    help='Print header line. This is a #-commented line with the column labels. Off by default.')
  parser.add_option('-d', '--debug', dest='debug', action='store_true',
    default=False,
    help='Turn on debug mode. You must also specify a single site to process in a final argument using UCSC coordinate format.')

  # argv=None lets optparse fall back to sys.argv[1:]
  (options, args) = parser.parse_args(argv)

  # read in positional arguments (only meaningful in debug mode)
  arguments = {}
  if options.debug and args:
    arguments['print_loc'] = args[0]

  return (options, arguments)
63 | |
64 | |
def main():
  """Command-line entry point.

  Parses the options and reads VCF data from stdin (or the -i file). For each
  data line, writes one tab-separated row per sample that passes the stranded
  coverage threshold to stdout (or the -o file). Fatal input problems are
  reported through fail(), which exits with status 1. On success the process
  exits with status 0.
  """
  (options, args) = get_options(OPT_DEFAULTS, USAGE, DESCRIPTION, EPILOG)

  infile = options.infile
  outfile = options.outfile
  print_header = options.print_header
  # -f is given as a percentage; the summary functions expect a fraction
  freq_thres = options.freq_thres / 100.0
  covg_thres = options.covg_thres
  debug = options.debug

  # Debug mode processes a single site given as "chr:pos" or just "pos".
  # print_chr stays None when no chromosome was given, meaning "match any chr".
  # It is initialized explicitly rather than relying on a NameError later.
  print_chr = None
  print_pos = None
  if debug:
    print_loc = args.get('print_loc')
    if print_loc:
      if ':' in print_loc:
        (print_chr, print_pos) = print_loc.split(':')
      else:
        print_pos = print_loc
    else:
      sys.stderr.write("Warning: No site coordinate found in arguments. "
        +"Turning off debug mode.\n")
      debug = False

  # set infile_handle to either stdin (the '-' default) or the input file
  if infile == OPT_DEFAULTS.get('infile'):
    infile_handle = sys.stdin
    sys.stderr.write("Reading from standard input..\n")
  elif os.path.exists(infile):
    infile_handle = open(infile, 'r')
  else:
    fail('Error: Input VCF file '+infile+' not found.')

  # set outfile_handle to either stdout (the '-' default) or a new output file;
  # an existing file is never overwritten
  if outfile == OPT_DEFAULTS.get('outfile'):
    outfile_handle = sys.stdout
  elif os.path.exists(outfile):
    fail('Error: The given output filename '+outfile+' already exists.')
  else:
    outfile_handle = open(outfile, 'w')

  try:
    if print_header:
      outfile_handle.write('#'+'\t'.join(COLUMNS)+"\n")

    # main loop: process and print one line at a time
    sample_names = []
    for line in infile_handle:
      line = line.rstrip('\r\n')

      # skip empty lines (e.g. a trailing blank line); indexing line[0] on
      # them used to raise IndexError
      if not line:
        continue

      # header lines; the #CHROM line supplies the sample names (columns 10+)
      if line.startswith('#'):
        if line[0:6].upper() == '#CHROM':
          sample_names = line.split('\t')[9:]
        continue

      if not sample_names:
        fail("Error in input VCF: Data line encountered before header line. "
          +"Failed on line:\n"+line)

      site_data = read_site(line, sample_names, CANONICAL_VARIANTS)

      # in debug mode, skip every site except the requested one
      if debug:
        if site_data['pos'] != print_pos:
          continue
        if print_chr is not None and site_data['chr'] != print_chr:
          continue

      site_summary = summarize_site(site_data, sample_names, CANONICAL_VARIANTS,
        freq_thres, covg_thres, debug=debug)

      # debug: show the raw count field of the first sample
      if debug and site_summary[0]['print']:
        print(line.split('\t')[9].split(':')[-1])

      print_site(outfile_handle, site_summary, COLUMNS)
  finally:
    # close any filehandles we opened ourselves, even if fail() exits early
    if infile_handle is not sys.stdin:
      infile_handle.close()
    if outfile_handle is not sys.stdout:
      outfile_handle.close()

  # keeps Galaxy from giving an error if there were messages on stderr
  sys.exit(0)
152 | |
153 | |
154 | |
def read_site(line, sample_names, canonical):
  """Parse one VCF data line into a site dict.

  The line must be site data, not a header line, and should already have its
  line ending stripped. The read counts for each sample are taken from the
  last ':'-separated value of its column. That value is a comma-separated
  list of '<strand><base>=<reads>' items, e.g. '+A=10,-A=5'.

  Arguments:
  line: The data line.
  sample_names: Sample names from the #CHROM header, in column order.
  canonical: The bases to keep; other variant types are skipped.
  Return value:
  {'chr': ..., 'pos': ..., 'samples': {sample_name: {'+A': 10, ...}}}.
    'chr' and 'pos' are the raw strings. Items whose read count is not an
    integer are skipped.
  Exits via fail() on malformed lines, on a sample-count mismatch, or on
  count data that is not strand-specific.
  """
  fields = line.split('\t')

  if len(fields) < 9:
    fail("Error in input VCF: wrong number of fields in data line. "
      +"Failed on line:\n"+line)

  site = {'chr': fields[0], 'pos': fields[1]}
  samples = fields[9:]

  if len(samples) < len(sample_names):
    fail("Error in input VCF: missing sample fields in data line. "
      +"Failed on line:\n"+line)
  elif len(samples) > len(sample_names):
    fail("Error in input VCF: more sample fields in data line than in header. "
      +"Failed on line:\n"+line)

  sample_counts = {}
  for (sample_name, sample_field) in zip(sample_names, samples):
    variant_counts = {}
    for count in sample_field.split(':')[-1].split(','):
      if not count:
        continue  # e.g. a trailing comma or an empty sample column
      parts = count.split('=')
      if len(parts) != 2:
        fail("Error in input VCF: Incorrect variant data format (must contain "
          +"a single '='). Failed on line:\n"+line)
      (variant, reads) = parts
      # keep only SNVs; the first character is the strand
      if variant[1:] not in canonical:
        continue
      if variant[0] not in ('+', '-'):
        fail("Error in input VCF: variant data not strand-specific. "
          +"Failed on line:\n"+line)
      try:
        variant_counts[variant] = int(reads)
      except ValueError:
        continue  # non-numeric count: skip this item
    sample_counts[sample_name] = variant_counts

  site['samples'] = sample_counts

  return site
209 | |
210 | |
def summarize_site(site, sample_names, canonical, freq_thres, covg_thres,
    debug=False):
  """Turn the parsed site data into one summary dict per sample.

  Arguments:
  site: The dict returned by read_site().
  sample_names: Sample names, in output order.
  canonical: The bases that always get a count column (0 if unseen).
  freq_thres: Allele frequency threshold as a fraction, passed to
    count_alleles().
  covg_thres: Minimum reads required on EACH strand.
  debug: Print intermediate values to stdout.
  Return value:
  A list with one dict per sample name. The 'print' key is False when the
  sample has no counts, or when its total coverage is 0 or either strand has
  fewer than covg_thres reads. Otherwise the dict holds the per-base read
  counts (both strands), 'coverage', 'alleles', 'major', 'minor' and 'freq'
  (minor reads / coverage). If the 2nd and 3rd ranked bases tie, the minor
  allele is reported as 'N' with freq 0.0, and a nonzero allele count is
  reported as 1.
  """
  site_summary = []
  for sample_name in sample_names:

    sample = {'print': False}
    variants = site['samples'].get(sample_name)
    if not variants:
      site_summary.append(sample)
      continue

    sample['sample'] = sample_name
    sample['chr'] = site['chr']
    sample['pos'] = site['pos']

    coverage = sum(variants.values())

    # get stranded coverage
    covg_plus = 0
    covg_minus = 0
    for variant, reads in variants.items():
      if variant[0] == '+':
        covg_plus += reads
      elif variant[0] == '-':
        covg_minus += reads
    # stranded coverage threshold
    if coverage <= 0 or covg_plus < covg_thres or covg_minus < covg_thres:
      site_summary.append(sample)
      continue
    sample['print'] = True

    # get an ordered list of read counts for all variants (either strand)
    ranked_bases = get_read_counts(variants, 0, strands='+-', debug=debug)

    # record read counts into dict for this sample, filling in any zeros
    for (base, reads) in ranked_bases:
      sample[base] = reads
    for variant in canonical:
      sample.setdefault(variant, 0)

    sample['alleles'] = count_alleles(variants, freq_thres, debug=debug)

    # set minor allele to N if there's a tie for 2nd
    if len(ranked_bases) >= 3 and ranked_bases[1][1] == ranked_bases[2][1]:
      ranked_bases[1] = ('N', 0)
      sample['alleles'] = 1 if sample['alleles'] else 0

    if debug:
      print(ranked_bases)

    sample['coverage'] = coverage
    if ranked_bases:
      sample['major'] = ranked_bases[0][0]
    else:
      sample['major'] = '.'
    if len(ranked_bases) >= 2:
      sample['minor'] = ranked_bases[1][0]
      sample['freq'] = ranked_bases[1][1] / float(coverage)
    else:
      sample['minor'] = '.'
      sample['freq'] = 0.0

    site_summary.append(sample)

  return site_summary
281 | |
282 | |
def print_site(filehandle, site, columns):
  """Write one tab-separated output row for each printable sample of a site.

  Arguments:
  filehandle: An already-open, writable file object.
  site: List of per-sample dicts; only those whose 'print' value is true are
    written.
  columns: Keys to output, in order. A key missing from a sample is written
    as 'None'.
  """
  printable = (entry for entry in site if entry['print'])
  for entry in printable:
    row = '\t'.join(map(str, (entry.get(col) for col in columns)))
    filehandle.write(row + "\n")
290 | |
291 | |
def get_read_counts(variant_counts, freq_thres, strands='+-', debug=False):
  """Count the reads for each base on the chosen strand(s) and rank them.

  Arguments:
  variant_counts: Dict of stranded variants (keys) and their read counts
    (values). Each key is a strand character followed by the base, e.g. '+A'.
  freq_thres: The frequency each allele needs to reach (>=) to be included,
    as a fraction of the total reads on the chosen strand(s).
  strands: Which strand(s) to count. Can be '+', '-', or '+-' for both
    (default).
  debug: If true, print the coverage and each base's frequency to stdout.
  Return value:
  ranked_bases: A list of (base, read count) tuples in descending order of
    read count, containing only bases that pass the threshold. Every base
    seen on any strand is considered, so with freq_thres <= 0 a base absent
    from the chosen strand(s) appears with a count of 0. An empty list is
    returned when the chosen strand(s) have no reads.
  """
  # Unique bases with the strand prefix removed. Deduplicating via a dict
  # keeps the same order as before, so ties rank the same way.
  variants = list(dict((variant[1:], 1) for variant in variant_counts))

  ranked_bases = []
  for variant in variants:
    reads = sum(variant_counts.get(strand+variant, 0) for strand in strands)
    ranked_bases.append((variant, reads))

  # coverage for the specified strands only
  coverage = sum(reads for (variant, reads) in variant_counts.items()
    if variant[0] in strands)

  if coverage < 1:
    return []

  # sort the list of alleles by read count (stable for ties)
  ranked_bases.sort(reverse=True, key=lambda base: base[1])

  if debug:
    print(strands+' coverage: '+str(coverage)+', freq_thres: '+str(freq_thres))
    for (base, reads) in ranked_bases:
      print(base+': '+str(reads)+'/'+str(float(coverage))+' = '+
        str(reads/float(coverage)))

  # remove bases below the frequency threshold
  return [(base, reads) for (base, reads) in ranked_bases
    if reads/float(coverage) >= freq_thres]
344 | |
345 | |
def count_alleles(variant_counts, freq_thres, debug=False):
  """Determine how many alleles to report, based on the strand-bias rule.

  The bases that pass the frequency threshold are determined on each strand
  separately, ignoring bases with zero reads on that strand. If both strands
  give the same set of bases (in any order), the allele count is the size of
  that set. Otherwise it is zero.

  Counting only bases with reads keeps the result consistent with the strand
  comparison. Previously, with freq_thres == 0, bases with zero reads (which
  trivially pass 0 >= 0) were added to the count even though they were
  excluded from the comparison.
  """
  alleles_plus = get_read_counts(variant_counts, freq_thres, debug=debug,
    strands='+')
  alleles_minus = get_read_counts(variant_counts, freq_thres, debug=debug,
    strands='-')

  if debug:
    print('+ '+str(alleles_plus))
    print('- '+str(alleles_minus))

  # check if each strand reports the same set of alleles (with reads)
  alleles_plus_sorted = sorted(base for (base, reads) in alleles_plus if reads)
  alleles_minus_sorted = sorted(base for (base, reads) in alleles_minus if reads)
  if alleles_plus_sorted == alleles_minus_sorted:
    return len(alleles_plus_sorted)

  return 0
370 | |
371 | |
def fail(message):
  """Report a fatal error: write message and a newline to stderr, exit 1."""
  err = sys.stderr
  err.write(message + '\n')
  raise SystemExit(1)
375 | |
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
  main()