import re
import sys
from structure.subMapping import *
from structure.transcript import *


class Mapping(object):
  """
  A class that represents a mapping: a query interval aligned onto a target
  interval, made of a list of sub-mappings, plus free-form tags such as
  "identity", "nbMismatches" or "nbGaps".
  """

  # string spellings accepted by setDirection
  _STRING_DIRECTIONS = {"+": 1, "-": -1, "1": 1, "-1": -1}


  def __init__(self):
    self.targetInterval   = Interval()
    self.queryInterval    = Interval()
    self.subMappings      = []
    self.size             = None
    self.direction        = 0
    # cached result of getTranscript(); reset to None whenever the mapping changes
    self.transcript       = None
    self.tags             = {}


  def copy(self, mapping):
    """
    Copy the content of another mapping into this one
    @param mapping: the mapping to copy (it is not modified)
    """
    # SubMapping(x) is used as a copy constructor, so sub-mappings are not shared
    for subMapping in mapping.subMappings:
      self.subMappings.append(SubMapping(subMapping))
    self.targetInterval.copy(mapping.targetInterval)
    self.queryInterval.copy(mapping.queryInterval)
    self.size      = mapping.size
    self.direction = mapping.direction
    self.tags      = dict(mapping.tags)
    # the transcript is derived from the fields above; getTranscript() rebuilds it
    self.transcript = None


  def setTargetInterval(self, interval):
    """
    Set the target interval (copied)
    @param interval: the new target interval
    """
    self.targetInterval.copy(interval)
    self.transcript = None


  def setQueryInterval(self, interval):
    """
    Set the query interval (copied)
    @param interval: the new query interval
    """
    self.queryInterval.copy(interval)
    self.transcript = None


  def addSubMapping(self, subMapping):
    """
    Append a copy of a sub-mapping, and update the size and the aggregated
    "identity", "nbMismatches" and "nbGaps" tags
    @param subMapping: the sub-mapping to add (it is not modified)
    """
    subMappingCopy = SubMapping(subMapping)
    self.subMappings.append(subMappingCopy)
    self.transcript = None

    if self.direction != 0:
      # propagate the mapping direction to the stored copy
      subMappingCopy.direction = self.direction
    if self.size is None:
      self.size = 0
    subTagNames = subMapping.getTagNames()
    if "identity" in subTagNames and subMapping.size is not None:
      subIdentity = subMapping.getTagValue("identity")
      # size-weighted average; a previous identity only counts if it is known
      previousWeight = self.size if "identity" in self.tags else 0
      totalWeight    = previousWeight + subMapping.size
      if totalWeight > 0:
        previousIdentity = self.tags.get("identity", 0)
        self.setTagValue("identity", (previousIdentity * previousWeight + subIdentity * subMapping.size) / float(totalWeight))
      elif "identity" not in self.tags:
        self.setTagValue("identity", subIdentity)
    if subMapping.size is not None:
      self.size += subMapping.size
    if "nbMismatches" in subTagNames:
      self.setTagValue("nbMismatches", self.tags.get("nbMismatches", 0) + subMapping.getTagValue("nbMismatches"))
    if "nbGaps" in subTagNames:
      self.setTagValue("nbGaps", self.tags.get("nbGaps", 0) + subMapping.getTagValue("nbGaps"))


  def setDirection(self, direction):
    """
    Set the direction of the mapping and of all its sub-mappings
    @param direction: a non-zero int (only its sign is kept), or one of the strings "+", "-", "1", "-1"
    Exits the program (sys.exit) on any other value.
    """
    if type(direction).__name__ == 'int':
      if direction == 0:
        sys.exit("Cannot understand direction " + str(direction))
      # keep an int on both Python 2 and 3 (a true division would give a float on 3)
      self.direction = 1 if direction > 0 else -1
    elif type(direction).__name__ == 'str':
      if direction not in self._STRING_DIRECTIONS:
        sys.exit("Cannot understand direction " + direction)
      self.direction = self._STRING_DIRECTIONS[direction]
    else:
      # str() so that the intended message is shown instead of a TypeError
      sys.exit("Cannot understand direction " + str(direction))
    for subMapping in self.subMappings:
      subMapping.direction = self.direction
    self.transcript = None


  def setSize(self, size):
    """
    Set the size of the mapping; also fills unset interval sizes and, when the
    identity is known, recomputes the number of mismatches
    @param size: the size of the mapping
    """
    self.size = size
    if self.queryInterval.size is None:
      self.queryInterval.size = self.size
    if self.targetInterval.size is None:
      self.targetInterval.size = self.size
    if self.size is not None and "identity" in self.tags:
      self.setTagValue("nbMismatches", self.size - round(self.size * self.getTagValue("identity") / 100.0))
    self.transcript = None


  def setTagValue(self, name, value):
    """
    Set a tag, invalidating the cached transcript
    @param name: the tag name
    @param value: the tag value
    """
    self.tags[name] = value
    self.transcript = None


  def getTagValue(self, name):
    """
    @param name: the tag name
    @return: the value of the tag (raises KeyError if it is not set)
    """
    return self.tags[name]


  def getTagNames(self):
    """
    @return: the names of the tags currently set
    """
    return self.tags.keys()


  def setIdentity(self, identity):
    """
    Set the identity (in percent); derive the number of mismatches if the size
    is known and the mismatches are not already set
    @param identity: the identity, in percent
    """
    self.setTagValue("identity", identity)
    if self.size is not None and "nbMismatches" not in self.tags:
      nbMismatches = 0 if self.size == 0 else self.size - round(self.size * self.getTagValue("identity") / 100.0)
      self.setTagValue("nbMismatches", nbMismatches)


  def setNbOccurrences(self, nbOccurrences):
    self.setTagValue("nbOccurrences", nbOccurrences)


  def setNbMismatches(self, nbMismatches):
    """
    Set the number of mismatches; derive the identity (in percent) if the size
    is known and the identity is not already set
    @param nbMismatches: the number of mismatches
    """
    self.setTagValue("nbMismatches", nbMismatches)
    if self.size is not None and "identity" not in self.tags:
      identity = 100 if self.size == 0 else (self.size - self.getTagValue("nbMismatches")) / float(self.size) * 100
      self.setTagValue("identity", identity)


  def setNbGaps(self, nbGaps):
    self.setTagValue("nbGaps", nbGaps)


  def setRank(self, rank):
    self.setTagValue("rank", rank)


  def setEvalue(self, evalue):
    self.setTagValue("evalue", evalue)


  def setOccurrence(self, occurrence):
    self.setTagValue("occurrence", occurrence)


  def setBestRegion(self, bestRegion):
    self.setTagValue("bestRegion", bestRegion)


  def mergeExons(self, distance):
    """
    Merge each sub-mapping into the previous kept one when their target
    intervals are at most C{distance} apart; the query distance of each merge
    is added to the "nbGaps" tag (which starts at 0 if unset)
    @param distance: the maximal target distance for two sub-mappings to be merged
    """
    previousSubMapping = None
    keptSubMappings    = []
    for subMapping in self.subMappings:
      if previousSubMapping is not None:
        targetDistance = subMapping.targetInterval.getDistance(previousSubMapping.targetInterval)
        queryDistance  = subMapping.queryInterval.getDistance(previousSubMapping.queryInterval)
        if targetDistance <= distance:
          self.setTagValue("nbGaps", self.tags.get("nbGaps", 0) + queryDistance)
          previousSubMapping.merge(subMapping)
          continue
      keptSubMappings.append(subMapping)
      previousSubMapping = subMapping
    self.subMappings = keptSubMappings
    self.transcript  = None


  def getTranscript(self):
    """
    Extract a transcript from this mapping (cached until the mapping changes)
    @return: a transcript
    """
    if self.transcript is not None:
      return self.transcript
    self.transcript = Transcript()
    self.transcript.copy(self.targetInterval)
    self.transcript.direction = self.direction
    self.transcript.name      = self.queryInterval.name
    if self.transcript.chromosome in ("arm_3LHet", "arm_3RHet"):
      self.transcript.setChromosome(self.transcript.chromosome[len("arm_"):])
    for subMapping in self.subMappings:
      self.transcript.addExon(subMapping.targetInterval)
    # exons are named in insertion order, before being sorted
    for exonNumber, exon in enumerate(self.transcript.exons, 1):
      exon.direction  = self.transcript.direction
      exon.name       = "%s-exon%d" % (self.transcript.name, exonNumber)
      exon.chromosome = self.transcript.chromosome
    self.transcript.sortExons()
    # default best region, unless the mapping carries its own
    if "bestRegion" not in self.tags:
      self.transcript.setTagValue("bestRegion", "(self)")
    for tag, value in self.tags.items():
      self.transcript.setTagValue(tag, value)
    return self.transcript


  def getErrorScore(self):
    """
    @return: 3 per gap + 1 per mismatch + 0.1 per extra sub-mapping (unset tags count as 0)
    """
    return self.tags.get("nbGaps", 0) * 3 + self.tags.get("nbMismatches", 0) + (len(self.subMappings) - 1) * 0.1


  def printGBrowseReference(self):
    return self.getTranscript().printGBrowseReference()


  def printGBrowseLine(self):
    return self.getTranscript().printGBrowseLine()


  def printGBrowse(self):
    return self.getTranscript().printGBrowse()


  def printBed(self):
    return self.getTranscript().printBed()


  def __str__(self):
    return "%s   ----   %s" % (str(self.getTranscript()), ", ".join([str(submapping) for submapping in self.subMappings]))
