comparison sashimi-plot.py @ 0:9304dd9a16a2 draft

"planemo upload for repository https://github.com/ARTbio/tools-artbio/tree/master/tools/sashimi_plot commit 746c03a1187e1d708af8628920a0c615cddcdacc"
author artbio
date Fri, 23 Aug 2019 11:38:29 -0400
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:9304dd9a16a2
1 #!/usr/bin/env python
2
3 # Import modules
4 import copy
5 import os
6 import re
7 import subprocess as sp
8 import sys
9 from argparse import ArgumentParser
10 from collections import OrderedDict
11
12
def define_options():
    """Build the command-line interface for the sashimi plot tool.

    Returns:
        ArgumentParser: fully configured parser; the caller is expected to
        invoke ``parse_args()`` (see the ``__main__`` block).
    """
    # Argument parsing
    parser = ArgumentParser(description="""Create sashimi plot for a given
        genomic region""")
    # Input alignments: either one bam or a tsv index of several bams.
    parser.add_argument("-b", "--bam", type=str,
                        help="""
                        Individual bam file or file with a list of bam files.
                        In the case of a list of files the format is tsv:
                        1col: id for bam file,
                        2col: path of bam file,
                        3+col: additional columns
                        """)
    parser.add_argument("-c", "--coordinates", type=str,
                        help="Genomic region. Format: chr:start-end (1-based)")
    parser.add_argument("-o", "--out-prefix", type=str, dest="out_prefix",
                        default="sashimi",
                        help="Prefix for plot file name [default=%(default)s]")
    parser.add_argument("-S", "--out-strand", type=str, dest="out_strand",
                        default="both", help="""Only for --strand other than
                        'NONE'. Choose which signal strand to plot:
                        <both> <plus> <minus> [default=%(default)s]""")
    parser.add_argument("-M", "--min-coverage", type=int, default=1,
                        dest="min_coverage", help="""Minimum number of reads
                        supporting a junction to be drawn [default=1]""")
    parser.add_argument("-j", "--junctions-bed", type=str, default="",
                        dest="junctions_bed", help="""Junction BED file name
                        [default=no junction file]""")
    parser.add_argument("-g", "--gtf",
                        help="Gtf file with annotation (only exons is enough)")
    parser.add_argument("-s", "--strand", default="NONE", type=str,
                        help="""Strand specificity: <NONE> <SENSE> <ANTISENSE>
                        <MATE1_SENSE> <MATE2_SENSE> [default=%(default)s]""")
    parser.add_argument("--shrink", action="store_true",
                        help="""Shrink the junctions by a factor for nicer
                        display [default=%(default)s]""")
    # Overlay/color/label indices are 1-based columns of the tsv index file.
    parser.add_argument("-O", "--overlay", type=int,
                        help="Index of column with overlay levels (1-based)")
    parser.add_argument("-A", "--aggr", type=str, default="",
                        help="""Aggregate function for overlay:
                        <mean> <median> <mean_j> <median_j>.
                        Use mean_j | median_j to keep density overlay but
                        aggregate junction counts [default=no aggregation]""")
    parser.add_argument("-C", "--color-factor", type=int, dest="color_factor",
                        help="Index of column with color levels (1-based)")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="""Transparency level for density histogram
                        [default=%(default)s]""")
    parser.add_argument("-P", "--palette", type=str,
                        help="""Color palette file. tsv file with >=1 columns,
                        where the color is the first column""")
    parser.add_argument("-L", "--labels", type=int, dest="labels", default=1,
                        help="""Index of column with labels (1-based)
                        [default=%(default)s]""")
    # Geometry / rendering options forwarded to the generated R script.
    parser.add_argument("--height", type=float, default=2,
                        help="""Height of the individual signal plot in inches
                        [default=%(default)s]""")
    parser.add_argument("--ann-height", type=float, default=1.5,
                        dest="ann_height", help="""Height of annotation plot in
                        inches [default=%(default)s]""")
    parser.add_argument("--width", type=float, default=10,
                        help="""Width of the plot in inches
                        [default=%(default)s]""")
    parser.add_argument("--base-size", type=float, default=14,
                        dest="base_size", help="""Base font size of the plot in
                        pch [default=%(default)s]""")
    parser.add_argument("-F", "--out-format", type=str, default="pdf",
                        dest="out_format", help="""Output file format:
                        <pdf> <svg> <png> <jpeg> <tiff>
                        [default=%(default)s]""")
    parser.add_argument("-R", "--out-resolution", type=int, default=300,
                        dest="out_resolution", help="""Output file resolution in
                        PPI (pixels per inch). Applies only to raster output
                        formats [default=%(default)s]""")
    return parser
87
88
def parse_coordinates(c):
    """Parse a region string 'chr:start-end' (1-based, commas allowed).

    Returns:
        (chrom, start, end) where [start, end) is 0-based half-open.
    """
    region = c.replace(",", "")
    chrom, _, span = region.partition(":")
    left, _, right = span.partition("-")
    # 1-based inclusive coordinates -> 0-based half-open interval.
    return chrom, int(left) - 1, int(right)
96
97
def count_operator(CIGAR_op, CIGAR_len, pos, start, end, a, junctions):
    """Apply one CIGAR operation to the coverage/junction accumulators.

    Args:
        CIGAR_op: one of "M", "I", "S", "D", "N".
        CIGAR_len: length of the operation.
        pos: current 0-based reference position of the read cursor.
        start, end: region boundaries (0-based half-open).
        a: per-base coverage list for the region (mutated for "M").
        junctions: {(donor, acceptor): count} dict (mutated for "N").

    Returns:
        Updated reference position after consuming the operation.
    """
    # Insertions and soft-clips consume no reference bases.
    if CIGAR_op in ("I", "S"):
        return pos

    if CIGAR_op == "M":
        # Count coverage only for aligned bases inside the window.
        lo = max(pos, start)
        hi = min(pos + CIGAR_len, end)
        for i in range(lo, hi):
            a[i - start] += 1
    elif CIGAR_op == "N":
        # Skipped reference region = splice junction donor -> acceptor.
        don, acc = pos, pos + CIGAR_len
        if start < don and acc < end:
            junctions[(don, acc)] = junctions.get((don, acc), 0) + 1
    # "D" (deletion) consumes reference bases but adds no coverage.

    return pos + CIGAR_len
126
127
def flip_read(s, samflag):
    """Decide whether the read's strand should be flipped for mode *s*.

    Returns 0 (keep) or 1 (flip). For the MATE*_SENSE modes the decision
    depends on which mate the read is (SAM flag bits 64/128); if neither
    bit is set the function falls through and returns None, matching the
    original behaviour.
    """
    if s in ("NONE", "SENSE"):
        return 0
    if s == "ANTISENSE":
        return 1
    if s == "MATE1_SENSE":
        flag = int(samflag)
        if flag & 64:       # first in pair
            return 0
        if flag & 128:      # second in pair
            return 1
    if s == "MATE2_SENSE":
        flag = int(samflag)
        if flag & 64:
            return 1
        if flag & 128:
            return 0
143
144
def read_bam(f, c, s):
    """Compute per-base coverage and junction counts for region *c* of bam *f*.

    Args:
        f: path to an indexed bam file (read via `samtools view`).
        c: region string "chr:start-end" (1-based).
        s: strand-specificity mode (NONE/SENSE/ANTISENSE/MATE1_SENSE/
           MATE2_SENSE).

    Returns:
        (a, junctions): both dicts keyed by strand ("+" always; "-" too
        unless s == "NONE"). `a[strand]` is a coverage list over the region,
        `junctions[strand]` an OrderedDict {(donor, acceptor): read count}.
    """

    _, start, end = parse_coordinates(c)

    # Initialize coverage array and junction dict
    a = {"+": [0] * (end - start)}
    junctions = {"+": OrderedDict()}
    if s != "NONE":
        a["-"] = [0] * (end - start)
        junctions["-"] = OrderedDict()

    # NOTE(review): shell=True with interpolated file/region strings is fine
    # for trusted CLI input, but a list argv with shell=False would be safer.
    p = sp.Popen("samtools view %s %s " % (f, c), shell=True, stdout=sp.PIPE)
    for line in p.communicate()[0].decode('utf8').strip().split("\n"):

        if line == "":
            continue

        # SAM columns: 2=FLAG, 4=POS (1-based leftmost), 6=CIGAR.
        line_sp = line.strip().split("\t")
        samflag, read_start, CIGAR = line_sp[1], int(line_sp[3]), line_sp[5]

        # Ignore reads with more exotic CIGAR operators
        if any(map(lambda x: x in CIGAR, ["H", "P", "X", "="])):
            continue

        # Reverse-strand bit (16) xor the library-orientation flip decides
        # which strand this read contributes to.
        read_strand = ["+", "-"][flip_read(s, samflag) ^ bool(int(samflag) &
                                                              16)]
        if s == "NONE":
            read_strand = "+"

        # Split the CIGAR into parallel lists of lengths and operators.
        CIGAR_lens = re.split("[MIDNS]", CIGAR)[:-1]
        CIGAR_ops = re.split("[0-9]+", CIGAR)[1:]

        pos = read_start

        for n, CIGAR_op in enumerate(CIGAR_ops):
            CIGAR_len = int(CIGAR_lens[n])
            pos = count_operator(CIGAR_op, CIGAR_len, pos, start, end,
                                 a[read_strand], junctions[read_strand])

    p.stdout.close()
    return a, junctions
186
187
def get_bam_path(index, path):
    """Resolve a bam path from the list file *index*.

    Absolute paths pass through untouched; relative paths are anchored at
    the directory containing the list file.
    """
    if not os.path.isabs(path):
        path = os.path.join(os.path.dirname(index), path)
    return path
193
194
def read_bam_input(f, overlay, color, label):
    """Yield (id, bam_path, overlay_level, color_level, label_text) tuples.

    *f* is either a single ".bam" path (one tuple is yielded, with the file
    basename as both id and label) or a tsv list file: col 1 = id,
    col 2 = bam path (relative paths resolved against the list file),
    further columns selected by the 1-based *overlay*, *color* and *label*
    indices (None when the index is falsy).
    """
    if f.endswith(".bam"):
        # BUGFIX: the original used `.strip(".bam")`, which strips the
        # character set {'.','b','a','m'} from both ends (e.g. "map.bam"
        # -> "p"). Slice off the extension instead.
        bn = os.path.basename(f.strip())[:-len(".bam")]
        yield bn, f, None, None, bn
        return
    with open(f) as openf:
        for line in openf:
            line_sp = line.strip().split("\t")
            bam = get_bam_path(f, line_sp[1])
            overlay_level = line_sp[overlay-1] if overlay else None
            color_level = line_sp[color-1] if color else None
            label_text = line_sp[label-1] if label else None
            yield line_sp[0], bam, overlay_level, color_level, label_text
208
209
def prepare_for_R(a, junctions, c, m):
    """Convert a coverage array and junction dict into parallel lists for R.

    Args:
        a: per-base coverage list for the region.
        junctions: {(donor, acceptor): count} for the same region/strand.
        c: region string "chr:start-end" (1-based).
        m: minimum junction coverage; junctions below it are dropped.

    Returns:
        (x, y, dons, accs, yd, ya, counts): genomic x coordinates, coverage
        values, junction donors/acceptors, arc anchor heights, and counts.
    """
    # BUGFIX: previously read the global `args.coordinates` and ignored the
    # `c` parameter; callers always pass args.coordinates, so behaviour is
    # unchanged, but the hidden global dependency is gone.
    _, start, _ = parse_coordinates(c)

    # Convert the array index to genomic coordinates
    x = list(i+start for i in range(len(a)))
    y = a

    # Arrays for R
    dons, accs, yd, ya, counts = [], [], [], [], []

    # Prepare arrays for junctions (which will be the arcs)
    for (don, acc), n in junctions.items():

        # Do not add junctions with less than defined coverage
        if n < m:
            continue

        dons.append(don)
        accs.append(acc)
        counts.append(n)

        # Arc anchors are taken just outside the junction ends.
        # NOTE(review): `acc - start + 1` can equal len(a) when acc == end-1,
        # which would raise IndexError — confirm junction bounds upstream.
        yd.append(a[don - start - 1])
        ya.append(a[acc - start + 1])

    return x, y, dons, accs, yd, ya, counts
236
237
def intersect_introns(data):
    """Yield the intersections of overlapping (start, end) intervals.

    Intervals are sorted first; strictly overlapping runs are collapsed to
    their common core (max of starts, min of ends), non-overlapping ones
    are yielded as-is. Touching intervals like (1,2),(2,3) are NOT merged
    because the test is a strict `b > c`.

    BUGFIX: an empty *data* now yields nothing instead of raising
    StopIteration/RuntimeError from the initial next() call.
    """
    it = iter(sorted(data))
    try:
        a, b = next(it)
    except StopIteration:
        return  # no introns at all
    for c, d in it:
        if b > c:
            # Use `if b > c` if you want (1,2), (2,3) not to be
            # treated as intersection.
            b = min(b, d)
            a = max(a, c)
        else:
            yield a, b
            a, b = c, d
    yield a, b
252
253
def shrink_density(x, y, introns):
    """Compress the density track by shrinking each intron to length L**0.7.

    Args:
        x: genomic x coordinates (must contain each intron's start/end —
           x.index() is used to locate them; assumes introns lie within x).
        y: coverage values parallel to x.
        introns: sorted, non-overlapping (start, end) pairs to compress.

    Returns:
        (new_x, new_y): coordinates shifted left by the accumulated shrink,
        values unchanged.
    """
    new_x, new_y = [], []
    shift = 0
    start = 0
    # introns are already sorted by coordinates
    for a, b in introns:
        end = x.index(a)+1
        # Copy the exonic stretch before this intron, shifted left.
        new_x += [int(i-shift) for i in x[start:end]]
        new_y += y[start:end]
        start = x.index(b)
        # Each intron of length L is displayed with length L**0.7.
        L = (b-a)
        shift += L-L**0.7
    new_x += [int(i-shift) for i in x[start:]]
    new_y += y[start:]
    return new_x, new_y
269
270
def shrink_junctions(dons, accs, introns):
    """Shift junction coordinates to match a density track compressed by
    shrink_density.

    Args:
        dons, accs: parallel lists of junction donor/acceptor coordinates.
        introns: sorted intersected introns, as passed to shrink_density.

    Returns:
        (new_dons, new_accs): shifted coordinates; junctions not spanning
        any intron keep their initial 0 placeholder.
    """
    new_dons, new_accs = [0]*len(dons), [0]*len(accs)
    shift_acc = 0
    shift_don = 0
    # Junctions whose donor shift has already been fixed by an earlier
    # intron; later introns then only update the acceptor side.
    s = set()
    junctions = list(zip(dons, accs))
    for a, b in introns:
        L = b - a
        # Same L - L**0.7 compression factor used in shrink_density.
        shift_acc += L-int(L**0.7)
        for i, (don, acc) in enumerate(junctions):
            # Only junctions that fully span this intron are affected.
            if a >= don and b <= acc:
                if (don, acc) not in s:
                    new_dons[i] = don - shift_don
                    new_accs[i] = acc - shift_acc
                else:
                    new_accs[i] = acc - shift_acc
                s.add((don, acc))
        # Donor shift for the next intron is everything accumulated so far.
        shift_don = shift_acc
    return new_dons, new_accs
290
291
def read_palette(f):
    """Return the plot colors.

    When *f* is given, read one color per line from the first tsv column;
    otherwise fall back to the red/green/blue/black default tuple.
    """
    default = "#ff0000", "#00ff00", "#0000ff", "#000000"
    if not f:
        return default
    with open(f) as handle:
        return [row.split("\t")[0].strip() for row in handle]
298
299
def read_gtf(f, c):
    """Collect transcripts and exons from GTF file *f* overlapping region *c*.

    Args:
        f: path to a GTF annotation file.
        c: region string "chr:start-end" (1-based).

    Returns:
        (transcripts, exons): transcripts maps transcript_id ->
        (start, end, strand) clipped to the region; exons maps
        transcript_id -> list of (start, end, strand) tuples. Strand values
        are pre-quoted (e.g. '"+"') for direct interpolation into R code.
    """
    exons = OrderedDict()
    transcripts = OrderedDict()
    chr, start, end = parse_coordinates(c)
    # Back to an inclusive end for the interval comparisons below.
    end = end - 1
    with open(f) as openf:
        for line in openf:
            # Skip GTF header/comment lines.
            if line.startswith("#"):
                continue
            (el_chr, _, el, el_start, el_end, _,
             strand, _, tags) = line.strip().split("\t")
            if el_chr != chr:
                continue
            # Parse the attribute column (`key "value"; key "value";`).
            d = dict(kv.strip().split(" ") for kv in
                     tags.strip(";").split("; "))
            transcript_id = d["transcript_id"]
            # GTF is 1-based inclusive; convert to 0-based half-open.
            el_start, el_end = int(el_start) - 1, int(el_end)
            strand = '"' + strand + '"'
            if el == "transcript":
                # Keep transcripts overlapping the region, clipped to it.
                if (el_end > start and el_start < end):
                    transcripts[transcript_id] = (max(start, el_start),
                                                  min(end, el_end),
                                                  strand)
                continue
            if el == "exon":
                if (start < el_start < end or start < el_end < end):
                    exons.setdefault(transcript_id,
                                     []).append((max(el_start, start),
                                                 min(end, el_end), strand))

    return transcripts, exons
331
332
def make_introns(transcripts, exons, intersected_introns=None):
    """Derive intron intervals per transcript, optionally shrinking them.

    Args:
        transcripts: transcript_id -> (start, end, strand) from read_gtf.
        exons: transcript_id -> [(start, end, strand), ...] from read_gtf.
        intersected_introns: optional sorted (start, end) pairs; when given,
            exon and transcript coordinates are shifted by the same
            L - L**0.7 compression used for the density track.

    Returns:
        dict with keys 'transcripts', 'exons', 'introns' holding the
        (possibly shifted) annotation; inputs are not mutated (deep copies).
    """
    new_transcripts = copy.deepcopy(transcripts)
    new_exons = copy.deepcopy(exons)
    introns = OrderedDict()
    if intersected_introns:
        # Shift exon/transcript coordinates to the shrunk coordinate system.
        for tx, (tx_start, tx_end, strand) in new_transcripts.items():
            total_shift = 0
            for a, b in intersected_introns:
                L = b - a
                shift = L - int(L**0.7)
                total_shift += shift
                for i, (exon_start, exon_end, strand) in \
                        enumerate(exons.get(tx, [])):
                    new_exon_start, new_exon_end = new_exons[tx][i][:2]
                    if a < exon_start:
                        if b > exon_end:
                            # Intron swallows the exon; for the last exon,
                            # rescale the shift by the partial overlap.
                            if i == len(exons[tx])-1:
                                total_shift = total_shift - shift + \
                                    (exon_start - a)*(1-int(L**-0.3))
                                shift = (exon_start - a)*(1-int(L**-0.3))
                            new_exon_end = new_exons[tx][i][1] - shift
                            new_exon_start = new_exons[tx][i][0] - shift
                        if b <= exon_end:
                            # Intron ends inside/at the exon: only the end
                            # moves left.
                            new_exon_end = new_exons[tx][i][1] - shift
                    new_exons[tx][i] = (new_exon_start, new_exon_end, strand)
            tx_start = min(tx_start,
                           sorted(new_exons.get(tx, [[sys.maxsize]]))[0][0])
            new_transcripts[tx] = (tx_start, tx_end - total_shift, strand)

    # Introns are the gaps between consecutive (shifted) exons, plus the
    # leading/trailing stretches of the transcript not covered by exons.
    for tx, (tx_start, tx_end, strand) in new_transcripts.items():
        intron_start = tx_start
        ex_end = 0
        for ex_start, ex_end, strand in sorted(new_exons.get(tx, [])):
            intron_end = ex_start
            if tx_start < ex_start:
                introns.setdefault(tx, []).append((intron_start, intron_end,
                                                   strand))
            intron_start = ex_end
        if tx_end > ex_end:
            introns.setdefault(tx, []).append((intron_start, tx_end, strand))
    d = {'transcripts': new_transcripts,
         'exons': new_exons,
         'introns': introns}
    return d
377
378
def gtf_for_ggplot(annotation, start, end, arrow_bins):
    """Render the annotation panel (exons, introns, strand arrows) as R code.

    Args:
        annotation: dict with 'exons' and 'introns' mappings
            (transcript_id -> list of (start, end, quoted-strand)),
            as produced by make_introns().
        start, end: x-axis limits for the annotation panel.
        arrow_bins: number of bins controlling strand-arrow spacing.

    Returns:
        str: R code defining `gtfp`, the ggplot annotation object, appended
        to the generated script by the caller.
    """
    # Spacing (in x units) between consecutive strand arrows on introns.
    arrow_space = int((end - start)/arrow_bins)
    s = """

    # data table with exons
    ann_list = list(
        "exons" = data.table(),
        "introns" = data.table()
    )
    """

    if annotation["exons"]:

        # Flatten {tx: [(start, end, strand), ...]} into parallel R vectors.
        s += """
        ann_list[['exons']] = data.table(
            tx = rep(c(%(tx_exons)s), c(%(n_exons)s)),
            start = c(%(exon_start)s),
            end = c(%(exon_end)s),
            strand = c(%(strand)s)
        )
        """ % ({
            "tx_exons": ",".join(annotation["exons"].keys()),
            "n_exons": ",".join(map(str, map(len,
                                             annotation["exons"].values()))),
            "exon_start": ",".join(map(str, (v[0] for vs in
                                             annotation["exons"].values() for v in vs))),
            "exon_end": ",".join(map(str, (v[1] for vs in
                                           annotation["exons"].values() for v in vs))),
            "strand": ",".join(map(str, (v[2] for vs in
                                         annotation["exons"].values() for v in vs))),
        })

    if annotation["introns"]:

        # `%%in%%` survives one round of %-formatting as the R `%in%`.
        s += """
        ann_list[['introns']] = data.table(
            tx = rep(c(%(tx_introns)s), c(%(n_introns)s)),
            start = c(%(intron_start)s),
            end = c(%(intron_end)s),
            strand = c(%(strand)s)
        )
        # Create data table for strand arrows
        txarrows = data.table()
        introns = ann_list[['introns']]
        # Add right-pointing arrows for plus strand
        if ("+" %%in%% introns$strand) {
            txarrows = rbind(
                txarrows,
                introns[strand=="+" & end-start>5, list(
                    seq(start+4,end,by=%(arrow_space)s)-1,
                    seq(start+4,end,by=%(arrow_space)s)
                    ), by=.(tx,start,end)
                ]
            )
        }
        # Add left-pointing arrows for minus strand
        if ("-" %%in%% introns$strand) {
            txarrows = rbind(txarrows,
                introns[strand=="-" & end-start>5,
                    list(seq(start,max(start+1, end-4),
                             by=%(arrow_space)s),
                         seq(start,max(start+1, end-4),
                             by=%(arrow_space)s)-1
                    ),
                    by=.(tx,start,end)
                ]
            )
        }
        """ % ({
            "tx_introns": ",".join(annotation["introns"].keys()),
            "n_introns": ",".join(map(str, map(len,
                                               annotation["introns"].values()))),
            "intron_start": ",".join(map(str, (v[0] for vs in
                                               annotation["introns"].values() for v in
                                               vs))),
            "intron_end": ",".join(map(str, (v[1] for vs in
                                             annotation["introns"].values() for v in
                                             vs))),
            "strand": ",".join(map(str, (v[2] for vs in
                                         annotation["introns"].values() for v in vs))),
            "arrow_space": arrow_space,
        })

    # Assemble the ggplot object: intron segments + arrows, exon bars,
    # then axis/theme cleanup so the panel aligns under the density plots.
    s += """

    gtfp = ggplot()
    if (length(ann_list[['introns']]) > 0) {
        gtfp = gtfp + geom_segment(data = ann_list[['introns']],
                                   aes(x = start,
                                       xend = end,
                                       y = tx,
                                       yend = tx),
                                   size = 0.3)
        gtfp = gtfp + geom_segment(data = txarrows,
                                   aes(x = V1,
                                       xend = V2,
                                       y = tx,
                                       yend = tx),
                                   arrow = arrow(length = unit(0.02, "npc")))
    }
    if (length(ann_list[['exons']]) > 0) {
        gtfp = gtfp + geom_segment(data = ann_list[['exons']],
                                   aes(x = start,
                                       xend = end,
                                       y = tx,
                                       yend = tx),
                                   size = 5,
                                   alpha = 1)
    }
    gtfp = gtfp + scale_y_discrete(expand = c(0, 0.5))
    gtfp = gtfp + scale_x_continuous(expand = c(0, 0.25),
                                     limits = c( %s,% s))
    gtfp = gtfp + labs(y = NULL)
    gtfp = gtfp + theme(axis.line = element_blank(),
                        axis.text.x = element_blank(),
                        axis.ticks = element_blank())
    """ % (start, end)

    return s
498
499
def setup_R_script(h, w, b, label_dict):
    """Emit the R script preamble: libraries, theme, sizes and labels.

    Args:
        h: total plot height in inches (before the font-size correction).
        w: plot width in inches.
        b: base font size.
        label_dict: id -> label mapping, rendered as a named R list.

    Returns:
        str: R code initialising the plotting session; `density_list` and
        `junction_list` are filled in later by make_R_lists().
    """
    s = """
    library(ggplot2)
    library(grid)
    library(gridExtra)
    library(data.table)
    library(gtable)

    scale_lwd = function(r) {
        lmin = 0.1
        lmax = 4
        return( r*(lmax-lmin)+lmin )
    }

    base_size = %(b)s
    height = ( %(h)s + base_size*0.352777778/67 ) * 1.02
    width = %(w)s
    theme_set(theme_bw(base_size=base_size))
    theme_update(
        plot.margin = unit(c(15,15,15,15), "pt"),
        panel.grid = element_blank(),
        panel.border = element_blank(),
        axis.line = element_line(size=0.5),
        axis.title.x = element_blank(),
        axis.title.y = element_text(angle=0, vjust=0.5)
    )

    labels = list(%(labels)s)

    density_list = list()
    junction_list = list()

    """ % ({
        'h': h,
        'w': w,
        'b': b,
        # Render labels as `"id"="label"` pairs for the R named list.
        'labels': ",".join(('"%s"="%s"' % (id, lab) for id, lab in
                            label_dict.items())),
    })
    return s
540
541
def median(lst):
    """Median of *lst*: middle element for odd length, mean of the two
    middle elements (as a float) for even length."""
    ordered = sorted(lst)
    mid, odd = divmod(len(ordered), 2)
    if odd:
        return ordered[mid]
    return sum(ordered[mid - 1:mid + 1]) / 2.
547
548
def mean(lst):
    """Arithmetic mean of *lst* (true division, so always a float)."""
    total = sum(lst)
    return total / len(lst)
551
552
def make_R_lists(id_list, d, overlay_dict, aggr, intersected_introns):
    """Emit R code filling `density_list` and `junction_list` per track.

    Args:
        id_list: bam ids, used when no overlay is requested.
        d: id -> (x, y, dons, accs, yd, ya, counts) from prepare_for_R.
        overlay_dict: overlay level -> [ids]; when non-empty, tracks are
            grouped per level instead of per id.
        aggr: "" or mean/median[_j]; without the "_j" suffix the density is
            aggregated across the overlaid tracks.
        intersected_introns: optional introns for coordinate shrinking.

    Returns:
        str: R assignments, one density/junction data.frame pair per track.
    """
    s = ""
    aggr_f = {
        "mean": mean,
        "median": median,
    }
    # With an overlay, iterate overlay levels instead of individual bams.
    id_list = id_list if not overlay_dict else overlay_dict.keys()
    # Iterate over ids to get bam signal and junctions
    for k in id_list:
        x, y, dons, accs, yd, ya, counts = [], [], [], [], [], [], []
        if not overlay_dict:
            x, y, dons, accs, yd, ya, counts = d[k]
            if intersected_introns:
                x, y = shrink_density(x, y, intersected_introns)
                dons, accs = shrink_junctions(dons, accs, intersected_introns)
        else:
            # Concatenate all member tracks of this overlay level.
            for id in overlay_dict[k]:
                xid, yid, donsid, accsid, ydid, yaid, countsid = d[id]
                if intersected_introns:
                    xid, yid = shrink_density(xid, yid, intersected_introns)
                    donsid, accsid = shrink_junctions(donsid, accsid,
                                                      intersected_introns)
                x += xid
                y += yid
                dons += donsid
                accs += accsid
                yd += ydid
                ya += yaid
                counts += countsid
            # Plain mean/median: replace the density by the aggregate of
            # the member tracks ("_j" variants aggregate junctions in R).
            if aggr and "_j" not in aggr:
                x = d[overlay_dict[k][0]][0]
                y = list(map(aggr_f[aggr], zip(*(d[id][1] for id in
                                                 overlay_dict[k]))))
                if intersected_introns:
                    x, y = shrink_density(x, y, intersected_introns)
                # dons, accs, yd, ya, counts = [], [], [], [], []
        s += """
        density_list[["%(id)s"]] = data.frame(x = c(%(x)s), y = c(%(y)s))
        junction_list[["%(id)s"]] = data.frame(x = c(%(dons)s),
                                               xend=c(%(accs)s),
                                               y=c(%(yd)s),
                                               yend=c(%(ya)s),
                                               count=c(%(counts)s))
        """ % ({
            "id": k,
            'x': ",".join(map(str, x)),
            'y': ",".join(map(str, y)),
            'dons': ",".join(map(str, dons)),
            'accs': ",".join(map(str, accs)),
            'yd': ",".join(map(str, yd)),
            'ya': ",".join(map(str, ya)),
            'counts': ",".join(map(str, counts)),
        })
    return s
607
608
def plot(R_script):
    """Pipe the generated R script into a batch R session for rendering."""
    p = sp.Popen("R --vanilla --slave", shell=True, stdin=sp.PIPE)
    p.communicate(input=R_script.encode())
    # NOTE(review): communicate() already closes stdin and waits for the
    # process; the explicit close/wait below are redundant but harmless.
    p.stdin.close()
    p.wait()
    return
615
616
def colorize(d, p, color_factor):
    """Emit the R `color_list` assignment mapping each id to a color.

    With *color_factor* set, each distinct value in *d* gets a palette
    color (palette recycled if too short); otherwise every id is grey.
    """
    levels = sorted(set(d.values()))
    # Recycle the palette when there are more levels than colors.
    if len(levels) > len(p):
        p = (p * len(levels))[:len(levels)]
    if color_factor:
        entries = ('%s="%s"' % (key, p[levels.index(level)])
                   for key, level in d.items())
    else:
        entries = ('%s="%s"' % (key, "grey") for key, _ in d.items())
    return "color_list = list(%s)\n" % ",".join(entries)
630
631
if __name__ == "__main__":

    # Maps the --out-strand choice to the strand keys used internally.
    strand_dict = {"plus": "+", "minus": "-"}

    parser = define_options()
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    # --aggr only makes sense when tracks are overlaid.
    if args.aggr and not args.overlay:
        print("""ERROR: Cannot apply aggregate function
              if overlay is not selected.""")
        exit(1)

    palette = read_palette(args.palette)

    (bam_dict, overlay_dict, color_dict,
     id_list, label_dict) = ({"+": OrderedDict()}, OrderedDict(),
                             OrderedDict(), [], OrderedDict())
    if args.strand != "NONE":
        bam_dict["-"] = OrderedDict()
    if args.junctions_bed != "":
        junctions_list = []

    # Read coverage/junctions for every bam and index them per strand.
    for (id, bam, overlay_level,
         color_level, label_text) in read_bam_input(args.bam,
                                                    args.overlay,
                                                    args.color_factor,
                                                    args.labels):
        if not os.path.isfile(bam):
            continue
        id_list.append(id)
        label_dict[id] = label_text
        a, junctions = read_bam(bam, args.coordinates, args.strand)
        # NOTE(review): in Python 3 dict_keys never compares equal to a
        # list and dict_values is not subscriptable, so this guard is
        # effectively dead code — confirm and rewrite as
        # `list(a.keys()) == ["+"]` / `list(a.values())[0]`.
        if a.keys() == ["+"] and all(map(lambda x: x == 0,
                                         list(a.values()[0]))):
            print("ERROR: No reads in the specified area.")
            exit(1)
        for strand in a:
            # Store junction information
            if args.junctions_bed:
                for k, v in zip(junctions[strand].keys(),
                                junctions[strand].values()):
                    if v > args.min_coverage:
                        junctions_list.append('\t'.join([args.coordinates.split
                                                         (':')[0], str(k[0]), str(k[1]),
                                                         id, str(v), strand]))
            bam_dict[strand][id] = prepare_for_R(a[strand],
                                                 junctions[strand],
                                                 args.coordinates,
                                                 args.min_coverage)
        # Track color/overlay bookkeeping for this bam id.
        if color_level is None:
            color_dict.setdefault(id, id)
        if overlay_level is not None:
            overlay_dict.setdefault(overlay_level, []).append(id)
            label_dict[overlay_level] = overlay_level
            color_dict.setdefault(overlay_level, overlay_level)
        if overlay_level is None:
            color_dict.setdefault(id, color_level)

    # No bam files
    if not bam_dict["+"]:
        print("ERROR: No available bam files.")
        exit(1)

    # Write junctions to BED
    if args.junctions_bed:
        if not args.junctions_bed.endswith('.bed'):
            args.junctions_bed = args.junctions_bed + '.bed'
        jbed = open(args.junctions_bed, 'w')
        jbed.write('\n'.join(sorted(junctions_list)))
        jbed.close()

    if args.gtf:
        transcripts, exons = read_gtf(args.gtf, args.coordinates)

    if args.out_format not in ('pdf', 'png', 'svg', 'tiff', 'jpeg'):
        print("""ERROR: Provided output format '%s' is not available.
              Please select among 'pdf', 'png', 'svg',
              'tiff' or 'jpeg'""" % args.out_format)
        exit(1)

    # Iterate for plus and minus strand
    for strand in bam_dict:

        # Output file name (allow tiff/tif and jpeg/jpg extensions)
        if args.out_prefix.endswith(('.pdf', '.png', '.svg', '.tiff',
                                     '.tif', '.jpeg', '.jpg')):
            out_split = os.path.splitext(args.out_prefix)
            # Only strip the extension when it agrees with --out-format.
            if (args.out_format == out_split[1][1:] or
                    args.out_format == 'tiff'
                    and out_split[1] in ('.tiff', '.tif') or
                    args.out_format == 'jpeg'
                    and out_split[1] in ('.jpeg', '.jpg')):
                args.out_prefix = out_split[0]
                out_suffix = out_split[1][1:]
            else:
                out_suffix = args.out_format
        else:
            out_suffix = args.out_format
        out_prefix = args.out_prefix + "_" + strand
        if args.strand == "NONE":
            out_prefix = args.out_prefix
        else:
            # Skip the strand the user did not ask for.
            if args.out_strand != "both" \
                    and strand != strand_dict[args.out_strand]:
                continue

        # Find set of junctions to perform shrink
        intersected_introns = None
        if args.shrink:
            introns = (v for vs in bam_dict[strand].values() for v in
                       zip(vs[2], vs[3]))
            intersected_introns = list(intersect_introns(introns))

        # *** PLOT *** Define plot height
        bam_height = args.height * len(id_list)
        if args.overlay:
            bam_height = args.height * len(overlay_dict)
        if args.gtf:
            bam_height += args.ann_height

        # *** PLOT *** Start R script by loading libraries,
        # initializing variables, etc...
        R_script = setup_R_script(bam_height, args.width,
                                  args.base_size, label_dict)

        R_script += colorize(color_dict, palette, args.color_factor)

        # *** PLOT *** Prepare annotation plot only for the first bam file
        arrow_bins = 50
        if args.gtf:
            # Make introns from annotation (they are shrunk if required)
            annotation = make_introns(transcripts, exons, intersected_introns)
            x = list(bam_dict[strand].values())[0][0]
            if args.shrink:
                x, _ = shrink_density(x, x, intersected_introns)
            R_script += gtf_for_ggplot(annotation, x[0], x[-1], arrow_bins)

        R_script += make_R_lists(id_list, bam_dict[strand], overlay_dict,
                                 args.aggr, intersected_introns)

        # The main rendering template. `%%%%` survives the single round of
        # %-formatting as `%%`, the R modulo operator.
        R_script += """

        pdf(NULL) # just to remove the blank pdf produced by ggplotGrob
        # fix problems with ggplot2 vs >3.0.0
        if(packageVersion('ggplot2') >= '3.0.0'){
            vs = 1
        } else {
            vs = 0
        }

        density_grobs = list();

        for (bam_index in 1:length(density_list)) {

            id = names(density_list)[bam_index]
            d = data.table(density_list[[id]])
            junctions = data.table(junction_list[[id]])

            maxheight = max(d[['y']])

            # Density plot
            gp = ggplot(d) + geom_bar(aes(x, y), width=1,
                                      position='identity',
                                      stat='identity',
                                      fill=color_list[[id]],
                                      alpha=%(alpha)s)
            gp = gp + labs(y=labels[[id]])
            gp = gp + scale_x_continuous(expand=c(0,0.2))

            # fix problems with ggplot2 vs >3.0.0
            if(packageVersion('ggplot2') >= '3.0.0') {
                gp = gp +
                    scale_y_continuous(breaks =
                        ggplot_build(gp
                        )$layout$panel_params[[1]]$y.major_source)
            } else {
                gp = gp +
                    scale_y_continuous(breaks =
                        ggplot_build(gp
                        )$layout$panel_ranges[[1]]$y.major_source)
            }

            # Aggregate junction counts
            row_i = c()
            if (nrow(junctions) >0 ) {

                junctions$jlabel = as.character(junctions$count)
                junctions = setNames(junctions[,.(max(y),
                                                  max(yend),
                                                  round(mean(count)),
                                                  paste(jlabel,
                                                        collapse=",")),
                                               keyby=.(x,xend)],
                                     names(junctions))
                if ("%(args.aggr)s" != "") {
                    junctions = setNames(
                        junctions[,.(max(y),
                                     max(yend),
                                     round(%(args.aggr)s(count)),
                                     round(%(args.aggr)s(count))),
                                  keyby=.(x,xend)],
                        names(junctions))
                }
                # The number of rows (unique junctions per bam) has to be
                # calculated after aggregation
                row_i = 1:nrow(junctions)
            }


            for (i in row_i) {

                j_tot_counts = sum(junctions[['count']])

                j = as.numeric(junctions[i,1:5])

                if ("%(args.aggr)s" != "") {
                    j[3] = as.numeric(d[x==j[1]-1,y])
                    j[4] = as.numeric(d[x==j[2]+1,y])
                }

                # Find intron midpoint
                xmid = round(mean(j[1:2]), 1)
                ymid = max(j[3:4]) * 1.2

                # Thickness of the arch
                lwd = scale_lwd(j[5]/j_tot_counts)

                curve_par = gpar(lwd=lwd, col=color_list[[id]])

                # Arc grobs

                # Choose position of the arch (top or bottom)
                nss = i
                if (nss%%%%2 == 0) { #bottom
                    ymid = -0.3 * maxheight
                    # Draw the arcs
                    # Left
                    curve = xsplineGrob(x = c(0, 0, 1, 1),
                                        y = c(1, 0, 0, 0),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                j[1], xmid, 0, ymid)
                    # Right
                    curve = xsplineGrob(x = c(1, 1, 0, 0),
                                        y = c(1, 0, 0, 0),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                xmid,
                                                j[2],
                                                0, ymid)
                }

                if (nss%%%%2 != 0) { #top
                    # Draw the arcs
                    # Left
                    curve = xsplineGrob(x = c(0, 0, 1, 1),
                                        y = c(0, 1, 1, 1),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                j[1], xmid, j[3], ymid)
                    # Right
                    curve = xsplineGrob(x = c(1, 1, 0, 0),
                                        y = c(0, 1, 1, 1),
                                        shape = 1,
                                        gp = curve_par)
                    gp = gp + annotation_custom(grob = curve,
                                                xmid, j[2], j[4], ymid)
                }

                # Add junction labels
                gp = gp + annotate("label", x = xmid, y = ymid,
                                   label = as.character(junctions[i,6]),
                                   vjust = 0.5, hjust = 0.5,
                                   label.padding = unit(0.01, "lines"),
                                   label.size = NA,
                                   size = (base_size*0.352777778)*0.6)

            }

            gpGrob = ggplotGrob(gp);
            gpGrob$layout$clip[gpGrob$layout$name=="panel"] <- "off"
            if (bam_index == 1) {
                # fix problems ggplot2 vs
                maxWidth = gpGrob$widths[2+vs] + gpGrob$widths[3+vs];
                maxYtextWidth = gpGrob$widths[3+vs];
                # Extract x axis grob (trim=F --> keep empty cells)
                xaxisGrob <- gtable_filter(gpGrob, "axis-b", trim=F)
                # fix problems ggplot2 vs
                xaxisGrob$heights[8+vs] = gpGrob$heights[1]
                x.axis.height = gpGrob$heights[7+vs] + gpGrob$heights[1]
            }


            # Remove x axis from all density plots
            kept_names = gpGrob$layout$name[gpGrob$layout$name != "axis-b"]
            gpGrob <- gtable_filter(gpGrob,
                                    paste(kept_names, sep = "",
                                          collapse = "|"),
                                    trim=F)

            # Find max width of y text and y label and max width of y text
            # fix problems ggplot2 vs
            maxWidth = grid::unit.pmax(maxWidth,
                                       gpGrob$widths[2+vs] +
                                       gpGrob$widths[3+vs]);
            maxYtextWidth = grid::unit.pmax(maxYtextWidth,
                                            gpGrob$widths[3+vs]);
            density_grobs[[id]] = gpGrob;
        }

        # Add x axis grob after density grobs BEFORE annotation grob
        density_grobs[["xaxis"]] = xaxisGrob

        # Annotation grob
        if (%(args.gtf)s == 1) {
            gtfGrob = ggplotGrob(gtfp);
            maxWidth = grid::unit.pmax(maxWidth,
                                       gtfGrob$widths[2+vs] +
                                       gtfGrob$widths[3+vs]);
            density_grobs[['gtf']] = gtfGrob;
        }

        # Reassign grob widths to align the plots
        for (id in names(density_grobs)) {
            density_grobs[[id]]$widths[1] <-
                density_grobs[[id]]$widths[1] +
                maxWidth - (density_grobs[[id]]$widths[2 + vs] +
                            maxYtextWidth)
            # fix problems ggplot2 vs
            density_grobs[[id]]$widths[3 + vs] <-
                maxYtextWidth # fix problems ggplot2 vs
        }

        # Heights for density, x axis and annotation
        heights = unit.c(
            unit(rep(%(signal_height)s,
                     length(density_list)), "in"),
            x.axis.height,
            unit(%(ann_height)s*%(args.gtf)s, "in")
        )

        # Arrange grobs
        argrobs = arrangeGrob(
            grobs=density_grobs,
            ncol=1,
            heights = heights,
        );

        # Save plot to file in the requested format
        if ("%(out_format)s" == "tiff"){
            # TIFF images will be lzw-compressed
            ggsave("%(out)s",
                   plot = argrobs,
                   device = "tiff",
                   width = width,
                   height = height,
                   units = "in",
                   dpi = %(out_resolution)s,
                   compression = "lzw")
        } else {
            ggsave("%(out)s",
                   plot = argrobs,
                   device = "%(out_format)s",
                   width = width,
                   height = height,
                   units = "in",
                   dpi = %(out_resolution)s)
        }

        dev.log = dev.off()

        """ % ({
            "out": "%s.%s" % (out_prefix, out_suffix),
            "out_format": args.out_format,
            "out_resolution": args.out_resolution,
            # 1.0 when an annotation panel is present, 0.0 otherwise.
            "args.gtf": float(bool(args.gtf)),
            # mean_j/median_j collapse to the bare R function name.
            "args.aggr": args.aggr.rstrip("_j"),
            "signal_height": args.height,
            "ann_height": args.ann_height,
            "alpha": args.alpha,
        })
        # Debug hook: dump the R script instead of running it.
        if os.getenv('GGSASHIMI_DEBUG') is not None:
            with open("R_script", 'w') as r:
                r.write(R_script)
        else:
            plot(R_script)
    exit()
1024 exit()