view SMART/Java/Python/test/Test_F_RestrictFromCoverage.py @ 31:0ab839023fe4

Uploaded
author m-zytnicki
date Tue, 30 Apr 2013 14:33:21 -0400
parents 94ab73e8a190
children
line wrap: on
line source

import unittest
import os, os.path
from optparse import OptionParser
from SMART.Java.Python.misc import Utils
from SMART.Java.Python.structure.Transcript import Transcript
from commons.core.writer.Gff3Writer import Gff3Writer
from commons.core.parsing.GffParser import GffParser
from SMART.Java.Python.RestrictFromCoverage import RestrictFromCoverage

# Role keys: used to index the per-file writer dict and passed to
# RestrictFromCoverage.setInputFileName to say which input is which.
REFERENCE, QUERY = 0, 1

class Test_F_Clusterize(unittest.TestCase):
    """Functional tests for RestrictFromCoverage.

    NOTE(review): the class name looks copy-pasted from the Clusterize tests;
    the tool under test is RestrictFromCoverage. The name is kept unchanged so
    existing test IDs and runner selections keep working.

    Every test shares one query transcript (chr1:1000-2000, "+") written in
    setUp. Each test then writes its own reference transcripts, runs the tool
    and inspects the GFF3 output file.
    """

    def setUp(self):
        """Create the GFF3 writers for both inputs and write the query file."""
        self._queryFileName  = "testQuery.gff3"
        self._refFileName    = "testRef.gff3"
        self._outputFileName = "output.gff3"
        # Gff3Writer instances keyed by input role (QUERY / REFERENCE).
        self._writers = {QUERY:     Gff3Writer(self._queryFileName, 0),
                         REFERENCE: Gff3Writer(self._refFileName,   0)}
        self._writeQuery()

    def tearDown(self):
        """Remove every file the test may have created."""
        for fileName in (self._queryFileName, self._refFileName, self._outputFileName):
            if os.path.exists(fileName):
                os.remove(fileName)

    def _writeQuery(self):
        """Write the single query transcript (chr1:1000-2000, "+") and close the file."""
        self._addTranscript(QUERY, 1, 1000, 2000, "+")
        self._writers[QUERY].close()

    def _writeReferences(self, values):
        """Write the reference transcripts and close the reference file.

        :param values: iterable of dicts with keys "cpt" (name suffix),
                       "start", "end" and "strand".
        """
        for value in values:
            self._addTranscript(REFERENCE, value["cpt"], value["start"], value["end"], value["strand"])
        self._writers[REFERENCE].close()

    def _addTranscript(self, role, cpt, start, end, strand):
        """Add transcript "test<cpt>" on chr1 to the writer of the given role."""
        transcript = Transcript()
        transcript.setChromosome("chr1")
        transcript.setName("test%d" % (cpt))
        transcript.setStart(start)
        transcript.setEnd(end)
        transcript.setDirection(strand)
        self._writers[role].addTranscript(transcript)

    def _checkTranscript(self, transcript, start, end, strand):
        """Assert the coordinates and direction of an output transcript."""
        self.assertEqual(transcript.getStart(),     start)
        self.assertEqual(transcript.getEnd(),       end)
        self.assertEqual(transcript.getDirection(), strand)

    def _startTool(self, minNucleotides = None, maxNucleotides = None, minPercent = None, maxPercent = None, minOverlap = None, maxOverlap = None, strands = False):
        """Run RestrictFromCoverage on the test files with the given bounds.

        Each min/max pair is forwarded unchanged to the matching setter;
        None is passed through as-is. Returns a GffParser over the output file.
        """
        rfc = RestrictFromCoverage(0)
        rfc.setInputFileName(self._queryFileName, "gff3", QUERY)
        rfc.setInputFileName(self._refFileName,   "gff3", REFERENCE)
        rfc.setOutputFileName(self._outputFileName)
        rfc.setNbNucleotides(minNucleotides, maxNucleotides)
        rfc.setPercent(minPercent, maxPercent)
        rfc.setOverlap(minOverlap, maxOverlap)
        rfc.setStrands(strands)
        rfc.run()
        return GffParser(self._outputFileName, 0)

    def test_simple(self):
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 2000, "strand": "+"}])
        parser = self._startTool()
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self._checkTranscript(transcript, 1000, 2000, 1)

    def test_nbOverlapsMin_pos(self):
        # Two identical references -> two overlaps, meeting minOverlap=2.
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 2000, "strand": "+"}, {"cpt": 2, "start": 1000, "end": 2000, "strand": "+"}])
        parser = self._startTool(minNucleotides = 1, minOverlap = 2)
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self._checkTranscript(transcript, 1000, 2000, 1)

    def test_nbOverlapsMin_neg(self):
        # A single reference cannot satisfy minOverlap=2.
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 2000, "strand": "+"}])
        parser = self._startTool(minNucleotides = 1, minOverlap = 2)
        self.assertEqual(parser.getNbTranscripts(), 0)

    def test_nbOverlapsMax_pos(self):
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 2000, "strand": "+"}])
        parser = self._startTool(minNucleotides = 1, maxOverlap = 1)
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self._checkTranscript(transcript, 1000, 2000, 1)

    def test_nbOverlapsMax_neg(self):
        # Two overlapping references exceed maxOverlap=1.
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 2000, "strand": "+"}, {"cpt": 2, "start": 1000, "end": 2000, "strand": "+"}])
        parser = self._startTool(minNucleotides = 1, maxOverlap = 1)
        self.assertEqual(parser.getNbTranscripts(), 0)

    def test_nbNucleotidesMin_pos(self):
        # Reference 1000-1100 shares roughly 100 nt with the query.
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 1100, "strand": "+"}])
        parser = self._startTool(minNucleotides = 100, minOverlap = 1)
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self._checkTranscript(transcript, 1000, 2000, 1)

    def test_nbNucleotidesMin_neg(self):
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 1100, "strand": "+"}])
        parser = self._startTool(minNucleotides = 200, minOverlap = 1)
        self.assertEqual(parser.getNbTranscripts(), 0)

    def test_PercentMin_pos(self):
        # Reference 1000-1500 covers about half of the query 1000-2000.
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 1500, "strand": "+"}])
        parser = self._startTool(minPercent = 50, minOverlap = 1)
        self.assertEqual(parser.getNbTranscripts(), 1)
        for transcript in parser.getIterator():
            self._checkTranscript(transcript, 1000, 2000, 1)

    def test_PercentMin_neg(self):
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 1500, "strand": "+"}])
        parser = self._startTool(minPercent = 100, minOverlap = 1)
        self.assertEqual(parser.getNbTranscripts(), 0)

    def test_NoStrand_neg(self):
        # Opposite-strand reference still counts when strands are not checked.
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 1500, "strand": "-"}])
        parser = self._startTool(minNucleotides = 1, minOverlap = 1)
        self.assertEqual(parser.getNbTranscripts(), 1)

    def test_strand_pos(self):
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 1500, "strand": "+"}])
        parser = self._startTool(minNucleotides = 1, minOverlap = 1, strands = True)
        self.assertEqual(parser.getNbTranscripts(), 1)

    def test_strand_neg(self):
        # Opposite-strand reference is ignored when strands=True.
        self._writeReferences([{"cpt": 1, "start": 1000, "end": 1500, "strand": "-"}])
        parser = self._startTool(minNucleotides = 1, minOverlap = 1, strands = True)
        self.assertEqual(parser.getNbTranscripts(), 0)

if "__main__" == __name__:
    # Run directly as a script (not imported by a test runner):
    # hand control to unittest's command-line entry point.
    unittest.main()