import unittest
import os, os.path, glob
from SMART.Java.Python.structure.Transcript import Transcript
from SMART.Java.Python.GetFlanking import GetFlanking
from commons.core.writer.Gff3Writer import Gff3Writer
from commons.core.parsing.GffParser import GffParser

class Test_F_GetFlanking(unittest.TestCase):
    """Functional tests for GetFlanking.

    Each test writes a small reference GFF3 file and a small query GFF3 file,
    runs GetFlanking with the query file as input 0 and the reference file as
    input 1, then parses the GFF3 output and checks the tags set on each query.
    Tests decorated with ``unittest.skip`` are disabled; the runner reports
    them as skipped instead of as silently passing.
    """

    def setUp(self):
        self.queryFileName     = "testQuery.gff3"
        self.referenceFileName = "testReference.gff3"
        self.outputFileName    = "testOutput.gff3"

    def tearDown(self):
        # The glob removes each test file as well as any other file whose
        # name starts with the same prefix.
        for fileRoot in (self.queryFileName, self.referenceFileName, self.outputFileName):
            for fileName in glob.glob("%s*" % (fileRoot)):
                os.remove(fileName)

    def test_run_simple(self):
        """With no direction set, each query is tagged with the expected reference."""
        self._writeTranscripts(self.referenceFileName, [
            self._createTranscript("chr1", 1000, 1100, "+", "ref1"),
            self._createTranscript("chr1", 2000, 2100, "+", "ref2"),
            self._createTranscript("chr1", 1000000, 1200000, "+", "ref3"),
        ])
        self._writeTranscripts(self.queryFileName, [
            self._createTranscript("chr1", 100, 200, "+", "query1"),
            self._createTranscript("chr1", 10000, 10100, "+", "query2"),
        ])
        gf = self._createGetFlanking()
        gf.run()
        first, second = self._readSortedOutput(2)

        self._checkTranscript(first, "chr1", 100, 200, "+", "query1")
        self.assertEqual(first.getTagValue("flanking"), "ref1")
        self.assertEqual(first.getTagValue("_region_flanking"), "downstream")
        self.assertEqual(first.getTagValue("_sense_flanking"), "collinear")

        self._checkTranscript(second, "chr1", 10000, 10100, "+", "query2")
        self.assertEqual(second.getTagValue("flanking"), "ref2")
        self.assertEqual(second.getTagValue("_region_flanking"), "upstream")
        self.assertEqual(second.getTagValue("_sense_flanking"), "collinear")

    @unittest.skip("disabled")
    def test_run_simple_downstream(self):
        """Downstream-only search on the + strand."""
        self._writeTranscripts(self.referenceFileName, [
            self._createTranscript("chr1", 300, 400, "+", "ref1"),
            self._createTranscript("chr1", 1000, 1100, "+", "ref2"),
        ])
        self._writeTranscripts(self.queryFileName, [
            self._createTranscript("chr1", 100, 200, "+", "query1"),
            self._createTranscript("chr1", 1200, 1300, "+", "query2"),
            self._createTranscript("chr1", 1400, 1500, "+", "query3"),
        ])
        gf = self._createGetFlanking()
        gf.addDownstreamDirection(True)
        gf.run()
        first, second, third = self._readSortedOutput(3)

        self._checkTranscript(first, "chr1", 100, 200, "+", "query1")
        self.assertEqual(first.getTagValue("flanking_downstream"), "ref1")
        self.assertEqual(first.getTagValue("_region_flanking"), "downstream")
        self.assertEqual(first.getTagValue("_sense_flanking"), "collinear")

        self._checkTranscript(second, "chr1", 1200, 1300, "+", "query2")
        self.assertIsNone(second.getTagValue("flanking_downstream"))

        self._checkTranscript(third, "chr1", 1400, 1500, "+", "query3")
        self.assertIsNone(third.getTagValue("flanking_downstream"))

    @unittest.skip("disabled")
    def test_run_simple_minus_strand_downstream(self):
        """Downstream-only search for queries on the - strand."""
        self._writeTranscripts(self.referenceFileName, [
            self._createTranscript("chr1", 1000, 1100, "+", "ref1"),
            self._createTranscript("chr1", 2000, 2100, "+", "ref2"),
        ])
        self._writeTranscripts(self.queryFileName, [
            self._createTranscript("chr1", 100, 200, "-", "query1"),
            self._createTranscript("chr1", 1200, 1300, "-", "query2"),
            self._createTranscript("chr1", 1400, 1500, "-", "query3"),
        ])
        gf = self._createGetFlanking()
        gf.addDownstreamDirection(True)
        gf.run()
        first, second, third = self._readSortedOutput(3)

        self._checkTranscript(first, "chr1", 100, 200, "-", "query1")
        self.assertIsNone(first.getTagValue("flanking_downstream"))

        self._checkTranscript(second, "chr1", 1200, 1300, "-", "query2")
        self.assertEqual(second.getTagValue("flanking_downstream"), "ref1")

        self._checkTranscript(third, "chr1", 1400, 1500, "-", "query3")
        self.assertEqual(third.getTagValue("flanking_downstream"), "ref1")

    @unittest.skip("disabled")
    def test_run_simple_upstream(self):
        """Upstream-only search on the + strand."""
        self._writeTranscripts(self.referenceFileName, [
            self._createTranscript("chr1", 500, 600, "+", "ref1"),
            self._createTranscript("chr1", 700, 800, "+", "ref2"),
            self._createTranscript("chr1", 2000, 2100, "+", "ref3"),
        ])
        self._writeTranscripts(self.queryFileName, [
            self._createTranscript("chr1", 100, 200, "+", "query1"),
            self._createTranscript("chr1", 300, 400, "+", "query2"),
            self._createTranscript("chr1", 1200, 1300, "+", "query3"),
        ])
        gf = self._createGetFlanking()
        gf.addUpstreamDirection(True)
        gf.run()
        first, second, third = self._readSortedOutput(3)

        self._checkTranscript(first, "chr1", 100, 200, "+", "query1")
        self.assertIsNone(first.getTagValue("flanking_upstream"))

        self._checkTranscript(second, "chr1", 300, 400, "+", "query2")
        self.assertIsNone(second.getTagValue("flanking_upstream"))

        self._checkTranscript(third, "chr1", 1200, 1300, "+", "query3")
        self.assertEqual(third.getTagValue("flanking_upstream"), "ref2")

    @unittest.skip("disabled")
    def test_run_simple_colinear(self):
        """Upstream search restricted to colinear references."""
        self._writeTranscripts(self.referenceFileName, [
            self._createTranscript("chr1", 100, 200, "+", "ref1"),
            self._createTranscript("chr1", 1000, 1100, "+", "ref2"),
            self._createTranscript("chr1", 1600, 1700, "+", "ref3"),
        ])
        self._writeTranscripts(self.queryFileName, [
            self._createTranscript("chr1", 1200, 1300, "-", "query1"),
            self._createTranscript("chr1", 1400, 1500, "+", "query2"),
        ])
        gf = self._createGetFlanking()
        gf.addUpstreamDirection(True)
        gf.setColinear(True)
        gf.run()
        first, second = self._readSortedOutput(2)

        self._checkTranscript(first, "chr1", 1200, 1300, "-", "query1")
        self.assertIsNone(first.getTagValue("flanking"))
        # With a direction set the tag name is suffixed (see query2 below), so
        # also check the directional tag; the unsuffixed check alone is vacuous.
        self.assertIsNone(first.getTagValue("flanking_upstream"))

        self._checkTranscript(second, "chr1", 1400, 1500, "+", "query2")
        self.assertEqual(second.getTagValue("flanking_upstream"), "ref2")

    @unittest.skip("disabled")
    def test_run_simple_max_distance(self):
        """A reference farther than the maximum distance is not reported."""
        self._writeTranscripts(self.referenceFileName, [
            self._createTranscript("chr1", 1000, 1100, "+", "ref"),
        ])
        self._writeTranscripts(self.queryFileName, [
            self._createTranscript("chr1", 2000, 2100, "-", "query1"),
        ])
        gf = self._createGetFlanking()
        gf.setMaxDistance(100)
        gf.run()
        (only,) = self._readSortedOutput(1)

        self._checkTranscript(only, "chr1", 2000, 2100, "-", "query1")
        self.assertIsNone(only.getTagValue("flanking"))

    def _writeTranscripts(self, fileName, transcripts):
        """Write ``transcripts`` to ``fileName`` as GFF3, always closing the writer."""
        writer = Gff3Writer(fileName, 0)
        try:
            for transcript in transcripts:
                writer.addTranscript(transcript)
        finally:
            writer.close()

    def _createGetFlanking(self):
        """Return ``GetFlanking(0)`` with query as input 0, reference as input 1 and the output file set."""
        gf = GetFlanking(0)
        gf.setInputFile(self.queryFileName, 'gff3', 0)
        gf.setInputFile(self.referenceFileName, 'gff3', 1)
        gf.setOutputFile(self.outputFileName)
        return gf

    def _readSortedOutput(self, expectedCount):
        """Parse the output file, assert it holds ``expectedCount`` transcripts,
        and return them as a list sorted by start position."""
        parser = GffParser(self.outputFileName)
        self.assertEqual(parser.getNbTranscripts(), expectedCount)
        return sorted(parser.getIterator(), key = lambda t: t.getStart())

    def _createTranscript(self, chromosome, start, end, strand, name):
        """Build a Transcript with the given location, strand and name."""
        transcript = Transcript()
        transcript.setChromosome(chromosome)
        transcript.setStart(start)
        transcript.setEnd(end)
        transcript.setDirection(strand)
        transcript.setName(name)
        return transcript

    def _checkTranscript(self, transcript, chromosome, start, end, strand, name):
        """Assert that ``transcript`` has the given chromosome, start, end, strand and name."""
        self.assertEqual(transcript.getChromosome(), chromosome)
        self.assertEqual(transcript.getStart(), start)
        self.assertEqual(transcript.getEnd(), end)
        self.assertEqual(transcript.getStrand(), strand)
        self.assertEqual(transcript.getName(), name)
        
if __name__ == "__main__":
    # Load and run every TestCase defined in this module with the stock runner.
    unittest.main(module="__main__")
