#! /usr/bin/env python
#
# Copyright INRA-URGI 2009-2010
# 
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
# 
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
# 
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
# 
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.
#

try:
	import cPickle as pickle
except:
	import pickle
import random, os
from heapq import heapify, heappop, heappush
from itertools import islice, cycle
from SMART.Java.Python.structure.Transcript import Transcript
from SMART.Java.Python.misc.Progress import Progress
from SMART.Java.Python.misc.UnlimitedProgress import UnlimitedProgress

# Number of transcripts buffered in memory per chromosome by FileSorter
# before that buffer is sorted and flushed to a temporary chunk file.
BUFFER_SIZE = 1024 * 100

class FileSorter(object):
	"""
	Sort the transcripts produced by a parser and pickle them to disk.

	In the default (batch) mode, the output is grouped by chromosome, in the
	iteration order of an internal dict. Within a chromosome, transcripts are
	ordered by increasing start, then by decreasing end. An external merge
	sort is used:
	  - transcripts are buffered per chromosome;
	  - every BUFFER_SIZE transcripts, the buffer is sorted and pickled to a
	    temporary chunk file;
	  - the chunk files of each chromosome are finally merged with a heap.

	With setPresorted(True), transcripts are copied in input order without
	any sorting.

	Output goes either to a single file or, with perChromosome(True), to one
	"<outputFileName>_<chromosome>.pkl" file per chromosome. Every output
	file is a plain concatenation of pickles.

	Typical use: construct, call the optional setters, setOutputFileName(),
	then sort().
	"""

	def __init__(self, parser, verbosity = 1):
		"""
		@param parser:    object providing getIterator() (yielding transcripts or
		                  "Mapping" objects) and a fileName attribute
		@param verbosity: verbosity level forwarded to the progress reporters
		"""
		self._parser                  = parser
		self._verbosity               = verbosity
		# chromosome -> list of (already closed) temporary chunk file objects;
		# only their .name is used afterwards
		self._chunks                  = {}
		self._nbElements              = 0
		self._nbElementsPerChromosome = {}
		self._perChromosome           = False
		self._isPreSorted             = False
		self._outputFileName          = None
		self._outputFileNames         = {}
		# the pid makes collisions between concurrent runs sharing a directory unlikely
		self._prefix                  = "tmpFile_%d_%d" % (os.getpid(), random.randint(0, 100000))
		self._chromosome              = None
		if "SMARTTMPPATH" in os.environ:
			self._prefix = os.path.join(os.environ["SMARTTMPPATH"], self._prefix)

	def selectChromosome(self, chromosome):
		"""Only keep transcripts located on this chromosome (None keeps all of them)."""
		self._chromosome = chromosome

	def perChromosome(self, boolean):
		"""Write one output file per chromosome instead of a single file."""
		self._perChromosome = boolean

	def setOutputFileName(self, fileName):
		"""
		Set the output file name.

		If perChromosome(True) has already been called, the extension is
		stripped, because the name is then used as a prefix for the
		"<name>_<chromosome>.pkl" files. Call perChromosome() first.
		"""
		self._outputFileName = fileName
		if self._perChromosome:
			self._outputFileName = os.path.splitext(self._outputFileName)[0]

	def setPresorted(self, presorted):
		"""Declare whether the input is already sorted (then it is copied as is)."""
		self._isPreSorted = presorted

	def sort(self):
		"""
		Read the whole input and write the output file(s).

		@raise ValueError: if setOutputFileName() has not been called
		"""
		if self._outputFileName is None:
			raise ValueError("setOutputFileName() must be called before sort()")
		if not self._isPreSorted:
			self._batchSort()
		else:
			self._presorted()

	@staticmethod
	def _sortKey(transcript):
		"""Sort key: increasing start, then decreasing end."""
		return (transcript.getStart(), -transcript.getEnd())

	def _iterTranscripts(self, progress):
		"""
		Yield (chromosome, transcript) for every selected input element.

		Also updates the element counters. progress is incremented for every
		parsed element, selected or not.
		"""
		for transcript in self._parser.getIterator():
			progress.inc()
			# compared by class name, presumably to avoid importing Mapping here
			if transcript.__class__.__name__ == "Mapping":
				transcript = transcript.getTranscript()
			chromosome = transcript.getChromosome()
			if self._chromosome is not None and chromosome != self._chromosome:
				continue
			self._nbElements += 1
			self._nbElementsPerChromosome[chromosome] = self._nbElementsPerChromosome.get(chromosome, 0) + 1
			yield chromosome, transcript

	def _presorted(self):
		"""Copy the (already sorted) input to the output file(s) in input order."""
		progress      = UnlimitedProgress(1000, "Writing files %s" % (self._parser.fileName), self._verbosity)
		curChromosome = None
		outputHandle  = None
		# chromosomes whose per-chromosome file was already created during this pass
		seen          = set()
		try:
			if not self._perChromosome:
				outputHandle = open(self._outputFileName, "wb")
			for chromosome, transcript in self._iterTranscripts(progress):
				if self._perChromosome and chromosome != curChromosome:
					if outputHandle is not None:
						outputHandle.close()
						outputHandle = None
					fileName = "%s_%s.pkl" % (self._outputFileName, chromosome)
					self._outputFileNames[chromosome] = fileName
					# append if the input revisits a chromosome; "wb" would truncate what was written
					outputHandle  = open(fileName, "ab" if chromosome in seen else "wb")
					seen.add(chromosome)
					curChromosome = chromosome
				pickle.dump(transcript, outputHandle, -1)
		finally:
			if outputHandle is not None:
				outputHandle.close()
		progress.done()

	def getNbElements(self):
		"""Number of transcripts processed (after chromosome selection), summed over sort() calls."""
		return self._nbElements

	def getNbElementsPerChromosome(self):
		"""Dict chromosome -> number of transcripts processed, summed over sort() calls."""
		return self._nbElementsPerChromosome

	def _printSorted(self, chromosome, chunk):
		"""Sort chunk in place and pickle it to a new temporary chunk file of chromosome."""
		chunk.sort(key = self._sortKey)
		fileName    = "%s_%s_%06i.tmp" % (self._prefix, chromosome, len(self._chunks[chromosome]))
		outputChunk = open(fileName, "wb", 32000)
		# registered before writing, so that _batchSort's cleanup also removes a partial chunk
		self._chunks[chromosome].append(outputChunk)
		try:
			for transcript in chunk:
				pickle.dump(transcript, outputChunk, -1)
		finally:
			outputChunk.close()

	def _pushNext(self, heap, index, reader):
		"""
		Push the next transcript of reader onto heap.

		When reader is exhausted, close it and delete its file instead.
		"""
		try:
			transcript = pickle.load(reader)
		except EOFError:
			reader.close()
			try:
				os.remove(reader.name)
			except OSError:
				pass  # best effort: _batchSort also removes leftover chunk files
			return
		# the chunk index breaks ties, so transcripts and files are never compared
		heappush(heap, self._sortKey(transcript) + (index, transcript, reader))

	def _merge(self, chunks):
		"""
		Yield the transcripts of the sorted chunk files in chunks, merged in sort order.

		The chunks list is not modified. Each chunk file is deleted as soon as
		it has been fully read.
		"""
		readers = []
		heap    = []
		try:
			for index, chunk in enumerate(chunks):
				reader = open(chunk.name, "rb")
				readers.append(reader)
				self._pushNext(heap, index, reader)
			while heap:
				start, end, index, transcript, reader = heappop(heap)
				yield transcript
				self._pushNext(heap, index, reader)
		finally:
			# also runs if the consumer stops early; closing twice is harmless
			for reader in readers:
				reader.close()

	def _batchSort(self):
		"""External merge sort of the input into the output file(s)."""
		# chromosome -> transcripts not yet flushed to a chunk file
		pending      = {}
		outputHandle = None
		try:
			progress = UnlimitedProgress(1000, "Sorting file %s" % (self._parser.fileName), self._verbosity)
			for chromosome, transcript in self._iterTranscripts(progress):
				if chromosome not in pending:
					pending[chromosome] = []
					self._chunks.setdefault(chromosome, [])
				batch = pending[chromosome]
				batch.append(transcript)
				if len(batch) == BUFFER_SIZE:
					self._printSorted(chromosome, batch)
					pending[chromosome] = []
			for chromosome, batch in pending.items():
				if batch:
					self._printSorted(chromosome, batch)
			progress.done()
			if not self._perChromosome:
				outputHandle = open(self._outputFileName, "wb")
			progress = Progress(len(self._chunks), "Writing sorted file %s" % (self._parser.fileName), self._verbosity)
			for chromosome, chunks in self._chunks.items():
				if self._perChromosome:
					self._outputFileNames[chromosome] = "%s_%s.pkl" % (self._outputFileName, chromosome)
					outputHandle = open(self._outputFileNames[chromosome], "wb")
				for transcript in self._merge(chunks):
					pickle.dump(transcript, outputHandle, -1)
				if self._perChromosome:
					outputHandle.close()
					outputHandle = None
				progress.inc()
			if outputHandle is not None:
				outputHandle.close()
				outputHandle = None
			progress.done()
		finally:
			if outputHandle is not None:
				outputHandle.close()
			for chunks in self._chunks.values():
				for chunk in chunks:
					try:
						os.remove(chunk.name)
					except OSError:
						pass  # already deleted by _merge, or never created
			# forget the deleted chunk files so that the instance can be reused
			self._chunks = {}

	def getOutputFileNames(self):
		"""Dict chromosome -> output file name (only filled in per-chromosome mode)."""
		return self._outputFileNames
