import os
import unittest
from SMART.Java.Python.structure.TranscriptListsComparator import TranscriptListsComparator
from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer
from commons.core.writer.Gff3Writer import Gff3Writer
from commons.core.parsing.GffParser import GffParser
from commons.core.utils.FileUtils import FileUtils

SMART_PATH = os.environ["REPET_PATH"] + "/SMART"

class Test_TranscriptListsComparator(unittest.TestCase):
    """Regression tests for TranscriptListsComparator.

    Each test loads one or two transcript files from the SMART test-data
    directory and configures a TranscriptListsComparator. It then runs one
    of the comparator's methods and checks either the returned value or the
    transcripts written to a GFF3 output file.
    """

    # GFF3 file written by the comparator and parsed back by the tests.
    # It is relative to the current working directory. Writer and reader
    # must use the same path so the test does not depend on where it is
    # launched from.
    OUTPUT_FILE = "output.gff3"

    def tearDown(self):
        """Delete the GFF3 output so no stale file leaks into later tests."""
        if os.path.exists(self.OUTPUT_FILE):
            os.remove(self.OUTPUT_FILE)

    def _testFile(self, fileName):
        """Return the path of *fileName* inside the SMART test-data directory."""
        return "%s/Java/Python/TestFiles/%s" % (SMART_PATH, fileName)

    def _container(self, fileName, fileFormat):
        """Build a TranscriptContainer over a test-data file (verbosity 0)."""
        return TranscriptContainer(self._testFile(fileName), fileFormat, 0)

    def _outputWriter(self):
        """Return a Gff3Writer targeting OUTPUT_FILE (verbosity 0)."""
        return Gff3Writer(self.OUTPUT_FILE, 0)

    def _parseOutput(self):
        """Return a GffParser reading back OUTPUT_FILE (verbosity 0)."""
        return GffParser(self.OUTPUT_FILE, 0)

    def test_compareTranscriptList(self):
        """compareTranscriptList writes the same locus once per strand."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptList1.bed", "bed")
        container2 = self._container("testTranscriptListsComparatorCompareTranscriptList2.bed", "bed")
        comparator = TranscriptListsComparator(None, 0)
        comparator.computeOdds(True)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        comparator.setOutputWriter(self._outputWriter())
        comparator.compareTranscriptList()

        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 2)
        for index, transcript in enumerate(parser.getIterator()):
            if index == 0:
                self.assertEqual(transcript.getChromosome(), "arm_X")
                self.assertEqual(transcript.getStart(), 1000)
                self.assertEqual(transcript.getEnd(), 1999)
                self.assertEqual(transcript.getDirection(), 1)
            elif index == 1:
                self.assertEqual(transcript.getChromosome(), "arm_X")
                self.assertEqual(transcript.getStart(), 1000)
                self.assertEqual(transcript.getEnd(), 1999)
                self.assertEqual(transcript.getDirection(), -1)

    def test_compareTranscriptListDistanceSimple(self):
        """Distance histogram with max distance 1000, in both query/reference orders."""
        container1 = self._container("testCompareTranscriptListDistanceSimple1.gff3", "gff")
        container2 = self._container("testCompareTranscriptListDistanceSimple2.gff3", "gff")

        comparator = TranscriptListsComparator(None, 0)
        comparator.computeOdds(True)
        comparator.setMaxDistance(1000)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        distances = comparator.compareTranscriptListDistance()

        self.assertEqual(distances, {0: 1})

        # Swapping query and reference gives a different, asymmetric result.
        comparator = TranscriptListsComparator(None, 0)
        comparator.computeOdds(True)
        comparator.setMaxDistance(1000)
        comparator.setInputTranscriptContainer(comparator.QUERY, container2)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container1)
        distances = comparator.compareTranscriptListDistance()

        self.assertEqual(distances, {0: 1, -1000: 1})

    def test_compareTranscriptListDistanceAntisense(self):
        """Distance histogram with getAntisenseOnly(True) and max distance 10000."""
        container1 = self._container("testCompareTranscriptListDistanceAntisense1.gff3", "gff")
        container2 = self._container("testCompareTranscriptListDistanceAntisense2.gff3", "gff")

        comparator = TranscriptListsComparator(None, 0)
        comparator.computeOdds(True)
        comparator.setMaxDistance(10000)
        comparator.getAntisenseOnly(True)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        distances = comparator.compareTranscriptListDistance()

        self.assertEqual(distances, {1000: 1})

    def test_compareTranscriptListMergeSimple(self):
        """Merging two BED lists yields a single 3000 bp transcript."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListMergeSimple1.bed", "bed")
        container2 = self._container("testTranscriptListsComparatorCompareTranscriptListMergeSimple2.bed", "bed")
        comparator = TranscriptListsComparator(None, 0)
        comparator.computeOdds(True)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        comparator.setOutputWriter(self._outputWriter())
        comparator.compareTranscriptListMerge()

        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self.assertEqual(transcript.getChromosome(), "arm_X")
            self.assertEqual(transcript.getStart(), 1000)
            self.assertEqual(transcript.getEnd(), 3999)
            self.assertEqual(transcript.getDirection(), 1)
            self.assertEqual(transcript.getSize(), 3000)

    def test_compareTranscriptListMergeSenseAntiSenseAway(self):
        """Merge with start restriction, 5' extension of the reference and antisense-only."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListMergeSenseAntiSenseAway1.bed", "bed")
        container2 = self._container("testTranscriptListsComparatorCompareTranscriptListMergeSenseAntiSenseAway2.bed", "bed")
        comparator = TranscriptListsComparator(None, 0)
        comparator.restrictToStart(comparator.QUERY, 2)
        comparator.restrictToStart(comparator.REFERENCE, 2)
        comparator.extendFivePrime(comparator.REFERENCE, 150)
        comparator.getAntisenseOnly(True)
        comparator.computeOdds(True)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        comparator.setOutputWriter(self._outputWriter())
        comparator.compareTranscriptListMerge()

        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self.assertEqual(transcript.getChromosome(), "arm_X")
            self.assertEqual(transcript.getStart(), 10000048)
            self.assertEqual(transcript.getEnd(), 10000199)
            self.assertEqual(transcript.getSize(), 152)
            self.assertEqual(transcript.getNbExons(), 1)

    def test_compareTranscriptListMergeAggregation(self):
        """Colinear-only merge with aggregate(True) produces one 200 bp transcript."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListMergeAggregation1.bed", "bed")
        container2 = self._container("testTranscriptListsComparatorCompareTranscriptListMergeAggregation2.bed", "bed")
        comparator = TranscriptListsComparator(None, 0)
        comparator.getColinearOnly(True)
        comparator.computeOdds(True)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        comparator.aggregate(True)
        comparator.setOutputWriter(self._outputWriter())
        comparator.compareTranscriptListMerge()

        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self.assertEqual(transcript.getChromosome(), "arm_X")
            self.assertEqual(transcript.getStart(), 10000000)
            self.assertEqual(transcript.getEnd(), 10000199)
            self.assertEqual(transcript.getSize(), 200)
            self.assertEqual(transcript.getNbExons(), 1)

    def test_compareTranscriptListSelfMerge(self):
        """Self-merging one list collapses it into one transcript tagged nbElements=3."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListSelfMerge1.gff3", "gff")

        comparator = TranscriptListsComparator(None, 0)
        comparator.computeOdds(True)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setOutputWriter(self._outputWriter())
        comparator.compareTranscriptListSelfMerge()

        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self.assertEqual(transcript.getChromosome(), "arm_X")
            self.assertEqual(transcript.getStart(), 1000)
            self.assertEqual(transcript.getEnd(), 2000)
            self.assertEqual(transcript.getDirection(), 1)
            self.assertEqual(transcript.getNbExons(), 1)
            self.assertEqual(transcript.getSize(), 1001)
            self.assertEqual(float(transcript.getTagValue("nbElements")), 3)

    def test_compareTranscriptListSelfMergeSense(self):
        """Same input as test_compareTranscriptListSelfMerge, with getColinearOnly(True)."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListSelfMerge1.gff3", "gff")

        comparator = TranscriptListsComparator(None, 0)
        comparator.getColinearOnly(True)
        comparator.computeOdds(True)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setOutputWriter(self._outputWriter())
        comparator.compareTranscriptListSelfMerge()

        # Read back the same file the writer produced. The previous hard-coded
        # $REPET_PATH/.../structure/test/output.gff3 path only matched when
        # the tests were launched from that directory.
        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self.assertEqual(transcript.getChromosome(), "arm_X")
            self.assertEqual(transcript.getStart(), 1000)
            self.assertEqual(transcript.getEnd(), 2000)
            self.assertEqual(transcript.getDirection(), 1)
            self.assertEqual(transcript.getNbExons(), 1)
            self.assertEqual(transcript.getSize(), 1001)

    def test_compareTranscriptListSelfMergeDifferentClusters(self):
        """Self-merge of a BED list yields one 100000 bp transcript."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListSelfMergeDifferentClusters1.bed", "bed")
        comparator = TranscriptListsComparator(None, 0)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setOutputWriter(self._outputWriter())
        comparator.compareTranscriptListSelfMerge()

        parser = self._parseOutput()
        # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self.assertEqual(transcript.getChromosome(), "arm_X")
            self.assertEqual(transcript.getStart(), 100)
            self.assertEqual(transcript.getEnd(), 100099)
            self.assertEqual(transcript.getDirection(), 1)
            self.assertEqual(transcript.getNbExons(), 1)
            self.assertEqual(transcript.getSize(), 100000)

    def test_compareTranscriptListgetDifferenceTranscriptList(self):
        """Difference of query minus reference is one transcript with two exons."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListGetDifference1.gff3", "gff")
        container2 = self._container("testTranscriptListsComparatorCompareTranscriptListGetDifference2.gff3", "gff")

        comparator = TranscriptListsComparator(None, 0)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        comparator.setOutputWriter(self._outputWriter())
        comparator.getDifferenceTranscriptList()

        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self.assertEqual(transcript.getChromosome(), "arm_X")
            self.assertEqual(transcript.getStart(), 1000)
            self.assertEqual(transcript.getEnd(), 4000)
            self.assertEqual(transcript.getDirection(), 1)
            self.assertEqual(transcript.getNbExons(), 2)
            exon1, exon2 = transcript.getExons()
            self.assertEqual(exon1.getStart(), 1000)
            self.assertEqual(exon1.getEnd(), 1999)
            self.assertEqual(exon2.getStart(), 3001)
            self.assertEqual(exon2.getEnd(), 4000)

    def test_compareTranscriptListgetDifferenceTranscriptListSplit(self):
        """With setSplitDifference(True) the two difference exons become two transcripts."""
        container1 = self._container("testTranscriptListsComparatorCompareTranscriptListGetDifference1.gff3", "gff")
        container2 = self._container("testTranscriptListsComparatorCompareTranscriptListGetDifference2.gff3", "gff")

        comparator = TranscriptListsComparator(None, 0)
        comparator.setInputTranscriptContainer(comparator.QUERY, container1)
        comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
        comparator.setSplitDifference(True)
        comparator.setOutputWriter(self._outputWriter())
        comparator.getDifferenceTranscriptList()

        parser = self._parseOutput()
        self.assertEqual(parser.getNbTranscripts(), 2)
        # 'index' rather than 'id' to avoid shadowing the builtin.
        for index, transcript in enumerate(parser.getIterator()):
            if index == 0:
                self.assertEqual(transcript.getChromosome(), "arm_X")
                self.assertEqual(transcript.getStart(), 1000)
                self.assertEqual(transcript.getEnd(), 1999)
                self.assertEqual(transcript.getDirection(), 1)
                self.assertEqual(transcript.getNbExons(), 1)
            else:
                self.assertEqual(transcript.getChromosome(), "arm_X")
                self.assertEqual(transcript.getStart(), 3001)
                self.assertEqual(transcript.getEnd(), 4000)
                self.assertEqual(transcript.getDirection(), 1)
                self.assertEqual(transcript.getNbExons(), 1)

if __name__ == '__main__':
    # Run every test in this module when executed directly as a script.
    unittest.main()
