comparison sashimi-plot.py @ 0:9304dd9a16a2 draft
"planemo upload for repository https://github.com/ARTbio/tools-artbio/tree/master/tools/sashimi_plot commit 746c03a1187e1d708af8628920a0c615cddcdacc"
| author | artbio |
|---|---|
| date | Fri, 23 Aug 2019 11:38:29 -0400 |
| parents | |
| children | |
| -1:000000000000 | 0:9304dd9a16a2 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 | |
| 3 # Import modules | |
| 4 import copy | |
| 5 import os | |
| 6 import re | |
| 7 import subprocess as sp | |
| 8 import sys | |
| 9 from argparse import ArgumentParser | |
| 10 from collections import OrderedDict | |
| 11 | |
| 12 | |
| 13 def define_options(): | |
| 14 # Argument parsing | |
| 15 parser = ArgumentParser(description="""Create sashimi plot for a given | |
| 16 genomic region""") | |
| 17 parser.add_argument("-b", "--bam", type=str, | |
| 18 help=""" | |
| 19 Individual bam file or file with a list of bam files. | |
| 20 In the case of a list of files the format is tsv: | |
| 21 1col: id for bam file, | |
| 22 2col: path of bam file, | |
| 23 3+col: additional columns | |
| 24 """) | |
| 25 parser.add_argument("-c", "--coordinates", type=str, | |
| 26 help="Genomic region. Format: chr:start-end (1-based)") | |
| 27 parser.add_argument("-o", "--out-prefix", type=str, dest="out_prefix", | |
| 28 default="sashimi", | |
| 29 help="Prefix for plot file name [default=%(default)s]") | |
| 30 parser.add_argument("-S", "--out-strand", type=str, dest="out_strand", | |
| 31 default="both", help="""Only for --strand other than | |
| 32 'NONE'. Choose which signal strand to plot: | |
| 33 <both> <plus> <minus> [default=%(default)s]""") | |
| 34 parser.add_argument("-M", "--min-coverage", type=int, default=1, | |
| 35 dest="min_coverage", help="""Minimum number of reads | |
| 36 supporting a junction to be drawn [default=1]""") | |
| 37 parser.add_argument("-j", "--junctions-bed", type=str, default="", | |
| 38 dest="junctions_bed", help="""Junction BED file name | |
| 39 [default=no junction file]""") | |
| 40 parser.add_argument("-g", "--gtf", | |
| 41 help="GTF file with annotation (exon lines are sufficient)") | |
| 42 parser.add_argument("-s", "--strand", default="NONE", type=str, | |
| 43 help="""Strand specificity: <NONE> <SENSE> <ANTISENSE> | |
| 44 <MATE1_SENSE> <MATE2_SENSE> [default=%(default)s]""") | |
| 45 parser.add_argument("--shrink", action="store_true", | |
| 46 help="""Shrink the junctions by a factor for nicer | |
| 47 display [default=%(default)s]""") | |
| 48 parser.add_argument("-O", "--overlay", type=int, | |
| 49 help="Index of column with overlay levels (1-based)") | |
| 50 parser.add_argument("-A", "--aggr", type=str, default="", | |
| 51 help="""Aggregate function for overlay: | |
| 52 <mean> <median> <mean_j> <median_j>. | |
| 53 Use mean_j | median_j to keep density overlay but | |
| 54 aggregate junction counts [default=no aggregation]""") | |
| 55 parser.add_argument("-C", "--color-factor", type=int, dest="color_factor", | |
| 56 help="Index of column with color levels (1-based)") | |
| 57 parser.add_argument("--alpha", type=float, default=0.5, | |
| 58 help="""Transparency level for density histogram | |
| 59 [default=%(default)s]""") | |
| 60 parser.add_argument("-P", "--palette", type=str, | |
| 61 help="""Color palette file: a tsv file with >=1 columns, | |
| 62 where the first column holds the colors""") | |
| 63 parser.add_argument("-L", "--labels", type=int, dest="labels", default=1, | |
| 64 help="""Index of column with labels (1-based) | |
| 65 [default=%(default)s]""") | |
| 66 parser.add_argument("--height", type=float, default=2, | |
| 67 help="""Height of the individual signal plot in inches | |
| 68 [default=%(default)s]""") | |
| 69 parser.add_argument("--ann-height", type=float, default=1.5, | |
| 70 dest="ann_height", help="""Height of annotation plot in | |
| 71 inches [default=%(default)s]""") | |
| 72 parser.add_argument("--width", type=float, default=10, | |
| 73 help="""Width of the plot in inches | |
| 74 [default=%(default)s]""") | |
| 75 parser.add_argument("--base-size", type=float, default=14, | |
| 76 dest="base_size", help="""Base font size of the plot in | |
| 77 points [default=%(default)s]""") | |
| 78 parser.add_argument("-F", "--out-format", type=str, default="pdf", | |
| 79 dest="out_format", help="""Output file format: | |
| 80 <pdf> <svg> <png> <jpeg> <tiff> | |
| 81 [default=%(default)s]""") | |
| 82 parser.add_argument("-R", "--out-resolution", type=int, default=300, | |
| 83 dest="out_resolution", help="""Output file resolution in | |
| 84 PPI (pixels per inch). Applies only to raster output | |
| 85 formats [default=%(default)s]""") | |
| 86 return parser | |
| 87 | |
| 88 | |
| 89 def parse_coordinates(c): | |
| 90 c = c.replace(",", "") | |
| 91 chr = c.split(":")[0] | |
| 92 start, end = c.split(":")[1].split("-") | |
| 93 # Convert to 0-based | |
| 94 start, end = int(start) - 1, int(end) | |
| 95 return chr, start, end | |
| 96 | |
| 97 | |
| 98 def count_operator(CIGAR_op, CIGAR_len, pos, start, end, a, junctions): | |
| 99 | |
| 100 # Match | |
| 101 if CIGAR_op == "M": | |
| 102 for i in range(pos, pos + CIGAR_len): | |
| 103 if i < start or i >= end: | |
| 104 continue | |
| 105 ind = i - start | |
| 106 a[ind] += 1 | |
| 107 | |
| 108 # Insertion or Soft-clip | |
| 109 if CIGAR_op == "I" or CIGAR_op == "S": | |
| 110 return pos | |
| 111 | |
| 112 # Deletion | |
| 113 if CIGAR_op == "D": | |
| 114 pass | |
| 115 | |
| 116 # Junction | |
| 117 if CIGAR_op == "N": | |
| 118 don = pos | |
| 119 acc = pos + CIGAR_len | |
| 120 if don > start and acc < end: | |
| 121 junctions[(don, acc)] = junctions.setdefault((don, acc), 0) + 1 | |
| 122 | |
| 123 pos = pos + CIGAR_len | |
| 124 | |
| 125 return pos | |
| 126 | |
| 127 | |
| 128 def flip_read(s, samflag): | |
| 129 if s == "NONE" or s == "SENSE": | |
| 130 return 0 | |
| 131 if s == "ANTISENSE": | |
| 132 return 1 | |
| 133 if s == "MATE1_SENSE": | |
| 134 if int(samflag) & 64: | |
| 135 return 0 | |
| 136 if int(samflag) & 128: | |
| 137 return 1 | |
| 138 if s == "MATE2_SENSE": | |
| 139 if int(samflag) & 64: | |
| 140 return 1 | |
| 141 if int(samflag) & 128: | |
| 142 return 0 | |
| 143 | |
| 144 | |
| 145 def read_bam(f, c, s): | |
| 146 | |
| 147 _, start, end = parse_coordinates(c) | |
| 148 | |
| 149 # Initialize coverage array and junction dict | |
| 150 a = {"+": [0] * (end - start)} | |
| 151 junctions = {"+": OrderedDict()} | |
| 152 if s != "NONE": | |
| 153 a["-"] = [0] * (end - start) | |
| 154 junctions["-"] = OrderedDict() | |
| 155 | |
| 156 p = sp.Popen("samtools view %s %s " % (f, c), shell=True, stdout=sp.PIPE) | |
| 157 for line in p.communicate()[0].decode('utf8').strip().split("\n"): | |
| 158 | |
| 159 if line == "": | |
| 160 continue | |
| 161 | |
| 162 line_sp = line.strip().split("\t") | |
| 163 samflag, read_start, CIGAR = line_sp[1], int(line_sp[3]), line_sp[5] | |
| 164 | |
| 165 # Ignore reads with more exotic CIGAR operators | |
| 166 if any(map(lambda x: x in CIGAR, ["H", "P", "X", "="])): | |
| 167 continue | |
| 168 | |
| 169 read_strand = ["+", "-"][flip_read(s, samflag) ^ bool(int(samflag) & | |
| 170 16)] | |
| 171 if s == "NONE": | |
| 172 read_strand = "+" | |
| 173 | |
| 174 CIGAR_lens = re.split("[MIDNS]", CIGAR)[:-1] | |
| 175 CIGAR_ops = re.split("[0-9]+", CIGAR)[1:] | |
| 176 | |
| 177 pos = read_start | |
| 178 | |
| 179 for n, CIGAR_op in enumerate(CIGAR_ops): | |
| 180 CIGAR_len = int(CIGAR_lens[n]) | |
| 181 pos = count_operator(CIGAR_op, CIGAR_len, pos, start, end, | |
| 182 a[read_strand], junctions[read_strand]) | |
| 183 | |
| 184 p.stdout.close() | |
| 185 return a, junctions | |
| 186 | |
| 187 | |
| 188 def get_bam_path(index, path): | |
| 189 if os.path.isabs(path): | |
| 190 return path | |
| 191 base_dir = os.path.dirname(index) | |
| 192 return os.path.join(base_dir, path) | |
| 193 | |
| 194 | |
| 195 def read_bam_input(f, overlay, color, label): | |
| 196 if f.endswith(".bam"): | |
| 197 bn = os.path.splitext(os.path.basename(f))[0] | |
| 198 yield bn, f, None, None, bn | |
| 199 return | |
| 200 with open(f) as openf: | |
| 201 for line in openf: | |
| 202 line_sp = line.strip().split("\t") | |
| 203 bam = get_bam_path(f, line_sp[1]) | |
| 204 overlay_level = line_sp[overlay-1] if overlay else None | |
| 205 color_level = line_sp[color-1] if color else None | |
| 206 label_text = line_sp[label-1] if label else None | |
| 207 yield line_sp[0], bam, overlay_level, color_level, label_text | |
| 208 | |
| 209 | |
| 210 def prepare_for_R(a, junctions, c, m): | |
| 211 | |
| 212 _, start, _ = parse_coordinates(args.coordinates) | |
| 213 | |
| 214 # Convert the array index to genomic coordinates | |
| 215 x = list(i+start for i in range(len(a))) | |
| 216 y = a | |
| 217 | |
| 218 # Arrays for R | |
| 219 dons, accs, yd, ya, counts = [], [], [], [], [] | |
| 220 | |
| 221 # Prepare arrays for junctions (which will be the arcs) | |
| 222 for (don, acc), n in junctions.items(): | |
| 223 | |
| 224 # Do not add junctions with less than defined coverage | |
| 225 if n < m: | |
| 226 continue | |
| 227 | |
| 228 dons.append(don) | |
| 229 accs.append(acc) | |
| 230 counts.append(n) | |
| 231 | |
| 232 yd.append(a[don - start - 1]) | |
| 233 ya.append(a[acc - start + 1]) | |
| 234 | |
| 235 return x, y, dons, accs, yd, ya, counts | |
| 236 | |
| 237 | |
| 238 def intersect_introns(data): | |
| 239 data = sorted(data) | |
| 240 it = iter(data) | |
| 241 a, b = next(it) | |
| 242 for c, d in it: | |
| 243 if b > c: | |
| 244 # Use `if b > c` if you want (1,2), (2,3) not to be | |
| 245 # treated as intersection. | |
| 246 b = min(b, d) | |
| 247 a = max(a, c) | |
| 248 else: | |
| 249 yield a, b | |
| 250 a, b = c, d | |
| 251 yield a, b | |
| 252 | |
| 253 | |
| 254 def shrink_density(x, y, introns): | |
| 255 new_x, new_y = [], [] | |
| 256 shift = 0 | |
| 257 start = 0 | |
| 258 # introns are already sorted by coordinates | |
| 259 for a, b in introns: | |
| 260 end = x.index(a)+1 | |
| 261 new_x += [int(i-shift) for i in x[start:end]] | |
| 262 new_y += y[start:end] | |
| 263 start = x.index(b) | |
| 264 L = (b-a) | |
| 265 shift += L-L**0.7 | |
| 266 new_x += [int(i-shift) for i in x[start:]] | |
| 267 new_y += y[start:] | |
| 268 return new_x, new_y | |
| 269 | |
| 270 | |
| 271 def shrink_junctions(dons, accs, introns): | |
| 272 new_dons, new_accs = [0]*len(dons), [0]*len(accs) | |
| 273 shift_acc = 0 | |
| 274 shift_don = 0 | |
| 275 s = set() | |
| 276 junctions = list(zip(dons, accs)) | |
| 277 for a, b in introns: | |
| 278 L = b - a | |
| 279 shift_acc += L-int(L**0.7) | |
| 280 for i, (don, acc) in enumerate(junctions): | |
| 281 if a >= don and b <= acc: | |
| 282 if (don, acc) not in s: | |
| 283 new_dons[i] = don - shift_don | |
| 284 new_accs[i] = acc - shift_acc | |
| 285 else: | |
| 286 new_accs[i] = acc - shift_acc | |
| 287 s.add((don, acc)) | |
| 288 shift_don = shift_acc | |
| 289 return new_dons, new_accs | |
| 290 | |
| 291 | |
| 292 def read_palette(f): | |
| 293 palette = "#ff0000", "#00ff00", "#0000ff", "#000000" | |
| 294 if f: | |
| 295 with open(f) as openf: | |
| 296 palette = list(line.split("\t")[0].strip() for line in openf) | |
| 297 return palette | |
| 298 | |
| 299 | |
| 300 def read_gtf(f, c): | |
| 301 exons = OrderedDict() | |
| 302 transcripts = OrderedDict() | |
| 303 chr, start, end = parse_coordinates(c) | |
| 304 end = end - 1 | |
| 305 with open(f) as openf: | |
| 306 for line in openf: | |
| 307 if line.startswith("#"): | |
| 308 continue | |
| 309 (el_chr, _, el, el_start, el_end, _, | |
| 310 strand, _, tags) = line.strip().split("\t") | |
| 311 if el_chr != chr: | |
| 312 continue | |
| 313 d = dict(kv.strip().split(" ") for kv in | |
| 314 tags.strip(";").split("; ")) | |
| 315 transcript_id = d["transcript_id"] | |
| 316 el_start, el_end = int(el_start) - 1, int(el_end) | |
| 317 strand = '"' + strand + '"' | |
| 318 if el == "transcript": | |
| 319 if (el_end > start and el_start < end): | |
| 320 transcripts[transcript_id] = (max(start, el_start), | |
| 321 min(end, el_end), | |
| 322 strand) | |
| 323 continue | |
| 324 if el == "exon": | |
| 325 if (start < el_start < end or start < el_end < end): | |
| 326 exons.setdefault(transcript_id, | |
| 327 []).append((max(el_start, start), | |
| 328 min(end, el_end), strand)) | |
| 329 | |
| 330 return transcripts, exons | |
| 331 | |
| 332 | |
| 333 def make_introns(transcripts, exons, intersected_introns=None): | |
| 334 new_transcripts = copy.deepcopy(transcripts) | |
| 335 new_exons = copy.deepcopy(exons) | |
| 336 introns = OrderedDict() | |
| 337 if intersected_introns: | |
| 338 for tx, (tx_start, tx_end, strand) in new_transcripts.items(): | |
| 339 total_shift = 0 | |
| 340 for a, b in intersected_introns: | |
| 341 L = b - a | |
| 342 shift = L - int(L**0.7) | |
| 343 total_shift += shift | |
| 344 for i, (exon_start, exon_end, strand) in \ | |
| 345 enumerate(exons.get(tx, [])): | |
| 346 new_exon_start, new_exon_end = new_exons[tx][i][:2] | |
| 347 if a < exon_start: | |
| 348 if b > exon_end: | |
| 349 if i == len(exons[tx])-1: | |
| 350 total_shift = total_shift - shift + \ | |
| 351 (exon_start - a)*(1-int(L**-0.3)) | |
| 352 shift = (exon_start - a)*(1-int(L**-0.3)) | |
| 353 new_exon_end = new_exons[tx][i][1] - shift | |
| 354 new_exon_start = new_exons[tx][i][0] - shift | |
| 355 if b <= exon_end: | |
| 356 new_exon_end = new_exons[tx][i][1] - shift | |
| 357 new_exons[tx][i] = (new_exon_start, new_exon_end, strand) | |
| 358 tx_start = min(tx_start, | |
| 359 sorted(new_exons.get(tx, [[sys.maxsize]]))[0][0]) | |
| 360 new_transcripts[tx] = (tx_start, tx_end - total_shift, strand) | |
| 361 | |
| 362 for tx, (tx_start, tx_end, strand) in new_transcripts.items(): | |
| 363 intron_start = tx_start | |
| 364 ex_end = 0 | |
| 365 for ex_start, ex_end, strand in sorted(new_exons.get(tx, [])): | |
| 366 intron_end = ex_start | |
| 367 if tx_start < ex_start: | |
| 368 introns.setdefault(tx, []).append((intron_start, intron_end, | |
| 369 strand)) | |
| 370 intron_start = ex_end | |
| 371 if tx_end > ex_end: | |
| 372 introns.setdefault(tx, []).append((intron_start, tx_end, strand)) | |
| 373 d = {'transcripts': new_transcripts, | |
| 374 'exons': new_exons, | |
| 375 'introns': introns} | |
| 376 return d | |
| 377 | |
| 378 | |
| 379 def gtf_for_ggplot(annotation, start, end, arrow_bins): | |
| 380 arrow_space = int((end - start)/arrow_bins) | |
| 381 s = """ | |
| 382 | |
| 383 # data table with exons | |
| 384 ann_list = list( | |
| 385 "exons" = data.table(), | |
| 386 "introns" = data.table() | |
| 387 ) | |
| 388 """ | |
| 389 | |
| 390 if annotation["exons"]: | |
| 391 | |
| 392 s += """ | |
| 393 ann_list[['exons']] = data.table( | |
| 394 tx = rep(c(%(tx_exons)s), c(%(n_exons)s)), | |
| 395 start = c(%(exon_start)s), | |
| 396 end = c(%(exon_end)s), | |
| 397 strand = c(%(strand)s) | |
| 398 ) | |
| 399 """ % ({ | |
| 400 "tx_exons": ",".join(annotation["exons"].keys()), | |
| 401 "n_exons": ",".join(map(str, map(len, | |
| 402 annotation["exons"].values()))), | |
| 403 "exon_start": ",".join(map(str, (v[0] for vs in | |
| 404 annotation["exons"].values() for v in vs))), | |
| 405 "exon_end": ",".join(map(str, (v[1] for vs in | |
| 406 annotation["exons"].values() for v in vs))), | |
| 407 "strand": ",".join(map(str, (v[2] for vs in | |
| 408 annotation["exons"].values() for v in vs))), | |
| 409 }) | |
| 410 | |
| 411 if annotation["introns"]: | |
| 412 | |
| 413 s += """ | |
| 414 ann_list[['introns']] = data.table( | |
| 415 tx = rep(c(%(tx_introns)s), c(%(n_introns)s)), | |
| 416 start = c(%(intron_start)s), | |
| 417 end = c(%(intron_end)s), | |
| 418 strand = c(%(strand)s) | |
| 419 ) | |
| 420 # Create data table for strand arrows | |
| 421 txarrows = data.table() | |
| 422 introns = ann_list[['introns']] | |
| 423 # Add right-pointing arrows for plus strand | |
| 424 if ("+" %%in%% introns$strand) { | |
| 425 txarrows = rbind( | |
| 426 txarrows, | |
| 427 introns[strand=="+" & end-start>5, list( | |
| 428 seq(start+4,end,by=%(arrow_space)s)-1, | |
| 429 seq(start+4,end,by=%(arrow_space)s) | |
| 430 ), by=.(tx,start,end) | |
| 431 ] | |
| 432 ) | |
| 433 } | |
| 434 # Add left-pointing arrows for minus strand | |
| 435 if ("-" %%in%% introns$strand) { | |
| 436 txarrows = rbind(txarrows, | |
| 437 introns[strand=="-" & end-start>5, | |
| 438 list(seq(start,max(start+1, end-4), | |
| 439 by=%(arrow_space)s), | |
| 440 seq(start,max(start+1, end-4), | |
| 441 by=%(arrow_space)s)-1 | |
| 442 ), | |
| 443 by=.(tx,start,end) | |
| 444 ] | |
| 445 ) | |
| 446 } | |
| 447 """ % ({ | |
| 448 "tx_introns": ",".join(annotation["introns"].keys()), | |
| 449 "n_introns": ",".join(map(str, map(len, | |
| 450 annotation["introns"].values()))), | |
| 451 "intron_start": ",".join(map(str, (v[0] for vs in | |
| 452 annotation["introns"].values() for v in | |
| 453 vs))), | |
| 454 "intron_end": ",".join(map(str, (v[1] for vs in | |
| 455 annotation["introns"].values() for v in | |
| 456 vs))), | |
| 457 "strand": ",".join(map(str, (v[2] for vs in | |
| 458 annotation["introns"].values() for v in vs))), | |
| 459 "arrow_space": arrow_space, | |
| 460 }) | |
| 461 | |
| 462 s += """ | |
| 463 | |
| 464 gtfp = ggplot() | |
| 465 if (length(ann_list[['introns']]) > 0) { | |
| 466 gtfp = gtfp + geom_segment(data = ann_list[['introns']], | |
| 467 aes(x = start, | |
| 468 xend = end, | |
| 469 y = tx, | |
| 470 yend = tx), | |
| 471 size = 0.3) | |
| 472 gtfp = gtfp + geom_segment(data = txarrows, | |
| 473 aes(x = V1, | |
| 474 xend = V2, | |
| 475 y = tx, | |
| 476 yend = tx), | |
| 477 arrow = arrow(length = unit(0.02, "npc"))) | |
| 478 } | |
| 479 if (length(ann_list[['exons']]) > 0) { | |
| 480 gtfp = gtfp + geom_segment(data = ann_list[['exons']], | |
| 481 aes(x = start, | |
| 482 xend = end, | |
| 483 y = tx, | |
| 484 yend = tx), | |
| 485 size = 5, | |
| 486 alpha = 1) | |
| 487 } | |
| 488 gtfp = gtfp + scale_y_discrete(expand = c(0, 0.5)) | |
| 489 gtfp = gtfp + scale_x_continuous(expand = c(0, 0.25), | |
| 490 limits = c(%s, %s)) | |
| 491 gtfp = gtfp + labs(y = NULL) | |
| 492 gtfp = gtfp + theme(axis.line = element_blank(), | |
| 493 axis.text.x = element_blank(), | |
| 494 axis.ticks = element_blank()) | |
| 495 """ % (start, end) | |
| 496 | |
| 497 return s | |
| 498 | |
| 499 | |
| 500 def setup_R_script(h, w, b, label_dict): | |
| 501 s = """ | |
| 502 library(ggplot2) | |
| 503 library(grid) | |
| 504 library(gridExtra) | |
| 505 library(data.table) | |
| 506 library(gtable) | |
| 507 | |
| 508 scale_lwd = function(r) { | |
| 509 lmin = 0.1 | |
| 510 lmax = 4 | |
| 511 return( r*(lmax-lmin)+lmin ) | |
| 512 } | |
| 513 | |
| 514 base_size = %(b)s | |
| 515 height = ( %(h)s + base_size*0.352777778/67 ) * 1.02 | |
| 516 width = %(w)s | |
| 517 theme_set(theme_bw(base_size=base_size)) | |
| 518 theme_update( | |
| 519 plot.margin = unit(c(15,15,15,15), "pt"), | |
| 520 panel.grid = element_blank(), | |
| 521 panel.border = element_blank(), | |
| 522 axis.line = element_line(size=0.5), | |
| 523 axis.title.x = element_blank(), | |
| 524 axis.title.y = element_text(angle=0, vjust=0.5) | |
| 525 ) | |
| 526 | |
| 527 labels = list(%(labels)s) | |
| 528 | |
| 529 density_list = list() | |
| 530 junction_list = list() | |
| 531 | |
| 532 """ % ({ | |
| 533 'h': h, | |
| 534 'w': w, | |
| 535 'b': b, | |
| 536 'labels': ",".join(('"%s"="%s"' % (id, lab) for id, lab in | |
| 537 label_dict.items())), | |
| 538 }) | |
| 539 return s | |
| 540 | |
| 541 | |
| 542 def median(lst): | |
| 543 quotient, remainder = divmod(len(lst), 2) | |
| 544 if remainder: | |
| 545 return sorted(lst)[quotient] | |
| 546 return sum(sorted(lst)[quotient - 1:quotient + 1]) / 2. | |
| 547 | |
| 548 | |
| 549 def mean(lst): | |
| 550 return sum(lst)/len(lst) | |
| 551 | |
| 552 | |
| 553 def make_R_lists(id_list, d, overlay_dict, aggr, intersected_introns): | |
| 554 s = "" | |
| 555 aggr_f = { | |
| 556 "mean": mean, | |
| 557 "median": median, | |
| 558 } | |
| 559 id_list = id_list if not overlay_dict else overlay_dict.keys() | |
| 560 # Iterate over ids to get bam signal and junctions | |
| 561 for k in id_list: | |
| 562 x, y, dons, accs, yd, ya, counts = [], [], [], [], [], [], [] | |
| 563 if not overlay_dict: | |
| 564 x, y, dons, accs, yd, ya, counts = d[k] | |
| 565 if intersected_introns: | |
| 566 x, y = shrink_density(x, y, intersected_introns) | |
| 567 dons, accs = shrink_junctions(dons, accs, intersected_introns) | |
| 568 else: | |
| 569 for id in overlay_dict[k]: | |
| 570 xid, yid, donsid, accsid, ydid, yaid, countsid = d[id] | |
| 571 if intersected_introns: | |
| 572 xid, yid = shrink_density(xid, yid, intersected_introns) | |
| 573 donsid, accsid = shrink_junctions(donsid, accsid, | |
| 574 intersected_introns) | |
| 575 x += xid | |
| 576 y += yid | |
| 577 dons += donsid | |
| 578 accs += accsid | |
| 579 yd += ydid | |
| 580 ya += yaid | |
| 581 counts += countsid | |
| 582 if aggr and "_j" not in aggr: | |
| 583 x = d[overlay_dict[k][0]][0] | |
| 584 y = list(map(aggr_f[aggr], zip(*(d[id][1] for id in | |
| 585 overlay_dict[k])))) | |
| 586 if intersected_introns: | |
| 587 x, y = shrink_density(x, y, intersected_introns) | |
| 588 # dons, accs, yd, ya, counts = [], [], [], [], [] | |
| 589 s += """ | |
| 590 density_list[["%(id)s"]] = data.frame(x = c(%(x)s), y = c(%(y)s)) | |
| 591 junction_list[["%(id)s"]] = data.frame(x = c(%(dons)s), | |
| 592 xend=c(%(accs)s), | |
| 593 y=c(%(yd)s), | |
| 594 yend=c(%(ya)s), | |
| 595 count=c(%(counts)s)) | |
| 596 """ % ({ | |
| 597 "id": k, | |
| 598 'x': ",".join(map(str, x)), | |
| 599 'y': ",".join(map(str, y)), | |
| 600 'dons': ",".join(map(str, dons)), | |
| 601 'accs': ",".join(map(str, accs)), | |
| 602 'yd': ",".join(map(str, yd)), | |
| 603 'ya': ",".join(map(str, ya)), | |
| 604 'counts': ",".join(map(str, counts)), | |
| 605 }) | |
| 606 return s | |
| 607 | |
| 608 | |
| 609 def plot(R_script): | |
| 610 p = sp.Popen("R --vanilla --slave", shell=True, stdin=sp.PIPE) | |
| 611 p.communicate(input=R_script.encode()) | |
| 612 p.stdin.close() | |
| 613 p.wait() | |
| 614 return | |
| 615 | |
| 616 | |
| 617 def colorize(d, p, color_factor): | |
| 618 levels = sorted(set(d.values())) | |
| 619 n = len(levels) | |
| 620 if n > len(p): | |
| 621 p = (p*n)[:n] | |
| 622 if color_factor: | |
| 623 s = "color_list = list(%s)\n" % (",".join('%s="%s"' % (k, | |
| 624 p[levels.index(v)]) for k, v in | |
| 625 d.items())) | |
| 626 else: | |
| 627 s = "color_list = list(%s)\n" % (",".join('%s="%s"' % (k, "grey") for | |
| 628 k, v in d.items())) | |
| 629 return s | |
| 630 | |
| 631 | |
| 632 if __name__ == "__main__": | |
| 633 | |
| 634 strand_dict = {"plus": "+", "minus": "-"} | |
| 635 | |
| 636 parser = define_options() | |
| 637 if len(sys.argv) == 1: | |
| 638 parser.print_help() | |
| 639 sys.exit(1) | |
| 640 args = parser.parse_args() | |
| 641 | |
| 642 if args.aggr and not args.overlay: | |
| 643 print("""ERROR: Cannot apply aggregate function | |
| 644 if overlay is not selected.""") | |
| 645 exit(1) | |
| 646 | |
| 647 palette = read_palette(args.palette) | |
| 648 | |
| 649 (bam_dict, overlay_dict, color_dict, | |
| 650 id_list, label_dict) = ({"+": OrderedDict()}, OrderedDict(), | |
| 651 OrderedDict(), [], OrderedDict()) | |
| 652 if args.strand != "NONE": | |
| 653 bam_dict["-"] = OrderedDict() | |
| 654 if args.junctions_bed != "": | |
| 655 junctions_list = [] | |
| 656 | |
| 657 for (id, bam, overlay_level, | |
| 658 color_level, label_text) in read_bam_input(args.bam, | |
| 659 args.overlay, | |
| 660 args.color_factor, | |
| 661 args.labels): | |
| 662 if not os.path.isfile(bam): | |
| 663 continue | |
| 664 id_list.append(id) | |
| 665 label_dict[id] = label_text | |
| 666 a, junctions = read_bam(bam, args.coordinates, args.strand) | |
| 667 if list(a.keys()) == ["+"] and all(map(lambda x: x == 0, | |
| 668 list(a.values())[0])): | |
| 669 print("ERROR: No reads in the specified area.") | |
| 670 exit(1) | |
| 671 for strand in a: | |
| 672 # Store junction information | |
| 673 if args.junctions_bed: | |
| 674 for k, v in zip(junctions[strand].keys(), | |
| 675 junctions[strand].values()): | |
| 676 if v > args.min_coverage: | |
| 677 junctions_list.append('\t'.join([args.coordinates.split | |
| 678 (':')[0], str(k[0]), str(k[1]), | |
| 679 id, str(v), strand])) | |
| 680 bam_dict[strand][id] = prepare_for_R(a[strand], | |
| 681 junctions[strand], | |
| 682 args.coordinates, | |
| 683 args.min_coverage) | |
| 684 if color_level is None: | |
| 685 color_dict.setdefault(id, id) | |
| 686 if overlay_level is not None: | |
| 687 overlay_dict.setdefault(overlay_level, []).append(id) | |
| 688 label_dict[overlay_level] = overlay_level | |
| 689 color_dict.setdefault(overlay_level, overlay_level) | |
| 690 if overlay_level is None: | |
| 691 color_dict.setdefault(id, color_level) | |
| 692 | |
| 693 # No bam files | |
| 694 if not bam_dict["+"]: | |
| 695 print("ERROR: No available bam files.") | |
| 696 exit(1) | |
| 697 | |
| 698 # Write junctions to BED | |
| 699 if args.junctions_bed: | |
| 700 if not args.junctions_bed.endswith('.bed'): | |
| 701 args.junctions_bed = args.junctions_bed + '.bed' | |
| 702 jbed = open(args.junctions_bed, 'w') | |
| 703 jbed.write('\n'.join(sorted(junctions_list))) | |
| 704 jbed.close() | |
| 705 | |
| 706 if args.gtf: | |
| 707 transcripts, exons = read_gtf(args.gtf, args.coordinates) | |
| 708 | |
| 709 if args.out_format not in ('pdf', 'png', 'svg', 'tiff', 'jpeg'): | |
| 710 print("""ERROR: Provided output format '%s' is not available. | |
| 711 Please select among 'pdf', 'png', 'svg', | |
| 712 'tiff' or 'jpeg'""" % args.out_format) | |
| 713 exit(1) | |
| 714 | |
| 715 # Iterate for plus and minus strand | |
| 716 for strand in bam_dict: | |
| 717 | |
| 718 # Output file name (allow tiff/tif and jpeg/jpg extensions) | |
| 719 if args.out_prefix.endswith(('.pdf', '.png', '.svg', '.tiff', | |
| 720 '.tif', '.jpeg', '.jpg')): | |
| 721 out_split = os.path.splitext(args.out_prefix) | |
| 722 if (args.out_format == out_split[1][1:] or | |
| 723 args.out_format == 'tiff' | |
| 724 and out_split[1] in ('.tiff', '.tif') or | |
| 725 args.out_format == 'jpeg' | |
| 726 and out_split[1] in ('.jpeg', '.jpg')): | |
| 727 args.out_prefix = out_split[0] | |
| 728 out_suffix = out_split[1][1:] | |
| 729 else: | |
| 730 out_suffix = args.out_format | |
| 731 else: | |
| 732 out_suffix = args.out_format | |
| 733 out_prefix = args.out_prefix + "_" + strand | |
| 734 if args.strand == "NONE": | |
| 735 out_prefix = args.out_prefix | |
| 736 else: | |
| 737 if args.out_strand != "both" \ | |
| 738 and strand != strand_dict[args.out_strand]: | |
| 739 continue | |
| 740 | |
| 741 # Find set of junctions to perform shrink | |
| 742 intersected_introns = None | |
| 743 if args.shrink: | |
| 744 introns = (v for vs in bam_dict[strand].values() for v in | |
| 745 zip(vs[2], vs[3])) | |
| 746 intersected_introns = list(intersect_introns(introns)) | |
| 747 | |
| 748 # *** PLOT *** Define plot height | |
| 749 bam_height = args.height * len(id_list) | |
| 750 if args.overlay: | |
| 751 bam_height = args.height * len(overlay_dict) | |
| 752 if args.gtf: | |
| 753 bam_height += args.ann_height | |
| 754 | |
| 755 # *** PLOT *** Start R script by loading libraries, | |
| 756 # initializing variables, etc... | |
| 757 R_script = setup_R_script(bam_height, args.width, | |
| 758 args.base_size, label_dict) | |
| 759 | |
| 760 R_script += colorize(color_dict, palette, args.color_factor) | |
| 761 | |
| 762 # *** PLOT *** Prepare annotation plot only for the first bam file | |
| 763 arrow_bins = 50 | |
| 764 if args.gtf: | |
| 765 # Make introns from annotation (they are shrunk if required) | |
| 766 annotation = make_introns(transcripts, exons, intersected_introns) | |
| 767 x = list(bam_dict[strand].values())[0][0] | |
| 768 if args.shrink: | |
| 769 x, _ = shrink_density(x, x, intersected_introns) | |
| 770 R_script += gtf_for_ggplot(annotation, x[0], x[-1], arrow_bins) | |
| 771 | |
| 772 R_script += make_R_lists(id_list, bam_dict[strand], overlay_dict, | |
| 773 args.aggr, intersected_introns) | |
| 774 | |
| 775 R_script += """ | |
| 776 | |
| 777 pdf(NULL) # just to remove the blank pdf produced by ggplotGrob | |
| 778 # fix problems with ggplot2 vs >3.0.0 | |
| 779 if(packageVersion('ggplot2') >= '3.0.0'){ | |
| 780 vs = 1 | |
| 781 } else { | |
| 782 vs = 0 | |
| 783 } | |
| 784 | |
| 785 density_grobs = list(); | |
| 786 | |
| 787 for (bam_index in 1:length(density_list)) { | |
| 788 | |
| 789 id = names(density_list)[bam_index] | |
| 790 d = data.table(density_list[[id]]) | |
| 791 junctions = data.table(junction_list[[id]]) | |
| 792 | |
| 793 maxheight = max(d[['y']]) | |
| 794 | |
| 795 # Density plot | |
| 796 gp = ggplot(d) + geom_bar(aes(x, y), width=1, | |
| 797 position='identity', | |
| 798 stat='identity', | |
| 799 fill=color_list[[id]], | |
| 800 alpha=%(alpha)s) | |
| 801 gp = gp + labs(y=labels[[id]]) | |
| 802 gp = gp + scale_x_continuous(expand=c(0,0.2)) | |
| 803 | |
| 804 # fix problems with ggplot2 vs >3.0.0 | |
| 805 if(packageVersion('ggplot2') >= '3.0.0') { | |
| 806 gp = gp + | |
| 807 scale_y_continuous(breaks = | |
| 808 ggplot_build(gp | |
| 809 )$layout$panel_params[[1]]$y.major_source) | |
| 810 } else { | |
| 811 gp = gp + | |
| 812 scale_y_continuous(breaks = | |
| 813 ggplot_build(gp | |
| 814 )$layout$panel_ranges[[1]]$y.major_source) | |
| 815 } | |
| 816 | |
| 817 # Aggregate junction counts | |
| 818 row_i = c() | |
| 819 if (nrow(junctions) >0 ) { | |
| 820 | |
| 821 junctions$jlabel = as.character(junctions$count) | |
| 822 junctions = setNames(junctions[,.(max(y), | |
| 823 max(yend), | |
| 824 round(mean(count)), | |
| 825 paste(jlabel, | |
| 826 collapse=",")), | |
| 827 keyby=.(x,xend)], | |
| 828 names(junctions)) | |
| 829 if ("%(args.aggr)s" != "") { | |
| 830 junctions = setNames( | |
| 831 junctions[,.(max(y), | |
| 832 max(yend), | |
| 833 round(%(args.aggr)s(count)), | |
| 834 round(%(args.aggr)s(count))), | |
| 835 keyby=.(x,xend)], | |
| 836 names(junctions)) | |
| 837 } | |
| 838 # The number of rows (unique junctions per bam) has to be | |
| 839 # calculated after aggregation | |
| 840 row_i = 1:nrow(junctions) | |
| 841 } | |
| 842 | |
| 843 | |
| 844 for (i in row_i) { | |
| 845 | |
| 846 j_tot_counts = sum(junctions[['count']]) | |
| 847 | |
| 848 j = as.numeric(junctions[i,1:5]) | |
| 849 | |
| 850 if ("%(args.aggr)s" != "") { | |
| 851 j[3] = as.numeric(d[x==j[1]-1,y]) | |
| 852 j[4] = as.numeric(d[x==j[2]+1,y]) | |
| 853 } | |
| 854 | |
| 855 # Find intron midpoint | |
| 856 xmid = round(mean(j[1:2]), 1) | |
| 857 ymid = max(j[3:4]) * 1.2 | |
| 858 | |
| 859 # Thickness of the arch | |
| 860 lwd = scale_lwd(j[5]/j_tot_counts) | |
| 861 | |
| 862 curve_par = gpar(lwd=lwd, col=color_list[[id]]) | |
| 863 | |
| 864 # Arc grobs | |
| 865 | |
| 866 # Choose position of the arch (top or bottom) | |
| 867 nss = i | |
| 868 if (nss%%%%2 == 0) { #bottom | |
| 869 ymid = -0.3 * maxheight | |
| 870 # Draw the arcs | |
| 871 # Left | |
| 872 curve = xsplineGrob(x = c(0, 0, 1, 1), | |
| 873 y = c(1, 0, 0, 0), | |
| 874 shape = 1, | |
| 875 gp = curve_par) | |
| 876 gp = gp + annotation_custom(grob = curve, | |
| 877 j[1], xmid, 0, ymid) | |
| 878 # Right | |
| 879 curve = xsplineGrob(x = c(1, 1, 0, 0), | |
| 880 y = c(1, 0, 0, 0), | |
| 881 shape = 1, | |
| 882 gp = curve_par) | |
| 883 gp = gp + annotation_custom(grob = curve, | |
| 884 xmid, | |
| 885 j[2], | |
| 886 0, ymid) | |
| 887 } | |
| 888 | |
| 889 if (nss%%%%2 != 0) { #top | |
| 890 # Draw the arcs | |
| 891 # Left | |
| 892 curve = xsplineGrob(x = c(0, 0, 1, 1), | |
| 893 y = c(0, 1, 1, 1), | |
| 894 shape = 1, | |
| 895 gp = curve_par) | |
| 896 gp = gp + annotation_custom(grob = curve, | |
| 897 j[1], xmid, j[3], ymid) | |
| 898 # Right | |
| 899 curve = xsplineGrob(x = c(1, 1, 0, 0), | |
| 900 y = c(0, 1, 1, 1), | |
| 901 shape = 1, | |
| 902 gp = curve_par) | |
| 903 gp = gp + annotation_custom(grob = curve, | |
| 904 xmid, j[2], j[4], ymid) | |
| 905 } | |
| 906 | |
| 907 # Add junction labels | |
| 908 gp = gp + annotate("label", x = xmid, y = ymid, | |
| 909 label = as.character(junctions[i,6]), | |
| 910 vjust = 0.5, hjust = 0.5, | |
| 911 label.padding = unit(0.01, "lines"), | |
| 912 label.size = NA, | |
| 913 size = (base_size*0.352777778)*0.6) | |
| 914 | |
| 915 } | |
| 916 | |
| 917 gpGrob = ggplotGrob(gp); | |
| 918 gpGrob$layout$clip[gpGrob$layout$name=="panel"] <- "off" | |
| 919 if (bam_index == 1) { | |
| 920 # fix problems ggplot2 vs | |
| 921 maxWidth = gpGrob$widths[2+vs] + gpGrob$widths[3+vs]; | |
| 922 maxYtextWidth = gpGrob$widths[3+vs]; | |
| 923 # Extract x axis grob (trim=F --> keep empty cells) | |
| 924 xaxisGrob <- gtable_filter(gpGrob, "axis-b", trim=F) | |
| 925 # fix problems ggplot2 vs | |
| 926 xaxisGrob$heights[8+vs] = gpGrob$heights[1] | |
| 927 x.axis.height = gpGrob$heights[7+vs] + gpGrob$heights[1] | |
| 928 } | |
| 929 | |
| 930 | |
| 931 # Remove x axis from all density plots | |
| 932 kept_names = gpGrob$layout$name[gpGrob$layout$name != "axis-b"] | |
| 933 gpGrob <- gtable_filter(gpGrob, | |
| 934 paste(kept_names, sep = "", | |
| 935 collapse = "|"), | |
| 936 trim=F) | |
| 937 | |
| 938 # Find max width of y text and y label and max width of y text | |
| 939 # fix problems ggplot2 vs | |
| 940 maxWidth = grid::unit.pmax(maxWidth, | |
| 941 gpGrob$widths[2+vs] + | |
| 942 gpGrob$widths[3+vs]); | |
| 943 maxYtextWidth = grid::unit.pmax(maxYtextWidth, | |
| 944 gpGrob$widths[3+vs]); | |
| 945 density_grobs[[id]] = gpGrob; | |
| 946 } | |
| 947 | |
| 948 # Add x axis grob after density grobs BEFORE annotation grob | |
| 949 density_grobs[["xaxis"]] = xaxisGrob | |
| 950 | |
| 951 # Annotation grob | |
| 952 if (%(args.gtf)s == 1) { | |
| 953 gtfGrob = ggplotGrob(gtfp); | |
| 954 maxWidth = grid::unit.pmax(maxWidth, | |
| 955 gtfGrob$widths[2+vs] + | |
| 956 gtfGrob$widths[3+vs]); | |
| 957 density_grobs[['gtf']] = gtfGrob; | |
| 958 } | |
| 959 | |
| 960 # Reassign grob widths to align the plots | |
| 961 for (id in names(density_grobs)) { | |
| 962 density_grobs[[id]]$widths[1] <- | |
| 963 density_grobs[[id]]$widths[1] + | |
| 964 maxWidth - (density_grobs[[id]]$widths[2 + vs] + | |
| 965 maxYtextWidth) | |
| 966 # fix problems ggplot2 vs | |
| 967 density_grobs[[id]]$widths[3 + vs] <- | |
| 968 maxYtextWidth # fix problems ggplot2 vs | |
| 969 } | |
| 970 | |
| 971 # Heights for density, x axis and annotation | |
| 972 heights = unit.c( | |
| 973 unit(rep(%(signal_height)s, | |
| 974 length(density_list)), "in"), | |
| 975 x.axis.height, | |
| 976 unit(%(ann_height)s*%(args.gtf)s, "in") | |
| 977 ) | |
| 978 | |
| 979 # Arrange grobs | |
| 980 argrobs = arrangeGrob( | |
| 981 grobs=density_grobs, | |
| 982 ncol=1, | |
| 983 heights = heights, | |
| 984 ); | |
| 985 | |
| 986 # Save plot to file in the requested format | |
| 987 if ("%(out_format)s" == "tiff"){ | |
| 988 # TIFF images will be lzw-compressed | |
| 989 ggsave("%(out)s", | |
| 990 plot = argrobs, | |
| 991 device = "tiff", | |
| 992 width = width, | |
| 993 height = height, | |
| 994 units = "in", | |
| 995 dpi = %(out_resolution)s, | |
| 996 compression = "lzw") | |
| 997 } else { | |
| 998 ggsave("%(out)s", | |
| 999 plot = argrobs, | |
| 1000 device = "%(out_format)s", | |
| 1001 width = width, | |
| 1002 height = height, | |
| 1003 units = "in", | |
| 1004 dpi = %(out_resolution)s) | |
| 1005 } | |
| 1006 | |
| 1007 dev.log = dev.off() | |
| 1008 | |
| 1009 """ % ({ | |
| 1010 "out": "%s.%s" % (out_prefix, out_suffix), | |
| 1011 "out_format": args.out_format, | |
| 1012 "out_resolution": args.out_resolution, | |
| 1013 "args.gtf": float(bool(args.gtf)), | |
| 1014 "args.aggr": args.aggr.rstrip("_j"), | |
| 1015 "signal_height": args.height, | |
| 1016 "ann_height": args.ann_height, | |
| 1017 "alpha": args.alpha, | |
| 1018 }) | |
| 1019 if os.getenv('GGSASHIMI_DEBUG') is not None: | |
| 1020 with open("R_script", 'w') as r: | |
| 1021 r.write(R_script) | |
| 1022 else: | |
| 1023 plot(R_script) | |
| 1024 exit() | |
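
For orientation, here is a minimal sketch of driving the script from Python. It uses only options defined in `define_options()` above; the BAM list, GTF path, region and output prefix are hypothetical, and `samtools` plus `R` (with ggplot2, grid, gridExtra, data.table, gtable) must be available on the PATH.

```python
# Illustrative invocation only -- file names and the region are made up.
import subprocess

subprocess.run(
    [
        "python", "sashimi-plot.py",
        "-b", "input_bams.tsv",           # tsv: id, bam path, optional extra columns
        "-c", "chr10:27040584-27048100",  # 1-based region (chr:start-end)
        "-g", "annotation.gtf",           # annotation with exon/transcript lines
        "-M", "5",                        # only draw junctions with >= 5 supporting reads
        "--shrink",                       # shrink long introns for nicer display
        "-o", "sashimi",                  # output prefix
        "-F", "pdf",                      # output format
    ],
    check=True,
)
```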

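Because the file name contains a hyphen, the script cannot be imported with a plain `import` statement. A minimal sketch, assuming `sashimi-plot.py` sits in the working directory, that loads it via `importlib` and checks the 1-based to 0-based conversion performed by `parse_coordinates()` (the region is hypothetical):

```python
# Minimal sketch: load sashimi-plot.py and exercise parse_coordinates().
import importlib.util

spec = importlib.util.spec_from_file_location("sashimi_plot", "sashimi-plot.py")
sashimi = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sashimi)  # executes only the top-level defs; the
                                  # __main__ guard keeps argparse from running

chrom, start, end = sashimi.parse_coordinates("chr10:27,040,584-27,048,100")
# The start is shifted to 0-based; commas are stripped before parsing.
assert (chrom, start, end) == ("chr10", 27040583, 27048100)
```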