diff allele-counts.py @ 5:31361191d2d2

Uploaded tarball. Version 1.1: Stranded output, slightly different handling of minor allele ties and 0 coverage sites, revised help text, added test datasets.
author nick
date Thu, 12 Sep 2013 11:34:23 -0400
parents 318fdf77aa54
children df3b28364cd2
line wrap: on
line diff
--- a/allele-counts.py	Tue Jun 04 00:16:29 2013 -0400
+++ b/allele-counts.py	Thu Sep 12 11:34:23 2013 -0400
@@ -1,25 +1,25 @@
 #!/usr/bin/python
 # This parses the output of Dan's "Naive Variant Detector" (previously,
 # "BAM Coverage"). It was forked from the code of "bam-coverage.py".
+# Or, to put it briefly,
+# cat variants.vcf | grep -v '^#' | cut -f 10 | cut -d ':' -f 4 | tr ',=' '\t:'
 #
 # New in this version:
 #   Made header line customizable
 #     - separate from internal column labels, which are used as dict keys
-#
-# TODO:
-# - test handling of -c 0 (and -f 0?)
-# - should it technically handle data lines that start with a '#'?
 import os
 import sys
+import random
 from optparse import OptionParser
 
 COLUMNS = ['sample', 'chr', 'pos', 'A', 'C', 'G', 'T', 'coverage', 'alleles', 'major', 'minor', 'freq'] #, 'bias']
 COLUMN_LABELS = ['SAMPLE', 'CHR',  'POS', 'A', 'C', 'G', 'T', 'CVRG', 'ALLELES', 'MAJOR', 'MINOR', 'MINOR.FREQ.PERC.'] #, 'STRAND.BIAS']
 CANONICAL_VARIANTS = ['A', 'C', 'G', 'T']
-USAGE = """Usage: cat variants.vcf | %prog [options] > alleles.csv
-       %prog [options] -i variants.vcf -o alleles.csv"""
+USAGE = """Usage: %prog [options] -i variants.vcf -o alleles.csv
+       cat variants.vcf | %prog [options] > alleles.csv"""
 OPT_DEFAULTS = {'infile':'-', 'outfile':'-', 'freq_thres':1.0, 'covg_thres':100,
-  'print_header':False, 'stdin':False}
+  'print_header':False, 'stdin':False, 'stranded':False, 'no_filter':False,
+  'debug_loc':'', 'seed':''}
 DESCRIPTION = """This will parse the VCF output of Dan's "Naive Variant Caller" (aka "BAM Coverage") Galaxy tool. For each position reported, it counts the number of reads of each base, determines the major allele, minor allele (second most frequent variant), and number of alleles above a threshold. So currently it only considers SNVs (ACGT), including in the coverage figure. By default it reads from stdin and prints to stdout."""
 EPILOG = """Requirements:
 The input VCF must report the variants for each strand.
@@ -40,27 +40,35 @@
     help='Print output data to this file instead of stdout.')
   parser.add_option('-f', '--freq-thres', dest='freq_thres', type='float',
     default=defaults.get('freq_thres'),
-    help='Frequency threshold for counting alleles, given in percentage: -f 1 = 1% frequency. Default is %default%.')
+    help=('Frequency threshold for counting alleles, given in percentage: -f 1 '
+      +'= 1% frequency. Default is %default%.'))
   parser.add_option('-c', '--covg-thres', dest='covg_thres', type='int',
     default=defaults.get('covg_thres'),
-    help='Coverage threshold. Each site must be supported by at least this many reads on each strand. Otherwise the site will not be printed in the output. The default is %default reads per strand.')
+    help=('Coverage threshold. Each site must be supported by at least this '
+      +'many reads on each strand. Otherwise the site will not be printed in '
+      +'the output. The default is %default reads per strand.'))
+  parser.add_option('-n', '--no-filter', dest='no_filter', action='store_const',
+    const=not(defaults.get('no_filter')), default=defaults.get('no_filter'),
+    help=('Operate without a frequency threshold or coverage threshold. '
+      +'Equivalent to "-c 0 -f 0".'))
   parser.add_option('-H', '--header', dest='print_header', action='store_const',
     const=not(defaults.get('print_header')), default=defaults.get('print_header'),
-    help='Print header line. This is a #-commented line with the column labels. Off by default.')
-  parser.add_option('-d', '--debug', dest='debug', action='store_true',
-    default=False,
-    help='Turn on debug mode. You must also specify a single site to process in a final argument using UCSC coordinate format.')
+    help=('Print header line. This is a #-commented line with the column '
+      +'labels. Off by default.'))
+  parser.add_option('-s', '--stranded', dest='stranded', action='store_const',
+    const=not(defaults.get('stranded')), default=defaults.get('stranded'),
+    help='Report variant counts by strand, in separate columns. Off by default.')
+  parser.add_option('-r', '--rand-seed', dest='seed',
+    default=defaults.get('seed'), help=('Seed for random number generator.'))
+  parser.add_option('-d', '--debug', dest='debug_loc',
+    default=defaults.get('debug_loc'),
+    help=('Turn on debug mode and specify a single site to process using UCSC '
+      +'coordinate format. You can also append a sample ID after another ":" '
+      +'to restrict it further.'))
 
   (options, args) = parser.parse_args()
 
-  # read in positional arguments
-  arguments = {}
-  if options.debug:
-    if len(args) >= 1:
-      arguments['print_loc'] = args[0]
-      args.remove(args[0])
-
-  return (options, arguments)
+  return (options, args)
 
 
 def main():
@@ -72,19 +80,26 @@
   print_header = options.print_header
   freq_thres = options.freq_thres / 100.0
   covg_thres = options.covg_thres
-  debug = options.debug
+  stranded = options.stranded
+  debug_loc = options.debug_loc
+  seed = options.seed
+  
+  if options.no_filter:
+    freq_thres = 0
+    covg_thres = 0
 
-  if debug:
-    print_loc = args.get('print_loc')
-    if print_loc:
-      if ':' in print_loc:
-        (print_chr, print_pos) = print_loc.split(':')
-      else:
-        print_pos = print_loc
-    else:
-      sys.stderr.write("Warning: No site coordinate found in arguments. "
-        +"Turning off debug mode.\n")
-      debug = False
+  if seed:
+    random.seed(seed)
+
+  debug = False
+  print_sample = ''
+  if debug_loc:
+    debug = True
+    coords = debug_loc.split(':')
+    print_chr = coords[0]
+    print_pos = ''
+    if len(coords) > 1: print_pos = coords[1]
+    if len(coords) > 2: print_sample = coords[2]
 
   # set infile_handle to either stdin or the input file
   if infile == OPT_DEFAULTS.get('infile'):
@@ -105,12 +120,19 @@
     except IOError, e:
       fail('Error: The given output filename '+outfile+' could not be opened.')
 
+  # Take care of column names, print header
   if len(COLUMNS) != len(COLUMN_LABELS):
-    fail('Error: Internal column names do not match column labels.')
+    fail('Error: Internal column names list do not match column labels list.')
+  if stranded:
+    COLUMNS[3:7]       = ['+A', '+C', '+G', '+T', '-A', '-C', '-G', '-T']
+    COLUMN_LABELS[3:7] = ['+A', '+C', '+G', '+T', '-A', '-C', '-G', '-T']
   if print_header:
     outfile_handle.write('#'+'\t'.join(COLUMN_LABELS)+"\n")
 
-  # main loop: process and print one line at a time
+  # main loop
+  # each iteration processes one VCF line and prints one or more output lines
+  # one VCF line    = one site, one or more samples
+  # one output line = one site, one sample
   sample_names = []
   for line in infile_handle:
     line = line.rstrip('\r\n')
@@ -128,19 +150,24 @@
     site_data = read_site(line, sample_names, CANONICAL_VARIANTS)
 
     if debug:
-      if site_data['pos'] != print_pos:
+      if print_pos != site_data['pos']:
+        continue
+      if print_chr != site_data['chr'] and print_chr != '':
         continue
-      try:
-        if site_data['chr'] != print_chr:
-          continue
-      except NameError, e:
-        pass  # No chr specified. Just go ahead and print the line.
+      if print_sample != '':
+        for sample in site_data['samples'].keys():
+          if sample.lower() != print_sample.lower():
+            site_data['samples'].pop(sample, None)
+        if len(site_data['samples']) == 0:
+          sys.stderr.write("Error: Sample '"+print_sample+"' not found.\n")
+          sys.exit(1)
+
 
     site_summary = summarize_site(site_data, sample_names, CANONICAL_VARIANTS,
-      freq_thres, covg_thres, debug=debug)
+      freq_thres, covg_thres, stranded, debug=debug)
 
     if debug and site_summary[0]['print']:
-      print line.split('\t')[9].split(':')[-1]
+        print line.split('\t')[9].split(':')[-1]
 
     print_site(outfile_handle, site_summary, COLUMNS)
 
@@ -158,8 +185,27 @@
 def read_site(line, sample_names, canonical):
   """Read in a line, parse the variants into a data structure, and return it.
   The line should be actual site data, not a header line, so check beforehand.
-  Notes:
-  - The line is assumed to have been chomped."""
+  Only the variants in 'canonical' will be read; all others are ignored.
+  Note: the line is assumed to have been chomped.
+  The returned data is stored in a dict, with the following structure:
+  {
+    'chr': 'chr1',
+    'pos': '2617',
+    'samples': {
+      'THYROID': {
+        '+A': 32,
+        '-A': 45,
+        '-G': 1,
+      },
+      'BLOOD': {
+        '+A': 2,
+        '-C': 1,
+        '+G': 37,
+        '-G': 42,
+      },
+    },
+  }
+  """
   
   site = {}
   fields = line.split('\t')
@@ -212,7 +258,7 @@
 
 
 def summarize_site(site, sample_names, canonical, freq_thres, covg_thres,
-  debug=False):
+  stranded=False, debug=False):
   """Take the raw data from the VCF line and transform it into the summary data
   to be printed in the output format."""
 
@@ -221,9 +267,6 @@
 
     sample = {'print':False}
     variants = site['samples'].get(sample_name)
-    if not variants:
-      site_summary.append(sample)
-      continue
 
     sample['sample'] = sample_name
     sample['chr']    = site['chr']
@@ -240,31 +283,49 @@
       elif variant[0] == '-':
         covg_minus += variants[variant]
     # stranded coverage threshold
-    if coverage <= 0 or covg_plus < covg_thres or covg_minus < covg_thres:
+    if covg_plus < covg_thres or covg_minus < covg_thres:
       site_summary.append(sample)
       continue
     else:
       sample['print'] = True
 
-    # get an ordered list of read counts for all variants (either strand)
-    ranked_bases = get_read_counts(variants, 0, strands='+-', debug=debug)
+    # get an ordered list of read counts for all variants (both strands)
+    bases = get_read_counts(variants, '+-')
+    ranked_bases = process_read_counts(bases, sort=True, debug=debug)
+
+    # prepare stranded or unstranded lists of base counts
+    base_count_lists = []
+    if stranded:
+      strands = ['+', '-']
+      base_count_lists.append(get_read_counts(variants, '+'))
+      base_count_lists.append(get_read_counts(variants, '-'))
+    else:
+      strands = ['']
+      base_count_lists.append(ranked_bases)
 
-    # record read counts into dict for this sample
-    for base in ranked_bases:
-      sample[base[0]] = base[1]
-    # fill in any zeros
-    for variant in canonical:
-      if not sample.has_key(variant):
-        sample[variant] = 0
+    # record read counts into output dict
+    # If stranded, this will loop twice, once for each strand, and prepend '+'
+    # or '-' to the base name. If not stranded, it will loop once, and prepend
+    # nothing ('').
+    for (strand, base_count_list) in zip(strands, base_count_lists):
+      for base_count in base_count_list:
+        sample[strand+base_count[0]] = base_count[1]
+      # fill in any zeros
+      for base in canonical:
+        if not sample.has_key(strand+base):
+          sample[strand+base] = 0
 
-    sample['alleles']  = count_alleles(variants, freq_thres, debug=debug)
+    sample['alleles'] = count_alleles(variants, freq_thres, debug=debug)
 
-    # set minor allele to N if there's a tie for 2nd
+    # If there's a tie for 2nd, randomly choose one to be 2nd
     if len(ranked_bases) >= 3 and ranked_bases[1][1] == ranked_bases[2][1]:
-      ranked_bases[1] = ('N', 0)
-      sample['alleles'] = 1 if sample['alleles'] else 0
+      swap = random.choice([True,False])
+      if swap:
+        tmp_base = ranked_bases[1]
+        ranked_bases[1] = ranked_bases[2]
+        ranked_bases[2] = tmp_base
 
-    if debug: print ranked_bases
+    if debug: print "ranked +-: "+str(ranked_bases)
 
     sample['coverage'] = coverage
     try:
@@ -283,67 +344,63 @@
   return site_summary
 
 
-def print_site(filehandle, site, columns):
-  """Print the output lines for one site (one per sample).
-  filehandle must be open."""
-  for sample in site:
-    if sample['print']:
-      fields = [str(sample.get(column)) for column in columns]
-      filehandle.write('\t'.join(fields)+"\n")
+def get_read_counts(stranded_counts, strands='+-'):
+  """Do a simple sum of the read counts per variant, on the specified strands.
+      Arguments:
+  stranded_counts: Dict of the stranded variants (keys) and their read counts
+    (values).
+  strands: Which strand(s) to count. Can be '+', '-', or '+-' for both (default).
+      Return value:
+  summed_counts: A list of the alleles and their read counts. The elements are
+    tuples (variant, read count)."""
+
+  variants = stranded_counts.keys()
+
+  summed_counts = {}
+  for variant in variants:
+    strand = variant[0]
+    base = variant[1:]
+    if strand in strands:
+      summed_counts[base] = stranded_counts[variant] + summed_counts.get(base, 0)
+
+  return summed_counts.items()
 
 
-def get_read_counts(variant_counts, freq_thres, strands='+-', debug=False):
-  """Count the number of reads for each base, and create a ranked list of
-  alleles passing the frequency threshold.
+def process_read_counts(variant_counts, freq_thres=0, sort=False, debug=False):
+  """Process a list of read counts by frequency filtering and/or sorting.
       Arguments:
-  variant_counts: Dict of the stranded variants (keys) and their read counts (values).
+  variant_counts: List of the non-stranded variants and their read counts. The
+    elements are tuples (variant, read count).
   freq_thres: The frequency threshold each allele needs to pass to be included.
-  strands: Which strand(s) to count. Can be '+', '-', or '+-' for both (default).
-  variants: A list of the variants of interest. Other types of variants will not
-    be included in the returned list. If no list is given, all variants found in
-    the variant_counts will be used.
+  sort: Whether to sort the list in descending order of read counts.
       Return value:
-  ranked_bases: A list of the alleles and their read counts. The elements are
-    tuples (base, read count). The alleles are listed in descending order of
-    frequency, and only those passing the threshold are included."""
-
-  # Get list of all variants from variant_counts list
-  variants = [variant[1:] for variant in variant_counts]
-  # deduplicate via a dict
-  variant_dict = dict((variant, 1) for variant in variants)
-  variants = variant_dict.keys()
-
-  ranked_bases = []
-  for variant in variants:
-    reads = 0
-    for strand in strands:
-      reads += variant_counts.get(strand+variant, 0)
-    ranked_bases.append((variant, reads))
+  variant_counts: A list of the processed alleles and their read counts. The
+    elements are tuples (variant, read count)."""
 
   # get coverage for the specified strands
   coverage = 0
   for variant in variant_counts:
-    if variant[0] in strands:
-      coverage += variant_counts.get(variant, 0)
-  # if debug: print "strands: "+strands+', covg: '+str(coverage)
+    coverage += variant[1]
 
-  if coverage < 1:
+  if coverage <= 0:
     return []
 
   # sort the list of alleles by read count
-  ranked_bases.sort(reverse=True, key=lambda base: base[1])
+  if sort:
+    variant_counts.sort(reverse=True, key=lambda variant: variant[1])
 
   if debug:
-    print strands+' coverage: '+str(coverage)+', freq_thres: '+str(freq_thres)
-    for base in ranked_bases:
-      print (base[0]+': '+str(base[1])+'/'+str(float(coverage))+' = '+
-        str(base[1]/float(coverage)))
+    print 'coverage: '+str(coverage)+', freq_thres: '+str(freq_thres)
+    for variant in variant_counts:
+      print (variant[0]+': '+str(variant[1])+'/'+str(float(coverage))+' = '+
+        str(variant[1]/float(coverage)))
 
   # remove bases below the frequency threshold
-  ranked_bases = [base for base in ranked_bases
-    if base[1]/float(coverage) >= freq_thres]
+  if freq_thres > 0:
+    variant_counts = [variant for variant in variant_counts
+      if variant[1]/float(coverage) >= freq_thres]
 
-  return ranked_bases
+  return variant_counts
 
 
 def count_alleles(variant_counts, freq_thres, debug=False):
@@ -354,16 +411,19 @@
   is zero."""
   allele_count = 0
 
-  alleles_plus  = get_read_counts(variant_counts, freq_thres, debug=debug,
-    strands='+')
-  alleles_minus = get_read_counts(variant_counts, freq_thres, debug=debug,
-    strands='-')
+  alleles_plus  = get_read_counts(variant_counts, '+')
+  alleles_plus  = process_read_counts(alleles_plus, freq_thres=freq_thres,
+    sort=False, debug=debug)
+  alleles_minus = get_read_counts(variant_counts, '-')
+  alleles_minus = process_read_counts(alleles_minus, freq_thres=freq_thres,
+    sort=False, debug=debug)
 
   if debug:
     print '+ '+str(alleles_plus)
     print '- '+str(alleles_minus)
 
-  # check if each strand reports the same set of alleles
+  # Check if each strand reports the same set of alleles.
+  # Sorting by base is to compare lists without regard to order (as sets).
   alleles_plus_sorted  = sorted([base[0] for base in alleles_plus if base[1]])
   alleles_minus_sorted = sorted([base[0] for base in alleles_minus if base[1]])
   if alleles_plus_sorted == alleles_minus_sorted:
@@ -372,9 +432,19 @@
   return allele_count
 
 
+def print_site(filehandle, site, columns):
+  """Print the output lines for one site (one per sample).
+  filehandle must be open."""
+  for sample in site:
+    if sample['print']:
+      fields = [str(sample.get(column)) for column in columns]
+      filehandle.write('\t'.join(fields)+"\n")
+
+
 def fail(message):
   sys.stderr.write(message+'\n')
   sys.exit(1)
 
+
 if __name__ == "__main__":
   main()
\ No newline at end of file