import unittest
from SMART.Java.Python.structure.Interval import Interval
from SMART.Java.Python.structure.Transcript import Transcript

class Test_Transcript(unittest.TestCase):
    """Unit tests for SMART's Transcript: size, exon overlap, merge,
    start extension/restriction, inclusion and difference."""

    def test_getSize(self):
        """getSize() on a plain transcript, on a copy of it, and on a
        transcript built only from exons."""
        transcript1 = Transcript()
        transcript1.setDirection("+")
        transcript1.setStart(2000)
        transcript1.setEnd(3000)
        transcript1.setChromosome("arm_X")

        self.assertEqual(transcript1.getSize(), 1001)

        transcript2 = Transcript()
        transcript2.copy(transcript1)
        # Verify the copy itself (the original line re-checked transcript1).
        self.assertEqual(transcript2.getSize(), 1001)

        transcript3 = Transcript()
        transcript3.setDirection("+")
        transcript3.setChromosome("arm_X")

        exon1 = Interval()
        exon1.setDirection("+")
        exon1.setChromosome("arm_X")
        exon1.setStart(100)
        exon1.setEnd(200)
        transcript3.addExon(exon1)

        exon2 = Interval()
        exon2.setDirection("+")
        exon2.setChromosome("arm_X")
        exon2.setStart(300)
        exon2.setEnd(400)
        transcript3.addExon(exon2)

        # NOTE(review): the two exons are 101 bp each (202 total); the
        # expected 203 presumably reflects how Transcript accounts for exons
        # when no start/end was set before addExon -- confirm in Transcript.
        self.assertEqual(transcript3.getSize(), 203)


    def test_overlapWithExons(self):
        """overlapWithExon() is true only when exons (not just the transcript
        spans) overlap on the same chromosome and strand."""
        exon1_1 = Interval()
        exon1_1.setChromosome("chr1")
        exon1_1.setStart(100)
        exon1_1.setEnd(200)
        exon1_1.setDirection("+")

        exon1_2 = Interval()
        exon1_2.setChromosome("chr1")
        exon1_2.setStart(500)
        exon1_2.setEnd(600)
        exon1_2.setDirection("+")

        transcript1 = Transcript()
        transcript1.setChromosome("chr1")
        transcript1.setStart(100)
        transcript1.setEnd(600)
        transcript1.setDirection("+")
        transcript1.addExon(exon1_1)
        transcript1.addExon(exon1_2)

        exon2_1 = Interval()
        exon2_1.copy(exon1_1)

        transcript2 = Transcript()
        transcript2.setChromosome("chr1")
        transcript2.setStart(100)
        transcript2.setEnd(200)
        transcript2.setDirection("+")
        transcript2.addExon(exon2_1)

        self.assertTrue(transcript1.overlapWithExon(transcript2))

        transcript2.reverse()
        # An exception from overlapWithExon on opposite strands is tolerated
        # (as in the original test), but a True result must fail the test.
        # Only the call is guarded so the AssertionError is no longer
        # swallowed by the except clause.
        try:
            overlaps = transcript1.overlapWithExon(transcript2)
        except Exception:
            overlaps = False
        self.assertFalse(overlaps)

        transcript2.reverse()
        transcript2.setChromosome("chr2")
        self.assertFalse(transcript1.overlapWithExon(transcript2))

        exon3_1 = Interval()
        exon3_1.copy(exon1_1)
        # End is set before start so the interval never becomes inverted
        # (300 > old end 200) between the two calls.
        exon3_1.setEnd(400)
        exon3_1.setStart(300)

        transcript3 = Transcript()
        transcript3.setChromosome("chr1")
        transcript3.setStart(300)
        transcript3.setEnd(400)
        transcript3.setDirection("+")
        transcript3.addExon(exon3_1)
        # transcript3 lies inside transcript1's span but within its intron.
        self.assertFalse(transcript1.overlapWithExon(transcript3))


    def test_merge(self):
        """merge() combines exons of same-chromosome transcripts and must
        raise when the chromosomes differ."""
        exon1_1 = Interval()
        exon1_1.setChromosome("chr1")
        exon1_1.setStart(100)
        exon1_1.setEnd(200)
        exon1_1.setDirection("+")

        exon1_2 = Interval()
        exon1_2.setChromosome("chr1")
        exon1_2.setStart(500)
        exon1_2.setEnd(600)
        exon1_2.setDirection("+")

        transcript1 = Transcript()
        transcript1.setChromosome("chr1")
        transcript1.setEnd(600)
        transcript1.setStart(100)
        transcript1.setDirection("+")
        transcript1.addExon(exon1_1)
        transcript1.addExon(exon1_2)

        exon2_1 = Interval()
        exon2_1.copy(exon1_1)

        transcript2 = Transcript()
        transcript2.setChromosome("chr1")
        transcript2.setEnd(200)
        transcript2.setStart(100)
        transcript2.setDirection("+")
        transcript2.addExon(exon2_1)

        # Merging an identical exon must not duplicate it.
        transcript1.merge(transcript2)
        transcript1.sortExonsIncreasing()
        exons = transcript1.getExons()
        self.assertEqual(len(exons), 2)
        exon1, exon2 = exons
        self.assertEqual(exon1.getStart(), 100)
        self.assertEqual(exon1.getEnd(),   200)
        self.assertEqual(exon2.getStart(), 500)
        self.assertEqual(exon2.getEnd(),   600)

        # Merging across chromosomes must raise. The original
        # try/self.fail()/except Exception swallowed the failure itself.
        transcript2.setChromosome("chr2")
        with self.assertRaises(Exception):
            transcript1.merge(transcript2)

        exon3_1 = Interval()
        exon3_1.copy(exon1_1)
        exon3_1.setEnd(650)
        exon3_1.setStart(550)

        transcript3 = Transcript()
        transcript3.setChromosome("chr1")
        transcript3.setEnd(650)
        transcript3.setStart(550)
        transcript3.setDirection("+")
        transcript3.addExon(exon3_1)

        # An overlapping exon extends the existing 500-600 exon to 650.
        transcript1.merge(transcript3)
        self.assertEqual(transcript1.getStart(), 100)
        self.assertEqual(transcript1.getEnd(),   650)
        exons = transcript1.getExons()
        self.assertEqual(len(exons), 2)
        exon1, exon2 = exons
        self.assertEqual(exon1.getStart(), 100)
        self.assertEqual(exon1.getEnd(),   200)
        self.assertEqual(exon2.getStart(), 500)
        self.assertEqual(exon2.getEnd(),   650)

        exon4_1 = Interval()
        exon4_1.copy(exon1_1)
        exon4_1.setEnd(400)
        exon4_1.setStart(300)

        transcript4 = Transcript()
        transcript4.setChromosome("chr1")
        transcript4.setStart(300)
        transcript4.setEnd(400)
        transcript4.setDirection("+")
        transcript4.addExon(exon4_1)

        # A disjoint exon inside the span is added as a third exon.
        transcript1.merge(transcript4)
        self.assertEqual(transcript1.getStart(), 100)
        self.assertEqual(transcript1.getEnd(),   650)
        transcript1.sortExonsIncreasing()
        exons = transcript1.getExons()
        self.assertEqual(len(exons), 3)
        exon1, exon2, exon3 = exons
        self.assertEqual(exon1.getStart(), 100)
        self.assertEqual(exon1.getEnd(),   200)
        self.assertEqual(exon2.getStart(), 300)
        self.assertEqual(exon2.getEnd(),   400)
        self.assertEqual(exon3.getStart(), 500)
        self.assertEqual(exon3.getEnd(),   650)


    def test_extendStart(self):
        """extendStart() moves the start coordinate on "+" and the end
        coordinate on "-"."""
        transcript1 = Transcript()
        transcript1.setStart(2000)
        transcript1.setEnd(3000)
        transcript1.setDirection("+")
        transcript1.setChromosome("arm_X")

        transcript2 = Transcript()
        transcript2.copy(transcript1)
        transcript2.setDirection("-")

        transcript1.extendStart(1000)
        transcript2.extendStart(1000)

        self.assertEqual(transcript1.getDirection(),  1)
        self.assertEqual(transcript1.getStart(),      1000)
        self.assertEqual(transcript1.getEnd(),        3000)
        self.assertEqual(transcript1.getChromosome(), "arm_X")

        self.assertEqual(transcript2.getDirection(),  -1)
        self.assertEqual(transcript2.getStart(),      2000)
        self.assertEqual(transcript2.getEnd(),        4000)
        self.assertEqual(transcript2.getChromosome(), "arm_X")


    def test_restrictStart(self):
        """restrictStart(301) on exons 100-200 / 300-500 keeps both exons and
        trims the second one to 300-400."""
        exon1_1 = Interval()
        exon1_1.setChromosome("chr1")
        exon1_1.setStart(100)
        exon1_1.setEnd(200)
        exon1_1.setDirection("+")

        exon1_2 = Interval()
        exon1_2.setChromosome("chr1")
        exon1_2.setStart(300)
        exon1_2.setEnd(500)
        exon1_2.setDirection("+")

        transcript1 = Transcript()
        transcript1.setChromosome("chr1")
        transcript1.setStart(100)
        transcript1.setEnd(500)
        transcript1.setDirection("+")
        transcript1.addExon(exon1_1)
        transcript1.addExon(exon1_2)

        transcript1.restrictStart(301)
        exons = transcript1.getExons()
        self.assertEqual(len(exons), 2)
        exon1, exon2 = exons
        self.assertEqual(exon1.getStart(), 100)
        self.assertEqual(exon1.getEnd(),   200)
        self.assertEqual(exon2.getStart(), 300)
        self.assertEqual(exon2.getEnd(),   400)


    def test__include(self):
        """include() requires the same chromosome and that every exon of the
        other transcript is covered."""
        iTranscript1 = Transcript()
        iTranscript1.setName("transcript1")
        iTranscript1.setChromosome("chr1")
        iTranscript1.setStart(100)
        iTranscript1.setEnd(200)
        iTranscript1.setDirection("+")

        # Identical transcripts include each other.
        iTranscript2 = Transcript()
        iTranscript2.copy(iTranscript1)
        iTranscript2.setName("transcript2")
        self.assertTrue(iTranscript1.include(iTranscript2))
        self.assertTrue(iTranscript2.include(iTranscript1))

        # Different chromosomes: no inclusion either way.
        iTranscript2.setChromosome("chr2")
        self.assertFalse(iTranscript1.include(iTranscript2))
        self.assertFalse(iTranscript2.include(iTranscript1))

        # An extra exon on transcript1 makes inclusion one-directional.
        iTranscript2.setChromosome("chr1")
        exon = Interval()
        exon.setChromosome("chr1")
        exon.setDirection("+")
        exon.setStart(300)
        exon.setEnd(400)
        iTranscript1.addExon(exon)
        self.assertTrue(iTranscript1.include(iTranscript2))
        self.assertFalse(iTranscript2.include(iTranscript1))

        # Each transcript now has an exon the other lacks.
        exon = Interval()
        exon.setChromosome("chr1")
        exon.setDirection("+")
        exon.setStart(500)
        exon.setEnd(600)
        iTranscript2.addExon(exon)
        self.assertFalse(iTranscript1.include(iTranscript2))
        self.assertFalse(iTranscript2.include(iTranscript1))


    def test__getDifference(self):
        """getDifference() removes the other transcript's region: trimming the
        end, the start, or splitting into two exons."""
        iTranscript1 = Transcript()
        iTranscript1.setName("transcript1")
        iTranscript1.setChromosome("chr1")
        iTranscript1.setStart(100)
        iTranscript1.setEnd(400)
        iTranscript1.setDirection("+")

        # Subtract the tail (200-400) -> 100-199.
        # assertEqual replaces assertTrue(value, expected): assertTrue treats
        # its second argument as a failure message, not an expected value.
        iTranscript2 = Transcript()
        iTranscript2.setName("transcript1")
        iTranscript2.setChromosome("chr1")
        iTranscript2.setStart(200)
        iTranscript2.setEnd(400)
        iTranscript2.setDirection("+")

        newTranscript = iTranscript1.getDifference(iTranscript2)
        self.assertEqual(newTranscript.getStart(), 100)
        self.assertEqual(newTranscript.getEnd(),   199)
        exons = newTranscript.getExons()
        self.assertEqual(len(exons), 1)
        exon1 = exons[0]
        self.assertEqual(exon1.getStart(), 100)
        self.assertEqual(exon1.getEnd(),   199)

        # Subtract the head (100-200) -> 201-400.
        iTranscript2 = Transcript()
        iTranscript2.setName("transcript1")
        iTranscript2.setChromosome("chr1")
        iTranscript2.setStart(100)
        iTranscript2.setEnd(200)
        iTranscript2.setDirection("+")

        newTranscript = iTranscript1.getDifference(iTranscript2)
        self.assertEqual(newTranscript.getStart(), 201)
        self.assertEqual(newTranscript.getEnd(),   400)
        exons = newTranscript.getExons()
        self.assertEqual(len(exons), 1)
        exon1 = exons[0]
        self.assertEqual(exon1.getStart(), 201)
        self.assertEqual(exon1.getEnd(),   400)

        # Subtract the middle (200-300) -> exons 100-199 and 301-400.
        iTranscript2 = Transcript()
        iTranscript2.setName("transcript1")
        iTranscript2.setChromosome("chr1")
        iTranscript2.setStart(200)
        iTranscript2.setEnd(300)
        iTranscript2.setDirection("+")

        newTranscript = iTranscript1.getDifference(iTranscript2)
        self.assertEqual(newTranscript.getStart(), 100)
        self.assertEqual(newTranscript.getEnd(),   400)
        exons = newTranscript.getExons()
        self.assertEqual(len(exons), 2)
        exon1, exon2 = exons
        self.assertEqual(exon1.getStart(), 100)
        self.assertEqual(exon1.getEnd(),   199)
        self.assertEqual(exon2.getStart(), 301)
        self.assertEqual(exon2.getEnd(),   400)

if "__main__" == __name__:
    # Executed directly (not imported): hand control to unittest's CLI runner.
    unittest.main()
