#! /usr/bin/env python
"""
Read a mapping file (many formats supported) and select some of the mappings.
The mappings should be sorted by read name.
"""

import os, random
import sys
from optparse import OptionParser, OptionGroup
from parsing.parserChooser import *
from writer.bedWriter import *
from writer.ucscWriter import *
from writer.gbWriter import *
from writer.gff2Writer import *
from writer.gff3Writer import *
from writer.fastaWriter import *
from writer.fastqWriter import *
from writer.mySqlTranscriptWriter import *
from misc.rPlotter import *
from misc.progress import *


# Value passed to mapping.mergeExons() when exons are merged instead of being
# filtered (presumably a distance in nucleotides -- TODO confirm).
distanceExons = 30
# Minimum size of a sub-mapping's target interval; when exon filtering is on,
# a mapping with a smaller sub-mapping is rejected as "too short exons".
exonSize      = 30


class MapperAnalyzer(object):
  """
  Analyse the output of a parser

  The mappings of the input file are grouped by query (read) name. Each group
  is filtered (size, identity, mismatches, gaps, exon size, number of hits),
  the surviving mappings are ranked by error score and written as transcripts
  to the GFF3 writer. Temporary SQL tables are used to store the sequence
  names/sizes, the already mapped reads and the names of the reads that have
  been written.

  Typical use: call the set*() methods, then analyze(), then read the
  counters (nbSequences, nbMappings, nbWrittenMappings, tooShort, ...).
  """

  def __init__(self, verbosity = 0):
    """
    @param verbosity: trace level, forwarded to every helper object
    """
    self.verbosity             = verbosity
    self.mySqlConnection       = MySqlConnection(verbosity)
    # counters reported at the end of the analysis
    self.tooShort              = 0
    self.tooManyMismatches     = 0
    self.tooManyGaps           = 0
    self.tooShortExons         = 0
    self.tooManyMappings       = 0
    self.nbMappings            = 0
    self.nbSequences           = 0
    self.nbAlreadyMapped       = 0
    self.nbWrittenMappings     = 0
    self.nbWrittenSequences    = 0
    # readers, writers and temporary tables
    self.parser                = None
    self.logHandle             = None
    # used to build the names of the temporary tables and files
    self.randomNumber          = random.randint(0, 100000)
    self.gff3Writer            = None
    self.alreadyMappedReader   = None
    self.unmatchedWriter       = None
    self.sequenceListParser    = None
    self.sequenceTable         = None
    self.alreadyMappedTable    = None
    self.mappedNamesTable      = None
    # filters (None means "no filter")
    self.minSize               = None
    self.minId                 = None
    self.maxMismatches         = None
    self.maxGaps               = None
    self.maxMappings           = None
    self.exons                 = True
    # None: not known yet; True: sequence names are stored as "<name>/1"
    self.suffix                = None


  def __del__(self):
    """
    Drop the temporary tables and close the output and log files
    """
    # getattr() is used because __init__ may have failed before the
    # attributes were created, and __del__ is called nevertheless
    for attribute in ("sequenceTable", "alreadyMappedTable", "mappedNamesTable"):
      table = getattr(self, attribute, None)
      if table is not None:
        table.remove()
    writer = getattr(self, "gff3Writer", None)
    if writer is not None:
      writer.close()
    logHandle = getattr(self, "logHandle", None)
    if logHandle is not None:
      logHandle.close()


  def setMappingFile(self, fileName, format):
    """
    Set the input mapping file
    @param fileName: name of the mapping file
    @param format:   format of the mapping file, as understood by ParserChooser
    """
    parserChooser = ParserChooser(self.verbosity)
    parserChooser.findFormat(format, "mapping")
    self.parser = parserChooser.getParser(fileName)


  def setSequenceFile(self, fileName, format):
    """
    Set the file which contains the mapped sequences
    @param fileName: name of the sequence file
    @param format:   "fasta" or "fastq" (exits on any other value)
    """
    if format == "fasta":
      self.sequenceListParser = FastaParser(fileName, self.verbosity)
    elif format == "fastq":
      self.sequenceListParser = FastqParser(fileName, self.verbosity)
    else:
      sys.exit("Do not understand sequence format %s" % (format))


  def setOutputFile(self, fileName, title):
    """
    Set the GFF3 output file
    @param fileName: name of the output file
    @param title:    title given to the writer
    """
    self.gff3Writer = Gff3Writer(fileName, self.verbosity)
    self.gff3Writer.setTitle(title)


  def setAlreadyMatched(self, fileName):
    """
    Set a GFF file of reads which have already been mapped
    @param fileName: name of the GFF file
    """
    self.alreadyMappedReader = GffParser(fileName, self.verbosity)


  def setRemainingFile(self, fileName, format):
    """
    Ask for the unmatched sequences to be written to "<fileName>_unmatched.<format>"
    @param fileName: prefix of the output file
    @param format:   "fasta" or "fastq" (exits on any other value)
    """
    if format == "fasta":
      self.unmatchedWriter = FastaWriter("%s_unmatched.fasta" % (fileName), self.verbosity)
    elif format == "fastq":
      self.unmatchedWriter = FastqWriter("%s_unmatched.fastq" % (fileName), self.verbosity)
    else:
      sys.exit("Do not understand %s format." % (format))
    # table of the names of the reads which have been written
    self.mappedNamesTable = MySqlTable("mappedNames_%d" % (self.randomNumber), self.verbosity)
    self.mappedNamesTable.create(["name"], {"name": "char"}, {"name": 50})


  def setLog(self, fileName):
    """
    Open a log file, where every rejected mapping is reported
    @param fileName: name of the log file
    """
    self.logHandle = open(fileName, "w")


  def setMinSize(self, size):
    """
    @param size: minimum size of a mapping, in percent of the size of the read
    """
    self.minSize = size


  def setMinId(self, id):
    """
    @param id: minimum value of the "identity" tag of a mapping
    """
    self.minId = id


  def setMaxMismatches(self, mismatches):
    """
    @param mismatches: maximum value of the "nbMismatches" tag of a mapping
    """
    self.maxMismatches = mismatches


  def setMaxGaps(self, gaps):
    """
    @param gaps: maximum value of the "nbGaps" tag of a mapping
    """
    self.maxGaps = gaps


  def setMaxMappings(self, mappings):
    """
    @param mappings: maximum number of occurrences of a read
    """
    self.maxMappings = mappings


  def acceptExons(self, b):
    """
    @param b: if True (default), mappings with short exons are discarded;
              if False, exons are merged and this filter is skipped
    """
    self.exons = b


  def _getQueryName(self, mapping):
    """
    Get the name of the read of a mapping, cut at the first space
    (the same normalisation is applied to the names of the sequence file)
    """
    return mapping.queryInterval.name.split(" ")[0]


  def _countOccurrences(self, mappings):
    """
    Sum the "nbOccurrences" tags of the mappings (a mapping without tag counts for one)
    """
    nbOccurrences = 0
    for mapping in mappings:
      if "nbOccurrences" in mapping.getTagNames():
        nbOccurrences += mapping.getTagValue("nbOccurrences")
      else:
        nbOccurrences += 1
    return nbOccurrences


  def _findSequence(self, name):
    """
    Get the (id, size) lines of the sequence table for the given name
    """
    # NOTE(review): the name is inserted as is in the SQL query (as everywhere
    # in this class); a name which contains a quote breaks the query
    query = self.mySqlConnection.executeQuery("SELECT id, size FROM %s WHERE name = '%s' LIMIT 1" % (self.sequenceTable.name, name))
    return query.getLines()


  def readInputMappings(self):
    """
    Count the mappings of the input file
    """
    self.nbMappings = self.parser.getNbMappings()
    if self.verbosity > 0:
      print("%i matches found" % (self.nbMappings))


  def storeAlreadyMapped(self):
    """
    Copy the already mapped reads to the output, and store their names in a temporary table
    """
    self.alreadyMappedTable = MySqlTable("tmpAlreadyMapped_%d" % (self.randomNumber), self.verbosity)
    self.alreadyMappedTable.create(["name"], {"name": "varchar"}, {"name": 100})
    self.alreadyMappedTable.createIndex("iSequences", ["name", ])
    tmpTranscriptFileName = "tmpTranscriptFile%d.dat" % (self.randomNumber)
    progress              = Progress(self.alreadyMappedReader.getNbTranscripts(), "Reading already mapped reads", self.verbosity)
    nbTranscripts         = 0
    try:
      tmpTranscriptFile = open(tmpTranscriptFileName, "w")
      try:
        for transcript in self.alreadyMappedReader.getIterator():
          nbTranscripts += 1
          tmpTranscriptFile.write("%d\t%s\n" % (nbTranscripts, transcript.name))
          self.gff3Writer.addTranscript(transcript)
          progress.inc()
      finally:
        tmpTranscriptFile.close()
      self.alreadyMappedTable.loadFromFile(tmpTranscriptFileName)
      progress.done()
    finally:
      # do not leave the temporary file behind, even on error
      if os.path.exists(tmpTranscriptFileName):
        os.remove(tmpTranscriptFileName)
    if self.verbosity >= 10:
      print("Optimizing sequence table...")
    # NOTE(review): second index on the same column as "iSequences"; kept as is
    self.alreadyMappedTable.createIndex("iNames", ["name"])
    if self.verbosity >= 10:
      print("... done.")
    self.nbAlreadyMapped = nbTranscripts


  def storeSequences(self):
    """
    Store the name and the size of every input sequence in a temporary table
    """
    self.sequenceTable = MySqlTable("tmpSequences_%d" % (self.randomNumber), self.verbosity)
    self.sequenceTable.create(["name", "size"], {"name": "char", "size": "int"}, {"name": 50, "size": 5})

    tmpSequenceFileName = "tmpSequenceFile%d.dat" % (self.randomNumber)
    progress            = Progress(self.sequenceListParser.getNbSequences(), "Reading sequences", self.verbosity)
    try:
      tmpSequenceFile = open(tmpSequenceFileName, "w")
      try:
        for sequence in self.sequenceListParser.getIterator():
          # only keep the first word of the header
          name              = sequence.name.split(" ")[0]
          self.nbSequences += 1
          tmpSequenceFile.write("%s\t%s\t%s\n" % (self.nbSequences, name, len(sequence.sequence)))
          progress.inc()
      finally:
        tmpSequenceFile.close()
      self.sequenceTable.loadFromFile(tmpSequenceFileName)
      progress.done()
    finally:
      # do not leave the temporary file behind, even on error
      if os.path.exists(tmpSequenceFileName):
        os.remove(tmpSequenceFileName)
    if self.verbosity > 0:
      print("%i sequences read" % (self.nbSequences))
    if self.verbosity >= 10:
      print("Optimizing sequence table...")
    self.sequenceTable.createIndex("iNames", ["name"])
    if self.verbosity >= 10:
      print("... done.")


  def checkOrder(self):
    """
    Check that the mappings are grouped by read name, and exit otherwise

    Every time a run of identical names ends, the name is stored in a
    temporary table; a new run whose name is already in the table means that
    the mappings of this read are not contiguous.
    """
    namesTable = MySqlTable("tmpNames_%d" % (self.randomNumber), self.verbosity)
    namesTable.create(["name"], {"name": "varchar"}, {"name": 100})
    namesTable.createIndex("iName", ["name", ])
    previousName = None
    progress     = Progress(self.nbMappings, "Checking mapping file", self.verbosity)
    try:
      for mapping in self.parser.getIterator():
        # same name normalisation as in analyze(), which groups the mappings
        name = self._getQueryName(mapping)
        if name != previousName:
          if previousName is not None:
            # the run of previousName is over: remember it
            self.mySqlConnection.executeQuery("INSERT INTO %s SET name = '%s'" % (namesTable.name, previousName))
          query = self.mySqlConnection.executeQuery("SELECT * FROM %s WHERE name = '%s'" % (namesTable.name, name))
          if not query.isEmpty():
            sys.exit("Error! Input mapping file is not ordered! (Name '%s' occurs at least twice)" % (name))
          previousName = name
        progress.inc()
      progress.done()
    finally:
      # sys.exit() raises SystemExit: the table is also removed in this case
      namesTable.remove()


  def checkPreviouslyMapped(self, name):
    """
    Check whether a read belongs to the already mapped reads
    @param name: name of the read
    @return:     True if the read has already been mapped, False otherwise
                 (always False if no file of already mapped reads is given)
    """
    if self.alreadyMappedReader is None:
      return False
    query = self.mySqlConnection.executeQuery("SELECT * FROM %s WHERE name = '%s' LIMIT 1" % (self.alreadyMappedTable.name, name))
    return not query.isEmpty()


  def findOriginalSize(self, name):
    """
    Get the size of a read, using the sequence table
    @param name: name of the read
    @return:     size of the read

    If the first read cannot be found under its name, "<name>/1" is tried;
    the form which succeeds is remembered (self.suffix) and is the only one
    used for the following reads. Exits if the read cannot be found.
    """
    lines = []
    if not self.suffix:
      # suffix is None (not known yet) or False (plain names)
      lines = self._findSequence(name)
      if len(lines) != 0:
        self.suffix = False
      elif self.suffix is None:
        self.suffix = True
      else:
        sys.exit("Cannot find name %s" % (name))
    if self.suffix:
      lines = self._findSequence("%s/1" % (name))
      if len(lines) == 0:
        sys.exit("Cannot find name %s" % (name))
    # columns are (id, size)
    return int(lines[0][1])


  def checkErrors(self, mapping):
    """
    Apply the per-mapping filters, update the counters and the log
    @param mapping: a mapping, whose queryInterval.size is the size of the read
    @return:        True if the mapping passes every filter

    Every filter is evaluated (no short-circuit), so that a mapping may
    increase several counters.
    """
    accepted = True
    # short size (minSize is a percentage of the size of the read)
    if self.minSize is not None and mapping.size * 100 < self.minSize * mapping.queryInterval.size:
      self.tooShort += 1
      accepted       = False
      if self.logHandle is not None:
        # size of the mapping first, size of the read second
        self.logHandle.write("size of mapping %s is too short (%i instead of %i)\n" % (str(mapping), mapping.size, mapping.queryInterval.size))
    # low identity (counted with the mismatches: there is no dedicated counter)
    if self.minId is not None and mapping.getTagValue("identity") < self.minId:
      self.tooManyMismatches += 1
      accepted                = False
      if self.logHandle is not None:
        self.logHandle.write("mapping %s has a low identity rate\n" % (str(mapping)))
    # too many mismatches
    if self.maxMismatches is not None and mapping.getTagValue("nbMismatches") > self.maxMismatches:
      self.tooManyMismatches += 1
      accepted                = False
      if self.logHandle is not None:
        self.logHandle.write("mapping %s has more mismatches than %i\n" % (str(mapping), self.maxMismatches))
    # too many gaps
    if self.maxGaps is not None and mapping.getTagValue("nbGaps") > self.maxGaps:
      self.tooManyGaps += 1
      accepted          = False
      if self.logHandle is not None:
        self.logHandle.write("mapping %s has more gaps than %i\n" % (str(mapping), self.maxGaps))
    # short exons (skipped when exons are merged, see acceptExons)
    if self.exons and min([subMapping.targetInterval.getSize() for subMapping in mapping.subMappings]) < exonSize:
      self.tooShortExons += 1
      accepted            = False
      if self.logHandle is not None:
        self.logHandle.write("sequence %s maps as too short exons\n" % (mapping))
    return accepted


  def checkNbMappings(self, mappings):
    """
    Check the number of occurrences of a read
    @param mappings: the accepted mappings of one read
    @return:         True if the read has at least one occurrence, and no more
                     than the maximum number of mappings (if set)
    """
    nbOccurrences = self._countOccurrences(mappings)
    if self.maxMappings is not None and nbOccurrences > self.maxMappings:
      self.tooManyMappings += 1
      if self.logHandle is not None:
        queryName = self._getQueryName(mappings[0]) if mappings else "(unknown)"
        self.logHandle.write("sequence %s maps %i times\n" % (queryName, nbOccurrences))
      return False
    return (nbOccurrences > 0)


  def sortMappings(self, mappings):
    """
    Sort the mappings of a read by increasing error score, and tag them
    @param mappings: the accepted mappings of one read
    @return:         the sorted mappings, with their number of occurrences,
                     occurrence index, rank ("<rank>Tie" for equal scores) and,
                     except for the first one, the region of the best mapping
    """
    if not mappings:
      return []
    nbOccurrences   = self._countOccurrences(mappings)
    orderedMappings = sorted(mappings, key = lambda mapping: mapping.getErrorScore())
    bestInterval    = orderedMappings[0].targetInterval
    bestRegion      = "%s:%d-%d" % (bestInterval.chromosome, bestInterval.start, bestInterval.end)
    cpt             = 1
    rank            = 1
    previousMapping = None
    previousScore   = None
    wasLastTie      = False
    rankedMappings  = []
    for mapping in orderedMappings:
      mapping.setNbOccurrences(nbOccurrences)
      mapping.setOccurrence(cpt)

      score = mapping.getErrorScore()
      if previousScore is not None and previousScore == score:
        if "Rank" in previousMapping.getTagNames():
          if not wasLastTie:
            # first tie of the series: the previous mapping becomes a tie too
            previousMapping.setRank("%sTie" % (rank))
          mapping.setRank("%sTie" % (rank))
          wasLastTie = True
      else:
        rank = cpt
        mapping.setRank(rank)
        wasLastTie = False
      if cpt != 1:
        mapping.setBestRegion(bestRegion)

      rankedMappings.append(mapping)
      previousMapping = mapping
      previousScore   = score
      cpt            += 1
    return rankedMappings


  def processMappings(self, mappings):
    """
    Filter, rank and write the mappings of one read
    @param mappings: all the mappings of one read (nothing is done if empty)
    """
    if not mappings:
      return
    selectedMappings = []
    # the sequence table stores the names cut at the first space
    name = self._getQueryName(mappings[0])
    size = self.findOriginalSize(name)
    for mapping in mappings:
      if not self.exons:
        mapping.mergeExons(distanceExons)
      mapping.queryInterval.size = size
      if self.checkErrors(mapping):
        selectedMappings.append(mapping)

    if self.checkNbMappings(selectedMappings):
      # remember names
      if self.unmatchedWriter is not None:
        storedName = "%s/1" % (name) if self.suffix else name
        self.mySqlConnection.executeQuery("INSERT INTO %s SET name = '%s'" % (self.mappedNamesTable.name, storedName))

      self.nbWrittenSequences += 1
      for mapping in self.sortMappings(selectedMappings):
        self.nbWrittenMappings += 1
        self.gff3Writer.addTranscript(mapping.getTranscript())


  def writeUnmatched(self):
    """
    Write the sequences whose name is not in the table of the written reads
    """
    progress = Progress(self.sequenceListParser.getNbSequences(), "Reading unmatched sequences", self.verbosity)
    # NOTE(review): the sequence parser has already been read by
    # storeSequences(); presumably getIterator() restarts from the top -- confirm
    for sequence in self.sequenceListParser.getIterator():
      name  = sequence.name.split(" ")[0]
      query = self.mySqlConnection.executeQuery("SELECT * FROM %s WHERE name = '%s' LIMIT 1" % (self.mappedNamesTable.name, name))
      if query.isEmpty():
        self.unmatchedWriter.addSequence(sequence)
      progress.inc()
    progress.done()


  def analyze(self):
    """
    Run the whole analysis: check the input, then filter and write the mappings
    """
    self.readInputMappings()
    self.checkOrder()
    self.storeSequences()

    if self.alreadyMappedReader is not None:
      self.storeAlreadyMapped()

    # read mappings, and process them read by read
    previousQueryName = None
    mappings          = []
    self.parser.reset()
    progress = Progress(self.nbMappings, "Reading mappings", self.verbosity)
    for mapping in self.parser.getIterator():
      progress.inc()

      queryName = self._getQueryName(mapping)

      # skip if mapping has been previously seen
      if self.checkPreviouslyMapped(queryName):
        if self.logHandle is not None:
          self.logHandle.write("Mapping %s has already been mapped.\n" % (queryName))
      elif previousQueryName == queryName:
        mappings.append(mapping)
      else:
        # new read: flush the mappings of the previous one
        if previousQueryName is not None:
          self.processMappings(mappings)
        previousQueryName = queryName
        mappings          = [mapping, ]

    # mappings of the last read
    self.processMappings(mappings)
    progress.done()

    if self.unmatchedWriter is not None:
      self.writeUnmatched()



if __name__ == "__main__":
  
  # parse command line
  description = "Mapper Analyzer: Read the output of an aligner, print statistics and possibly translate into BED or GBrowse formats. [Category: Mappings]"

  parser = OptionParser(description = description)
  compGroup = OptionGroup(parser, "Compulsory options")
  filtGroup = OptionGroup(parser, "Filtering options")
  tranGroup = OptionGroup(parser, "Transformation options")
  outpGroup = OptionGroup(parser, "Output options")
  otheGroup = OptionGroup(parser, "Other options")
  compGroup.add_option("-i", "--input",        dest="inputFileName",     action="store",                        type="string", help="input file (output of the tool) [compulsory] [format: file in mapping format given by -f]")
  compGroup.add_option("-f", "--format",       dest="format",            action="store",      default="seqmap", type="string", help="format of the file [compulsory] [format: mapping file format]")
  compGroup.add_option("-q", "--sequences",    dest="sequencesFileName", action="store",                        type="string", help="file of the sequences [compulsory] [format: file in sequence format given by -k]")
  compGroup.add_option("-k", "--seqFormat",    dest="sequenceFormat",    action="store",      default="fasta",  type="string", help="format of the sequences: fasta or fastq [default: fasta] [format: sequence file format]")
  compGroup.add_option("-o", "--output",       dest="outputFileName",    action="store",                        type="string", help="output file [compulsory] [format: output file in GFF3 format]")
  filtGroup.add_option("-n", "--number",       dest="number",            action="store",      default=None,     type="int",    help="max. number of occurrences of a sequence [format: int]")
  filtGroup.add_option("-s", "--size",         dest="size",              action="store",      default=None,     type="int",    help="minimum pourcentage of size [format: int]")
  filtGroup.add_option("-d", "--identity",     dest="identity",          action="store",      default=None,     type="int",    help="minimum pourcentage of identity [format: int]")
  filtGroup.add_option("-m", "--mismatch",     dest="mismatch",          action="store",      default=None,     type="int",    help="maximum number of mismatches [format: int]")
  filtGroup.add_option("-p", "--gap",          dest="gap",               action="store",      default=None,     type="int",    help="maximum number of gaps [format: int]")
  tranGroup.add_option("-e", "--mergeExons",   dest="mergeExons",        action="store_true", default=False,                   help="merge exons when introns are short [format: bool] [default: false]")
  tranGroup.add_option("-x", "--removeExons",  dest="removeExons",       action="store_true", default=False,                   help="remove transcripts when exons are short [format: bool] [default: false]")
  outpGroup.add_option("-t", "--title",        dest="title",             action="store",      default="SMART",  type="string", help="title of the UCSC track [format: string] [default: SMART]")
  outpGroup.add_option("-r", "--remaining",    dest="remaining",         action="store_true", default=False,                   help="print the unmatched sequences [format: bool] [default: false]")
  otheGroup.add_option("-a", "--append",       dest="appendFileName",    action="store",      default=None,     type="string", help="append to GFF3 file [format: file in GFF3 format]")  
  otheGroup.add_option("-v", "--verbosity",    dest="verbosity",         action="store",      default=1,        type="int",    help="trace level [default: 1] [format: int]")
  otheGroup.add_option("-l", "--log",          dest="log",               action="store_true", default=False,                   help="write a log file [format: bool] [default: false]")
  parser.add_option_group(compGroup)
  parser.add_option_group(filtGroup)
  parser.add_option_group(tranGroup)
  parser.add_option_group(outpGroup)
  parser.add_option_group(otheGroup)
  (options, args) = parser.parse_args()

  # the options without default value are compulsory: print the usage and
  # stop here, rather than crashing later on a None file name
  for (value, flag) in ((options.inputFileName, "-i"), (options.sequencesFileName, "-q"), (options.outputFileName, "-o")):
    if value is None:
      parser.error("option %s is compulsory" % (flag))

  # configure the analyzer (option attributes are named after the "dest" values)
  analyzer = MapperAnalyzer(options.verbosity)
  analyzer.setMappingFile(options.inputFileName, options.format)
  analyzer.setSequenceFile(options.sequencesFileName, options.sequenceFormat)
  analyzer.setOutputFile("%s.gff3" % (options.outputFileName), options.title)
  if options.appendFileName is not None:
    analyzer.setAlreadyMatched(options.appendFileName)
  if options.remaining:
    analyzer.setRemainingFile(options.outputFileName, options.sequenceFormat)
  if options.number is not None:
    analyzer.setMaxMappings(options.number)
  if options.size is not None:
    analyzer.setMinSize(options.size)
  if options.identity is not None:
    analyzer.setMinId(options.identity)
  if options.mismatch is not None:
    analyzer.setMaxMismatches(options.mismatch)
  if options.gap is not None:
    analyzer.setMaxGaps(options.gap)
  # NOTE(review): options.removeExons is parsed but never read here; short
  # exons are filtered unless --mergeExons is given -- confirm this is intended
  if options.mergeExons:
    analyzer.acceptExons(False)
  if options.log:
    analyzer.setLog("%s.log" % (options.outputFileName))
  analyzer.analyze()

  # print statistics
  if options.verbosity > 0:

    def percentage(part, total):
      # an empty input must not crash the report with a division by zero
      return float(part) / total * 100 if total else 0.0

    nbSequences = analyzer.nbSequences
    nbMappings  = analyzer.nbMappings
    print("kept %i sequences over %s (%f%%)" % (analyzer.nbWrittenSequences, nbSequences, percentage(analyzer.nbWrittenSequences, nbSequences)))
    if options.appendFileName is not None:
      # NOTE(review): this line adds a number of mappings to a number of
      # sequences, as the original code did -- confirm the intended figure
      nbKept = nbMappings + analyzer.nbAlreadyMapped
      print("kept %i sequences over %s (%f%%) including already mapped sequences" % (nbKept, nbSequences, percentage(nbKept, nbSequences)))
    print("kept %i mappings over %i (%f%%)" % (analyzer.nbWrittenMappings, nbMappings, percentage(analyzer.nbWrittenMappings, nbMappings)))
    print("removed %i too short mappings (%f%%)" % (analyzer.tooShort, percentage(analyzer.tooShort, nbMappings)))
    print("removed %i mappings with too many mismatches (%f%%)" % (analyzer.tooManyMismatches, percentage(analyzer.tooManyMismatches, nbMappings)))
    print("removed %i mappings with too many gaps (%f%%)" % (analyzer.tooManyGaps, percentage(analyzer.tooManyGaps, nbMappings)))
    print("removed %i mappings with too short exons (%f%%)" % (analyzer.tooShortExons, percentage(analyzer.tooShortExons, nbMappings)))
    print("removed %i sequences with too many hits (%f%%)" % (analyzer.tooManyMappings, percentage(analyzer.tooManyMappings, nbSequences)))
    nbUnmapped = nbSequences - nbMappings
    print("%i sequences have no mapping (%f%%)" % (nbUnmapped, percentage(nbUnmapped, nbSequences)))
    if options.appendFileName is not None:
      nbUnmapped = nbSequences - nbMappings - analyzer.nbAlreadyMapped
      print("%i sequences have no mapping (%f%%) excluding already mapped sequences" % (nbUnmapped, percentage(nbUnmapped, nbSequences)))

