#! /usr/bin/env python
#
# Copyright INRA-URGI 2009-2010
# 
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
# 
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
# 
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
# 
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.
#
import os, os.path
import struct
import shelve
import sys
from SMART.Java.Python.ncList.NCListFilePickle import NCListFilePickle, NCListFileUnpickle
from SMART.Java.Python.ncList.NCIndex import NCIndex
from SMART.Java.Python.misc.Progress import Progress

# Size in bytes of one native C long: every cell of the binary tables is
# stored as one such value (see pack / unpack).
LONG_SIZE = struct.calcsize('l')

# Table identifiers. H, L and T name the binary tables handled by NCList;
# G is only used as a key of the offset dictionary, for addresses in the
# transcript file.
H, L, T, G = range(4)

# Number of long cells making up one record of each table.
H_CELL_SIZE, L_CELL_SIZE, T_CELL_SIZE = 2, 5, 6

# Position of each cell inside a record of tables L and T
# (NEW only exists in table T, which has six cells).
START, END, ADDRESS, LIST, PARENT, NEW = range(6)

# In table H a record is (START, LENGTH): the second cell is a length.
LENGTH = 1

def pack(input):
    """Serialize a value as one native C long.

    The value is first converted to an integer, so numeric strings and
    floats (truncated toward zero) are accepted as well.

    ``int`` is used instead of the Python-2-only ``long``: in Python 2
    ``int()`` promotes to ``long`` automatically when needed, so the bytes
    produced are the same, and the function also works under Python 3.

    Returns the packed bytes, using the platform's native size and byte
    order for a C long.
    Raises ``struct.error`` when the value does not fit in a native long.
    """
    return struct.pack("l", int(input))
def unpack(input):
    """Decode one native C long from its packed bytes.

    ``input`` must be exactly the size of a native long; ``struct.error``
    is raised otherwise.
    """
    (value,) = struct.unpack("l", input)
    return value


class NCList(object):
    """Build and read the binary tables describing nested lists of transcripts.

    Three tables of native longs are handled (one file each):

    * T: temporary table, one record per transcript
      (START, END, ADDRESS, LIST, PARENT, NEW); removed once the build is over.
    * L: one record per transcript (START, END, ADDRESS, LIST, PARENT), laid
      out so that the members of a same list are contiguous.
    * H: one record per list (START, LENGTH): position of its first member
      in L and number of members.

    A record index of -1 stands for "no record" (the virtual parent of the
    top-level transcripts); values read from or written to it are kept in
    memory in ``_missingValues`` instead of a file.

    NOTE(review): the build presumably expects the parser to yield
    transcripts ordered so that a contained transcript comes after its
    container -- confirm against the code producing the pickle file.
    """

    def __init__(self, verbosity):
        """Create an empty, unconfigured list.

        verbosity -- verbosity level forwarded to the parser, the index and
                     the progress reporters.
        """
        self._verbosity = verbosity
        self._subPos = 0
        self._parentPos = 0
        self._nbLines = 0
        self._nbLists = 0
        self._chromosome = None
        self._transcriptFileName = None
        self._lHandle = None
        self._hHandle = None
        self._tHandle = None
        self._parser = None
        self._sizeDict = {H: H_CELL_SIZE, L: L_CELL_SIZE, T: T_CELL_SIZE}
        self._offsets = {H: 0, L: 0, G: 0}
        self._fileNameDict = {}
        self._handleDict = {}
        self._createIndex = False
        # Defined up front so that getSizeFirstList() / getIndex() do not
        # raise AttributeError before buildLists() has run (or when no index
        # was requested).
        self._sizeFirstList = 0
        self._index = None
        # Values of the virtual record -1, per table and per cell.
        self._missingValues = dict((table, {}) for table in self._sizeDict)
        self._missingValues[T][LIST] = -1
        self._missingValues[L][LIST] = 0
        self._missingValues[T][NEW] = -1

    def __del__(self):
        """Close the L and H read handles, if any.

        getattr() is used because __del__ may run on an instance whose
        __init__ never completed.
        """
        for attribute in ("_lHandle", "_hHandle"):
            handle = getattr(self, attribute, None)
            if handle is not None:
                handle.close()

    def createIndex(self, boolean):
        """Tell whether buildLists() should also build an NCIndex."""
        self._createIndex = boolean

    def setChromosome(self, chromosome):
        """Set the chromosome; must be set before setFileName() for the
        table file names to be computed."""
        self._chromosome = chromosome

    def setFileName(self, fileName):
        """Set the transcript file, open a parser on it and derive the
        names of the table files."""
        self._transcriptFileName = fileName
        self._parser = NCListFileUnpickle(fileName, self._verbosity)
        self._setFileNames(fileName)

    def setNbElements(self, nbElements):
        """Set the number of transcripts (records of tables T and L)."""
        self._nbLines = nbElements

    def setOffset(self, fileType, offset):
        """Set the offset added to addresses of the given kind (H, L or G)
        by the public readers."""
        self._offsets[fileType] = offset

    def _setFileNames(self, fileName):
        """Compute the H, L and T file names from the transcript file name.

        The extension of fileName is dropped; when the SMARTTMPPATH
        environment variable is set, it is joined in front of the result.
        Nothing is computed when the chromosome or the file name is unset.
        """
        # Single-argument print: same output as the original print
        # statements under Python 2, and valid syntax under Python 3.
        print("Got file name %s" % (fileName,))
        if self._chromosome is not None and fileName is not None:
            coreName = os.path.splitext(fileName)[0]
            if "SMARTTMPPATH" in os.environ:
                coreName = os.path.join(os.environ["SMARTTMPPATH"], coreName)
            print("Used core name %s" % (coreName,))
            self._hFileName = "%s_H.bin" % (coreName)
            self._lFileName = "%s_L.bin" % (coreName)
            self._tFileName = "%s_T.bin" % (coreName)
            self._fileNameDict = {H: self._hFileName, L: self._lFileName, T: self._tFileName}

    def getSizeFirstList(self):
        """Return the number of top-level transcripts counted by the last
        build (0 when nothing has been built)."""
        return self._sizeFirstList

    def _writeSubListIntoH(self, SubListAddr, SubListLength):
        """Append one (address, length) record at the current position of
        the H handle."""
        self._hHandle.write(pack(SubListAddr))
        self._hHandle.write(pack(SubListLength))
        self._subPos += H_CELL_SIZE

    def _writeParentIntoL(self, readAddr, subListAddr, parentAddr, start, end):
        """Append one (start, end, address, sub-list, parent) record at the
        current position of the L handle."""
        self._lHandle.write(pack(start))
        self._lHandle.write(pack(end))
        self._lHandle.write(pack(readAddr))
        self._lHandle.write(pack(subListAddr))
        self._lHandle.write(pack(parentAddr))
        self._parentPos += L_CELL_SIZE

    def getLLineElements(self, subListLAddr):
        """Read record subListLAddr of table L.

        Returns (gff3Addr, subListHAddr, parentLAddr, start, end), or five
        -1 when the address is -1 / None or lies past the end of the file.
        """
        missing = (-1, -1, -1, -1, -1)
        if subListLAddr is None or subListLAddr == -1:
            return missing
        self._lHandle.seek(subListLAddr * L_CELL_SIZE * LONG_SIZE + self._offsets[L])
        startBin = self._lHandle.read(LONG_SIZE)
        if len(startBin) < LONG_SIZE:
            return missing
        start = unpack(startBin)
        end = unpack(self._lHandle.read(LONG_SIZE))
        gff3Addr = unpack(self._lHandle.read(LONG_SIZE))
        subListHAddr = unpack(self._lHandle.read(LONG_SIZE))
        parentLAddr = unpack(self._lHandle.read(LONG_SIZE))
        return gff3Addr, subListHAddr, parentLAddr, start, end

    def getHLineElements(self, subListHAddr):
        """Read record subListHAddr of table H.

        Returns (subListStart, subListElementsNb), or (-1, -1) when the
        record lies past the end of the file.
        """
        self._hHandle.seek(subListHAddr * H_CELL_SIZE * LONG_SIZE + self._offsets[H])
        subListStartBin = self._hHandle.read(LONG_SIZE)
        # Compare with LONG_SIZE, not a literal 8: a native long is 4 bytes
        # on some platforms, where a literal 8 would flag every read as short.
        if len(subListStartBin) < LONG_SIZE:
            return -1, -1
        subListStart = unpack(subListStartBin)
        subListElementsNb = unpack(self._hHandle.read(LONG_SIZE))
        return subListStart, subListElementsNb

    def getRefGffAddr(self, currentRefLAddr):
        """Return the transcript-file address stored in record
        currentRefLAddr of table L (-1 when the record is missing)."""
        return self.getLLineElements(currentRefLAddr)[0]

    def getIntervalFromAdress(self, address):
        """Return the transcript the parser reads at the given address
        (the G offset is added to it)."""
        self._parser.gotoAddress(int(address) + self._offsets[G])
        return self._parser.getNextTranscript()

    def removeFiles(self):
        """Do nothing (the method currently has no effect)."""
        return

    def buildLists(self):
        """Build tables L and H (and the index, if requested) from the
        transcript file, then close the files and delete table T."""
        if self._createIndex:
            self._index = NCIndex(self._verbosity)
        self._createTables()
        self._labelLists()
        self._computeSubStart()
        self._computeAbsPosition()
        self._cleanFiles()

    def _createTables(self):
        """Count the lists, create the three table files filled with -1,
        then load the transcripts into table T."""
        self._initLists()
        self._createTable(H, self._nbLists)
        self._createTable(T, self._nbLines)
        self._createTable(L, self._nbLines)
        self._fillTables()

    def _initLists(self):
        """Count the lists: one for the top level, plus one for each
        transcript included in the transcript read just before it."""
        previousTranscript = None
        self._nbLists = 1
        progress = Progress(self._nbLines, "Initializing lists", self._verbosity-5)
        for transcript in self._parser.getIterator():
            if self._isIncluded(transcript, previousTranscript):
                self._nbLists += 1
            previousTranscript = transcript
            progress.inc()
        progress.done()

    def _isIncluded(self, transcript1, transcript2):
        """Tell whether transcript1 lies within the bounds of transcript2
        (False when either one is None)."""
        return transcript1 is not None and transcript2 is not None and transcript1.getStart() >= transcript2.getStart() and transcript1.getEnd() <= transcript2.getEnd()

    def _createTable(self, name, size):
        """Create the file of table `name` with `size` records, every cell
        set to -1, and register its handle (left open for read/write)."""
        nbCells = self._sizeDict[name] * size
        handle = open(self._fileNameDict[name], "w+b")
        progress = Progress(nbCells, "Initializing table %d" % (name), self._verbosity-5)
        emptyCell = pack(-1)  # loop invariant: packed once
        for i in xrange(nbCells):
            handle.write(emptyCell)
            progress.inc()
        progress.done()
        self._handleDict[name] = handle

    def _fillTables(self):
        """Copy bounds and file address of every transcript into table T,
        and reset the length of every list of table H to 0."""
        progress = Progress(self._nbLines, "Filling table T", self._verbosity-5)
        for i, transcript in enumerate(self._parser.getIterator()):
            self._writeValue(T, i, START, transcript.getStart())
            self._writeValue(T, i, END, transcript.getEnd())
            self._writeValue(T, i, ADDRESS, self._parser.getCurrentTranscriptAddress())
            self._writeValue(T, i, PARENT, -1)
            self._writeValue(T, i, LIST, -1)
            progress.inc()
        progress.done()
        progress = Progress(self._nbLists, "Filling table H", self._verbosity-5)
        for i in xrange(self._nbLists):
            self._writeValue(H, i, LENGTH, 0)
            progress.inc()
        progress.done()

    def _labelLists(self):
        """Find the parent of every transcript and count the members of
        every list.

        For each record of T, the candidate parent is the previous record;
        the PARENT chain is climbed until a record containing the current
        interval is found, or -1 (top level) is reached. The list holding
        the children of that parent gets a new id the first time it is
        needed, and its length in H is incremented.
        """
        progress = Progress(self._nbLines, "Getting table structure", self._verbosity-5)
        nextL = 0
        for i in xrange(self._nbLines):
            p = i - 1
            start = self._readValue(T, i, START)
            end = self._readValue(T, i, END)
            while p != -1 and (start < self._readValue(T, p, START) or end > self._readValue(T, p, END)):
                p = self._readValue(T, p, PARENT)
            thisL = self._readValue(T, p, LIST)
            if thisL == -1:
                # First child of p: allocate the id of its children list.
                thisL = nextL
                nextL += 1
                length = 0
                self._writeValue(T, p, LIST, thisL)
            else:
                length = self._readValue(H, thisL, LENGTH)
            self._writeValue(T, i, PARENT, p)
            self._writeValue(H, thisL, LENGTH, length + 1)
            progress.inc()
        progress.done()

    def _computeSubStart(self):
        """Turn the list lengths of H into start positions in L (running
        sum), and reset the lengths to 0 so that they can serve as fill
        counters afterwards."""
        # The progress is sized on the number of lists actually iterated
        # (the original used the number of transcripts).
        progress = Progress(self._nbLists, "Getting table sub-lists", self._verbosity-5)
        total = 0
        for i in xrange(self._nbLists):
            self._writeValue(H, i, START, total)
            total += self._readValue(H, i, LENGTH)
            self._writeValue(H, i, LENGTH, 0)
            progress.inc()
        progress.done()

    def _computeAbsPosition(self):
        """Write every transcript at its final position in table L.

        Each record of T is placed in the next free slot of its parent's
        children list; its new position is stored in T (NEW) so that its
        own children can refer to it as their parent. Top-level transcripts
        are counted and, if requested, added to the index.
        """
        progress = Progress(self._nbLines, "Writing table", self._verbosity-5)
        self._sizeFirstList = 0
        for i in xrange(self._nbLines):
            s = self._readValue(T, i, START)
            e = self._readValue(T, i, END)
            a = self._readValue(T, i, ADDRESS)
            pt = self._readValue(T, i, PARENT)   # parent, as a record of T
            h = self._readValue(T, pt, LIST)     # list holding the parent's children
            pl = self._readValue(T, pt, NEW)     # parent, as a record of L
            nb = self._readValue(H, h, LENGTH)   # members already placed in that list
            l = self._readValue(H, h, START) + nb
            self._writeValue(T, i, NEW, l)
            self._writeValue(L, l, START, s)
            self._writeValue(L, l, END, e)
            self._writeValue(L, l, ADDRESS, a)
            self._writeValue(L, l, LIST, -1)
            self._writeValue(L, l, PARENT, pl)
            self._writeValue(H, h, LENGTH, nb+1)
            if nb == 0:
                # First member of the list: link the parent to its list.
                self._writeValue(L, pl, LIST, h)
            if pl == -1:
                self._sizeFirstList += 1
                if self._createIndex:
                    self._index.addTranscript(e, l)
            progress.inc()
        progress.done()

    def closeFiles(self):
        """Close every registered table handle and drop the parser.

        The handle dictionary is emptied rather than deleted, so that the
        method can safely be called more than once.
        """
        for handle in self._handleDict.values():
            handle.close()
        self._handleDict = {}
        self._lHandle = None
        self._hHandle = None
        self._tHandle = None
        self._parser = None

    def openFiles(self):
        """Open tables L and H read-only and a new parser on the
        transcript file."""
        self._lHandle = open(self._fileNameDict[L], "rb")
        self._hHandle = open(self._fileNameDict[H], "rb")
        self._handleDict = {H: self._hHandle, L: self._lHandle}
        self._parser = NCListFileUnpickle(self._transcriptFileName, self._verbosity)

    def _cleanFiles(self):
        """Close all the files and delete the temporary table T."""
        self.closeFiles()
        os.remove(self._fileNameDict[T])

    def _getPosition(self, table, line, key):
        """Move the handle of `table` to cell `key` of record `line` and
        return it (offsets set with setOffset() are not applied here)."""
        handle = self._handleDict[table]
        handle.seek(self._sizeDict[table] * line * LONG_SIZE + key * LONG_SIZE)
        return handle

    def _writeValue(self, table, line, key, value):
        """Write `value` in cell `key` of record `line` of `table`; record
        -1 is kept in memory."""
        if line == -1:
            self._missingValues[table][key] = value
            return
        handle = self._getPosition(table, line, key)
        handle.write(pack(value))

    def _readValue(self, table, line, key):
        """Return cell `key` of record `line` of `table`; record -1 is read
        from memory."""
        if line == -1:
            return self._missingValues[table][key]
        handle = self._getPosition(table, line, key)
        return unpack(handle.read(LONG_SIZE))

    def getIndex(self):
        """Return the index created by buildLists(), or None when no index
        has been created."""
        return self._index
