view smart_toolShed/SMART/Java/Python/GetDifferentialExpression.py @ 4:1fc014126d55

Uploaded
author yufei-luo
date Fri, 18 Jan 2013 04:45:50 -0500
parents e0f8dcca02ed
children
line wrap: on
line source

#! /usr/bin/env python
#
# Copyright INRA-URGI 2009-2010
# 
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
# 
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
# 
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
# 
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.
#
"""Get the differential expression between 2 conditions (2 files), on regions defined by a third file"""

import os, re
from optparse import OptionParser
from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer
from commons.core.writer.Gff3Writer import Gff3Writer
from SMART.Java.Python.misc.Progress import Progress
from SMART.Java.Python.misc.RPlotter import RPlotter
from SMART.Java.Python.misc import Utils
from SMART.Java.Python.mySql.MySqlConnection import MySqlConnection
from SMART.Java.Python.structure.Transcript import Transcript

class GetDifferentialExpression(object):
    """Compare the read counts of two samples over a set of reference regions.

    Typical use: configure the object with the set* methods, then call
    getDifferentialExpression(), which reads both samples, counts the reads
    falling in every reference region, optionally normalizes the counts,
    computes one p-value per region through Utils.fisherExactPValueBulk()
    and writes the annotated regions with the GFF3 writer.

    The code is written so that its syntax is valid under both Python 2 and
    Python 3 (single-argument parenthesised prints, explicit floor division).
    """

    # Format of the keys built by regionToString(): "chromosome:start-end(strand)".
    _REGION_PATTERN = re.compile(r"^(\S+):(\d+)-(\d+)\((\S)\)$")

    def __init__(self, verbosity = 1):
        """Create the database connection and reset every setting to its default."""
        self.verbosity              = verbosity
        self.mySqlConnection        = MySqlConnection(verbosity)
        self.inputs                 = (0, 1)
        self.transcriptContainers   = [None, None]
        self.transcriptContainerRef = None
        self.outputFileName         = None
        self.writer                 = None
        self.tables                 = [None, None]
        self.nbElements             = [0, 0]

        # region key -> (normalized count 1, normalized count 2, raw count 1, raw count 2)
        self.regionsToValues = {}
        # region key -> name written in the output
        self.regionsToNames  = {}
        # result of Utils.fisherExactPValueBulk(), looked up with (count 1, count 2)
        self.valuesToPvalues = {}

        self.oriented                      = True
        self.simpleNormalization           = False
        self.simpleNormalizationParameters = None
        self.adjustedNormalization         = False
        self.fixedSizeFactor               = None
        self.normalizationSize             = None
        self.normalizationFactors          = [1, 1]
        self.fdr                           = None
        self.fdrPvalue                     = None

        self.plot        = False
        self.plotter     = None
        self.plotterName = None
        self.points      = {}

    def setInputFile(self, i, fileName, fileFormat):
        """Declare sample number i (0 or 1) and make it share this object's connection."""
        self.transcriptContainers[i] = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainers[i].mySqlConnection = self.mySqlConnection

    def setReferenceFile(self, fileName, fileFormat):
        """Declare the file holding the reference regions."""
        self.transcriptContainerRef = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainerRef.mySqlConnection = self.mySqlConnection

    def setOutputFile(self, fileName):
        """Remember the output file name and open a GFF3 writer on it."""
        self.outputFileName = fileName
        self.writer         = Gff3Writer(fileName, self.verbosity)

    def setOriented(self, boolean):
        """Tell whether the read direction must match the region direction."""
        self.oriented = boolean

    def setSimpleNormalization(self, boolean):
        """Enable or disable the normalization by the number of reads."""
        self.simpleNormalization = boolean

    def setSimpleNormalizationParameters(self, parameters):
        """Use user-provided read counts for the simple normalization.

        parameters: None (nothing is changed) or a string made of one
        positive integer per sample, separated by commas, e.g. "1000,1500".
        Raises ValueError when the string does not hold exactly one positive
        integer per sample; nothing is modified in that case.
        """
        if parameters is None:
            return
        counts = [int(field) for field in parameters.split(",")]
        if len(counts) != len(self.inputs):
            raise ValueError("Expected %d comma-separated read counts, got %d in '%s'" % (len(self.inputs), len(counts), parameters))
        # The counts are used as denominators, so they must be strictly positive.
        if min(counts) <= 0:
            raise ValueError("Read counts should be positive, got '%s'" % (parameters))
        self.simpleNormalization           = True
        self.simpleNormalizationParameters = counts

    def setAdjustedNormalization(self, boolean):
        """Enable or disable the normalization based on the 'mean' regions."""
        self.adjustedNormalization = boolean

    def setFixedSizeNormalization(self, value):
        """Set the magnification factor of the fixed-size normalization (None to disable it)."""
        self.fixedSizeFactor = value

    def setFdr(self, fdr):
        """Set the false discovery rate (None to disable the FDR filter)."""
        self.fdr = fdr

    def setPlot(self, boolean):
        """Enable or disable the plot of the counts."""
        self.plot = boolean

    def setPlotterName(self, plotterName):
        """Set the name handed to RPlotter when the plotter is created."""
        self.plotterName = plotterName

    def setPlotter(self):
        """Switch plotting on and create a fresh log-log point plotter."""
        self.plot    = True
        self.plotter = RPlotter(self.plotterName, self.verbosity)
        self.plotter.setPoints(True)
        self.plotter.setLog("xy")
        self.points = {}

    def readInput(self, i):
        """Store sample i in the database, index its tables and count its elements."""
        self.transcriptContainers[i].storeIntoDatabase()
        self.tables[i] = self.transcriptContainers[i].getTables()
        progress       = Progress(len(self.tables[i].keys()), "Adding indices", self.verbosity)
        for chromosome in self.tables[i]:
            if self.oriented:
                self.tables[i][chromosome].createIndex("iStartEndDir_%s_%d" % (chromosome, i), ("start", "end", "direction"))
            else:
                self.tables[i][chromosome].createIndex("iStartEnd_%s_%d" % (chromosome, i), ("start", "end"))
            progress.inc()
        progress.done()

        progress = Progress(self.transcriptContainers[i].getNbTranscripts(), "Reading sample %d" % (i + 1), self.verbosity)
        for chromosome in self.tables[i]:
            for transcript in self.tables[i][chromosome].getIterator():
                # An element counts for one read unless it carries an "nbElements" tag.
                # NOTE(review): assumes getTagValue() returns a number here -- confirm.
                if "nbElements" in transcript.getTagNames():
                    self.nbElements[i] += transcript.getTagValue("nbElements")
                else:
                    self.nbElements[i] += 1
                progress.inc()
        progress.done()
        if self.verbosity > 0:
            print("%d elements in sample %d" % (self.nbElements[i], i + 1))

    def computeSimpleNormalizationFactors(self):
        """Scale both samples to the average number of reads.

        The read counts are the user-provided ones when available, the
        counted ones otherwise.  A sample without any read keeps its
        current factor instead of triggering a division by zero.
        """
        nbElements = self.nbElements
        if self.simpleNormalizationParameters is not None:
            if self.verbosity > 0:
                print("Using provided normalization parameters: %s" % (", ".join([str(parameter) for parameter in self.simpleNormalizationParameters])))
            nbElements = self.simpleNormalizationParameters
        avgNbElements = int(float(sum(nbElements)) / len(nbElements))
        for i in self.inputs:
            if nbElements[i] == 0:
                continue
            self.normalizationFactors[i] = float(avgNbElements) / nbElements[i]
            self.nbElements[i]          *= self.normalizationFactors[i]
        if self.verbosity > 1:
            print("Normalizing to average # reads: %d" % (avgNbElements))
            if self.simpleNormalizationParameters is not None:
                print("# reads: %s" % (", ".join([str(nbElement) for nbElement in self.nbElements])))

    def __del__(self):
        """Delete the database, if the connection was ever created."""
        # __del__ also runs when __init__ failed before the attribute was set.
        connection = getattr(self, "mySqlConnection", None)
        if connection is not None:
            connection.deleteDatabase()

    def regionToString(self, transcript):
        """Return the "chromosome:start-end(strand)" key of a transcript."""
        strand = "+" if transcript.getDirection() == 1 else "-"
        return "%s:%d-%d(%s)" % (transcript.getChromosome(), transcript.getStart(), transcript.getEnd(), strand)

    def stringToRegion(self, region):
        """Build a Transcript from a key produced by regionToString().

        Raises Exception when the key does not have the expected format.
        """
        m = self._REGION_PATTERN.search(region)
        if m is None:
            raise Exception("Internal format error: cannot parse region '%s'" % (region))
        transcript = Transcript()
        transcript.setChromosome(m.group(1))
        transcript.setStart(int(m.group(2)))
        transcript.setEnd(int(m.group(3)))
        transcript.setDirection(m.group(4))
        return transcript

    def computeMinimumSize(self):
        """Set normalizationSize to the smallest (end - start) of the references.

        This is the size of the smallest reference minus one; it is the
        window used by useFixedSizeNormalization().
        """
        self.normalizationSize = 1000000000
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting minimum reference size", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart())
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize + 1))

    def useFixedSizeNormalization(self, start, end, starts):
        """Return the count of a region, rescaled to a region of fixedSizeFactor bases.

        start, end: bounds of the region (inclusive).
        starts:     dict mapping a start position to the number of reads
                    starting there; it is left untouched.
        Returns 0 when starts is empty, a float otherwise.
        """
        if not starts:
            return 0
        # A one-base reference gives a null window; use at least one base so
        # that the division below is always defined.
        window = max(1, self.normalizationSize)
        counts = dict(starts)
        for position in range(start - window, end + 1 + window):
            counts.setdefault(position, 0)
        # Reads starting before the region are moved to its first base.
        for position in [position for position in counts if position < start]:
            counts[start]   += counts[position]
            counts[position] = 0
        running = 0
        total   = 0
        for position in range(start - window, end + 1):
            running += counts[position + window] - counts[position]
            total   += running
        # Both divisions are forced to floats: with two ints, Python 2 would
        # truncate the scaling factor (down to 0 for regions longer than the factor).
        return (float(total) / window) * (float(self.fixedSizeFactor) / (end - start + 1))

    @staticmethod
    def _parseNbElements(tags):
        """Return the "nbElements" value of a "key=value;key=value" string (1 when absent)."""
        nb = 1
        for tag in tags.split(";"):
            # partition() copes with empty tags and with values containing "=".
            key, separator, value = tag.partition("=")
            if key == "nbElements":
                nb = int(float(value))
        return nb

    def retrieveCounts(self, transcriptRef, i):
        """Count the reads of sample i included in the exons of a reference.

        Returns (raw count, normalized count); the normalized count is the
        raw one unless the fixed-size normalization is enabled.
        """
        if transcriptRef.getChromosome() not in self.tables[i]:
            return (0, 0)
        cumulatedCount           = 0
        cumulatedNormalizedCount = 0
        for exon in transcriptRef.getExons():
            count   = 0
            starts  = {}
            command = "SELECT start, tags FROM '%s' WHERE start >= %d AND end <= %d" % (self.tables[i][exon.getChromosome()].getName(), exon.getStart(), exon.getEnd())
            if self.oriented:
                command += " AND direction = %d" % (exon.getDirection())
            query = self.mySqlConnection.executeQuery(command)
            for line in query.getIterator():
                nb       = self._parseNbElements(line[1])
                position = int(line[0])
                count   += nb
                # Several reads may start at the same position: accumulate them.
                starts[position] = starts.get(position, 0) + nb
            if self.fixedSizeFactor is None:
                normalizedCount = count
            else:
                normalizedCount = self.useFixedSizeNormalization(exon.getStart(), exon.getEnd(), starts)
            cumulatedCount           += count
            cumulatedNormalizedCount += normalizedCount
        return (cumulatedCount, cumulatedNormalizedCount)

    def getAllCounts(self):
        """Name every reference region and store its counts in both samples.

        Regions without any read in both samples are named but not stored.
        """
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting counts", self.verbosity)
        for cpt, transcriptRef in enumerate(self.transcriptContainerRef.getIterator()):
            region = self.regionToString(transcriptRef)
            if "ID" in transcriptRef.getTagNames():
                self.regionsToNames[region] = transcriptRef.getTagValue("ID")
            elif transcriptRef.getName() is not None:
                self.regionsToNames[region] = transcriptRef.getName()
            else:
                self.regionsToNames[region] = "region_%d" % (cpt)
            values           = [None, None]
            normalizedValues = [None, None]
            for i in self.inputs:
                values[i], normalizedValues[i] = self.retrieveCounts(transcriptRef, i)
                normalizedValues[i]            = int(self.normalizationFactors[i] * normalizedValues[i])
            if sum(values) != 0:
                self.regionsToValues[region] = (normalizedValues[0], normalizedValues[1], values[0], values[1])
            progress.inc()
        progress.done()

    def computeAdjustedNormalizationFactors(self):
        """Refine the normalization factors using the regions of average expression.

        The regions whose corrected average count lies between the first and
        the third quartiles are summed per sample, and both samples are
        scaled to the mean of these sums.  Nothing is changed when there is
        no region or when a sample has no read in the selected regions.
        """
        nbElements = len(self.regionsToValues)
        if nbElements == 0:
            return

        def correctedAverage(values):
            corrected = [values[i] * self.normalizationFactors[i] for i in self.inputs]
            return float(sum(corrected)) / len(corrected)

        avgValues = []
        progress  = Progress(nbElements, "Normalization step 1", self.verbosity)
        for values in self.regionsToValues.values():
            avgValues.append(correctedAverage(values))
            progress.inc()
        progress.done()

        avgValues.sort()
        # Floor divisions: the quartile ranks are used as list indices.
        minAvgValues = avgValues[nbElements // 4]
        maxAvgValues = avgValues[nbElements * 3 // 4]
        sums         = [0, 0]
        progress     = Progress(nbElements, "Normalization step 2", self.verbosity)
        for values in self.regionsToValues.values():
            if minAvgValues <= correctedAverage(values) <= maxAvgValues:
                for i in self.inputs:
                    sums[i] += values[i]
            progress.inc()
        progress.done()

        if 0 in sums:
            if self.verbosity > 0:
                print("Cannot adjust the normalization: a sample has no read in the 'mean' regions")
            return
        avgSums = float(sum(sums)) / len(sums)
        for i in self.inputs:
            previousNbElements            = self.nbElements[i]
            self.normalizationFactors[i] *= float(avgSums) / sums[i]
            self.nbElements[i]           *= self.normalizationFactors[i]
            if self.verbosity > 1:
                print("Normalizing sample %d: %s to %s" % ((i + 1), previousNbElements, int(self.nbElements[i])))

    def getMinimumReferenceSize(self):
        """Set normalizationSize to the size (end - start + 1) of the smallest reference."""
        self.normalizationSize = 1000000000
        progress               = Progress(self.transcriptContainerRef.getNbTranscripts(), "Reference element sizes", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart() + 1)
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize))

    def computePvalues(self):
        """Apply the normalization factors to the counts and compute the p-values."""
        normalizedValues = set()
        progress         = Progress(len(self.regionsToValues), "Normalizing counts", self.verbosity)
        for region in self.regionsToValues:
            values = self.regionsToValues[region]
            # NOTE(review): getAllCounts() already multiplied these counts by the
            # factors in use at that time -- confirm that scaling again is intended.
            normalizedValues0            = int(round(values[0] * self.normalizationFactors[0]))
            normalizedValues1            = int(round(values[1] * self.normalizationFactors[1]))
            self.regionsToValues[region] = (normalizedValues0, normalizedValues1, values[2], values[3])
            normalizedValues.add((normalizedValues0, normalizedValues1, self.nbElements[0] - normalizedValues0, self.nbElements[1] - normalizedValues1, values[2], values[3]))
            progress.inc()
        progress.done()

        if self.verbosity > 1:
            print("Computing p-values...")
        self.valuesToPvalues = Utils.fisherExactPValueBulk(list(normalizedValues))
        if self.verbosity > 1:
            print("... done")

    def setTagValues(self, transcript, values, pValue):
        """Replace the tags and exons of a transcript by the results of the test.

        values: (normalized count 1, normalized count 2, raw count 1, raw count 2).
        The "regulation" tag is "equal" when the normalized counts are equal
        or when the FDR filter rejects the p-value, "up" when the second
        condition has more reads, "down" otherwise.  Returns the transcript.
        """
        # Work on a copy of the names: tags are deleted while looping.
        for tag in list(transcript.getTagNames()):
            transcript.deleteTag(tag)
        transcript.removeExons()
        transcript.setTagValue("pValue", str(pValue))
        transcript.setTagValue("nbReadsCond1", str(values[0]))
        transcript.setTagValue("nbReadsCond2", str(values[1]))
        transcript.setTagValue("nbUnnormalizedReadsCond1", str(values[2]))
        transcript.setTagValue("nbUnnormalizedReadsCond2", str(values[3]))
        # Without any FDR threshold (fdrPvalue is None), no region is significant.
        rejected = self.fdr is not None and (self.fdrPvalue is None or pValue > self.fdrPvalue)
        if values[0] == values[1] or rejected:
            transcript.setTagValue("regulation", "equal")
        elif values[0] < values[1]:
            transcript.setTagValue("regulation", "up")
        else:
            transcript.setTagValue("regulation", "down")
        return transcript

    def _findFdrThreshold(self, pValues):
        """Return (p-value, rank) of the Benjamini-Hochberg threshold, or None.

        The threshold is the largest p-value p(k) such that
        p(k) <= fdr * k / m, where k is the 1-based rank of the p-value in
        increasing order and m the number of p-values.
        """
        nbValues = len(pValues)
        for i, pValue in enumerate(sorted(pValues, reverse = True)):
            rank = nbValues - i
            if pValue <= float(self.fdr) * rank / nbValues:
                return (pValue, rank)
        return None

    def computeFdr(self):
        """Set fdrPvalue to the largest p-value accepted at the requested FDR.

        fdrPvalue is None when no p-value passes the test.
        """
        pValues   = []
        nbRegions = len(self.regionsToValues)
        progress  = Progress(nbRegions, "Computing FDR", self.verbosity)
        for values in self.regionsToValues.values():
            pValues.append(self.valuesToPvalues[values[0:2]])
            progress.inc()
        progress.done()

        self.fdrPvalue = None
        threshold      = self._findFdrThreshold(pValues)
        if threshold is None:
            return
        self.fdrPvalue, rank = threshold
        if self.verbosity > 1:
            print("FDR: %f, k: %i, m: %d" % (self.fdrPvalue, rank, nbRegions))

    def writeDifferentialExpression(self):
        """Write every stored region with its tags, then draw the plot if requested."""
        if self.plot:
            self.setPlotter()

        progress = Progress(len(self.regionsToValues), "Writing output", self.verbosity)
        for region, values in self.regionsToValues.items():
            transcript = self.stringToRegion(region)
            pValue     = self.valuesToPvalues[values[0:2]]
            transcript.setName(self.regionsToNames[region])
            transcript = self.setTagValues(transcript, values, pValue)
            self.writer.addTranscript(transcript)

            if self.plot:
                self.points[region] = (values[0], values[1])
            progress.inc()
        progress.done()
        self.writer.write()
        self.writer.close()

        if self.plot:
            self.plotter.addLine(self.points)
            self.plotter.plot()

    def getDifferentialExpression(self):
        """Run the whole analysis, from the reading of the samples to the output."""
        for i in self.inputs:
            self.readInput(i)

        if self.simpleNormalization:
            self.computeSimpleNormalizationFactors()
        if self.fixedSizeFactor is not None:
            self.computeMinimumSize()

        self.getAllCounts()

        if self.adjustedNormalization:
            self.computeAdjustedNormalizationFactors()

        self.computePvalues()

        if self.fdr is not None:
            self.computeFdr()

        self.writeDifferentialExpression()


if __name__ == "__main__":

    # Command line entry point: parse the options, check the compulsory ones,
    # configure the analysis and run it.
    description = "Get Differential Expression v1.0.1: Get the differential expression between 2 conditions using Fisher's exact test, on regions defined by a third file. [Category: Data Comparison]"

    parser = OptionParser(description = description)
    parser.add_option("-i", "--input1",           dest="inputFileName1",    action="store",                     type="string", help="input file 1 [compulsory] [format: file in transcript format given by -f]")
    parser.add_option("-f", "--format1",          dest="format1",           action="store",                     type="string", help="format of file 1 [compulsory] [format: transcript file format]")
    parser.add_option("-j", "--input2",           dest="inputFileName2",    action="store",                     type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
    parser.add_option("-g", "--format2",          dest="format2",           action="store",                     type="string", help="format of file 2 [compulsory] [format: transcript file format]")
    parser.add_option("-k", "--reference",        dest="referenceFileName", action="store",                     type="string", help="reference file [compulsory] [format: file in transcript format given by -l]")
    parser.add_option("-l", "--referenceFormat",  dest="referenceFormat",   action="store",                     type="string", help="format of reference file [compulsory] [format: transcript file format]")
    parser.add_option("-o", "--output",           dest="outputFileName",    action="store",                     type="string", help="output file [format: output file in gff3 format]")
    parser.add_option("-n", "--notOriented",      dest="notOriented",       action="store_true", default=False,                help="if the reads are not oriented [default: False] [format: bool]")
    parser.add_option("-s", "--simple",           dest="simple",            action="store_true", default=False,                help="normalize using the number of reads in each condition [format: bool]")
    # The value is a comma-separated pair of integers, hence a string and not a boolean.
    parser.add_option("-S", "--simpleParameters", dest="simpleParameters",  action="store",      default=None,  type="string", help="provide the number of reads of each condition, separated by a comma [format: string]")
    parser.add_option("-a", "--adjusted",         dest="adjusted",          action="store_true", default=False,                help="normalize using the number of reads of 'mean' regions [format: bool]")
    parser.add_option("-x", "--fixedSizeFactor",  dest="fixedSizeFactor",   action="store",      default=None,  type="int",    help="give the magnification factor for the normalization using fixed size sliding windows in reference regions (leave empty for no such normalization) [format: int]")
    parser.add_option("-d", "--fdr",              dest="fdr",               action="store",      default=None,  type="float",  help="use FDR [format: float]")
    parser.add_option("-p", "--plot",             dest="plotName",          action="store",      default=None,  type="string", help="plot cloud plot [format: output file in PNG format]")
    parser.add_option("-v", "--verbosity",        dest="verbosity",         action="store",      default=1,     type="int",    help="trace level [format: int]")
    (options, args) = parser.parse_args()

    # Stop with a usage message, rather than an obscure error later on, when
    # an option documented as compulsory is not given.
    compulsoryOptions = (
        ("-i", options.inputFileName1),
        ("-f", options.format1),
        ("-j", options.inputFileName2),
        ("-g", options.format2),
        ("-k", options.referenceFileName),
        ("-l", options.referenceFormat),
    )
    missingOptions = [flag for flag, value in compulsoryOptions if value is None]
    if missingOptions:
        parser.error("missing compulsory option(s): %s" % (", ".join(missingOptions)))

    differentialExpression = GetDifferentialExpression(options.verbosity)
    try:
        differentialExpression.setInputFile(0, options.inputFileName1, options.format1)
        differentialExpression.setInputFile(1, options.inputFileName2, options.format2)
        differentialExpression.setReferenceFile(options.referenceFileName, options.referenceFormat)
        differentialExpression.setOutputFile(options.outputFileName)
        if options.plotName is not None:
            differentialExpression.setPlotterName(options.plotName)
            differentialExpression.setPlotter()
        differentialExpression.setOriented(not options.notOriented)
        differentialExpression.setSimpleNormalization(options.simple)
        differentialExpression.setSimpleNormalizationParameters(options.simpleParameters)
        differentialExpression.setAdjustedNormalization(options.adjusted)
        differentialExpression.setFixedSizeNormalization(options.fixedSizeFactor)
        differentialExpression.setFdr(options.fdr)
        differentialExpression.getDifferentialExpression()
    finally:
        # Delete the database even when the analysis fails.
        differentialExpression.mySqlConnection.deleteDatabase()