Mercurial > repos > artbio > sashimi_plot
comparison sashimi-plot.py @ 0:9304dd9a16a2 draft
"planemo upload for repository https://github.com/ARTbio/tools-artbio/tree/master/tools/sashimi_plot commit 746c03a1187e1d708af8628920a0c615cddcdacc"
author | artbio |
---|---|
date | Fri, 23 Aug 2019 11:38:29 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:9304dd9a16a2 |
---|---|
1 #!/usr/bin/env python | |
2 | |
3 # Import modules | |
4 import copy | |
5 import os | |
6 import re | |
7 import subprocess as sp | |
8 import sys | |
9 from argparse import ArgumentParser | |
10 from collections import OrderedDict | |
11 | |
12 | |
def define_options():
    """Build and return the ArgumentParser for the sashimi-plot CLI.

    All options are optional at the argparse level; downstream code expects
    at least --bam and --coordinates to be usable.
    """
    # Argument parsing
    parser = ArgumentParser(description="""Create sashimi plot for a given
    genomic region""")
    parser.add_argument("-b", "--bam", type=str,
                        help="""
                        Individual bam file or file with a list of bam files.
                        In the case of a list of files the format is tsv:
                        1col: id for bam file,
                        2col: path of bam file,
                        3+col: additional columns
                        """)
    parser.add_argument("-c", "--coordinates", type=str,
                        help="Genomic region. Format: chr:start-end (1-based)")
    parser.add_argument("-o", "--out-prefix", type=str, dest="out_prefix",
                        default="sashimi",
                        help="Prefix for plot file name [default=%(default)s]")
    parser.add_argument("-S", "--out-strand", type=str, dest="out_strand",
                        default="both", help="""Only for --strand other than
                        'NONE'. Choose which signal strand to plot:
                        <both> <plus> <minus> [default=%(default)s]""")
    parser.add_argument("-M", "--min-coverage", type=int, default=1,
                        dest="min_coverage", help="""Minimum number of reads
                        supporting a junction to be drawn [default=1]""")
    parser.add_argument("-j", "--junctions-bed", type=str, default="",
                        dest="junctions_bed", help="""Junction BED file name
                        [default=no junction file]""")
    parser.add_argument("-g", "--gtf",
                        help="Gtf file with annotation (only exons is enough)")
    parser.add_argument("-s", "--strand", default="NONE", type=str,
                        help="""Strand specificity: <NONE> <SENSE> <ANTISENSE>
                        <MATE1_SENSE> <MATE2_SENSE> [default=%(default)s]""")
    parser.add_argument("--shrink", action="store_true",
                        help="""Shrink the junctions by a factor for nicer
                        display [default=%(default)s]""")
    parser.add_argument("-O", "--overlay", type=int,
                        help="Index of column with overlay levels (1-based)")
    parser.add_argument("-A", "--aggr", type=str, default="",
                        help="""Aggregate function for overlay:
                        <mean> <median> <mean_j> <median_j>.
                        Use mean_j | median_j to keep density overlay but
                        aggregate junction counts [default=no aggregation]""")
    parser.add_argument("-C", "--color-factor", type=int, dest="color_factor",
                        help="Index of column with color levels (1-based)")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="""Transparency level for density histogram
                        [default=%(default)s]""")
    parser.add_argument("-P", "--palette", type=str,
                        help="""Color palette file. tsv file with >=1 columns,
                        where the color is the first column""")
    parser.add_argument("-L", "--labels", type=int, dest="labels", default=1,
                        help="""Index of column with labels (1-based)
                        [default=%(default)s]""")
    parser.add_argument("--height", type=float, default=2,
                        help="""Height of the individual signal plot in inches
                        [default=%(default)s]""")
    parser.add_argument("--ann-height", type=float, default=1.5,
                        dest="ann_height", help="""Height of annotation plot in
                        inches [default=%(default)s]""")
    parser.add_argument("--width", type=float, default=10,
                        help="""Width of the plot in inches
                        [default=%(default)s]""")
    parser.add_argument("--base-size", type=float, default=14,
                        dest="base_size", help="""Base font size of the plot in
                        pch [default=%(default)s]""")
    parser.add_argument("-F", "--out-format", type=str, default="pdf",
                        dest="out_format", help="""Output file format:
                        <pdf> <svg> <png> <jpeg> <tiff>
                        [default=%(default)s]""")
    parser.add_argument("-R", "--out-resolution", type=int, default=300,
                        dest="out_resolution", help="""Output file resolution in
                        PPI (pixels per inch). Applies only to raster output
                        formats [default=%(default)s]""")
    return parser
87 | |
88 | |
def parse_coordinates(c):
    """Parse a 'chr:start-end' region string (1-based, commas allowed).

    Returns (chromosome, start, end) where start is converted to 0-based
    and end is left 1-based, i.e. a half-open [start, end) interval.
    """
    c = c.replace(",", "")  # tolerate 1,000-style thousand separators
    # Renamed from `chr`, which shadowed the builtin chr()
    chrom = c.split(":")[0]
    start, end = c.split(":")[1].split("-")
    # Convert to 0-based
    start, end = int(start) - 1, int(end)
    return chrom, start, end
96 | |
97 | |
def count_operator(CIGAR_op, CIGAR_len, pos, start, end, a, junctions):
    """Apply one CIGAR operation to the coverage array / junction counts.

    a is the per-base coverage list covering [start, end); junctions maps
    (donor, acceptor) genomic positions to read counts. Returns the
    reference position after consuming this operation.
    """
    # Match: add coverage for the part of the operation inside the window
    if CIGAR_op == "M":
        lo = max(pos, start)
        hi = min(pos + CIGAR_len, end)
        for genome_pos in range(lo, hi):
            a[genome_pos - start] += 1
    # Insertions and soft-clips consume no reference bases
    elif CIGAR_op in ("I", "S"):
        return pos
    # Skipped region = splice junction; count it if fully inside the window
    elif CIGAR_op == "N":
        don, acc = pos, pos + CIGAR_len
        if don > start and acc < end:
            junctions[(don, acc)] = junctions.get((don, acc), 0) + 1
    # M, D, N (and anything else) advance along the reference
    return pos + CIGAR_len
126 | |
127 | |
def flip_read(s, samflag):
    """Return 1 when the read strand must be flipped for strandness `s`.

    s is one of NONE/SENSE/ANTISENSE/MATE1_SENSE/MATE2_SENSE; samflag is
    the SAM FLAG field (string or int). Returns 0 for "keep", 1 for "flip";
    mate-sense modes return None when neither mate bit (64/128) is set,
    matching the original behavior.
    """
    if s in ("NONE", "SENSE"):
        return 0
    if s == "ANTISENSE":
        return 1
    flag = int(samflag)
    is_mate1 = bool(flag & 64)
    is_mate2 = bool(flag & 128)
    if s == "MATE1_SENSE":
        if is_mate1:
            return 0
        if is_mate2:
            return 1
    if s == "MATE2_SENSE":
        if is_mate1:
            return 1
        if is_mate2:
            return 0
143 | |
144 | |
def read_bam(f, c, s):
    """Extract per-base coverage and junction counts from a bam file.

    f: path to an indexed bam file (samtools must be on PATH).
    c: region string "chr:start-end" (passed verbatim to samtools view).
    s: strand specificity: NONE/SENSE/ANTISENSE/MATE1_SENSE/MATE2_SENSE.

    Returns (a, junctions): a maps strand ("+", plus "-" when s != NONE)
    to a coverage list over the region; junctions maps strand to an
    OrderedDict {(donor, acceptor): read count}.
    """

    _, start, end = parse_coordinates(c)

    # Initialize coverage array and junction dict
    a = {"+": [0] * (end - start)}
    junctions = {"+": OrderedDict()}
    if s != "NONE":
        a["-"] = [0] * (end - start)
        junctions["-"] = OrderedDict()

    # NOTE(review): shell=True with string interpolation is shell-injectable
    # if f or c ever come from untrusted input — consider a list argv.
    p = sp.Popen("samtools view %s %s " % (f, c), shell=True, stdout=sp.PIPE)
    for line in p.communicate()[0].decode('utf8').strip().split("\n"):

        if line == "":
            continue

        line_sp = line.strip().split("\t")
        samflag, read_start, CIGAR = line_sp[1], int(line_sp[3]), line_sp[5]

        # Ignore reads with more exotic CIGAR operators
        if any(map(lambda x: x in CIGAR, ["H", "P", "X", "="])):
            continue

        # FLAG bit 16 = reverse strand; flip_read corrects for the
        # library's strandness protocol
        read_strand = ["+", "-"][flip_read(s, samflag) ^ bool(int(samflag) &
                                                              16)]
        if s == "NONE":
            read_strand = "+"

        # Split CIGAR into parallel lists of lengths and operators
        CIGAR_lens = re.split("[MIDNS]", CIGAR)[:-1]
        CIGAR_ops = re.split("[0-9]+", CIGAR)[1:]

        pos = read_start

        # Walk the read along the reference, updating coverage/junctions
        for n, CIGAR_op in enumerate(CIGAR_ops):
            CIGAR_len = int(CIGAR_lens[n])
            pos = count_operator(CIGAR_op, CIGAR_len, pos, start, end,
                                 a[read_strand], junctions[read_strand])

    p.stdout.close()
    return a, junctions
186 | |
187 | |
def get_bam_path(index, path):
    """Resolve a bam path from a list file: absolute paths pass through,
    relative ones are resolved against the list file's directory."""
    if not os.path.isabs(path):
        path = os.path.join(os.path.dirname(index), path)
    return path
193 | |
194 | |
def read_bam_input(f, overlay, color, label):
    """Yield (id, bam_path, overlay_level, color_level, label_text) tuples.

    f is either a single .bam file or a tsv list (col1 = id, col2 = bam
    path, further columns selectable via the 1-based overlay/color/label
    column indices; falsy index -> None for that field).
    """
    if f.endswith(".bam"):
        # BUGFIX: the original used basename.strip(".bam"), which strips
        # any leading/trailing characters in the SET {'.','b','a','m'}
        # (e.g. "mba.bam" -> ""); slice the extension off instead.
        bn = f.strip().split("/")[-1][:-len(".bam")]
        yield bn, f, None, None, bn
        return
    with open(f) as openf:
        for line in openf:
            line_sp = line.strip().split("\t")
            bam = get_bam_path(f, line_sp[1])
            overlay_level = line_sp[overlay-1] if overlay else None
            color_level = line_sp[color-1] if color else None
            label_text = line_sp[label-1] if label else None
            yield line_sp[0], bam, overlay_level, color_level, label_text
208 | |
209 | |
def prepare_for_R(a, junctions, c, m):
    """Convert coverage/junctions into the parallel arrays the R code needs.

    a: per-base coverage list; junctions: {(donor, acceptor): count};
    c: region string "chr:start-end"; m: minimum junction coverage.

    Returns (x, y, dons, accs, yd, ya, counts): genomic x positions, the
    density y values, and per-junction donor/acceptor positions, arc
    anchor heights and read counts.
    """

    # BUGFIX: was parse_coordinates(args.coordinates) — it silently relied
    # on the global `args` instead of the `c` parameter it was given.
    _, start, _ = parse_coordinates(c)

    # Convert the array index to genomic coordinates
    x = list(i+start for i in range(len(a)))
    y = a

    # Arrays for R
    dons, accs, yd, ya, counts = [], [], [], [], []

    # Prepare arrays for junctions (which will be the arcs)
    for (don, acc), n in junctions.items():

        # Do not add junctions with less than defined coverage
        if n < m:
            continue

        dons.append(don)
        accs.append(acc)
        counts.append(n)

        # Arc endpoints sit on the density just outside the junction
        yd.append(a[don - start - 1])
        ya.append(a[acc - start + 1])

    return x, y, dons, accs, yd, ya, counts
236 | |
237 | |
def intersect_introns(data):
    """Yield the intersection of overlapping (start, end) intervals.

    Overlapping runs are collapsed to their common core (max start,
    min end); non-overlapping intervals are yielded as-is, sorted.
    Touching intervals like (1,2),(2,3) are NOT treated as overlapping.
    """
    data = sorted(data)
    it = iter(data)
    try:
        a, b = next(it)
    except StopIteration:
        # BUGFIX: with no introns the bare next() raised StopIteration,
        # which PEP 479 (Python 3.7+) turns into a RuntimeError.
        return
    for c, d in it:
        if b > c:
            # Use `if b > c` if you want (1,2), (2,3) not to be
            # treated as intersection.
            b = min(b, d)
            a = max(a, c)
        else:
            yield a, b
            a, b = c, d
    yield a, b
252 | |
253 | |
def shrink_density(x, y, introns):
    """Compress the x axis across introns for a more compact display.

    Each intron (a, b) of length L is shrunk to L**0.7; every position to
    the right of it is shifted left by the accumulated difference. x must
    contain a and b; introns must be sorted and non-overlapping. Returns
    the new (x, y) lists (values inside introns are dropped).
    """
    out_x, out_y = [], []
    shift = 0
    seg_start = 0
    # introns are already sorted by coordinates
    for a, b in introns:
        seg_end = x.index(a) + 1
        out_x.extend(int(pos - shift) for pos in x[seg_start:seg_end])
        out_y.extend(y[seg_start:seg_end])
        seg_start = x.index(b)
        length = b - a
        shift += length - length ** 0.7
    out_x.extend(int(pos - shift) for pos in x[seg_start:])
    out_y.extend(y[seg_start:])
    return out_x, out_y
269 | |
270 | |
def shrink_junctions(dons, accs, introns):
    """Shift junction donor/acceptor positions to match shrunken introns.

    dons, accs: parallel lists of junction donor/acceptor coordinates.
    introns: sorted (start, end) intervals being shrunk with the same
    L - int(L**0.7) scaling used by shrink_density.

    Returns new (dons, accs) lists with the accumulated shifts applied.
    NOTE(review): junctions spanning no intron are left at 0 in the output
    — presumably every drawn junction spans at least one intersected
    intron; confirm with callers.
    """
    new_dons, new_accs = [0]*len(dons), [0]*len(accs)
    shift_acc = 0
    shift_don = 0
    s = set()  # junctions whose donor side has already been fixed
    junctions = list(zip(dons, accs))
    for a, b in introns:
        L = b - a
        shift_acc += L-int(L**0.7)
        for i, (don, acc) in enumerate(junctions):
            # Only junctions that fully span the current intron move
            if a >= don and b <= acc:
                if (don, acc) not in s:
                    # First intron inside this junction: shift both ends;
                    # the donor uses the shift accumulated BEFORE it
                    new_dons[i] = don - shift_don
                    new_accs[i] = acc - shift_acc
                else:
                    # Further introns inside: only the acceptor moves more
                    new_accs[i] = acc - shift_acc
                s.add((don, acc))
        shift_don = shift_acc
    return new_dons, new_accs
290 | |
291 | |
def read_palette(f):
    """Return plot colors: first tsv column of file f, or a built-in
    red/green/blue/black default when f is falsy."""
    if not f:
        return "#ff0000", "#00ff00", "#0000ff", "#000000"
    with open(f) as openf:
        return [line.split("\t")[0].strip() for line in openf]
298 | |
299 | |
def read_gtf(f, c):
    """Collect transcripts and exons from GTF file f overlapping region c.

    Returns (transcripts, exons):
      transcripts: {transcript_id: (start, end, strand)} clipped to region
      exons: {transcript_id: [(start, end, strand), ...]} clipped likewise
    Starts are converted to 0-based; strand is wrapped in double quotes so
    it can be pasted directly into the generated R script.
    """
    exons = OrderedDict()
    transcripts = OrderedDict()
    chr, start, end = parse_coordinates(c)
    end = end - 1  # make the region end usable in the overlap tests below
    with open(f) as openf:
        for line in openf:
            if line.startswith("#"):  # skip GTF header lines
                continue
            (el_chr, _, el, el_start, el_end, _,
             strand, _, tags) = line.strip().split("\t")
            if el_chr != chr:
                continue
            # Parse the attribute column: 'key "value"; key "value"; ...'
            d = dict(kv.strip().split(" ") for kv in
                     tags.strip(";").split("; "))
            transcript_id = d["transcript_id"]
            el_start, el_end = int(el_start) - 1, int(el_end)
            strand = '"' + strand + '"'  # quoted for injection into R code
            if el == "transcript":
                if (el_end > start and el_start < end):
                    transcripts[transcript_id] = (max(start, el_start),
                                                  min(end, el_end),
                                                  strand)
                continue
            if el == "exon":
                # NOTE(review): strict comparisons exclude exons that only
                # touch the region boundary — confirm this is intended.
                if (start < el_start < end or start < el_end < end):
                    exons.setdefault(transcript_id,
                                     []).append((max(el_start, start),
                                                 min(end, el_end), strand))

    return transcripts, exons
331 | |
332 | |
def make_introns(transcripts, exons, intersected_introns=None):
    """Derive intron intervals per transcript, optionally on shrunk axes.

    transcripts: {tx_id: (start, end, strand)} and exons:
    {tx_id: [(start, end, strand), ...]} as produced by read_gtf.
    intersected_introns: sorted (start, end) intervals being shrunk with
    the same L - int(L**0.7) scaling used elsewhere, or None.

    Returns {'transcripts': ..., 'exons': ..., 'introns': ...}; the inputs
    are not modified (deep copies are shifted).
    """
    new_transcripts = copy.deepcopy(transcripts)
    new_exons = copy.deepcopy(exons)
    introns = OrderedDict()
    if intersected_introns:
        # Shift annotation coordinates left by the cumulative amount that
        # all preceding introns were shrunk
        for tx, (tx_start, tx_end, strand) in new_transcripts.items():
            total_shift = 0
            for a, b in intersected_introns:
                L = b - a
                shift = L - int(L**0.7)
                total_shift += shift
                for i, (exon_start, exon_end, strand) in \
                        enumerate(exons.get(tx, [])):
                    new_exon_start, new_exon_end = new_exons[tx][i][:2]
                    if a < exon_start:
                        if b > exon_end:
                            # Intron extends past this exon
                            if i == len(exons[tx])-1:
                                # NOTE(review): int(L**-0.3) is 0 for any
                                # L > 1, so the (1 - ...) factor is
                                # effectively 1 — confirm the exponent.
                                total_shift = total_shift - shift + \
                                    (exon_start - a)*(1-int(L**-0.3))
                            shift = (exon_start - a)*(1-int(L**-0.3))
                            new_exon_end = new_exons[tx][i][1] - shift
                        new_exon_start = new_exons[tx][i][0] - shift
                    if b <= exon_end:
                        new_exon_end = new_exons[tx][i][1] - shift
                    new_exons[tx][i] = (new_exon_start, new_exon_end, strand)
            # Keep the transcript start anchored to its leftmost exon
            tx_start = min(tx_start,
                           sorted(new_exons.get(tx, [[sys.maxsize]]))[0][0])
            new_transcripts[tx] = (tx_start, tx_end - total_shift, strand)

    # Introns are the gaps between consecutive (sorted) exons, plus any
    # trailing gap up to the transcript end
    for tx, (tx_start, tx_end, strand) in new_transcripts.items():
        intron_start = tx_start
        ex_end = 0
        for ex_start, ex_end, strand in sorted(new_exons.get(tx, [])):
            intron_end = ex_start
            if tx_start < ex_start:
                introns.setdefault(tx, []).append((intron_start, intron_end,
                                                   strand))
            intron_start = ex_end
        if tx_end > ex_end:
            introns.setdefault(tx, []).append((intron_start, tx_end, strand))
    d = {'transcripts': new_transcripts,
         'exons': new_exons,
         'introns': introns}
    return d
377 | |
378 | |
def gtf_for_ggplot(annotation, start, end, arrow_bins):
    """Generate the R code that draws the gene-annotation panel.

    annotation: the dict produced by make_introns; start/end: x-axis
    limits; arrow_bins: how many strand-direction arrows to spread along
    the plotted range. Returns a string of R/ggplot2 code defining `gtfp`.
    """
    arrow_space = int((end - start)/arrow_bins)
    s = """

    # data table with exons
    ann_list = list(
        "exons" = data.table(),
        "introns" = data.table()
    )
    """

    if annotation["exons"]:

        # Flatten the per-transcript exon tuples into parallel R vectors
        s += """
        ann_list[['exons']] = data.table(
            tx = rep(c(%(tx_exons)s), c(%(n_exons)s)),
            start = c(%(exon_start)s),
            end = c(%(exon_end)s),
            strand = c(%(strand)s)
        )
        """ % ({
            "tx_exons": ",".join(annotation["exons"].keys()),
            "n_exons": ",".join(map(str, map(len,
                                             annotation["exons"].values()))),
            "exon_start": ",".join(map(str, (v[0] for vs in
                                             annotation["exons"].values() for v in vs))),
            "exon_end": ",".join(map(str, (v[1] for vs in
                                           annotation["exons"].values() for v in vs))),
            "strand": ",".join(map(str, (v[2] for vs in
                                         annotation["exons"].values() for v in vs))),
        })

    if annotation["introns"]:

        # %%in%% becomes R's %in% after Python %-formatting
        s += """
        ann_list[['introns']] = data.table(
            tx = rep(c(%(tx_introns)s), c(%(n_introns)s)),
            start = c(%(intron_start)s),
            end = c(%(intron_end)s),
            strand = c(%(strand)s)
        )
        # Create data table for strand arrows
        txarrows = data.table()
        introns = ann_list[['introns']]
        # Add right-pointing arrows for plus strand
        if ("+" %%in%% introns$strand) {
            txarrows = rbind(
                txarrows,
                introns[strand=="+" & end-start>5, list(
                    seq(start+4,end,by=%(arrow_space)s)-1,
                    seq(start+4,end,by=%(arrow_space)s)
                    ), by=.(tx,start,end)
                ]
            )
        }
        # Add left-pointing arrows for minus strand
        if ("-" %%in%% introns$strand) {
            txarrows = rbind(txarrows,
                introns[strand=="-" & end-start>5,
                    list(seq(start,max(start+1, end-4),
                             by=%(arrow_space)s),
                         seq(start,max(start+1, end-4),
                             by=%(arrow_space)s)-1
                    ),
                    by=.(tx,start,end)
                ]
            )
        }
        """ % ({
            "tx_introns": ",".join(annotation["introns"].keys()),
            "n_introns": ",".join(map(str, map(len,
                                               annotation["introns"].values()))),
            "intron_start": ",".join(map(str, (v[0] for vs in
                                               annotation["introns"].values() for v in
                                               vs))),
            "intron_end": ",".join(map(str, (v[1] for vs in
                                             annotation["introns"].values() for v in
                                             vs))),
            "strand": ",".join(map(str, (v[2] for vs in
                                         annotation["introns"].values() for v in vs))),
            "arrow_space": arrow_space,
        })

    # NOTE(review): "% s" below uses the space flag with %s — it formats the
    # same as "%s" for these values, but looks like a typo for "%s".
    s += """

    gtfp = ggplot()
    if (length(ann_list[['introns']]) > 0) {
        gtfp = gtfp + geom_segment(data = ann_list[['introns']],
                                   aes(x = start,
                                       xend = end,
                                       y = tx,
                                       yend = tx),
                                   size = 0.3)
        gtfp = gtfp + geom_segment(data = txarrows,
                                   aes(x = V1,
                                       xend = V2,
                                       y = tx,
                                       yend = tx),
                                   arrow = arrow(length = unit(0.02, "npc")))
    }
    if (length(ann_list[['exons']]) > 0) {
        gtfp = gtfp + geom_segment(data = ann_list[['exons']],
                                   aes(x = start,
                                       xend = end,
                                       y = tx,
                                       yend = tx),
                                   size = 5,
                                   alpha = 1)
    }
    gtfp = gtfp + scale_y_discrete(expand = c(0, 0.5))
    gtfp = gtfp + scale_x_continuous(expand = c(0, 0.25),
                                     limits = c( %s,% s))
    gtfp = gtfp + labs(y = NULL)
    gtfp = gtfp + theme(axis.line = element_blank(),
                        axis.text.x = element_blank(),
                        axis.ticks = element_blank())
    """ % (start, end)

    return s
498 | |
499 | |
def setup_R_script(h, w, b, label_dict):
    """Return the R preamble: libraries, theme, plot sizes and labels.

    h: total plot height in inches; w: width in inches; b: base font
    size; label_dict: {sample_id: label} rendered into an R named list.
    """
    s = """
    library(ggplot2)
    library(grid)
    library(gridExtra)
    library(data.table)
    library(gtable)

    scale_lwd = function(r) {
        lmin = 0.1
        lmax = 4
        return( r*(lmax-lmin)+lmin )
    }

    base_size = %(b)s
    height = ( %(h)s + base_size*0.352777778/67 ) * 1.02
    width = %(w)s
    theme_set(theme_bw(base_size=base_size))
    theme_update(
        plot.margin = unit(c(15,15,15,15), "pt"),
        panel.grid = element_blank(),
        panel.border = element_blank(),
        axis.line = element_line(size=0.5),
        axis.title.x = element_blank(),
        axis.title.y = element_text(angle=0, vjust=0.5)
    )

    labels = list(%(labels)s)

    density_list = list()
    junction_list = list()

    """ % ({
        'h': h,
        'w': w,
        'b': b,
        # Render labels as an R named list: "id"="label", ...
        'labels': ",".join(('"%s"="%s"' % (id, lab) for id, lab in
                            label_dict.items())),
    })
    return s
540 | |
541 | |
def median(lst):
    """Return the median of lst: middle element for odd length, mean of
    the two middle elements (true division) for even length."""
    ordered = sorted(lst)
    mid, odd = divmod(len(ordered), 2)
    if odd:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2.
547 | |
548 | |
def mean(lst):
    """Return the arithmetic mean of lst (true division)."""
    total = sum(lst)
    return total / len(lst)
551 | |
552 | |
def make_R_lists(id_list, d, overlay_dict, aggr, intersected_introns):
    """Serialize densities and junctions into R data.frame assignments.

    id_list: sample ids (ignored when overlay_dict is non-empty, whose
    keys are used instead); d: {id: (x, y, dons, accs, yd, ya, counts)}
    from prepare_for_R; aggr: "" or mean/median[_j]; intersected_introns:
    intervals for coordinate shrinking, or falsy for none.
    Returns R code filling density_list and junction_list.
    """
    s = ""
    aggr_f = {
        "mean": mean,
        "median": median,
    }
    id_list = id_list if not overlay_dict else overlay_dict.keys()
    # Iterate over ids to get bam signal and junctions
    for k in id_list:
        x, y, dons, accs, yd, ya, counts = [], [], [], [], [], [], []
        if not overlay_dict:
            x, y, dons, accs, yd, ya, counts = d[k]
            if intersected_introns:
                x, y = shrink_density(x, y, intersected_introns)
                dons, accs = shrink_junctions(dons, accs, intersected_introns)
        else:
            # Overlay: concatenate the data of all samples in the group
            for id in overlay_dict[k]:
                xid, yid, donsid, accsid, ydid, yaid, countsid = d[id]
                if intersected_introns:
                    xid, yid = shrink_density(xid, yid, intersected_introns)
                    donsid, accsid = shrink_junctions(donsid, accsid,
                                                      intersected_introns)
                x += xid
                y += yid
                dons += donsid
                accs += accsid
                yd += ydid
                ya += yaid
                counts += countsid
            # Plain mean/median (no _j): aggregate the density itself;
            # junction aggregation is done later in the R code
            if aggr and "_j" not in aggr:
                x = d[overlay_dict[k][0]][0]
                y = list(map(aggr_f[aggr], zip(*(d[id][1] for id in
                                                 overlay_dict[k]))))
                if intersected_introns:
                    x, y = shrink_density(x, y, intersected_introns)
                # dons, accs, yd, ya, counts = [], [], [], [], []
        s += """
        density_list[["%(id)s"]] = data.frame(x = c(%(x)s), y = c(%(y)s))
        junction_list[["%(id)s"]] = data.frame(x = c(%(dons)s),
                                               xend=c(%(accs)s),
                                               y=c(%(yd)s),
                                               yend=c(%(ya)s),
                                               count=c(%(counts)s))
        """ % ({
            "id": k,
            'x': ",".join(map(str, x)),
            'y': ",".join(map(str, y)),
            'dons': ",".join(map(str, dons)),
            'accs': ",".join(map(str, accs)),
            'yd': ",".join(map(str, yd)),
            'ya': ",".join(map(str, ya)),
            'counts': ",".join(map(str, counts)),
        })
    return s
607 | |
608 | |
def plot(R_script):
    """Pipe the generated R script into a headless R session (must be on
    PATH); blocks until R exits."""
    p = sp.Popen("R --vanilla --slave", shell=True, stdin=sp.PIPE)
    p.communicate(input=R_script.encode())
    # communicate() already closes stdin and waits for the process; these
    # two calls are redundant but harmless no-ops.
    p.stdin.close()
    p.wait()
    return
615 | |
616 | |
def colorize(d, p, color_factor):
    """Emit the R `color_list` assignment mapping each id in d to a color.

    d: {id: color_level}; p: palette sequence; color_factor: when falsy,
    every id is mapped to "grey", otherwise each distinct level (sorted)
    picks a palette entry, recycling the palette if it is too short.
    """
    levels = sorted(set(d.values()))
    n_levels = len(levels)
    palette = p
    if n_levels > len(palette):
        # Recycle the palette to cover all levels
        palette = (palette * n_levels)[:n_levels]
    if color_factor:
        entries = ('%s="%s"' % (k, palette[levels.index(v)])
                   for k, v in d.items())
    else:
        entries = ('%s="%s"' % (k, "grey") for k in d)
    return "color_list = list(%s)\n" % ",".join(entries)
630 | |
631 | |
if __name__ == "__main__":

    # Maps the --out-strand option values to the internal strand keys
    strand_dict = {"plus": "+", "minus": "-"}

    parser = define_options()
    if len(sys.argv) == 1:
        # Called with no arguments: print usage instead of failing later
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    if args.aggr and not args.overlay:
        print("""ERROR: Cannot apply aggregate function
        if overlay is not selected.""")
        exit(1)

    palette = read_palette(args.palette)

    # bam_dict: per-strand {id: prepared R arrays}; the "-" entry exists
    # only for stranded protocols
    (bam_dict, overlay_dict, color_dict,
     id_list, label_dict) = ({"+": OrderedDict()}, OrderedDict(),
                             OrderedDict(), [], OrderedDict())
    if args.strand != "NONE":
        bam_dict["-"] = OrderedDict()
    if args.junctions_bed != "":
        junctions_list = []
656 | |
657 for (id, bam, overlay_level, | |
658 color_level, label_text) in read_bam_input(args.bam, | |
659 args.overlay, | |
660 args.color_factor, | |
661 args.labels): | |
662 if not os.path.isfile(bam): | |
663 continue | |
664 id_list.append(id) | |
665 label_dict[id] = label_text | |
666 a, junctions = read_bam(bam, args.coordinates, args.strand) | |
667 if a.keys() == ["+"] and all(map(lambda x: x == 0, | |
668 list(a.values()[0]))): | |
669 print("ERROR: No reads in the specified area.") | |
670 exit(1) | |
671 for strand in a: | |
672 # Store junction information | |
673 if args.junctions_bed: | |
674 for k, v in zip(junctions[strand].keys(), | |
675 junctions[strand].values()): | |
676 if v > args.min_coverage: | |
677 junctions_list.append('\t'.join([args.coordinates.split | |
678 (':')[0], str(k[0]), str(k[1]), | |
679 id, str(v), strand])) | |
680 bam_dict[strand][id] = prepare_for_R(a[strand], | |
681 junctions[strand], | |
682 args.coordinates, | |
683 args.min_coverage) | |
684 if color_level is None: | |
685 color_dict.setdefault(id, id) | |
686 if overlay_level is not None: | |
687 overlay_dict.setdefault(overlay_level, []).append(id) | |
688 label_dict[overlay_level] = overlay_level | |
689 color_dict.setdefault(overlay_level, overlay_level) | |
690 if overlay_level is None: | |
691 color_dict.setdefault(id, color_level) | |
692 | |
693 # No bam files | |
694 if not bam_dict["+"]: | |
695 print("ERROR: No available bam files.") | |
696 exit(1) | |
697 | |
698 # Write junctions to BED | |
699 if args.junctions_bed: | |
700 if not args.junctions_bed.endswith('.bed'): | |
701 args.junctions_bed = args.junctions_bed + '.bed' | |
702 jbed = open(args.junctions_bed, 'w') | |
703 jbed.write('\n'.join(sorted(junctions_list))) | |
704 jbed.close() | |
705 | |
706 if args.gtf: | |
707 transcripts, exons = read_gtf(args.gtf, args.coordinates) | |
708 | |
709 if args.out_format not in ('pdf', 'png', 'svg', 'tiff', 'jpeg'): | |
710 print("""ERROR: Provided output format '%s' is not available. | |
711 Please select among 'pdf', 'png', 'svg', | |
712 'tiff' or 'jpeg'""" % args.out_format) | |
713 exit(1) | |
714 | |
    # Iterate for plus and minus strand
    for strand in bam_dict:

        # Output file name (allow tiff/tif and jpeg/jpg extensions)
        if args.out_prefix.endswith(('.pdf', '.png', '.svg', '.tiff',
                                     '.tif', '.jpeg', '.jpg')):
            out_split = os.path.splitext(args.out_prefix)
            # Strip the extension only when it agrees with --out-format
            if (args.out_format == out_split[1][1:] or
                    args.out_format == 'tiff'
                    and out_split[1] in ('.tiff', '.tif') or
                    args.out_format == 'jpeg'
                    and out_split[1] in ('.jpeg', '.jpg')):
                args.out_prefix = out_split[0]
                out_suffix = out_split[1][1:]
            else:
                out_suffix = args.out_format
        else:
            out_suffix = args.out_format
        out_prefix = args.out_prefix + "_" + strand
        if args.strand == "NONE":
            out_prefix = args.out_prefix
        else:
            # Skip the strand the user did not ask for
            if args.out_strand != "both" \
                    and strand != strand_dict[args.out_strand]:
                continue

        # Find set of junctions to perform shrink
        intersected_introns = None
        if args.shrink:
            introns = (v for vs in bam_dict[strand].values() for v in
                       zip(vs[2], vs[3]))
            intersected_introns = list(intersect_introns(introns))

        # *** PLOT *** Define plot height
        bam_height = args.height * len(id_list)
        if args.overlay:
            bam_height = args.height * len(overlay_dict)
        if args.gtf:
            bam_height += args.ann_height

        # *** PLOT *** Start R script by loading libraries,
        # initializing variables, etc...
        R_script = setup_R_script(bam_height, args.width,
                                  args.base_size, label_dict)

        R_script += colorize(color_dict, palette, args.color_factor)

        # *** PLOT *** Prepare annotation plot only for the first bam file
        arrow_bins = 50
        if args.gtf:
            # Make introns from annotation (they are shrunk if required)
            annotation = make_introns(transcripts, exons, intersected_introns)
            x = list(bam_dict[strand].values())[0][0]
            if args.shrink:
                x, _ = shrink_density(x, x, intersected_introns)
            R_script += gtf_for_ggplot(annotation, x[0], x[-1], arrow_bins)

        R_script += make_R_lists(id_list, bam_dict[strand], overlay_dict,
                                 args.aggr, intersected_introns)

        # Main R body: one density+arcs panel per sample, then grob
        # alignment, optional annotation panel, and ggsave. %%%% in the
        # template survives Python %-formatting as R's %% (modulo).
        R_script += """

        pdf(NULL) # just to remove the blank pdf produced by ggplotGrob
        # fix problems with ggplot2 vs >3.0.0
        if(packageVersion('ggplot2') >= '3.0.0'){
            vs = 1
        } else {
            vs = 0
        }

        density_grobs = list();

        for (bam_index in 1:length(density_list)) {

            id = names(density_list)[bam_index]
            d = data.table(density_list[[id]])
            junctions = data.table(junction_list[[id]])

            maxheight = max(d[['y']])

            # Density plot
            gp = ggplot(d) + geom_bar(aes(x, y), width=1,
                                      position='identity',
                                      stat='identity',
                                      fill=color_list[[id]],
                                      alpha=%(alpha)s)
            gp = gp + labs(y=labels[[id]])
            gp = gp + scale_x_continuous(expand=c(0,0.2))

            # fix problems with ggplot2 vs >3.0.0
            if(packageVersion('ggplot2') >= '3.0.0') {
                gp = gp +
                    scale_y_continuous(breaks =
                        ggplot_build(gp
                        )$layout$panel_params[[1]]$y.major_source)
            } else {
                gp = gp +
                    scale_y_continuous(breaks =
                        ggplot_build(gp
                        )$layout$panel_ranges[[1]]$y.major_source)
            }

            # Aggregate junction counts
            row_i = c()
            if (nrow(junctions) >0 ) {

                junctions$jlabel = as.character(junctions$count)
                junctions = setNames(junctions[,.(max(y),
                                                  max(yend),
                                                  round(mean(count)),
                                                  paste(jlabel,
                                                        collapse=",")),
                                               keyby=.(x,xend)],
                                     names(junctions))
                if ("%(args.aggr)s" != "") {
                    junctions = setNames(
                        junctions[,.(max(y),
                                     max(yend),
                                     round(%(args.aggr)s(count)),
                                     round(%(args.aggr)s(count))),
                                  keyby=.(x,xend)],
                        names(junctions))
                }
                # The number of rows (unique junctions per bam) has to be
                # calculated after aggregation
                row_i = 1:nrow(junctions)
            }


            for (i in row_i) {

                j_tot_counts = sum(junctions[['count']])

                j = as.numeric(junctions[i,1:5])

                if ("%(args.aggr)s" != "") {
                    j[3] = as.numeric(d[x==j[1]-1,y])
                    j[4] = as.numeric(d[x==j[2]+1,y])
                }

                # Find intron midpoint
                xmid = round(mean(j[1:2]), 1)
                ymid = max(j[3:4]) * 1.2

                # Thickness of the arch
                lwd = scale_lwd(j[5]/j_tot_counts)

                curve_par = gpar(lwd=lwd, col=color_list[[id]])

                # Arc grobs

                # Choose position of the arch (top or bottom)
                nss = i
                if (nss%%%%2 == 0) { #bottom
                    ymid = -0.3 * maxheight
                    # Draw the arcs
                    # Left
                    curve = xsplineGrob(x = c(0, 0, 1, 1),
                                        y = c(1, 0, 0, 0),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                j[1], xmid, 0, ymid)
                    # Right
                    curve = xsplineGrob(x = c(1, 1, 0, 0),
                                        y = c(1, 0, 0, 0),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                xmid,
                                                j[2],
                                                0, ymid)
                }

                if (nss%%%%2 != 0) { #top
                    # Draw the arcs
                    # Left
                    curve = xsplineGrob(x = c(0, 0, 1, 1),
                                        y = c(0, 1, 1, 1),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                j[1], xmid, j[3], ymid)
                    # Right
                    curve = xsplineGrob(x = c(1, 1, 0, 0),
                                        y = c(0, 1, 1, 1),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                xmid, j[2], j[4], ymid)
                }

                # Add junction labels
                gp = gp + annotate("label", x = xmid, y = ymid,
                                   label = as.character(junctions[i,6]),
                                   vjust = 0.5, hjust = 0.5,
                                   label.padding = unit(0.01, "lines"),
                                   label.size = NA,
                                   size = (base_size*0.352777778)*0.6)

            }

            gpGrob = ggplotGrob(gp);
            gpGrob$layout$clip[gpGrob$layout$name=="panel"] <- "off"
            if (bam_index == 1) {
                # fix problems ggplot2 vs
                maxWidth = gpGrob$widths[2+vs] + gpGrob$widths[3+vs];
                maxYtextWidth = gpGrob$widths[3+vs];
                # Extract x axis grob (trim=F --> keep empty cells)
                xaxisGrob <- gtable_filter(gpGrob, "axis-b", trim=F)
                # fix problems ggplot2 vs
                xaxisGrob$heights[8+vs] = gpGrob$heights[1]
                x.axis.height = gpGrob$heights[7+vs] + gpGrob$heights[1]
            }


            # Remove x axis from all density plots
            kept_names = gpGrob$layout$name[gpGrob$layout$name != "axis-b"]
            gpGrob <- gtable_filter(gpGrob,
                                    paste(kept_names, sep = "",
                                          collapse = "|"),
                                    trim=F)

            # Find max width of y text and y label and max width of y text
            # fix problems ggplot2 vs
            maxWidth = grid::unit.pmax(maxWidth,
                                       gpGrob$widths[2+vs] +
                                       gpGrob$widths[3+vs]);
            maxYtextWidth = grid::unit.pmax(maxYtextWidth,
                                            gpGrob$widths[3+vs]);
            density_grobs[[id]] = gpGrob;
        }

        # Add x axis grob after density grobs BEFORE annotation grob
        density_grobs[["xaxis"]] = xaxisGrob

        # Annotation grob
        if (%(args.gtf)s == 1) {
            gtfGrob = ggplotGrob(gtfp);
            maxWidth = grid::unit.pmax(maxWidth,
                                       gtfGrob$widths[2+vs] +
                                       gtfGrob$widths[3+vs]);
            density_grobs[['gtf']] = gtfGrob;
        }

        # Reassign grob widths to align the plots
        for (id in names(density_grobs)) {
            density_grobs[[id]]$widths[1] <-
                density_grobs[[id]]$widths[1] +
                maxWidth - (density_grobs[[id]]$widths[2 + vs] +
                            maxYtextWidth)
            # fix problems ggplot2 vs
            density_grobs[[id]]$widths[3 + vs] <-
                maxYtextWidth # fix problems ggplot2 vs
        }

        # Heights for density, x axis and annotation
        heights = unit.c(
            unit(rep(%(signal_height)s,
                     length(density_list)), "in"),
            x.axis.height,
            unit(%(ann_height)s*%(args.gtf)s, "in")
        )

        # Arrange grobs
        argrobs = arrangeGrob(
            grobs=density_grobs,
            ncol=1,
            heights = heights,
        );

        # Save plot to file in the requested format
        if ("%(out_format)s" == "tiff"){
            # TIFF images will be lzw-compressed
            ggsave("%(out)s",
                   plot = argrobs,
                   device = "tiff",
                   width = width,
                   height = height,
                   units = "in",
                   dpi = %(out_resolution)s,
                   compression = "lzw")
        } else {
            ggsave("%(out)s",
                   plot = argrobs,
                   device = "%(out_format)s",
                   width = width,
                   height = height,
                   units = "in",
                   dpi = %(out_resolution)s)
        }

        dev.log = dev.off()

        """ % ({
            "out": "%s.%s" % (out_prefix, out_suffix),
            "out_format": args.out_format,
            "out_resolution": args.out_resolution,
            "args.gtf": float(bool(args.gtf)),
            # rstrip("_j") strips a character SET, but is safe for the
            # allowed values mean/median/mean_j/median_j
            "args.aggr": args.aggr.rstrip("_j"),
            "signal_height": args.height,
            "ann_height": args.ann_height,
            "alpha": args.alpha,
        })
        # With GGSASHIMI_DEBUG set, dump the R script instead of running it
        if os.getenv('GGSASHIMI_DEBUG') is not None:
            with open("R_script", 'w') as r:
                r.write(R_script)
        else:
            plot(R_script)
    exit()