diff SMART/Java/Python/GetDifferentialExpression.py @ 38:2c0c0a89fad7

Uploaded
author m-zytnicki
date Thu, 02 May 2013 09:56:47 -0400
parents 769e306b7933
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/SMART/Java/Python/GetDifferentialExpression.py	Thu May 02 09:56:47 2013 -0400
@@ -0,0 +1,441 @@
+#! /usr/bin/env python
+#
+# Copyright INRA-URGI 2009-2010
+# 
+# This software is governed by the CeCILL license under French law and
+# abiding by the rules of distribution of free software. You can use,
+# modify and/ or redistribute the software under the terms of the CeCILL
+# license as circulated by CEA, CNRS and INRIA at the following URL
+# "http://www.cecill.info".
+# 
+# As a counterpart to the access to the source code and rights to copy,
+# modify and redistribute granted by the license, users are provided only
+# with a limited warranty and the software's author, the holder of the
+# economic rights, and the successive licensors have only limited
+# liability.
+# 
+# In this respect, the user's attention is drawn to the risks associated
+# with loading, using, modifying and/or developing or reproducing the
+# software by the user in light of its specific status of free software,
+# that may mean that it is complicated to manipulate, and that also
+# therefore means that it is reserved for developers and experienced
+# professionals having in-depth computer knowledge. Users are therefore
+# encouraged to load and test the software's suitability as regards their
+# requirements in conditions enabling the security of their systems and/or
+# data to be ensured and, more generally, to use and operate it in the
+# same conditions as regards security.
+# 
+# The fact that you are presently reading this means that you have had
+# knowledge of the CeCILL license and that you accept its terms.
+#
+"""Get the differential expression between 2 conditions (2 files), on regions defined by a third file"""
+
+import os, re
+from optparse import OptionParser
+from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer
+from commons.core.writer.Gff3Writer import Gff3Writer
+from SMART.Java.Python.misc.Progress import Progress
+from SMART.Java.Python.misc.RPlotter import RPlotter
+from SMART.Java.Python.misc import Utils
+from SMART.Java.Python.mySql.MySqlConnection import MySqlConnection
+from SMART.Java.Python.structure.Transcript import Transcript
+
+class GetDifferentialExpression(object):
+    
+    def __init__(self, verbosity = 1):
+        self.verbosity              = verbosity
+        self.mySqlConnection        = MySqlConnection(verbosity)
+        self.inputs                 = (0, 1)
+        self.transcriptContainers   = [None, None]
+        self.transcriptContainerRef = None
+        self.outputFileName         = None
+        self.writer                 = None
+        self.tables                 = [None, None]
+        self.nbElements             = [0, 0]
+
+        self.regionsToValues = {}
+        self.regionsToNames  = {}
+        self.valuesToPvalues = {}
+
+        self.oriented                      = True
+        self.simpleNormalization           = False
+        self.simpleNormalizationParameters = None
+        self.adjustedNormalization         = False
+        self.fixedSizeFactor               = None
+        self.normalizationSize             = None
+        self.normalizationFactors          = [1, 1]
+        self.fdr                           = None 
+        self.fdrPvalue                     = None 
+
+        self.plot    = False
+        self.plotter = None
+        self.plotterName = None
+        self.points  = {}
+
+
+    def setInputFile(self, i, fileName, fileFormat):
+        self.transcriptContainers[i] = TranscriptContainer(fileName, fileFormat, self.verbosity)
+        self.transcriptContainers[i].mySqlConnection = self.mySqlConnection
+
+
+    def setReferenceFile(self, fileName, fileFormat):
+        self.transcriptContainerRef = TranscriptContainer(fileName, fileFormat, self.verbosity)
+        self.transcriptContainerRef.mySqlConnection = self.mySqlConnection
+
+
+    def setOutputFile(self, fileName):
+        self.outputFileName = fileName
+        self.writer         = Gff3Writer(fileName, self.verbosity)
+
+    
+    def setOriented(self, boolean):
+        self.oriented = boolean
+
+
+    def setSimpleNormalization(self, boolean):
+        self.simpleNormalization = boolean
+
+
+    def setSimpleNormalizationParameters(self, parameters):
+        if parameters != None:
+            self.simpleNormalization = True
+            self.simpleNormalizationParameters = [0, 0]
+            for i, splittedParameter in enumerate(parameters.split(",")):
+                self.simpleNormalizationParameters[i] = int(splittedParameter)
+
+
+    def setAdjustedNormalization(self, boolean):
+        self.adjustedNormalization = boolean
+
+
+    def setFixedSizeNormalization(self, value):
+        self.fixedSizeFactor = value
+
+
+    def setFdr(self, fdr):
+        self.fdr = fdr
+
+
+    def setPlot(self, boolean):
+        self.plot = boolean
+
+
+    def setPlotterName(self, plotterName):
+        self.plotterName = plotterName
+
+    def setPlotter(self):
+        self.plot    = True
+        self.plotter = RPlotter(self.plotterName, self.verbosity)
+        self.plotter.setPoints(True)
+        self.plotter.setLog("xy")
+        self.points = {}
+
+
+    def readInput(self, i):
+        self.transcriptContainers[i].storeIntoDatabase()
+        self.tables[i] = self.transcriptContainers[i].getTables()
+        progress       = Progress(len(self.tables[i].keys()), "Adding indices", self.verbosity)
+        for chromosome in self.tables[i]:
+            if self.oriented:
+                self.tables[i][chromosome].createIndex("iStartEndDir_%s_%d" % (chromosome, i), ("start", "end", "direction"))
+            else:
+                self.tables[i][chromosome].createIndex("iStartEnd_%s_%d" % (chromosome, i), ("start", "end"))
+            progress.inc()
+        progress.done()
+    
+        progress = Progress(self.transcriptContainers[i].getNbTranscripts(), "Reading sample %d" % (i +1), self.verbosity)
+        for chromosome in self.tables[i]:
+            for transcript in self.tables[i][chromosome].getIterator():
+                self.nbElements[i] += 1 if "nbElements" not in transcript.getTagNames() else transcript.getTagValue("nbElements")
+                progress.inc()
+        progress.done()
+        if self.verbosity > 0:
+            print "%d elements in sample %d" % (self.nbElements[i], i+1)
+
+
+    def computeSimpleNormalizationFactors(self):
+        nbElements = self.nbElements
+        if self.simpleNormalizationParameters != None:
+            print "Using provided normalization parameters: %s" % (", ".join([str(parameter) for parameter in self.simpleNormalizationParameters]))
+            nbElements = self.simpleNormalizationParameters
+        avgNbElements = int(float(sum(nbElements)) / len(nbElements))
+        for i in self.inputs:
+            self.normalizationFactors[i] = float(avgNbElements) / nbElements[i]
+            self.nbElements[i]          *= self.normalizationFactors[i]
+        if self.verbosity > 1:
+            print "Normalizing to average # reads: %d" % (avgNbElements)
+            if self.simpleNormalizationParameters != None:
+                print "# reads: %s" % (", ".join([str(nbElement) for nbElement in self.nbElements]))
+
+    def __del__(self):
+        self.mySqlConnection.deleteDatabase()
+
+    def regionToString(self, transcript):
+        return "%s:%d-%d(%s)" % (transcript.getChromosome(), transcript.getStart(), transcript.getEnd(), "+" if transcript.getDirection() == 1 else "-")
+
+    def stringToRegion(self, region):
+        m = re.search(r"^(\S+):(\d+)-(\d+)\((\S)\)$", region)
+        if m == None:
+            raise Exception("Internal format error: cannot parse region '%s'" % (region))
+        transcript = Transcript()
+        transcript.setChromosome(m.group(1))
+        transcript.setStart(int(m.group(2)))
+        transcript.setEnd(int(m.group(3)))
+        transcript.setDirection(m.group(4))
+        return transcript
+
+    def computeMinimumSize(self):
+        self.normalizationSize = 1000000000
+        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting minimum reference size", self.verbosity)
+        for transcriptRef in self.transcriptContainerRef.getIterator():
+            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart())
+            progress.inc()
+        progress.done()
+        if self.verbosity > 1:
+            print "Minimum reference size: %d" % (self.normalizationSize+1)
+
+    def useFixedSizeNormalization(self, start, end, starts):
+        currentNb = 0
+        sum       = 0
+        if not starts:
+            return 0
+        for i in range(start - self.normalizationSize, end + 1 + self.normalizationSize):
+            if i not in starts:
+                starts[i] = 0
+        for i, s in starts.iteritems():
+            if i < start:
+                starts[start] += s
+                starts[i]      = 0
+        for i in range(start - self.normalizationSize, end + 1):
+            currentNb += starts[i+self.normalizationSize] - starts[i]
+            sum       += currentNb
+        return (float(sum) / self.normalizationSize) * (self.fixedSizeFactor / (end - start + 1))
+
+    def retrieveCounts(self, transcriptRef, i):
+        if transcriptRef.getChromosome() not in self.tables[i]:
+            return (0, 0)
+        cumulatedCount           = 0
+        cumulatedNormalizedCount = 0
+        for exon in transcriptRef.getExons():
+            count   = 0
+            starts  = {}
+            command = "SELECT start, tags FROM '%s' WHERE start >= %d AND end <= %d" % (self.tables[i][exon.getChromosome()].getName(), exon.getStart(), exon.getEnd())
+            if self.oriented:
+                command += " AND direction = %d" % (exon.getDirection())
+            query = self.mySqlConnection.executeQuery(command)
+            for line in query.getIterator():
+                nb   = 1
+                tags = line[1].split(";")
+                for tag in tags:
+                    key, value = tag.split("=")
+                    if key == "nbElements":
+                        nb = int(float(value))
+                count += nb
+                starts[int(line[0])] = nb
+            normalizedCount = count if self.fixedSizeFactor == None else self.useFixedSizeNormalization(exon.getStart(), exon.getEnd(), starts)
+            cumulatedCount           += count
+            cumulatedNormalizedCount += normalizedCount
+        return (cumulatedCount, cumulatedNormalizedCount)
+
+    def getAllCounts(self):
+        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting counts", self.verbosity)
+        for cpt, transcriptRef in enumerate(self.transcriptContainerRef.getIterator()):
+            if "ID" in transcriptRef.getTagNames():
+                self.regionsToNames[self.regionToString(transcriptRef)] = transcriptRef.getTagValue("ID")
+            elif transcriptRef.getName() != None:
+                self.regionsToNames[self.regionToString(transcriptRef)] = transcriptRef.getName()
+            else:
+                self.regionsToNames[self.regionToString(transcriptRef)] = "region_%d" % (cpt)
+            values           = [None, None]
+            normalizedValues = [None, None]
+            for i in self.inputs:
+                values[i], normalizedValues[i] = self.retrieveCounts(transcriptRef, i)
+                normalizedValues[i]            = int(self.normalizationFactors[i] * normalizedValues[i])
+            if sum(values) != 0:
+                self.regionsToValues[self.regionToString(transcriptRef)] = (normalizedValues[0], normalizedValues[1], values[0], values[1])
+            progress.inc()
+        progress.done()
+
+    def computeAdjustedNormalizationFactors(self):
+        nbElements = len(self.regionsToValues.keys())
+        avgValues  = []
+        progress   = Progress(nbElements, "Normalization step 1", self.verbosity)
+        for values in self.regionsToValues.values():
+            correctedValues = [values[i] * self.normalizationFactors[i] for i in self.inputs]
+            avgValues.append(float(sum(correctedValues)) / len(correctedValues))
+            progress.inc()
+        progress.done()
+
+        sortedAvgValues = sorted(avgValues)
+        minAvgValues    = sortedAvgValues[nbElements / 4]
+        maxAvgValues    = sortedAvgValues[nbElements * 3 / 4]
+        sums            = [0, 0]
+        progress        = Progress(nbElements, "Normalization step 2", self.verbosity)
+        for values in self.regionsToValues.values():
+            correctedValues = [values[i] * self.normalizationFactors[i] for i in self.inputs]
+            avgValue        = float(sum(correctedValues)) / len(correctedValues)
+            if minAvgValues <= avgValue and avgValue <= maxAvgValues:
+                for i in self.inputs:
+                    sums[i] += values[i]
+            progress.inc()
+        progress.done()
+
+        avgSums = float(sum(sums)) / len(sums)
+        for i in self.inputs:
+            if self.verbosity > 1:
+                print "Normalizing sample %d: %s to" % ((i+1), self.nbElements[i]),
+            self.normalizationFactors[i] *= float(avgSums) / sums[i]
+            self.nbElements[i]           *= self.normalizationFactors[i]
+            if self.verbosity > 1:
+                print "%s" % (int(self.nbElements[i]))
+                
+    def getMinimumReferenceSize(self):
+        self.normalizationSize = 1000000000
+        progress               = Progress(self.transcriptContainerRef.getNbTranscripts(), "Reference element sizes", self.verbosity)
+        for transcriptRef in self.transcriptContainerRef.getIterator():
+            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart() + 1)
+            progress.inc()
+        progress.done()
+        if self.verbosity > 1:
+            print "Minimum reference size: %d" % (self.normalizationSize)
+
+    def computePvalues(self):
+        normalizedValues = set()
+        progress         = Progress(len(self.regionsToValues.keys()), "Normalizing counts", self.verbosity)
+        for region in self.regionsToValues:
+            values                       = self.regionsToValues[region]
+            normalizedValues0            = int(round(values[0] * self.normalizationFactors[0]))
+            normalizedValues1            = int(round(values[1] * self.normalizationFactors[1]))
+            self.regionsToValues[region] = (normalizedValues0, normalizedValues1, self.regionsToValues[region][2], self.regionsToValues[region][3])
+            normalizedValues.add((normalizedValues0, normalizedValues1, self.nbElements[0] - normalizedValues0, self.nbElements[1] - normalizedValues1, self.regionsToValues[region][2], self.regionsToValues[region][3]))
+            progress.inc()
+        progress.done()
+
+        if self.verbosity > 1:
+            print "Computing p-values..."
+        self.valuesToPvalues = Utils.fisherExactPValueBulk(list(normalizedValues))
+        if self.verbosity > 1:
+            print "... done"
+
+    def setTagValues(self, transcript, values, pValue):
+        for tag in transcript.getTagNames():
+            transcript.deleteTag(tag)
+        transcript.removeExons()
+        transcript.setTagValue("pValue", str(pValue))
+        transcript.setTagValue("nbReadsCond1", str(values[0]))
+        transcript.setTagValue("nbReadsCond2", str(values[1]))
+        transcript.setTagValue("nbUnnormalizedReadsCond1", str(values[2]))
+        transcript.setTagValue("nbUnnormalizedReadsCond2", str(values[3]))
+        if (values[0] == values[1]) or (self.fdr != None and pValue > self.fdrPvalue):
+            transcript.setTagValue("regulation", "equal")
+        elif values[0] < values[1]:
+            transcript.setTagValue("regulation", "up")
+        else:
+            transcript.setTagValue("regulation", "down")
+        return transcript
+
+    def computeFdr(self):
+        pValues   = []
+        nbRegions = len(self.regionsToValues.keys())
+        progress  = Progress(nbRegions, "Computing FDR", self.verbosity)
+        for values in self.regionsToValues.values():
+            pValues.append(self.valuesToPvalues[values[0:2]])
+            progress.inc()
+        progress.done()
+        
+        for i, pValue in enumerate(reversed(sorted(pValues))):
+            if pValue <= self.fdr * (nbRegions - 1 - i) / nbRegions:
+                self.fdrPvalue = pValue
+                if self.verbosity > 1:
+                    print "FDR: %f, k: %i, m: %d" % (pValue, nbRegions - 1 - i, nbRegions)
+                return
+
+    def writeDifferentialExpression(self):
+        if self.plot:
+            self.setPlotter()
+
+        cpt = 1
+        progress = Progress(len(self.regionsToValues.keys()), "Writing output", self.verbosity)
+        for region, values in self.regionsToValues.iteritems():
+            transcript = self.stringToRegion(region)
+            pValue     = self.valuesToPvalues[values[0:2]]
+            transcript.setName(self.regionsToNames[region])
+            transcript = self.setTagValues(transcript, values, pValue)
+            self.writer.addTranscript(transcript)
+            cpt += 1
+
+            if self.plot:
+                self.points[region] = (values[0], values[1])
+        progress.done()
+        self.writer.write()
+        self.writer.close()
+
+        if self.plot:
+            self.plotter.addLine(self.points)
+            self.plotter.plot()
+
+    def getDifferentialExpression(self):
+        for i in self.inputs:
+            self.readInput(i)
+
+        if self.simpleNormalization:
+            self.computeSimpleNormalizationFactors()
+        if self.fixedSizeFactor != None:
+            self.computeMinimumSize()
+
+        self.getAllCounts()
+
+        if self.adjustedNormalization:
+            self.computeAdjustedNormalizationFactors()
+
+        self.computePvalues()
+
+        if self.fdr != None:
+            self.computeFdr()
+            
+        self.writeDifferentialExpression()
+
+
+if __name__ == "__main__":
+    
+    # parse command line
+    description = "Get Differential Expression v1.0.1: Get the differential expression between 2 conditions using Fisher's exact test, on regions defined by a third file. [Category: Data Comparison]"
+
+    parser = OptionParser(description = description)
+    parser.add_option("-i", "--input1",           dest="inputFileName1",    action="store",                     type="string", help="input file 1 [compulsory] [format: file in transcript format given by -f]")
+    parser.add_option("-f", "--format1",          dest="format1",           action="store",                     type="string", help="format of file 1 [compulsory] [format: transcript file format]")
+    parser.add_option("-j", "--input2",           dest="inputFileName2",    action="store",                     type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
+    parser.add_option("-g", "--format2",          dest="format2",           action="store",                     type="string", help="format of file 2 [compulsory] [format: transcript file format]")
+    parser.add_option("-k", "--reference",        dest="referenceFileName", action="store",                     type="string", help="reference file [compulsory] [format: file in transcript format given by -l]")
+    parser.add_option("-l", "--referenceFormat",  dest="referenceFormat",   action="store",                     type="string", help="format of reference file [compulsory] [format: transcript file format]")
+    parser.add_option("-o", "--output",           dest="outputFileName",    action="store",                     type="string", help="output file [format: output file in gff3 format]")
+    parser.add_option("-n", "--notOriented",      dest="notOriented",       action="store_true", default=False,                help="if the reads are not oriented [default: False] [format: bool]")
+    parser.add_option("-s", "--simple",           dest="simple",            action="store_true", default=False,                help="normalize using the number of reads in each condition [format: bool]")
+    parser.add_option("-S", "--simpleParameters", dest="simpleParameters",  action="store",      default=None,  type="string", help="provide the number of reads [format: bool]")
+    parser.add_option("-a", "--adjusted",         dest="adjusted",          action="store_true", default=False,                help="normalize using the number of reads of 'mean' regions [format: bool]")
+    parser.add_option("-x", "--fixedSizeFactor",  dest="fixedSizeFactor",   action="store",      default=None,  type="int",    help="give the magnification factor for the normalization using fixed size sliding windows in reference regions (leave empty for no such normalization) [format: int]")
+    parser.add_option("-d", "--fdr",              dest="fdr",               action="store",      default=None,  type="float",  help="use FDR [format: float]")
+    parser.add_option("-p", "--plot",             dest="plotName",          action="store",      default=None,  type="string", help="plot cloud plot [format: output file in PNG format]")
+    parser.add_option("-v", "--verbosity",        dest="verbosity",         action="store",      default=1,     type="int",    help="trace level [format: int]")
+    (options, args) = parser.parse_args()
+
+
+        
+    differentialExpression = GetDifferentialExpression(options.verbosity)
+    differentialExpression.setInputFile(0, options.inputFileName1, options.format1)
+    differentialExpression.setInputFile(1, options.inputFileName2, options.format2)
+    differentialExpression.setReferenceFile(options.referenceFileName, options.referenceFormat)
+    differentialExpression.setOutputFile(options.outputFileName)
+    if options.plotName != None :
+        differentialExpression.setPlotterName(options.plotName)
+        differentialExpression.setPlotter()
+    differentialExpression.setOriented(not options.notOriented)
+    differentialExpression.setSimpleNormalization(options.simple)
+    differentialExpression.setSimpleNormalizationParameters(options.simpleParameters)
+    differentialExpression.setAdjustedNormalization(options.adjusted)
+    differentialExpression.setFixedSizeNormalization(options.fixedSizeFactor)
+    differentialExpression.setFdr(options.fdr)
+    differentialExpression.getDifferentialExpression()
+    differentialExpression.mySqlConnection.deleteDatabase()
+    
+