Mercurial > repos > yufei-luo > s_mart
view smart_toolShed/SMART/Java/Python/GetDifferentialExpression.py @ 4:1fc014126d55
Uploaded
author | yufei-luo |
---|---|
date | Fri, 18 Jan 2013 04:45:50 -0500 |
parents | e0f8dcca02ed |
children |
line wrap: on
line source
#! /usr/bin/env python # # Copyright INRA-URGI 2009-2010 # # This software is governed by the CeCILL license under French law and # abiding by the rules of distribution of free software. You can use, # modify and/ or redistribute the software under the terms of the CeCILL # license as circulated by CEA, CNRS and INRIA at the following URL # "http://www.cecill.info". # # As a counterpart to the access to the source code and rights to copy, # modify and redistribute granted by the license, users are provided only # with a limited warranty and the software's author, the holder of the # economic rights, and the successive licensors have only limited # liability. # # In this respect, the user's attention is drawn to the risks associated # with loading, using, modifying and/or developing or reproducing the # software by the user in light of its specific status of free software, # that may mean that it is complicated to manipulate, and that also # therefore means that it is reserved for developers and experienced # professionals having in-depth computer knowledge. Users are therefore # encouraged to load and test the software's suitability as regards their # requirements in conditions enabling the security of their systems and/or # data to be ensured and, more generally, to use and operate it in the # same conditions as regards security. # # The fact that you are presently reading this means that you have had # knowledge of the CeCILL license and that you accept its terms. 
#
"""Get the differential expression between 2 conditions (2 files), on regions defined by a third file"""

import os
import re
from optparse import OptionParser
from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer
from commons.core.writer.Gff3Writer import Gff3Writer
from SMART.Java.Python.misc.Progress import Progress
from SMART.Java.Python.misc.RPlotter import RPlotter
from SMART.Java.Python.misc import Utils
from SMART.Java.Python.mySql.MySqlConnection import MySqlConnection
from SMART.Java.Python.structure.Transcript import Transcript


class GetDifferentialExpression(object):
    """Compare the read counts of two samples on the regions of a reference file.

    Workflow (see getDifferentialExpression):
      1. both samples are stored in the database and their reads are counted;
      2. optional "simple" normalization on the total number of reads;
      3. the reads contained in each reference exon are counted, optionally
         with a fixed-size sliding-window normalization;
      4. optional "adjusted" normalization on the inter-quartile regions;
      5. a Fisher's exact test p-value is computed for each region,
         optionally thresholded with a Benjamini-Hochberg FDR;
      6. the regions are written in GFF3 (and optionally plotted).

    Normalization factors accumulate in ``normalizationFactors``. The counts
    stored by getAllCounts() already include the factors known at that time
    (recorded in ``countNormalizationFactors``); later steps only apply the
    remaining ratio, so that no factor is applied twice.
    """

    # "chromosome:start-end(strand)", as produced by regionToString()
    _REGION_PATTERN = re.compile(r"^(\S+):(\d+)-(\d+)\((\S)\)$")

    def __init__(self, verbosity = 1):
        self.verbosity = verbosity
        self.mySqlConnection = MySqlConnection(verbosity)
        self.inputs = (0, 1)                        # indices of the two conditions
        self.transcriptContainers = [None, None]    # one container per condition
        self.transcriptContainerRef = None          # reference regions
        self.outputFileName = None
        self.writer = None
        self.tables = [None, None]                  # per condition: {chromosome: table}
        self.nbElements = [0, 0]                    # per condition: (normalized) number of reads
        self.regionsToValues = {}                   # region string -> (norm1, norm2, raw1, raw2)
        self.regionsToNames = {}                    # region string -> name written in the output
        self.valuesToPvalues = {}                   # (norm1, norm2) -> p-value
        self.oriented = True
        self.simpleNormalization = False
        self.simpleNormalizationParameters = None
        self.adjustedNormalization = False
        self.fixedSizeFactor = None
        self.normalizationSize = None
        self.normalizationFactors = [1, 1]
        self.countNormalizationFactors = [1, 1]     # factors already folded into regionsToValues
        self.fdr = None
        self.fdrPvalue = None
        self.plot = False
        self.plotter = None
        self.plotterName = None
        self.points = {}

    def setInputFile(self, i, fileName, fileFormat):
        """Set the reads file of condition i (0 or 1)."""
        self.transcriptContainers[i] = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainers[i].mySqlConnection = self.mySqlConnection

    def setReferenceFile(self, fileName, fileFormat):
        """Set the file holding the regions on which reads are counted."""
        self.transcriptContainerRef = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainerRef.mySqlConnection = self.mySqlConnection

    def setOutputFile(self, fileName):
        """Set the GFF3 output file."""
        self.outputFileName = fileName
        self.writer = Gff3Writer(fileName, self.verbosity)

    def setOriented(self, boolean):
        """Only count reads on the same strand as the region when True."""
        self.oriented = boolean

    def setSimpleNormalization(self, boolean):
        """Enable normalization on the total number of reads of each condition."""
        self.simpleNormalization = boolean

    def setSimpleNormalizationParameters(self, parameters):
        """Provide the read totals ("n1,n2") used by the simple normalization.

        Providing them also enables the simple normalization; None is ignored.
        Raises Exception unless exactly one integer per condition is given.
        """
        if parameters is None:
            return
        values = [int(value) for value in parameters.split(",")]
        if len(values) != len(self.inputs):
            raise Exception("Expected %d comma-separated numbers of reads, got '%s'" % (len(self.inputs), parameters))
        self.simpleNormalization = True
        self.simpleNormalizationParameters = values

    def setAdjustedNormalization(self, boolean):
        """Enable normalization on the reads of the inter-quartile regions."""
        self.adjustedNormalization = boolean

    def setFixedSizeNormalization(self, value):
        """Set the magnification factor of the fixed-size window normalization (None disables it)."""
        self.fixedSizeFactor = value

    def setFdr(self, fdr):
        """Set the false discovery rate (None disables the FDR threshold)."""
        self.fdr = fdr

    def setPlot(self, boolean):
        self.plot = boolean

    def setPlotterName(self, plotterName):
        self.plotterName = plotterName

    def setPlotter(self):
        """Create a log-log scatter plotter writing to plotterName."""
        self.plot = True
        self.plotter = RPlotter(self.plotterName, self.verbosity)
        self.plotter.setPoints(True)
        self.plotter.setLog("xy")
        self.points = {}

    def readInput(self, i):
        """Store condition i in the database, index its tables and count its reads.

        A read carrying an "nbElements" tag counts for that many reads.
        """
        container = self.transcriptContainers[i]
        container.storeIntoDatabase()
        self.tables[i] = container.getTables()
        progress = Progress(len(self.tables[i].keys()), "Adding indices", self.verbosity)
        for chromosome in self.tables[i]:
            if self.oriented:
                self.tables[i][chromosome].createIndex("iStartEndDir_%s_%d" % (chromosome, i), ("start", "end", "direction"))
            else:
                self.tables[i][chromosome].createIndex("iStartEnd_%s_%d" % (chromosome, i), ("start", "end"))
            progress.inc()
        progress.done()
        progress = Progress(container.getNbTranscripts(), "Reading sample %d" % (i + 1), self.verbosity)
        for chromosome in self.tables[i]:
            for transcript in self.tables[i][chromosome].getIterator():
                if "nbElements" in transcript.getTagNames():
                    # parsed like in retrieveCounts(), so totals and per-region counts agree
                    self.nbElements[i] += int(float(transcript.getTagValue("nbElements")))
                else:
                    self.nbElements[i] += 1
                progress.inc()
        progress.done()
        if self.verbosity > 0:
            print("%d elements in sample %d" % (self.nbElements[i], i + 1))

    def computeSimpleNormalizationFactors(self):
        """Scale each condition to the average number of reads of both conditions.

        The totals are the counted reads, or simpleNormalizationParameters when
        provided. Raises Exception if a total is zero.
        """
        nbElements = self.nbElements
        if self.simpleNormalizationParameters is not None:
            if self.verbosity > 0:
                print("Using provided normalization parameters: %s" % (", ".join([str(parameter) for parameter in self.simpleNormalizationParameters])))
            nbElements = self.simpleNormalizationParameters
        for i in self.inputs:
            if nbElements[i] == 0:
                raise Exception("Cannot normalize: sample %d has no read" % (i + 1))
        avgNbElements = int(float(sum(nbElements)) / len(nbElements))
        for i in self.inputs:
            self.normalizationFactors[i] = float(avgNbElements) / nbElements[i]
            self.nbElements[i] *= self.normalizationFactors[i]
        if self.verbosity > 1:
            print("Normalizing to average # reads: %d" % (avgNbElements))
            if self.simpleNormalizationParameters is not None:
                print("# reads: %s" % (", ".join([str(nbElement) for nbElement in self.nbElements])))

    def __del__(self):
        # getattr: __init__ may have failed before the connection existed
        connection = getattr(self, "mySqlConnection", None)
        if connection is not None:
            connection.deleteDatabase()

    def regionToString(self, transcript):
        """Encode the position of a transcript as "chromosome:start-end(strand)"."""
        strand = "+" if transcript.getDirection() == 1 else "-"
        return "%s:%d-%d(%s)" % (transcript.getChromosome(), transcript.getStart(), transcript.getEnd(), strand)

    def stringToRegion(self, region):
        """Build a Transcript back from a string made by regionToString()."""
        m = self._REGION_PATTERN.search(region)
        if m is None:
            raise Exception("Internal format error: cannot parse region '%s'" % (region))
        transcript = Transcript()
        transcript.setChromosome(m.group(1))
        transcript.setStart(int(m.group(2)))
        transcript.setEnd(int(m.group(3)))
        transcript.setDirection(m.group(4))
        return transcript

    def computeMinimumSize(self):
        """Set normalizationSize (the sliding-window width) from the smallest reference region.

        normalizationSize is end - start of that region, clamped to at least 1.
        """
        self.normalizationSize = 1000000000
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting minimum reference size", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart())
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize + 1))
        # a 1-nt reference region would give a 0-wide window and a division by zero
        self.normalizationSize = max(1, self.normalizationSize)

    def useFixedSizeNormalization(self, start, end, starts):
        """Return the fixed-size-window normalized count of the region [start, end].

        starts maps read start positions to numbers of reads; it is modified in
        place. With N = normalizationSize, the number of reads starting in
        (p, p + N] is summed over every p in [start - N, end], divided by N and
        scaled by fixedSizeFactor / (end - start + 1). Returns 0 without reads.
        """
        if not starts:
            return 0
        windowSize = self.normalizationSize
        for position in range(start - windowSize, end + 1 + windowSize):
            starts.setdefault(position, 0)
        # reads starting before the region are counted at its first position
        for position, nb in list(starts.items()):
            if position < start:
                starts[start] += nb
                starts[position] = 0
        currentNb = 0
        total = 0
        for position in range(start - windowSize, end + 1):
            currentNb += starts[position + windowSize] - starts[position]
            total += currentNb
        # float(): with int operands, Python 2 would truncate the factor (usually to 0)
        return (float(total) / windowSize) * (float(self.fixedSizeFactor) / (end - start + 1))

    @staticmethod
    def _parseNbElements(tags):
        """Return the "nbElements" value of a raw "key=value;..." tag string (1 when absent)."""
        nb = 1
        for tag in (tags or "").split(";"):
            key, separator, value = tag.partition("=")
            if separator and key == "nbElements":
                nb = int(float(value))
        return nb

    def retrieveCounts(self, transcriptRef, i):
        """Count the reads of condition i contained in the exons of transcriptRef.

        Returns (raw count, count after fixed-size normalization if enabled,
        otherwise the raw count again), summed over the exons.
        """
        if transcriptRef.getChromosome() not in self.tables[i]:
            return (0, 0)
        cumulatedCount = 0
        cumulatedNormalizedCount = 0
        for exon in transcriptRef.getExons():
            count = 0
            starts = {}
            command = "SELECT start, tags FROM '%s' WHERE start >= %d AND end <= %d" % (self.tables[i][exon.getChromosome()].getName(), exon.getStart(), exon.getEnd())
            if self.oriented:
                command += " AND direction = %d" % (exon.getDirection())
            query = self.mySqlConnection.executeQuery(command)
            for line in query.getIterator():
                nb = self._parseNbElements(line[1])
                count += nb
                position = int(line[0])
                # several reads may share a start position: accumulate, do not overwrite
                starts[position] = starts.get(position, 0) + nb
            if self.fixedSizeFactor is None:
                normalizedCount = count
            else:
                normalizedCount = self.useFixedSizeNormalization(exon.getStart(), exon.getEnd(), starts)
            cumulatedCount += count
            cumulatedNormalizedCount += normalizedCount
        return (cumulatedCount, cumulatedNormalizedCount)

    def getAllCounts(self):
        """Fill regionsToNames and regionsToValues for every reference region.

        Names come from the "ID" tag, else the transcript name, else
        "region_<index>". Only regions with at least one read are kept; their
        counts are scaled by the current normalizationFactors.
        """
        # remember what is folded into the stored counts, so it is not re-applied later
        self.countNormalizationFactors = list(self.normalizationFactors)
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting counts", self.verbosity)
        for cpt, transcriptRef in enumerate(self.transcriptContainerRef.getIterator()):
            region = self.regionToString(transcriptRef)
            if "ID" in transcriptRef.getTagNames():
                self.regionsToNames[region] = transcriptRef.getTagValue("ID")
            elif transcriptRef.getName() is not None:
                self.regionsToNames[region] = transcriptRef.getName()
            else:
                self.regionsToNames[region] = "region_%d" % (cpt)
            values = [None, None]
            normalizedValues = [None, None]
            for i in self.inputs:
                values[i], normalizedValues[i] = self.retrieveCounts(transcriptRef, i)
                normalizedValues[i] = int(self.normalizationFactors[i] * normalizedValues[i])
            if sum(values) != 0:
                self.regionsToValues[region] = (normalizedValues[0], normalizedValues[1], values[0], values[1])
            progress.inc()
        progress.done()

    def _pendingFactor(self, i):
        """Normalization factor of condition i not yet folded into regionsToValues."""
        return float(self.normalizationFactors[i]) / self.countNormalizationFactors[i]

    def computeAdjustedNormalizationFactors(self):
        """Equalize the reads of both conditions on the inter-quartile regions.

        Regions whose mean normalized count lies between the first and third
        quartiles are summed per condition; each factor (and nbElements) is then
        scaled by mean(sums) / sum. Does nothing when no region has reads;
        raises Exception if a condition has no read in those regions.
        """
        nbElements = len(self.regionsToValues)
        if nbElements == 0:
            if self.verbosity > 0:
                print("No region with reads: skipping adjusted normalization")
            return
        pendingFactors = [self._pendingFactor(i) for i in self.inputs]

        def normalized(values):
            return [values[i] * pendingFactors[i] for i in self.inputs]

        avgValues = []
        progress = Progress(nbElements, "Normalization step 1", self.verbosity)
        for values in self.regionsToValues.values():
            correctedValues = normalized(values)
            avgValues.append(float(sum(correctedValues)) / len(correctedValues))
            progress.inc()
        progress.done()
        sortedAvgValues = sorted(avgValues)
        minAvgValues = sortedAvgValues[nbElements // 4]
        maxAvgValues = sortedAvgValues[nbElements * 3 // 4]
        sums = [0, 0]
        progress = Progress(nbElements, "Normalization step 2", self.verbosity)
        for values in self.regionsToValues.values():
            correctedValues = normalized(values)
            avgValue = float(sum(correctedValues)) / len(correctedValues)
            if minAvgValues <= avgValue <= maxAvgValues:
                for i in self.inputs:
                    sums[i] += correctedValues[i]
            progress.inc()
        progress.done()
        for i in self.inputs:
            if sums[i] == 0:
                raise Exception("Cannot compute adjusted normalization: sample %d has no read in the inter-quartile regions" % (i + 1))
        avgSums = float(sum(sums)) / len(sums)
        for i in self.inputs:
            message = "Normalizing sample %d: %s to" % ((i + 1), self.nbElements[i])
            ratio = avgSums / sums[i]
            self.normalizationFactors[i] *= ratio
            # nbElements already includes the previous factors: apply only the new ratio
            self.nbElements[i] *= ratio
            if self.verbosity > 1:
                print("%s %s" % (message, int(self.nbElements[i])))

    def getMinimumReferenceSize(self):
        """Set normalizationSize to the length (end - start + 1) of the smallest reference region.

        Not called by getDifferentialExpression(), which uses computeMinimumSize().
        """
        self.normalizationSize = 1000000000
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Reference element sizes", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart() + 1)
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize))

    def computePvalues(self):
        """Apply the pending normalization to the counts and compute the p-values.

        Each distinct (norm1, norm2, total1 - norm1, total2 - norm2, raw1, raw2)
        tuple is sent to Utils.fisherExactPValueBulk; the result is looked up
        later by (norm1, norm2).
        """
        pendingFactors = [self._pendingFactor(i) for i in self.inputs]
        normalizedValues = set()
        progress = Progress(len(self.regionsToValues), "Normalizing counts", self.verbosity)
        for region, values in list(self.regionsToValues.items()):
            normalizedValues0 = int(round(values[0] * pendingFactors[0]))
            normalizedValues1 = int(round(values[1] * pendingFactors[1]))
            self.regionsToValues[region] = (normalizedValues0, normalizedValues1, values[2], values[3])
            normalizedValues.add((normalizedValues0, normalizedValues1, self.nbElements[0] - normalizedValues0, self.nbElements[1] - normalizedValues1, values[2], values[3]))
            progress.inc()
        progress.done()
        # every factor is now folded into regionsToValues
        self.countNormalizationFactors = list(self.normalizationFactors)
        if self.verbosity > 1:
            print("Computing p-values...")
        self.valuesToPvalues = Utils.fisherExactPValueBulk(list(normalizedValues))
        if self.verbosity > 1:
            print("... done")

    def setTagValues(self, transcript, values, pValue):
        """Replace the tags and exons of transcript with the comparison results.

        "regulation" is "equal" when the normalized counts are equal or, with an
        FDR, when the p-value is above the FDR threshold (or no threshold was
        found); otherwise "up" if condition 2 has more reads, else "down".
        """
        for tag in list(transcript.getTagNames()):
            transcript.deleteTag(tag)
        transcript.removeExons()
        transcript.setTagValue("pValue", str(pValue))
        transcript.setTagValue("nbReadsCond1", str(values[0]))
        transcript.setTagValue("nbReadsCond2", str(values[1]))
        transcript.setTagValue("nbUnnormalizedReadsCond1", str(values[2]))
        transcript.setTagValue("nbUnnormalizedReadsCond2", str(values[3]))
        # explicit None check: without a threshold, nothing passes the FDR
        notSignificant = self.fdr is not None and (self.fdrPvalue is None or pValue > self.fdrPvalue)
        if values[0] == values[1] or notSignificant:
            transcript.setTagValue("regulation", "equal")
        elif values[0] < values[1]:
            transcript.setTagValue("regulation", "up")
        else:
            transcript.setTagValue("regulation", "down")
        return transcript

    def computeFdr(self):
        """Set fdrPvalue with the Benjamini-Hochberg step-up procedure.

        fdrPvalue is the largest p-value p(k) (k-th smallest of m) such that
        p(k) <= fdr * k / m; it stays None when no p-value qualifies.
        """
        nbRegions = len(self.regionsToValues)
        pValues = []
        progress = Progress(nbRegions, "Computing FDR", self.verbosity)
        for values in self.regionsToValues.values():
            pValues.append(self.valuesToPvalues[values[0:2]])
            progress.inc()
        progress.done()
        for i, pValue in enumerate(sorted(pValues, reverse=True)):
            rank = nbRegions - i  # 1-based rank in increasing order
            if pValue <= self.fdr * rank / float(nbRegions):
                self.fdrPvalue = pValue
                if self.verbosity > 1:
                    print("FDR: %f, k: %i, m: %d" % (pValue, rank, nbRegions))
                return

    def writeDifferentialExpression(self):
        """Write every counted region to the GFF3 output (and the plot, if enabled)."""
        if self.plot:
            self.setPlotter()
        progress = Progress(len(self.regionsToValues), "Writing output", self.verbosity)
        for region, values in self.regionsToValues.items():
            transcript = self.stringToRegion(region)
            pValue = self.valuesToPvalues[values[0:2]]
            transcript.setName(self.regionsToNames[region])
            transcript = self.setTagValues(transcript, values, pValue)
            self.writer.addTranscript(transcript)
            if self.plot:
                self.points[region] = (values[0], values[1])
            progress.inc()
        progress.done()
        self.writer.write()
        self.writer.close()
        if self.plot:
            self.plotter.addLine(self.points)
            self.plotter.plot()

    def getDifferentialExpression(self):
        """Run the whole analysis, from reading the inputs to writing the output."""
        for i in self.inputs:
            self.readInput(i)
        if self.simpleNormalization:
            self.computeSimpleNormalizationFactors()
        if self.fixedSizeFactor is not None:
            self.computeMinimumSize()
        self.getAllCounts()
        if self.adjustedNormalization:
            self.computeAdjustedNormalizationFactors()
        self.computePvalues()
        if self.fdr is not None:
            self.computeFdr()
        self.writeDifferentialExpression()


if __name__ == "__main__":

    # parse command line
    description = "Get Differential Expression v1.0.1: Get the differential expression between 2 conditions using Fisher's exact test, on regions defined by a third file. [Category: Data Comparison]"

    parser = OptionParser(description = description)
    parser.add_option("-i", "--input1", dest="inputFileName1", action="store", type="string", help="input file 1 [compulsory] [format: file in transcript format given by -f]")
    parser.add_option("-f", "--format1", dest="format1", action="store", type="string", help="format of file 1 [compulsory] [format: transcript file format]")
    parser.add_option("-j", "--input2", dest="inputFileName2", action="store", type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
    parser.add_option("-g", "--format2", dest="format2", action="store", type="string", help="format of file 2 [compulsory] [format: transcript file format]")
    parser.add_option("-k", "--reference", dest="referenceFileName", action="store", type="string", help="reference file [compulsory] [format: file in transcript format given by -l]")
    parser.add_option("-l", "--referenceFormat", dest="referenceFormat", action="store", type="string", help="format of reference file [compulsory] [format: transcript file format]")
    parser.add_option("-o", "--output", dest="outputFileName", action="store", type="string", help="output file [format: output file in gff3 format]")
    parser.add_option("-n", "--notOriented", dest="notOriented", action="store_true", default=False, help="if the reads are not oriented [default: False] [format: bool]")
    parser.add_option("-s", "--simple", dest="simple", action="store_true", default=False, help="normalize using the number of reads in each condition [format: bool]")
    parser.add_option("-S", "--simpleParameters", dest="simpleParameters", action="store", default=None, type="string", help="provide the number of reads [format: bool]")
    parser.add_option("-a", "--adjusted", dest="adjusted", action="store_true", default=False, help="normalize using the number of reads of 'mean' regions [format: bool]")
    parser.add_option("-x", "--fixedSizeFactor", dest="fixedSizeFactor", action="store", default=None, type="int", help="give the magnification factor for the normalization using fixed size sliding windows in reference regions (leave empty for no such normalization) [format: int]")
    parser.add_option("-d", "--fdr", dest="fdr", action="store", default=None, type="float", help="use FDR [format: float]")
    parser.add_option("-p", "--plot", dest="plotName", action="store", default=None, type="string", help="plot cloud plot [format: output file in PNG format]")
    parser.add_option("-v", "--verbosity", dest="verbosity", action="store", default=1, type="int", help="trace level [format: int]")
    (options, args) = parser.parse_args()

    # fail early with a usage message instead of an obscure error later on
    for optionName, flag in (("inputFileName1", "-i"), ("format1", "-f"), ("inputFileName2", "-j"), ("format2", "-g"), ("referenceFileName", "-k"), ("referenceFormat", "-l")):
        if getattr(options, optionName) is None:
            parser.error("option %s is compulsory" % (flag))

    differentialExpression = GetDifferentialExpression(options.verbosity)
    differentialExpression.setInputFile(0, options.inputFileName1, options.format1)
    differentialExpression.setInputFile(1, options.inputFileName2, options.format2)
    differentialExpression.setReferenceFile(options.referenceFileName, options.referenceFormat)
    differentialExpression.setOutputFile(options.outputFileName)
    if options.plotName is not None:
        differentialExpression.setPlotterName(options.plotName)
        differentialExpression.setPlotter()
    differentialExpression.setOriented(not options.notOriented)
    differentialExpression.setSimpleNormalization(options.simple)
    differentialExpression.setSimpleNormalizationParameters(options.simpleParameters)
    differentialExpression.setAdjustedNormalization(options.adjusted)
    differentialExpression.setFixedSizeNormalization(options.fixedSizeFactor)
    differentialExpression.setFdr(options.fdr)
    differentialExpression.getDifferentialExpression()
    differentialExpression.mySqlConnection.deleteDatabase()