#! /usr/bin/env python
"""
Cluster the data into regions (defined by size and overlap with next region) and keep only highest peaks.
"""

import os
import sys
from optparse import OptionParser
from structure.transcriptContainer import *
from writer.transcriptWriter import *
from misc.rPlotter import *
from misc.progress import *
from misc import utils


def getMedian(values):
  """
  Return the median of a list of numbers, or 0 if the list is empty.
  The input list is left untouched (a sorted copy is used).
  """
  if not values:
    return 0
  ordered = sorted(values)
  size    = len(ordered)
  middle  = size // 2
  if size % 2 == 1:
    return ordered[middle]
  return (ordered[middle - 1] + ordered[middle]) / 2.0


def combineValues(operation, nbElements, total, values):
  """
  Compute the score of one region.
  @param operation:  one of "sum", "avg", "med", "min", "max", or None (use the number of elements)
  @param nbElements: number of features counted in the region
  @param total:      sum of the tag values of the region
  @param values:     list of the tag values of the region
  @return:           the combined value (0 for an empty region, except for "sum" which returns total)
  """
  if operation == "sum":
    return total
  if operation == "avg":
    if nbElements == 0:
      return 0
    return float(total) / nbElements
  if operation == "med":
    return getMedian(values)
  if operation == "min":
    return min(values) if values else 0
  if operation == "max":
    return max(values) if values else 0
  return nbElements


def getThresholdValue(sortedValues, percentage):
  """
  Return the smallest value that belongs to the highest 'percentage' % of the values.
  @param sortedValues: list of values, sorted in increasing order
  @param percentage:   percentage of the highest values to keep (100 keeps everything)
  @return:             the cut-off value, or None if the list is empty
  """
  if not sortedValues:
    return None
  # skip the lowest (100 - percentage)% of the values; the product is computed
  # on integers first so that round percentages give exact indices
  index = int(len(sortedValues) * (100 - percentage) / 100.0)
  # clamp so that out-of-range percentages never wrap around the list
  index = max(0, min(index, len(sortedValues) - 1))
  return sortedValues[index]


if __name__ == "__main__":

  # parse command line
  description = "Clusterize by Sliding Windows: Produces a GFF3 file that clusters a list of transcripts using a sliding window. [Category: Sliding Windows]"

  parser = OptionParser(description = description)
  parser.add_option("-i", "--input",        dest="inputFileName",     action="store",                        type="string", help="input file [compulsory] [format: file in transcript format given by -f]")
  parser.add_option("-f", "--inputFormat",  dest="inputFormat",       action="store",                        type="string", help="format of the input file [compulsory] [format: transcript file format]")
  parser.add_option("-o", "--output",       dest="outputFileName",    action="store",                        type="string", help="output file [compulsory] [format: output file in GFF3 format]")
  parser.add_option("-s", "--size",         dest="size",              action="store",                        type="int",    help="size of the regions [compulsory] [format: int]")
  parser.add_option("-e", "--overlap",      dest="overlap",           action="store",                        type="int",    help="overlap between two consecutive regions [compulsory] [format: int]")
  parser.add_option("-t", "--threshold",    dest="threshold",         action="store",      default=100,      type="int",    help="keep the highest n% peaks [format: int] [default: 100]")
  parser.add_option("-n", "--number",       dest="number",            action="store",      default=0,        type="int",    help="min. number of features per peak [format: int] [default: 0]")
  parser.add_option("-u", "--unique",       dest="unique",            action="store",      default=0,        type="int",    help="min. number of unique features per peak [default: 0] [format: int] [default: 0]")
  parser.add_option("-g", "--tag",          dest="tag",               action="store",      default=None,     type="string", help="use a given tag (instead of summing number of features) [format: string]")
  parser.add_option("-r", "--operation",    dest="operation",         action="store",      default=None,     type="string", help="combine tag value with given operation [format: choice (sum, avg, med, min, max)]")
  parser.add_option("-2", "--strands",      dest="strands",           action="store_true", default=False,                   help="consider the two strands separately [format: bool] [default: false]")
  parser.add_option("-p", "--plot",         dest="plot",              action="store_true", default=False,                   help="plot regions [format: bool] [default: false]")
  parser.add_option("-x", "--excel",        dest="excel",             action="store_true", default=False,                   help="write an Excel file [format: bool] [default: false]")
  parser.add_option("-y", "--mysql",        dest="mysql",             action="store_true", default=False,                   help="mySQL output [format: bool] [default: false]")
  parser.add_option("-v", "--verbosity",    dest="verbosity",         action="store",      default=1,        type="int",    help="trace level [format: int] [default: 1]")
  parser.add_option("-l", "--log",          dest="log",               action="store_true", default=False,                   help="write a log file [format: bool] [default: false]")
  (options, args) = parser.parse_args()

  # check the parameters before touching any file
  if options.inputFileName is None or options.inputFormat is None or options.outputFileName is None:
    sys.exit("Input file, input format and output file are compulsory! Aborting...")

  if options.size is None or options.overlap is None:
    sys.exit("Size and overlap of the regions are compulsory! Aborting...")

  # distance between the starts of two consecutive regions
  step = options.size - options.overlap
  if step <= 0:
    sys.exit("Size (%d) should be greater than overlap (%d)! Aborting..." % (options.size, options.overlap))

  if options.tag is None and options.operation is not None:
    sys.exit("Trying to combine the values without specifying tag! Aborting...")

  if options.operation is not None and options.operation not in ("sum", "avg", "med", "min", "max"):
    sys.exit("Do not understand operation '%s'! Aborting..." % (options.operation))

  # NOTE(review): the mySQL output below relies on 'regionWriter' and 'transcriptWriter',
  # which are not defined in this script; fail early instead of crashing after the whole computation
  if options.mysql:
    missingNames = [name for name in ("regionWriter", "transcriptWriter") if name not in globals()]
    if missingNames:
      sys.exit("mySQL output is not available (%s undefined)! Aborting..." % (", ".join(missingNames)))

  # NOTE(review): the log file is created but nothing is written to it in this script
  if options.log:
    logHandle = open("%s.log" % options.outputFileName, "w")

  # remove possible existing output file
  if os.path.exists(options.outputFileName):
    os.unlink(options.outputFileName)

  # create parser
  container = TranscriptContainer(options.inputFileName, options.inputFormat, options.verbosity)

  # get the highest start position for each chromosome (bins are indexed on the start positions)
  sizes = {}
  progress = Progress(container.getNbTranscripts(), "Getting sizes in %s" % (options.inputFileName), options.verbosity)
  for transcript in container.getIterator():
    sizes[transcript.chromosome] = max(sizes.get(transcript.chromosome, transcript.start), transcript.start)
    progress.inc()
  progress.done()

  print("%d transcripts parsed" % (container.getNbTranscripts()))

  # In the following, 1 represents plus strand, -1 represents minus strand and 0 represents both strands
  if options.strands:
    strands = [-1, 1]
  else:
    strands = [0]

  # initialize bins
  binsPerStrand    = {}
  uniquesPerStrand = {}
  sumsPerStrand    = {}
  valuesPerStrand  = {}
  for strand in strands:
    binsPerStrand[strand]    = {}
    uniquesPerStrand[strand] = {}
    sumsPerStrand[strand]    = {}
    valuesPerStrand[strand]  = {}
    for chromosome in sizes:
      nbBins = sizes[chromosome] // step + 1
      binsPerStrand[strand][chromosome]    = dict((i, 0)   for i in range(nbBins))
      uniquesPerStrand[strand][chromosome] = dict((i, 0)   for i in range(nbBins))
      sumsPerStrand[strand][chromosome]    = dict((i, 0.0) for i in range(nbBins))
      valuesPerStrand[strand][chromosome]  = dict((i, [])  for i in range(nbBins))

  # set bins
  progress = Progress(container.getNbTranscripts(), "Setting bins", options.verbosity)
  for transcript in container.getIterator():

    # find number of occurrences
    nbOccurrences = transcript.getTagValue("nbOccurrences")
    if nbOccurrences is None:
      nbOccurrences = 1

    strand     = transcript.direction if options.strands else 0
    chromosome = transcript.chromosome

    # find the value of the tag, if needed
    value = None
    if options.tag is not None:
      if options.tag not in transcript.getTagNames():
        sys.exit("Tag %s undefined in transcript %s" % (options.tag, transcript))
      value = float(transcript.getTagValue(options.tag))

    # find the regions which contain the transcript: the bin given by its start,
    # and the previous one when the start falls into the overlap between the two regions
    binId  = transcript.start // step
    binIds = [binId]
    if binId > 0 and transcript.start <= binId * step + options.overlap:
      binIds.append(binId - 1)

    # update bins and number of occurrences
    for currentBin in binIds:
      binsPerStrand[strand][chromosome][currentBin] += 1
      if nbOccurrences == 1:
        uniquesPerStrand[strand][chromosome][currentBin] += 1
      if value is not None:
        sumsPerStrand[strand][chromosome][currentBin] += value
        valuesPerStrand[strand][chromosome][currentBin].append(value)

    progress.inc()
  progress.done()

  # remove bins with few unique matches, or few features
  for strand in strands:
    nbRegions         = 0
    nbRemoved         = 0
    nbRemovedFeatures = 0
    for chromosome in binsPerStrand[strand]:
      bins    = binsPerStrand[strand][chromosome]
      uniques = uniquesPerStrand[strand][chromosome]
      nbRegions += len(bins)
      # the positions are collected first, since a dictionary cannot shrink while it is read
      tooFewUniques  = [pos for pos in bins if uniques[pos] < options.unique]
      tooFewFeatures = [pos for pos in bins if uniques[pos] >= options.unique and bins[pos] < options.number]
      nbRemoved         += len(tooFewUniques)
      nbRemovedFeatures += len(tooFewFeatures)
      for pos in tooFewUniques + tooFewFeatures:
        del binsPerStrand[strand][chromosome][pos]
        del sumsPerStrand[strand][chromosome][pos]
        del valuesPerStrand[strand][chromosome][pos]
    adjunct = ""
    if strand != 0:
      adjunct = " on strand %d" % (strand)
    print("%d regions found%s" % (nbRegions, adjunct))
    print("%d regions removed for they have too few unique matches%s" % (nbRemoved, adjunct))
    if options.number > 0:
      print("%d regions removed for they have too few features%s" % (nbRemovedFeatures, adjunct))

  # aggregate data: one score per remaining region
  toBePlottedPerStrand = {}
  for strand in strands:
    toBePlottedPerStrand[strand] = {}
    for chromosome in binsPerStrand[strand]:
      bins   = binsPerStrand[strand][chromosome]
      sums   = sumsPerStrand[strand][chromosome]
      values = valuesPerStrand[strand][chromosome]
      toBePlottedPerStrand[strand][chromosome] = dict((binId, combineValues(options.operation, bins[binId], sums[binId], values[binId])) for binId in bins)

  # plot the regions
  if options.plot:
    for strand in strands:
      adjunct = ""
      if strand != 0:
        adjunct = "Strand%d" % (strand)
      for chromosome in toBePlottedPerStrand[strand]:
        if toBePlottedPerStrand[strand][chromosome]:
          plotter = RPlotter("%s%s%s.png" % (options.outputFileName, chromosome.capitalize(), adjunct), options.verbosity)
          plotter.setFill(0)
          plotter.addLine(toBePlottedPerStrand[strand][chromosome], chromosome)
          plotter.plot()

  # write Excel file
  if options.excel:
    for strand in strands:
      if options.strands:
        excelFileName = "%sStrand%d.csv" % (options.outputFileName, strand)
      else:
        excelFileName = "%s.csv" % (options.outputFileName)
      regions = toBePlottedPerStrand[strand]
      allBins = [binId for chromosome in regions for binId in regions[chromosome]]
      maxBin  = max(allBins) if allBins else -1
      excelFile = open(excelFileName, "w")
      try:
        for binId in range(0, maxBin + 1):
          excelFile.write(",%d-%d" % (binId * step, binId * step + options.size))
        excelFile.write("\n")
        for chromosome in regions:
          excelFile.write("%s" % (chromosome))
          # one cell per column of the header, left empty for removed or missing regions,
          # so that each value stays under the right coordinates
          for binId in range(0, maxBin + 1):
            if binId in regions[chromosome]:
              excelFile.write(",%f" % (regions[chromosome][binId]))
            else:
              excelFile.write(",")
          excelFile.write("\n")
      finally:
        excelFile.close()

  # get the threshold
  values = sorted(toBePlottedPerStrand[strand][chromosome][pos] for strand in strands for chromosome in toBePlottedPerStrand[strand] for pos in toBePlottedPerStrand[strand][chromosome])
  if not values:
    sys.exit("No region found! Aborting...")
  threshold = getThresholdValue(values, options.threshold)
  print("got %d values" % (len(values)))
  print("best is %.1f" % (values[-1]))
  print("keeping peaks higher than %d" % (threshold))

  # print the regions
  cpt     = 1
  tagOp   = "nb"
  tagName = "Elements"
  if options.operation is not None:
    tagOp = options.operation.lower()
  if options.tag is not None:
    tagName = options.tag.title()
  writer  = Gff3Writer("%s.gff3" % (options.outputFileName), options.verbosity)
  for strand in strands:
    for chromosome in toBePlottedPerStrand[strand]:
      # regions are numbered by increasing position
      for pos in sorted(toBePlottedPerStrand[strand][chromosome]):
        if toBePlottedPerStrand[strand][chromosome][pos] >= threshold:
          transcript = Transcript()
          transcript.setName("region%d" % cpt)
          transcript.setChromosome(chromosome)
          transcript.setStart(pos * step)
          transcript.setEnd(pos * step + options.size)
          # regions computed on both strands are arbitrarily reported on the plus strand
          transcript.setDirection(strand if strand != 0 else 1)
          transcript.setTagValue("nbElements", binsPerStrand[strand][chromosome][pos])
          transcript.setTagValue("%s%s" % (tagOp, tagName), str(toBePlottedPerStrand[strand][chromosome][pos]))
          writer.addTranscript(transcript)
          cpt += 1
          if options.mysql:
            # NOTE(review): 'regionWriter' is used here as a single writer, but as a per-strand container below -- confirm
            regionWriter.addTranscript(transcript)

  print("keeping %d regions" % (cpt - 1))

  if options.mysql:
    for strand in strands:
      regionWriter[strand].write()
      transcriptWriter[strand].write()

  if options.log:
    logHandle.close()
  
