Mercurial > repos > yufei-luo > s_mart
comparison smart_toolShed/SMART/Java/Python/GetDifferentialExpression.py @ 0:e0f8dcca02ed
Uploaded S-MART tool. A toolbox that manages RNA-Seq and ChIP-Seq data.
author | yufei-luo |
---|---|
date | Thu, 17 Jan 2013 10:52:14 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e0f8dcca02ed |
---|---|
1 #! /usr/bin/env python | |
2 # | |
3 # Copyright INRA-URGI 2009-2010 | |
4 # | |
5 # This software is governed by the CeCILL license under French law and | |
6 # abiding by the rules of distribution of free software. You can use, | |
7 # modify and/ or redistribute the software under the terms of the CeCILL | |
8 # license as circulated by CEA, CNRS and INRIA at the following URL | |
9 # "http://www.cecill.info". | |
10 # | |
11 # As a counterpart to the access to the source code and rights to copy, | |
12 # modify and redistribute granted by the license, users are provided only | |
13 # with a limited warranty and the software's author, the holder of the | |
14 # economic rights, and the successive licensors have only limited | |
15 # liability. | |
16 # | |
17 # In this respect, the user's attention is drawn to the risks associated | |
18 # with loading, using, modifying and/or developing or reproducing the | |
19 # software by the user in light of its specific status of free software, | |
20 # that may mean that it is complicated to manipulate, and that also | |
21 # therefore means that it is reserved for developers and experienced | |
22 # professionals having in-depth computer knowledge. Users are therefore | |
23 # encouraged to load and test the software's suitability as regards their | |
24 # requirements in conditions enabling the security of their systems and/or | |
25 # data to be ensured and, more generally, to use and operate it in the | |
26 # same conditions as regards security. | |
27 # | |
28 # The fact that you are presently reading this means that you have had | |
29 # knowledge of the CeCILL license and that you accept its terms. | |
30 # | |
31 """Get the differential expression between 2 conditions (2 files), on regions defined by a third file""" | |
32 | |
33 import os, re | |
34 from optparse import OptionParser | |
35 from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer | |
36 from commons.core.writer.Gff3Writer import Gff3Writer | |
37 from SMART.Java.Python.misc.Progress import Progress | |
38 from SMART.Java.Python.misc.RPlotter import RPlotter | |
39 from SMART.Java.Python.misc import Utils | |
40 from SMART.Java.Python.mySql.MySqlConnection import MySqlConnection | |
41 from SMART.Java.Python.structure.Transcript import Transcript | |
42 | |
class GetDifferentialExpression(object):
    """Compare the read counts of two samples on a set of reference regions.

    The two input samples are stored into database tables, the reads falling
    inside each reference region are counted for both samples, the counts
    are optionally normalized, and a p-value is computed for every region
    with Utils.fisherExactPValueBulk.  The result is written with a
    Gff3Writer and can optionally be plotted with an RPlotter.
    """

    def __init__(self, verbosity=1):
        """Create the shared database connection and reset every setting.

        verbosity -- trace level forwarded to all the helper objects
        """
        self.verbosity = verbosity
        self.mySqlConnection = MySqlConnection(verbosity)
        # indices of the two samples, used everywhere as "for i in self.inputs"
        self.inputs = (0, 1)
        self.transcriptContainers = [None, None]
        self.transcriptContainerRef = None
        self.outputFileName = None
        self.writer = None
        # per sample: the value returned by TranscriptContainer.getTables()
        self.tables = [None, None]
        # per sample: number of reads (becomes a float once normalized)
        self.nbElements = [0, 0]

        # region string -> (normalized 1, normalized 2, raw 1, raw 2)
        self.regionsToValues = {}
        # region string -> name written in the output
        self.regionsToNames = {}
        # (normalized 1, normalized 2) -> p-value
        self.valuesToPvalues = {}

        self.oriented = True
        self.simpleNormalization = False
        self.simpleNormalizationParameters = None
        self.adjustedNormalization = False
        self.fixedSizeFactor = None
        self.normalizationSize = None
        self.normalizationFactors = [1, 1]
        self.fdr = None
        # p-value threshold found by computeFdr (None when none was found)
        self.fdrPvalue = None

        self.plot = False
        self.plotter = None
        self.plotterName = None
        self.points = {}

    def setInputFile(self, i, fileName, fileFormat):
        """Declare sample number i (0 or 1) and attach the shared connection."""
        self.transcriptContainers[i] = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainers[i].mySqlConnection = self.mySqlConnection

    def setReferenceFile(self, fileName, fileFormat):
        """Declare the file holding the reference regions."""
        self.transcriptContainerRef = TranscriptContainer(fileName, fileFormat, self.verbosity)
        self.transcriptContainerRef.mySqlConnection = self.mySqlConnection

    def setOutputFile(self, fileName):
        """Remember the output file name and open a Gff3Writer on it."""
        self.outputFileName = fileName
        self.writer = Gff3Writer(fileName, self.verbosity)

    def setOriented(self, boolean):
        """Tell whether the strand must match when counting reads."""
        self.oriented = boolean

    def setSimpleNormalization(self, boolean):
        """Switch the normalization by total number of reads on or off."""
        self.simpleNormalization = boolean

    def setSimpleNormalizationParameters(self, parameters):
        """Provide the number of reads of each sample as "n1,n2".

        Passing None leaves every setting untouched.  Any other value also
        switches the simple normalization on.

        Raises ValueError when the string does not hold exactly two
        integers: a single value would leave a zero denominator for the
        second sample and a third value has no sample to apply to.
        """
        if parameters is None:
            return
        fields = parameters.split(",")
        if len(fields) != 2:
            raise ValueError("Expected 2 comma-separated normalization parameters, got '%s'" % (parameters))
        self.simpleNormalizationParameters = [int(field) for field in fields]
        self.simpleNormalization = True

    def setAdjustedNormalization(self, boolean):
        """Switch the normalization on 'mean' regions on or off."""
        self.adjustedNormalization = boolean

    def setFixedSizeNormalization(self, value):
        """Set the magnification factor of the fixed-size normalization (None disables it)."""
        self.fixedSizeFactor = value

    def setFdr(self, fdr):
        """Set the false discovery rate (None disables the FDR filter)."""
        self.fdr = fdr

    def setPlot(self, boolean):
        """Switch the plot on or off."""
        self.plot = boolean

    def setPlotterName(self, plotterName):
        """Set the name handed to the RPlotter."""
        self.plotterName = plotterName

    def setPlotter(self):
        """Create a fresh plotter (points, log-log axes) and enable plotting."""
        self.plot = True
        self.plotter = RPlotter(self.plotterName, self.verbosity)
        self.plotter.setPoints(True)
        self.plotter.setLog("xy")
        self.points = {}

    def readInput(self, i):
        """Store sample i in the database, index its tables and count its reads.

        A transcript carrying an "nbElements" tag contributes the value of
        that tag; any other transcript contributes 1.
        """
        container = self.transcriptContainers[i]
        container.storeIntoDatabase()
        self.tables[i] = container.getTables()
        tables = self.tables[i]

        progress = Progress(len(tables.keys()), "Adding indices", self.verbosity)
        for chromosome in tables:
            if self.oriented:
                tables[chromosome].createIndex("iStartEndDir_%s_%d" % (chromosome, i), ("start", "end", "direction"))
            else:
                tables[chromosome].createIndex("iStartEnd_%s_%d" % (chromosome, i), ("start", "end"))
            progress.inc()
        progress.done()

        progress = Progress(container.getNbTranscripts(), "Reading sample %d" % (i + 1), self.verbosity)
        for chromosome in tables:
            for transcript in tables[chromosome].getIterator():
                if "nbElements" in transcript.getTagNames():
                    self.nbElements[i] += transcript.getTagValue("nbElements")
                else:
                    self.nbElements[i] += 1
                progress.inc()
        progress.done()
        if self.verbosity > 0:
            print("%d elements in sample %d" % (self.nbElements[i], i + 1))

    def computeSimpleNormalizationFactors(self):
        """Scale both samples to their (truncated) average number of reads.

        The numbers of reads are the provided normalization parameters when
        there are some, and the counts gathered by readInput otherwise.
        """
        nbElements = self.nbElements
        if self.simpleNormalizationParameters is not None:
            print("Using provided normalization parameters: %s" % (", ".join([str(parameter) for parameter in self.simpleNormalizationParameters])))
            nbElements = self.simpleNormalizationParameters
        avgNbElements = int(float(sum(nbElements)) / len(nbElements))
        for i in self.inputs:
            self.normalizationFactors[i] = float(avgNbElements) / nbElements[i]
            self.nbElements[i] *= self.normalizationFactors[i]
        if self.verbosity > 1:
            print("Normalizing to average # reads: %d" % (avgNbElements))
            if self.simpleNormalizationParameters is not None:
                print("# reads: %s" % (", ".join([str(nbElement) for nbElement in self.nbElements])))

    def __del__(self):
        """Drop the working database when the object is collected."""
        # the attribute is missing when __init__ failed before creating it
        connection = getattr(self, "mySqlConnection", None)
        if connection is not None:
            connection.deleteDatabase()

    def regionToString(self, transcript):
        """Serialize the coordinates of a transcript as "chr:start-end(strand)"."""
        strand = "+" if transcript.getDirection() == 1 else "-"
        return "%s:%d-%d(%s)" % (transcript.getChromosome(), transcript.getStart(), transcript.getEnd(), strand)

    def stringToRegion(self, region):
        """Build a Transcript from a string produced by regionToString.

        Raises Exception when the string does not have the expected format.
        """
        m = re.search(r"^(\S+):(\d+)-(\d+)\((\S)\)$", region)
        if m is None:
            raise Exception("Internal format error: cannot parse region '%s'" % (region))
        transcript = Transcript()
        transcript.setChromosome(m.group(1))
        transcript.setStart(int(m.group(2)))
        transcript.setEnd(int(m.group(3)))
        transcript.setDirection(m.group(4))
        return transcript

    def computeMinimumSize(self):
        """Store in normalizationSize the smallest (end - start) of the reference regions.

        The stored value is one less than the size in nucleotides, which is
        why the trace message adds 1.
        """
        self.normalizationSize = 1000000000
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting minimum reference size", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart())
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize + 1))

    def useFixedSizeNormalization(self, start, end, starts):
        """Return the count of a region rescaled with fixed-size sliding windows.

        start, end -- bounds of the region
        starts     -- dict: read start position -> number of reads; it is
                      completed and modified in place

        Returns 0 when starts is empty.  Otherwise the counts seen by a
        window of self.normalizationSize positions sliding along the region
        are summed, averaged over the window size, and multiplied by
        fixedSizeFactor / (region size).
        """
        if not starts:
            return 0
        window = self.normalizationSize
        # make every position that will be looked up below available
        for position in range(start - window, end + 1 + window):
            starts.setdefault(position, 0)
        # move whatever lies before the region onto its first position
        for position, nb in list(starts.items()):
            if position < start:
                starts[start] += nb
                starts[position] = 0
        currentNb = 0
        total = 0
        for position in range(start - window, end + 1):
            currentNb += starts[position + window] - starts[position]
            total += currentNb
        # explicit float: with two integers the original "/" truncated the
        # scaling factor (to 0 as soon as the region is longer than the factor)
        return (float(total) / window) * (float(self.fixedSizeFactor) / (end - start + 1))

    def retrieveCounts(self, transcriptRef, i):
        """Count the reads of sample i included in the exons of a reference.

        Returns (raw count, normalized count).  Both are 0 when the sample
        has no table for the chromosome of the reference.  The normalized
        count equals the raw count unless a fixed size factor is set.
        """
        if transcriptRef.getChromosome() not in self.tables[i]:
            return (0, 0)
        cumulatedCount = 0
        cumulatedNormalizedCount = 0
        for exon in transcriptRef.getExons():
            count = 0
            starts = {}
            command = "SELECT start, tags FROM '%s' WHERE start >= %d AND end <= %d" % (self.tables[i][exon.getChromosome()].getName(), exon.getStart(), exon.getEnd())
            if self.oriented:
                command += " AND direction = %d" % (exon.getDirection())
            query = self.mySqlConnection.executeQuery(command)
            for line in query.getIterator():
                nb = 1
                for tag in (line[1] or "").split(";"):
                    # partition tolerates empty tags and "=" inside a value
                    key, separator, value = tag.partition("=")
                    if separator and key == "nbElements":
                        nb = int(float(value))
                count += nb
                # accumulate: several reads may start at the same position
                position = int(line[0])
                starts[position] = starts.get(position, 0) + nb
            if self.fixedSizeFactor is None:
                normalizedCount = count
            else:
                normalizedCount = self.useFixedSizeNormalization(exon.getStart(), exon.getEnd(), starts)
            cumulatedCount += count
            cumulatedNormalizedCount += normalizedCount
        return (cumulatedCount, cumulatedNormalizedCount)

    def getAllCounts(self):
        """Fill regionsToNames and regionsToValues from the reference regions.

        A region is named after its "ID" tag, else its name, else
        "region_<index>".  Regions without any read in both samples are not
        stored in regionsToValues.
        """
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Getting counts", self.verbosity)
        for cpt, transcriptRef in enumerate(self.transcriptContainerRef.getIterator()):
            region = self.regionToString(transcriptRef)
            if "ID" in transcriptRef.getTagNames():
                self.regionsToNames[region] = transcriptRef.getTagValue("ID")
            elif transcriptRef.getName() is not None:
                self.regionsToNames[region] = transcriptRef.getName()
            else:
                self.regionsToNames[region] = "region_%d" % (cpt)
            values = [None, None]
            normalizedValues = [None, None]
            for i in self.inputs:
                values[i], normalizedValues[i] = self.retrieveCounts(transcriptRef, i)
                normalizedValues[i] = int(self.normalizationFactors[i] * normalizedValues[i])
            if sum(values) != 0:
                self.regionsToValues[region] = (normalizedValues[0], normalizedValues[1], values[0], values[1])
            progress.inc()
        progress.done()

    def computeAdjustedNormalizationFactors(self):
        """Refine the normalization factors using the 'mean' regions only.

        The regions whose average corrected count lies between the first
        and the third quartile are selected; the factors are then updated
        so that both samples get the same total count on these regions.
        Nothing is done when no region was stored.
        """
        nbElements = len(self.regionsToValues.keys())
        if nbElements == 0:
            # nothing to normalize on (the quartile lookup would fail)
            return
        avgValues = []
        progress = Progress(nbElements, "Normalization step 1", self.verbosity)
        for values in self.regionsToValues.values():
            correctedValues = [values[i] * self.normalizationFactors[i] for i in self.inputs]
            avgValues.append(float(sum(correctedValues)) / len(correctedValues))
            progress.inc()
        progress.done()

        sortedAvgValues = sorted(avgValues)
        minAvgValues = sortedAvgValues[nbElements // 4]
        maxAvgValues = sortedAvgValues[nbElements * 3 // 4]
        sums = [0, 0]
        progress = Progress(nbElements, "Normalization step 2", self.verbosity)
        for values in self.regionsToValues.values():
            correctedValues = [values[i] * self.normalizationFactors[i] for i in self.inputs]
            avgValue = float(sum(correctedValues)) / len(correctedValues)
            if minAvgValues <= avgValue <= maxAvgValues:
                for i in self.inputs:
                    sums[i] += values[i]
            progress.inc()
        progress.done()

        avgSums = float(sum(sums)) / len(sums)
        for i in self.inputs:
            previousNbElements = self.nbElements[i]
            self.normalizationFactors[i] *= float(avgSums) / sums[i]
            self.nbElements[i] *= self.normalizationFactors[i]
            if self.verbosity > 1:
                print("Normalizing sample %d: %s to %s" % ((i + 1), previousNbElements, int(self.nbElements[i])))

    def getMinimumReferenceSize(self):
        """Store in normalizationSize the size of the smallest reference region.

        Unlike computeMinimumSize, the stored value is (end - start + 1).
        """
        self.normalizationSize = 1000000000
        progress = Progress(self.transcriptContainerRef.getNbTranscripts(), "Reference element sizes", self.verbosity)
        for transcriptRef in self.transcriptContainerRef.getIterator():
            self.normalizationSize = min(self.normalizationSize, transcriptRef.getEnd() - transcriptRef.getStart() + 1)
            progress.inc()
        progress.done()
        if self.verbosity > 1:
            print("Minimum reference size: %d" % (self.normalizationSize))

    def computePvalues(self):
        """Apply the normalization factors and compute the p-values in bulk.

        Each tuple handed to Utils.fisherExactPValueBulk holds the two
        normalized counts, the remaining reads of each sample, and the two
        raw counts.  The result is stored in valuesToPvalues and is looked
        up elsewhere with the pair of normalized counts.
        """
        normalizedValues = set()
        progress = Progress(len(self.regionsToValues.keys()), "Normalizing counts", self.verbosity)
        for region in self.regionsToValues:
            values = self.regionsToValues[region]
            # NOTE(review): values[0] and values[1] were already multiplied
            # by normalizationFactors in getAllCounts; the factors are
            # applied a second time here -- confirm this is intended.
            normalizedValues0 = int(round(values[0] * self.normalizationFactors[0]))
            normalizedValues1 = int(round(values[1] * self.normalizationFactors[1]))
            self.regionsToValues[region] = (normalizedValues0, normalizedValues1, values[2], values[3])
            normalizedValues.add((normalizedValues0, normalizedValues1, self.nbElements[0] - normalizedValues0, self.nbElements[1] - normalizedValues1, values[2], values[3]))
            progress.inc()
        progress.done()

        if self.verbosity > 1:
            print("Computing p-values...")
        self.valuesToPvalues = Utils.fisherExactPValueBulk(list(normalizedValues))
        if self.verbosity > 1:
            print("... done")

    def setTagValues(self, transcript, values, pValue):
        """Replace the tags of a transcript by the differential expression results.

        transcript -- the region to annotate (its exons are removed)
        values     -- (normalized 1, normalized 2, raw 1, raw 2)
        pValue     -- p-value of the region

        The "regulation" tag is "equal" when both normalized counts are the
        same or when the FDR filter is on and the region does not pass it,
        "up" when the second condition has more reads, "down" otherwise.
        Returns the modified transcript.
        """
        # iterate over a copy: tags are deleted while looping
        for tag in list(transcript.getTagNames()):
            transcript.deleteTag(tag)
        transcript.removeExons()
        transcript.setTagValue("pValue", str(pValue))
        transcript.setTagValue("nbReadsCond1", str(values[0]))
        transcript.setTagValue("nbReadsCond2", str(values[1]))
        transcript.setTagValue("nbUnnormalizedReadsCond1", str(values[2]))
        transcript.setTagValue("nbUnnormalizedReadsCond2", str(values[3]))
        # no threshold found by computeFdr means that no region passes
        filteredOut = self.fdr is not None and (self.fdrPvalue is None or pValue > self.fdrPvalue)
        if values[0] == values[1] or filteredOut:
            transcript.setTagValue("regulation", "equal")
        elif values[0] < values[1]:
            transcript.setTagValue("regulation", "up")
        else:
            transcript.setTagValue("regulation", "down")
        return transcript

    def computeFdr(self):
        """Find the p-value threshold matching the requested FDR.

        The p-values are scanned from the largest to the smallest; the
        first one not greater than fdr * rank / nbRegions becomes
        fdrPvalue.  fdrPvalue is left unchanged when none qualifies.
        """
        pValues = []
        nbRegions = len(self.regionsToValues.keys())
        progress = Progress(nbRegions, "Computing FDR", self.verbosity)
        for values in self.regionsToValues.values():
            pValues.append(self.valuesToPvalues[values[0:2]])
            progress.inc()
        progress.done()

        for i, pValue in enumerate(sorted(pValues, reverse=True)):
            # NOTE(review): this rank is zero-based; the textbook
            # Benjamini-Hochberg rule uses the one-based rank
            # (nbRegions - i) -- confirm the stricter bound is wanted.
            rank = nbRegions - 1 - i
            if pValue <= self.fdr * rank / float(nbRegions):
                self.fdrPvalue = pValue
                if self.verbosity > 1:
                    print("FDR: %f, k: %i, m: %d" % (pValue, rank, nbRegions))
                return

    def writeDifferentialExpression(self):
        """Write one annotated transcript per region, then draw the plot if asked."""
        if self.plot:
            self.setPlotter()

        progress = Progress(len(self.regionsToValues.keys()), "Writing output", self.verbosity)
        for region, values in self.regionsToValues.items():
            transcript = self.stringToRegion(region)
            pValue = self.valuesToPvalues[values[0:2]]
            transcript.setName(self.regionsToNames[region])
            transcript = self.setTagValues(transcript, values, pValue)
            self.writer.addTranscript(transcript)
            if self.plot:
                self.points[region] = (values[0], values[1])
            progress.inc()
        progress.done()
        self.writer.write()
        self.writer.close()

        if self.plot:
            self.plotter.addLine(self.points)
            self.plotter.plot()

    def getDifferentialExpression(self):
        """Run the whole analysis: read, count, normalize, test, filter, write."""
        for i in self.inputs:
            self.readInput(i)

        if self.simpleNormalization:
            self.computeSimpleNormalizationFactors()
        if self.fixedSizeFactor is not None:
            self.computeMinimumSize()

        self.getAllCounts()

        if self.adjustedNormalization:
            self.computeAdjustedNormalizationFactors()

        self.computePvalues()

        if self.fdr is not None:
            self.computeFdr()

        self.writeDifferentialExpression()
397 | |
398 | |
if __name__ == "__main__":

    # parse command line
    description = "Get Differential Expression v1.0.1: Get the differential expression between 2 conditions using Fisher's exact test, on regions defined by a third file. [Category: Data Comparison]"

    parser = OptionParser(description=description)
    parser.add_option("-i", "--input1", dest="inputFileName1", action="store", type="string", help="input file 1 [compulsory] [format: file in transcript format given by -f]")
    parser.add_option("-f", "--format1", dest="format1", action="store", type="string", help="format of file 1 [compulsory] [format: transcript file format]")
    parser.add_option("-j", "--input2", dest="inputFileName2", action="store", type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
    parser.add_option("-g", "--format2", dest="format2", action="store", type="string", help="format of file 2 [compulsory] [format: transcript file format]")
    parser.add_option("-k", "--reference", dest="referenceFileName", action="store", type="string", help="reference file [compulsory] [format: file in transcript format given by -l]")
    parser.add_option("-l", "--referenceFormat", dest="referenceFormat", action="store", type="string", help="format of reference file [compulsory] [format: transcript file format]")
    parser.add_option("-o", "--output", dest="outputFileName", action="store", type="string", help="output file [format: output file in gff3 format]")
    parser.add_option("-n", "--notOriented", dest="notOriented", action="store_true", default=False, help="if the reads are not oriented [default: False] [format: bool]")
    parser.add_option("-s", "--simple", dest="simple", action="store_true", default=False, help="normalize using the number of reads in each condition [format: bool]")
    # the value is a string such as "1000,2000", not a flag
    parser.add_option("-S", "--simpleParameters", dest="simpleParameters", action="store", default=None, type="string", help="provide the number of reads [format: string]")
    parser.add_option("-a", "--adjusted", dest="adjusted", action="store_true", default=False, help="normalize using the number of reads of 'mean' regions [format: bool]")
    parser.add_option("-x", "--fixedSizeFactor", dest="fixedSizeFactor", action="store", default=None, type="int", help="give the magnification factor for the normalization using fixed size sliding windows in reference regions (leave empty for no such normalization) [format: int]")
    parser.add_option("-d", "--fdr", dest="fdr", action="store", default=None, type="float", help="use FDR [format: float]")
    parser.add_option("-p", "--plot", dest="plotName", action="store", default=None, type="string", help="plot cloud plot [format: output file in PNG format]")
    parser.add_option("-v", "--verbosity", dest="verbosity", action="store", default=1, type="int", help="trace level [format: int]")
    (options, args) = parser.parse_args()

    # stop with a usage message rather than failing later on a None file name
    compulsoryOptions = (
        ("-i", options.inputFileName1),
        ("-f", options.format1),
        ("-j", options.inputFileName2),
        ("-g", options.format2),
        ("-k", options.referenceFileName),
        ("-l", options.referenceFormat),
    )
    missingOptions = [flag for flag, value in compulsoryOptions if value is None]
    if missingOptions:
        parser.error("missing compulsory option(s): %s" % (", ".join(missingOptions)))

    differentialExpression = GetDifferentialExpression(options.verbosity)
    differentialExpression.setInputFile(0, options.inputFileName1, options.format1)
    differentialExpression.setInputFile(1, options.inputFileName2, options.format2)
    differentialExpression.setReferenceFile(options.referenceFileName, options.referenceFormat)
    differentialExpression.setOutputFile(options.outputFileName)
    if options.plotName is not None:
        differentialExpression.setPlotterName(options.plotName)
        differentialExpression.setPlotter()
    differentialExpression.setOriented(not options.notOriented)
    differentialExpression.setSimpleNormalization(options.simple)
    differentialExpression.setSimpleNormalizationParameters(options.simpleParameters)
    differentialExpression.setAdjustedNormalization(options.adjusted)
    differentialExpression.setFixedSizeNormalization(options.fixedSizeFactor)
    differentialExpression.setFdr(options.fdr)
    try:
        differentialExpression.getDifferentialExpression()
    finally:
        # drop the working database even when the analysis fails
        differentialExpression.mySqlConnection.deleteDatabase()
440 | |
441 |