#! /usr/bin/env python
"""Get the differential expression between 2 conditions (2 files), on regions defined by a third file"""

import os
from optparse import OptionParser
from structure.transcriptContainer import *
from writer.gff3Writer import *
from misc.progress import *
from misc.rPlotter import *
from misc import utils


def setTagValues(transcript, values, pValue):
  """Replace the annotation of a transcript with differential-expression results.

  All existing tags and exons of the transcript are removed, then these tags
  are set (as strings):
    - "pValue":       the p-value given for this region;
    - "nbReadsCond1": values[0], the read count in condition 1;
    - "nbReadsCond2": values[1], the read count in condition 2;
    - "regulation":   "up" if condition 2 has more reads than condition 1,
                      "down" if it has fewer, "equal" otherwise.

  @param transcript: transcript to annotate (modified in place)
  @param values:     count tuple whose first two items are the read counts
                     in conditions 1 and 2
  @param pValue:     p-value computed for these counts
  @return:           the same transcript object
  """
  # Snapshot the tag names first: deleting tags while iterating a live
  # collection of names would silently skip some of them.
  for tag in list(transcript.getTagNames()):
    transcript.deleteTag(tag)
  transcript.removeExons()
  transcript.setTagValue("pValue", str(pValue))
  transcript.setTagValue("nbReadsCond1", str(values[0]))
  transcript.setTagValue("nbReadsCond2", str(values[1]))
  if values[0] < values[1]:
    regulation = "up"
  elif values[0] > values[1]:
    regulation = "down"
  else:
    regulation = "equal"
  transcript.setTagValue("regulation", regulation)
  return transcript


def _flushPValues(uncomputedPValues, computedPValues, writer):
  """Compute all pending p-values in one bulk call and write the waiting transcripts.

  @param uncomputedPValues: dict mapping a count tuple to the list of
                            transcripts waiting for its p-value; it is
                            emptied before returning
  @param computedPValues:   cache dict (count tuple -> p-value), updated with
                            every newly computed p-value
  @param writer:            writer that receives the annotated transcripts
  """
  if not uncomputedPValues:
    # Nothing pending: do not call the bulk routine with an empty list.
    return
  computedValues = utils.fisherExactPValueBulk(list(uncomputedPValues.keys()))
  for values, pValue in computedValues.items():
    for transcript in uncomputedPValues[values]:
      writer.addTranscript(setTagValues(transcript, values, pValue))
    # Remember the result so that later regions with the same counts reuse it.
    computedPValues[values] = pValue
  uncomputedPValues.clear()


if __name__ == "__main__":
  
  # parse command line
  description = "Get Differential Expression: Get the differential expression between 2 conditions using Fisher's exact test, on regions defined by a third file. [Category: Comparison]"

  parser = OptionParser(description = description)
  parser.add_option("-i", "--input1",          dest="inputFileName1",    action="store",                 type="string", help="input file 1 [compulsory] [format: file in transcript format given by -f]")
  parser.add_option("-f", "--format1",         dest="format1",           action="store",                 type="string", help="format of file 1 [compulsory] [format: transcript file format]")
  parser.add_option("-j", "--input2",          dest="inputFileName2",    action="store",                 type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
  parser.add_option("-g", "--format2",         dest="format2",           action="store",                 type="string", help="format of file 2 [compulsory] [format: transcript file format]")
  parser.add_option("-k", "--reference",       dest="referenceFileName", action="store",                 type="string", help="reference file [compulsory] [format: file in transcript format given by -l]")
  parser.add_option("-l", "--referenceFormat", dest="referenceFormat",   action="store",                 type="string", help="format of reference file [compulsory] [format: transcript file format]")
  parser.add_option("-o", "--output",          dest="outputFileName",    action="store",                 type="string", help="output file [format: output file in GFF3 format]")
  parser.add_option("-n", "--normalization",   dest="normalization",     action="store_true", default=False,            help="normalize using the number of reads in each condition [format: bool]")
  parser.add_option("-p", "--plot",            dest="plot",              action="store_true", default=False,            help="plot cloud plot [format: bool]")
  parser.add_option("-v", "--verbosity",       dest="verbosity",         action="store",      default=1, type="int",    help="trace level [format: int]")
  (options, args) = parser.parse_args()

  inputs                 = (0, 1)
  mySqlConnection        = MySqlConnection(options.verbosity)
  transcriptContainers   = [TranscriptContainer(options.inputFileName1, options.format1, options.verbosity), TranscriptContainer(options.inputFileName2, options.format2, options.verbosity)]
  transcriptContainerRef = TranscriptContainer(options.referenceFileName, options.referenceFormat, options.verbosity)
  tables                 = [None, None]
  nbElements             = [0, 0]
  normalizationFactors   = [0, 0]
  writer                 = Gff3Writer("%s.gff3" % options.outputFileName, options.verbosity)
  # Flush the pending p-values once more than this many distinct count tuples are waiting.
  nbPValuesComputed      = 10000
  # count tuple -> p-value, for tuples already sent to the bulk Fisher test
  computedPValues        = {}
  # count tuple -> list of transcripts still waiting for that p-value
  uncomputedPValues      = {}
  if options.plot:
    plotter = RPlotter("%s.png" % (options.outputFileName), options.verbosity)
    plotter.setPoints(True)
    plotter.setLog("xy")
    points  = {}

  # Store both samples in the database and index the columns queried below.
  for i in inputs:
    transcriptContainers[i].storeIntoDatabase()
    tables[i] = transcriptContainers[i].getTables()
    for chromosome in tables[i]:
      # NOTE(review): the same index name is used for every table; confirm the
      # database backend scopes index names per table rather than per database.
      mySqlConnection.executeQuery("CREATE INDEX iStartEndDir ON %s (start, end, direction)" % tables[i][chromosome].name)

  # Count the elements of each sample: a transcript carrying an "nbElements"
  # tag stands for that many elements, otherwise it counts as one.
  for i in inputs:
    progress = Progress(transcriptContainers[i].getNbTranscripts(), "Reading sample %d" % (i +1), options.verbosity)
    for chromosome in tables[i]:
      for transcript in tables[i][chromosome].getIterator():
        nbElements[i] += 1 if "nbElements" not in transcript.getTagNames() else transcript.getTagValue("nbElements")
        progress.inc()
    progress.done()
    if options.verbosity > 0:
      print("%d elements in sample %d" % (nbElements[i], i+1))

  avgNbElements = int(float(sum(nbElements)) / len(nbElements))
  if options.normalization and options.verbosity > 1:
    print("Normalization to average # reads: %d" % (avgNbElements))
  for i in inputs:
    # The factor is only used when normalizing; skip empty samples, which
    # would otherwise raise ZeroDivisionError even without normalization.
    if nbElements[i] > 0:
      normalizationFactors[i] = float(avgNbElements) / nbElements[i]
    if options.normalization:
      nbElements[i] = avgNbElements

  progress = Progress(transcriptContainerRef.getNbTranscripts(), "Reading reference", options.verbosity)
  for transcriptRef in transcriptContainerRef.getIterator():
    progress.inc()
    # Per sample: number of elements lying inside the reference region, on the same strand.
    transcripts = [0 for i in inputs]
    for i in inputs:
      if transcriptRef.chromosome not in tables[i]: continue
      command = "SELECT tags FROM %s WHERE start >= %d AND end <= %d AND direction = %d" % (tables[i][transcriptRef.chromosome].name, transcriptRef.start, transcriptRef.end, transcriptRef.direction)
      query   = mySqlConnection.executeQuery(command)
      for line in query.getIterator():
        nb   = 1
        tags = line[0].split(";")
        for tag in tags:
          # partition (not split) tolerates empty tag strings and "=" inside values
          key, _, value = tag.partition("=")
          if key == "nbElements":
            nb = int(value)
        transcripts[i] += nb

      if options.normalization:
        transcripts[i] = int(normalizationFactors[i] * transcripts[i])

    if sum(transcripts) == 0:
      continue

    # Counts inside / outside the region for each condition, fed to Fisher's exact test.
    values = (transcripts[0], transcripts[1], nbElements[0] - transcripts[0], nbElements[1] - transcripts[1])
    transcript = Transcript()
    transcript.copy(transcriptRef)

    if options.plot:
      points["%s:%d-%d(%s)" % (transcriptRef.chromosome, transcriptRef.start, transcriptRef.end, "+" if transcriptRef.direction == 1 else "-")] = (transcripts[0], transcripts[1])

    # p-value is already computed: reuse the cached value
    if values in computedPValues:
      writer.addTranscript(setTagValues(transcript, values, computedPValues[values]))
      continue

    # keep the transcript for postponed (bulk) computation
    uncomputedPValues.setdefault(values, []).append(transcript)

    # compute p-values in bulk once enough distinct count tuples are pending
    if len(uncomputedPValues) > nbPValuesComputed:
      _flushPValues(uncomputedPValues, computedPValues, writer)

  # compute the remaining p-values
  _flushPValues(uncomputedPValues, computedPValues, writer)

  progress.done()
  # NOTE(review): the writer is never explicitly closed here; confirm that
  # Gff3Writer flushes its output on its own.

  if options.plot:
    plotter.addLine(points)
    plotter.plot()

