import re
from structure.transcript import *
from mySql.mySqlTable import *
from structure.transcriptListIterator import *
from misc.progress import *


class TranscriptList(object):
  """A list of transcripts, grouped by chromosome.

  Transcripts are stored in ``self.transcripts``, a dict mapping a chromosome
  name to the list of transcripts read on that chromosome.
  ``self.longestTranscript`` tracks the largest ``end - start`` of any
  transcript added through ``addTranscript``.
  """

  def __init__(self, verbosity = 0):
    self.transcripts       = dict()
    self.longestTranscript = 0
    self.verbosity         = verbosity


  def getTranscript(self, chromosome, index):
    """Return the ``index``-th transcript stored for ``chromosome``."""
    return self.transcripts[chromosome][index]


  def getChromosomes(self):
    """Return the list of chromosomes that have at least one stored list."""
    return list(self.transcripts.keys())


  def getTranscriptsOnChromosome(self, chromosome):
    """Return the (live) transcript list of ``chromosome``, or a new empty list if unknown."""
    if chromosome not in self.transcripts:
      return []
    return self.transcripts[chromosome]


  def addTranscript(self, transcript):
    """Append ``transcript`` to the list of its chromosome and update ``longestTranscript``."""
    self.transcripts.setdefault(transcript.chromosome, []).append(transcript)
    self.longestTranscript = max(self.longestTranscript, transcript.end - transcript.start)


  def removeTranscript(self, chromosome, i):
    """Delete the ``i``-th transcript of ``chromosome``."""
    del self.transcripts[chromosome][i]


  def removeAll(self):
    """Remove every transcript and reset the longest-transcript tracker."""
    self.transcripts       = {}
    self.longestTranscript = 0


  def getNbTranscripts(self):
    """Return the total number of stored entries, summed over all chromosomes."""
    return sum(len(transcripts) for transcripts in self.transcripts.values())


  def getSize(self):
    """Return the sum of ``getSize()`` over every stored transcript."""
    return sum(transcript.getSize() for transcripts in self.transcripts.values() for transcript in transcripts)


  def sort(self):
    """Sort, in place, the transcripts of each chromosome by start position (stable)."""
    for transcripts in self.transcripts.values():
      transcripts.sort(key = lambda transcript: transcript.start)


  @staticmethod
  def _overlapsAnySorted(transcript, sortedCandidates, overlaps):
    """Return True if ``overlaps(transcript, candidate)`` holds for some candidate.

    ``sortedCandidates`` must be sorted by start. The scan stops at the first
    candidate starting after ``transcript`` ends, since no later candidate can
    intersect it. NOTE(review): that early exit assumes ``overlaps`` only holds
    for transcripts whose [start, end] ranges intersect; confirm for
    Transcript.overlapWith / overlapWithExon.
    """
    for candidate in sortedCandidates:
      if overlaps(transcript, candidate):
        return True
      if candidate.start > transcript.end:
        return False
    return False


  def _removeOverlapping(self, transcriptList, overlaps):
    """Keep only the transcripts that ``overlaps`` no transcript of ``transcriptList``.

    ``transcriptList`` is sorted in place first. Chromosomes missing from
    ``transcriptList`` keep all their transcripts. ``None`` entries are dropped.
    """
    transcriptList.sort()
    for chromosome in self.transcripts:
      transcripts = self.transcripts[chromosome]
      candidates  = transcriptList.getTranscriptsOnChromosome(chromosome)
      progress    = Progress(len(transcripts), "Handling chromosome %s" % (chromosome), self.verbosity)
      kept        = []
      for transcript in transcripts:
        progress.inc()
        if transcript is not None and not self._overlapsAnySorted(transcript, candidates, overlaps):
          kept.append(transcript)
      # Replacing the value of an existing key does not resize the dict, so this is safe while iterating it.
      self.transcripts[chromosome] = kept
      # One Progress per chromosome, so finish each one here.
      progress.done()


  def removeOverlapWith(self, transcriptList):
    """Remove the transcripts overlapping (``overlapWith``) any transcript of ``transcriptList``.

    Side effect: ``transcriptList`` is sorted in place.
    """
    self._removeOverlapping(transcriptList, lambda this, that: this.overlapWith(that))


  def removeOverlapWithExon(self, transcriptList):
    """Remove the transcripts overlapping (``overlapWithExon``) any transcript of ``transcriptList``.

    Side effect: ``transcriptList`` is sorted in place.
    """
    self._removeOverlapping(transcriptList, lambda this, that: this.overlapWithExon(that))


  def setDefaultTagValue(self, name, value):
    """Call ``setTag(name, value)`` on every (non-None) transcript."""
    for transcript in self.getIterator():
      transcript.setTag(name, value)


  def storeDatabase(self, mySqlConnection):
    """Write the transcripts and their exons into two temporary MySQL tables.

    Transcripts go to "TmpTranscriptsTable". Exons go to "TmpIntervalsTable",
    with an extra "idTranscript" column set to the id returned by
    ``addLine`` for their transcript.
    """
    transcriptsTable   = MySqlTable("TmpTranscriptsTable", mySqlConnection)
    transcriptsTable.create(Transcript.getSqlVariables(), Transcript.getSqlTypes())
    # Copy before extending, in case Interval hands out shared class-level containers.
    intervalsVariables = list(Interval.getSqlVariables()) + ["idTranscript"]
    intervalsTypes     = dict(Interval.getSqlTypes())
    intervalsTypes["idTranscript"] = "int"
    intervalsTable = MySqlTable("TmpIntervalsTable", mySqlConnection)
    intervalsTable.create(intervalsVariables, intervalsTypes)
    for transcript in self.getIterator():
      idTranscript = transcriptsTable.addLine(transcript.getSqlValues())
      for exon in transcript.getExons():
        intervalValues = exon.getSqlValues()
        intervalValues["idTranscript"] = idTranscript
        intervalsTable.addLine(intervalValues)


  def getIterator(self):
    """Yield every non-None transcript, chromosome by chromosome, in stored order."""
    for chromosome in list(self.transcripts.keys()):
      for transcript in self.transcripts[chromosome]:
        if transcript is not None:
          yield transcript


  def __str__(self):
    return "".join(str(transcript) for transcript in self.getIterator())

