Mercurial > repos > yufei-luo > s_mart
diff SMART/Java/Python/GetDifferentialExpression.py @ 38:2c0c0a89fad7
Uploaded
author | m-zytnicki |
---|---|
date | Thu, 02 May 2013 09:56:47 -0400 |
parents | 769e306b7933 |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/SMART/Java/Python/GetDifferentialExpression.py Thu May 02 09:56:47 2013 -0400 @@ -0,0 +1,441 @@ +#! /usr/bin/env python +# +# Copyright INRA-URGI 2009-2010 +# +# This software is governed by the CeCILL license under French law and +# abiding by the rules of distribution of free software. You can use, +# modify and/ or redistribute the software under the terms of the CeCILL +# license as circulated by CEA, CNRS and INRIA at the following URL +# "http://www.cecill.info". +# +# As a counterpart to the access to the source code and rights to copy, +# modify and redistribute granted by the license, users are provided only +# with a limited warranty and the software's author, the holder of the +# economic rights, and the successive licensors have only limited +# liability. +# +# In this respect, the user's attention is drawn to the risks associated +# with loading, using, modifying and/or developing or reproducing the +# software by the user in light of its specific status of free software, +# that may mean that it is complicated to manipulate, and that also +# therefore means that it is reserved for developers and experienced +# professionals having in-depth computer knowledge. Users are therefore +# encouraged to load and test the software's suitability as regards their +# requirements in conditions enabling the security of their systems and/or +# data to be ensured and, more generally, to use and operate it in the +# same conditions as regards security. +# +# The fact that you are presently reading this means that you have had +# knowledge of the CeCILL license and that you accept its terms. 
#
"""Get the differential expression between 2 conditions (2 files), on regions defined by a third file"""

import os
import re
from optparse import OptionParser
from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer
from commons.core.writer.Gff3Writer import Gff3Writer
from SMART.Java.Python.misc.Progress import Progress
from SMART.Java.Python.misc.RPlotter import RPlotter
from SMART.Java.Python.misc import Utils
from SMART.Java.Python.mySql.MySqlConnection import MySqlConnection
from SMART.Java.Python.structure.Transcript import Transcript


class GetDifferentialExpression(object):
    """Count the reads of two samples on reference regions and compare them.

    For every region of the reference file, the reads of each sample lying
    inside the region are counted, optionally normalized, and a p-value is
    obtained from Utils.fisherExactPValueBulk.  Results are written through
    a Gff3Writer, and optionally plotted with an RPlotter.
    """

    def __init__(self, verbosity = 1):
        """Create the database connection and reset every setting.

        verbosity -- trace level; messages are printed above 0 or 1
                     depending on the message.
        """
        self.verbosity              = verbosity
        self.mySqlConnection        = MySqlConnection(verbosity)
        # Indices of the two samples; used to address every per-sample list.
        self.inputs                 = (0, 1)
        self.transcriptContainers   = [None, None]
        self.transcriptContainerRef = None
        self.outputFileName         = None
        self.writer                 = None
        # Per sample: the value returned by TranscriptContainer.getTables(),
        # indexed by chromosome.
        self.tables                 = [None, None]
        # Per sample: total number of reads (rescaled by the normalizations).
        self.nbElements             = [0, 0]

        # region string -> (normalized count 1, normalized count 2,
        #                   raw count 1, raw count 2)
        self.regionsToValues = {}
        # region string -> name written in the output
        self.regionsToNames  = {}
        # Result of Utils.fisherExactPValueBulk; looked up with
        # (normalized count 1, normalized count 2).
        self.valuesToPvalues = {}

        self.oriented                      = True
        self.simpleNormalization           = False
        self.simpleNormalizationParameters = None
        self.adjustedNormalization         = False
        self.fixedSizeFactor               = None
        self.normalizationSize             = None
        self.normalizationFactors          = [1, 1]
        self.fdr                           = None
        # p-value threshold selected by computeFdr (None when none passes).
        self.fdrPvalue                     = None

        self.plot        = False
        self.plotter     = None
        self.plotterName = None
        self.points      = {}


    def setInputFile(self, i, fileName, fileFormat):
        """Declare sample i (0 or 1) and make it share the database connection."""
        self.transcriptContainers[i] = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainers[i].mySqlConnection = self.mySqlConnection


    def setReferenceFile(self, fileName, fileFormat):
        """Declare the file holding the regions on which reads are counted."""
        self.transcriptContainerRef = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainerRef.mySqlConnection = self.mySqlConnection


    def setOutputFile(self, fileName):
        """Remember the output file name and open a GFF3 writer on it."""
        self.outputFileName = fileName
        self.writer         = Gff3Writer(fileName, self.verbosity)


    def setOriented(self, boolean):
        """When true, only reads with the direction of the region are counted."""
        self.oriented = boolean


    def setSimpleNormalization(self, boolean):
        """Enable or disable the normalization by total number of reads."""
        self.simpleNormalization = boolean


    def setSimpleNormalizationParameters(self, parameters):
        """Provide the read totals used by the simple normalization.

        parameters -- None (nothing is changed), or a string made of one
                      integer per sample, separated by commas.

        Raises ValueError when a value is not an integer or when the number
        of values is not the number of samples.
        """
        if parameters is None:
            return
        counts = [int(splittedParameter) for splittedParameter in parameters.split(",")]
        # A missing value used to be silently replaced by 0 (division by
        # zero later on) and an extra one raised an IndexError.
        if len(counts) != len(self.inputs):
            raise ValueError("Expected %d comma-separated numbers of reads, got '%s'" % (len(self.inputs), parameters))
        self.simpleNormalization           = True
        self.simpleNormalizationParameters = counts


    def setAdjustedNormalization(self, boolean):
        """Enable or disable the normalization on 'mean' regions."""
        self.adjustedNormalization = boolean


    def setFixedSizeNormalization(self, value):
        """Set the magnification factor of the fixed-size normalization (None disables it)."""
        self.fixedSizeFactor = value


    def setFdr(self, fdr):
        """Set the false discovery rate (None disables the FDR filter)."""
        self.fdr = fdr


    def setPlot(self, boolean):
        """Enable or disable the plot."""
        self.plot = boolean


    def setPlotterName(self, plotterName):
        """Set the name given to the RPlotter."""
        self.plotterName = plotterName

    def setPlotter(self):
        """Create the plotter (points, log scale on both axes) and enable the plot."""
        self.plot    = True
        self.plotter = RPlotter(self.plotterName, self.verbosity)
        self.plotter.setPoints(True)
        self.plotter.setLog("xy")
        self.points  = {}


    def readInput(self, i):
        """Store sample i in the database, index its tables and count its reads."""
        self.transcriptContainers[i].storeIntoDatabase()
        self.tables[i] = self.transcriptContainers[i].getTables()
        progress = Progress(len(self.tables[i].keys()), "Adding indices", self.verbosity)
        for chromosome in self.tables[i]:
            if self.oriented:
                self.tables[i][chromosome].createIndex("iStartEndDir_%s_%d" % (chromosome, i), ("start", "end", "direction"))
            else:
                self.tables[i][chromosome].createIndex("iStartEnd_%s_%d" % (chromosome, i), ("start", "end"))
            progress.inc()
        progress.done()

        progress = Progress(self.transcriptContainers[i].getNbTranscripts(), "Reading sample %d" % (i +1), self.verbosity)
        for chromosome in self.tables[i]:
            for transcript in self.tables[i][chromosome].getIterator():
                # An element counts for its "nbElements" tag when it has one, for 1 otherwise.
                self.nbElements[i] += 1 if "nbElements" not in transcript.getTagNames() else transcript.getTagValue("nbElements")
                progress.inc()
        progress.done()
        if self.verbosity > 0:
            print("%d elements in sample %d" % (self.nbElements[i], i+1))


    def computeSimpleNormalizationFactors(self):
        """Scale both samples to their average number of reads.

        The totals are the provided parameters when there are some, the
        counted ones otherwise.  Updates normalizationFactors and nbElements.
        """
        nbElements = self.nbElements
        if self.simpleNormalizationParameters is not None:
            print("Using provided normalization parameters: %s" % (", ".join([str(parameter) for parameter in self.simpleNormalizationParameters])))
            nbElements = self.simpleNormalizationParameters
        avgNbElements = int(float(sum(nbElements)) / len(nbElements))
        for i in self.inputs:
            self.normalizationFactors[i]  = float(avgNbElements) / nbElements[i]
            self.nbElements[i]           *= self.normalizationFactors[i]
        if self.verbosity > 1:
            print("Normalizing to average # reads: %d" % (avgNbElements))
            if self.simpleNormalizationParameters is not None:
                print("# reads: %s" % (", ".join([str(nbElement) for nbElement in self.nbElements])))

    def __del__(self):
        """Delete the database when the object is destroyed."""
        self.mySqlConnection.deleteDatabase()

    def regionToString(self, transcript):
        """Return the key of a region, formatted as "chromosome:start-end(strand)"."""
        return "%s:%d-%d(%s)" % (transcript.getChromosome(), transcript.getStart(), transcript.getEnd(), "+" if transcript.getDirection() == 1 else "-")

    def stringToRegion(self, region):
        """Build a Transcript from a key produced by regionToString.

        Raises Exception when the string does not have the expected format.
        """
        m = re.search(r"^(\S+):(\d+)-(\d+)\((\S)\)$", region)
        if m is None:
            raise Exception("Internal format error: cannot parse region '%s'" % (region))
        transcript = Transcript()
        transcript.setChromosome(m.group(1))
        transcript.setStart(int(m.group(2)))
        transcript.setEnd(int(m.group(3)))
        transcript.setDirection(m.group(4))
        return transcript

    def computeMinimumSize(self):
        """Set normalizationSize to the smallest (end - start) of the reference regions.

        The stored value is one less than the length of the smallest region,
        which is why the trace adds 1.
        """
        self.normalizationSize = 1000000000
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting minimum reference size", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart())
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize+1))

    def useFixedSizeNormalization(self, start, end, starts):
        """Return the count of a region normalized with fixed-size sliding windows.

        start, end -- bounds of the region.
        starts     -- dict: read start position -> number of reads; it is
                      modified in place (padded with zeros around the region).

        A window of normalizationSize positions slides over the region; the
        reads starting in each window are summed, the sum is divided by the
        window size and scaled by fixedSizeFactor / (size of the region).
        Returns 0 when there is no read.
        """
        if not starts:
            return 0
        currentNb = 0
        total     = 0
        for i in range(start - self.normalizationSize, end + 1 + self.normalizationSize):
            if i not in starts:
                starts[i] = 0
        # Reads starting before the region are moved to its first position.
        # Only existing keys are assigned, so the dict keeps its size.
        for i, s in starts.items():
            if i < start:
                starts[start] += s
                starts[i]      = 0
        for i in range(start - self.normalizationSize, end + 1):
            currentNb += starts[i+self.normalizationSize] - starts[i]
            total     += currentNb
        # fixedSizeFactor is an int when given on the command line: without
        # the conversion, the ratio would be truncated (to 0 for every
        # region longer than the factor).
        return (float(total) / self.normalizationSize) * (float(self.fixedSizeFactor) / (end - start + 1))

    def retrieveCounts(self, transcriptRef, i):
        """Count the reads of sample i included in the exons of a region.

        Returns (count, normalized count); the latter is the count itself
        unless the fixed-size normalization is used.  Returns (0, 0) when the
        sample has no table for the chromosome of the region.
        """
        if transcriptRef.getChromosome() not in self.tables[i]:
            return (0, 0)
        cumulatedCount           = 0
        cumulatedNormalizedCount = 0
        for exon in transcriptRef.getExons():
            count   = 0
            starts  = {}
            # Every inserted value goes through %d; the table name is
            # provided by the table object.
            command = "SELECT start, tags FROM '%s' WHERE start >= %d AND end <= %d" % (self.tables[i][exon.getChromosome()].getName(), exon.getStart(), exon.getEnd())
            if self.oriented:
                command += " AND direction = %d" % (exon.getDirection())
            query = self.mySqlConnection.executeQuery(command)
            for line in query.getIterator():
                nb = 1
                for tag in line[1].split(";"):
                    # partition() accepts tags with no "=" or with several,
                    # where split("=") could not be unpacked.
                    key, _, value = tag.partition("=")
                    if key == "nbElements":
                        nb = int(float(value))
                count                += nb
                starts[int(line[0])]  = nb
            normalizedCount = count if self.fixedSizeFactor is None else self.useFixedSizeNormalization(exon.getStart(), exon.getEnd(), starts)
            cumulatedCount           += count
            cumulatedNormalizedCount += normalizedCount
        return (cumulatedCount, cumulatedNormalizedCount)

    def getAllCounts(self):
        """Fill regionsToNames and regionsToValues from the reference regions.

        A region is named after its "ID" tag, else its name, else
        "region_<rank>".  Regions with no read in both samples get no entry
        in regionsToValues.
        """
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting counts", self.verbosity)
        for cpt, transcriptRef in enumerate(self.transcriptContainerRef.getIterator()):
            region = self.regionToString(transcriptRef)
            if "ID" in transcriptRef.getTagNames():
                self.regionsToNames[region] = transcriptRef.getTagValue("ID")
            elif transcriptRef.getName() is not None:
                self.regionsToNames[region] = transcriptRef.getName()
            else:
                self.regionsToNames[region] = "region_%d" % (cpt)
            values           = [None, None]
            normalizedValues = [None, None]
            for i in self.inputs:
                values[i], normalizedValues[i] = self.retrieveCounts(transcriptRef, i)
                normalizedValues[i]            = int(self.normalizationFactors[i] * normalizedValues[i])
            if sum(values) != 0:
                self.regionsToValues[region] = (normalizedValues[0], normalizedValues[1], values[0], values[1])
            progress.inc()
        progress.done()

    def computeAdjustedNormalizationFactors(self):
        """Adjust the factors so that the 'mean' regions have equal totals.

        The regions whose average corrected count lies between the first and
        the third quartiles are selected; the factor of each sample is
        multiplied by (mean of the two sums) / (sum of the sample) computed
        on these regions, and nbElements is rescaled by the new factor.
        """
        nbElements = len(self.regionsToValues.keys())
        avgValues  = []
        progress   = Progress(nbElements, "Normalization step 1", self.verbosity)
        for values in self.regionsToValues.values():
            correctedValues = [values[i] * self.normalizationFactors[i] for i in self.inputs]
            avgValues.append(float(sum(correctedValues)) / len(correctedValues))
            progress.inc()
        progress.done()

        sortedAvgValues = sorted(avgValues)
        # Explicit floor division: the results are list indices.
        minAvgValues    = sortedAvgValues[nbElements // 4]
        maxAvgValues    = sortedAvgValues[nbElements * 3 // 4]
        sums            = [0, 0]
        progress        = Progress(nbElements, "Normalization step 2", self.verbosity)
        for values in self.regionsToValues.values():
            correctedValues = [values[i] * self.normalizationFactors[i] for i in self.inputs]
            avgValue        = float(sum(correctedValues)) / len(correctedValues)
            if minAvgValues <= avgValue <= maxAvgValues:
                for i in self.inputs:
                    sums[i] += values[i]
            progress.inc()
        progress.done()

        avgSums = float(sum(sums)) / len(sums)
        for i in self.inputs:
            previousNbElements            = self.nbElements[i]
            self.normalizationFactors[i] *= float(avgSums) / sums[i]
            self.nbElements[i]           *= self.normalizationFactors[i]
            if self.verbosity > 1:
                print("Normalizing sample %d: %s to %s" % ((i+1), previousNbElements, int(self.nbElements[i])))

    def getMinimumReferenceSize(self):
        """Set normalizationSize to the length of the smallest reference region.

        Unlike computeMinimumSize, the stored value is (end - start + 1).
        This method is not called by the class itself.
        """
        self.normalizationSize = 1000000000
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Reference element sizes", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart() + 1)
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize))

    def computePvalues(self):
        """Apply the normalization factors to the counts and compute the p-values.

        regionsToValues is updated with the rescaled counts, and the distinct
        tuples (count 1, count 2, rest of sample 1, rest of sample 2, raw
        count 1, raw count 2) are given to Utils.fisherExactPValueBulk.
        """
        # NOTE(review): getAllCounts already multiplied the counts by the
        # factors known at that time -- confirm this second scaling is wanted.
        normalizedValues = set()
        progress         = Progress(len(self.regionsToValues.keys()), "Normalizing counts", self.verbosity)
        for region in self.regionsToValues:
            values            = self.regionsToValues[region]
            normalizedValues0 = int(round(values[0] * self.normalizationFactors[0]))
            normalizedValues1 = int(round(values[1] * self.normalizationFactors[1]))
            self.regionsToValues[region] = (normalizedValues0, normalizedValues1, values[2], values[3])
            normalizedValues.add((normalizedValues0, normalizedValues1, self.nbElements[0] - normalizedValues0, self.nbElements[1] - normalizedValues1, values[2], values[3]))
            progress.inc()
        progress.done()

        if self.verbosity > 1:
            print("Computing p-values...")
        self.valuesToPvalues = Utils.fisherExactPValueBulk(list(normalizedValues))
        if self.verbosity > 1:
            print("... done")

    def setTagValues(self, transcript, values, pValue):
        """Replace the tags and exons of a transcript by the results of a region.

        values -- (normalized count 1, normalized count 2, raw count 1,
                  raw count 2).
        pValue -- p-value of the region.

        The "regulation" tag is "equal" when both normalized counts are equal
        or when the FDR is used and the p-value is above its threshold; else
        "up" when the second count is the greater one, "down" otherwise.
        Returns the transcript.
        """
        # Work on a copy of the names: tags are deleted while looping.
        for tag in list(transcript.getTagNames()):
            transcript.deleteTag(tag)
        transcript.removeExons()
        transcript.setTagValue("pValue", str(pValue))
        transcript.setTagValue("nbReadsCond1", str(values[0]))
        transcript.setTagValue("nbReadsCond2", str(values[1]))
        transcript.setTagValue("nbUnnormalizedReadsCond1", str(values[2]))
        transcript.setTagValue("nbUnnormalizedReadsCond2", str(values[3]))
        # No threshold (fdrPvalue is None) means that no region passed the
        # FDR: every region is then reported as "equal".
        notSignificant = self.fdr is not None and (self.fdrPvalue is None or pValue > self.fdrPvalue)
        if values[0] == values[1] or notSignificant:
            transcript.setTagValue("regulation", "equal")
        elif values[0] < values[1]:
            transcript.setTagValue("regulation", "up")
        else:
            transcript.setTagValue("regulation", "down")
        return transcript

    def computeFdr(self):
        """Select the p-value threshold with the Benjamini-Hochberg procedure.

        fdrPvalue is set to the greatest p-value p such that
        p <= fdr * k / m, where k is the rank of p among the m p-values
        sorted in increasing order, starting at 1.  It is left unchanged when
        no p-value satisfies the condition.
        """
        pValues   = []
        nbRegions = len(self.regionsToValues.keys())
        progress  = Progress(nbRegions, "Computing FDR", self.verbosity)
        for values in self.regionsToValues.values():
            pValues.append(self.valuesToPvalues[values[0:2]])
            progress.inc()
        progress.done()

        for i, pValue in enumerate(sorted(pValues, reverse=True)):
            # The greatest p-value (i == 0) has rank m, the smallest rank 1.
            rank = nbRegions - i
            if pValue <= float(self.fdr) * rank / nbRegions:
                self.fdrPvalue = pValue
                if self.verbosity > 1:
                    print("FDR: %f, k: %i, m: %d" % (pValue, rank, nbRegions))
                return

    def writeDifferentialExpression(self):
        """Write one transcript per counted region, then draw the plot if required."""
        if self.plot:
            self.setPlotter()

        progress = Progress(len(self.regionsToValues.keys()), "Writing output", self.verbosity)
        for region, values in self.regionsToValues.items():
            transcript = self.stringToRegion(region)
            pValue     = self.valuesToPvalues[values[0:2]]
            transcript.setName(self.regionsToNames[region])
            transcript = self.setTagValues(transcript, values, pValue)
            self.writer.addTranscript(transcript)

            if self.plot:
                self.points[region] = (values[0], values[1])
            progress.inc()
        progress.done()
        self.writer.write()
        self.writer.close()

        if self.plot:
            self.plotter.addLine(self.points)
            self.plotter.plot()

    def getDifferentialExpression(self):
        """Run the whole analysis: read, normalize, count, test and write."""
        for i in self.inputs:
            self.readInput(i)

        if self.simpleNormalization:
            self.computeSimpleNormalizationFactors()
        if self.fixedSizeFactor is not None:
            self.computeMinimumSize()

        self.getAllCounts()

        if self.adjustedNormalization:
            self.computeAdjustedNormalizationFactors()

        self.computePvalues()

        if self.fdr is not None:
            self.computeFdr()

        self.writeDifferentialExpression()


if __name__ == "__main__":

    # parse command line
    description = "Get Differential Expression v1.0.1: Get the differential expression between 2 conditions using Fisher's exact test, on regions defined by a third file. [Category: Data Comparison]"

    parser = OptionParser(description = description)
    parser.add_option("-i", "--input1", dest="inputFileName1", action="store", type="string", help="input file 1 [compulsory] [format: file in transcript format given by -f]")
    parser.add_option("-f", "--format1", dest="format1", action="store", type="string", help="format of file 1 [compulsory] [format: transcript file format]")
    parser.add_option("-j", "--input2", dest="inputFileName2", action="store", type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
    parser.add_option("-g", "--format2", dest="format2", action="store", type="string", help="format of file 2 [compulsory] [format: transcript file format]")
    parser.add_option("-k", "--reference", dest="referenceFileName", action="store", type="string", help="reference file [compulsory] [format: file in transcript format given by -l]")
    parser.add_option("-l", "--referenceFormat", dest="referenceFormat", action="store", type="string", help="format of reference file [compulsory] [format: transcript file format]")
    parser.add_option("-o", "--output", dest="outputFileName", action="store", type="string", help="output file [format: output file in gff3 format]")
    parser.add_option("-n", "--notOriented", dest="notOriented", action="store_true", default=False, help="if the reads are not oriented [default: False] [format: bool]")
    parser.add_option("-s", "--simple", dest="simple", action="store_true", default=False, help="normalize using the number of reads in each condition [format: bool]")
    parser.add_option("-S", "--simpleParameters", dest="simpleParameters", action="store", default=None, type="string", help="provide the number of reads [format: bool]")
    parser.add_option("-a", "--adjusted", dest="adjusted", action="store_true", default=False, help="normalize using the number of reads of 'mean' regions [format: bool]")
    parser.add_option("-x", "--fixedSizeFactor", dest="fixedSizeFactor", action="store", default=None, type="int", help="give the magnification factor for the normalization using fixed size sliding windows in reference regions (leave empty for no such normalization) [format: int]")
    parser.add_option("-d", "--fdr", dest="fdr", action="store", default=None, type="float", help="use FDR [format: float]")
    parser.add_option("-p", "--plot", dest="plotName", action="store", default=None, type="string", help="plot cloud plot [format: output file in PNG format]")
    parser.add_option("-v", "--verbosity", dest="verbosity", action="store", default=1, type="int", help="trace level [format: int]")
    (options, args) = parser.parse_args()

    differentialExpression = GetDifferentialExpression(options.verbosity)
    differentialExpression.setInputFile(0, options.inputFileName1, options.format1)
    differentialExpression.setInputFile(1, options.inputFileName2, options.format2)
    differentialExpression.setReferenceFile(options.referenceFileName, options.referenceFormat)
    differentialExpression.setOutputFile(options.outputFileName)
    if options.plotName is not None:
        differentialExpression.setPlotterName(options.plotName)
        differentialExpression.setPlotter()
    differentialExpression.setOriented(not options.notOriented)
    differentialExpression.setSimpleNormalization(options.simple)
    differentialExpression.setSimpleNormalizationParameters(options.simpleParameters)
    differentialExpression.setAdjustedNormalization(options.adjusted)
    differentialExpression.setFixedSizeNormalization(options.fixedSizeFactor)
    differentialExpression.setFdr(options.fdr)
    differentialExpression.getDifferentialExpression()
    differentialExpression.mySqlConnection.deleteDatabase()