Mercurial > repos > davidvanzessen > mutation_analysis
change_o/DefineClones.py @ 0:8a5a2abbb870 (draft, default, tip)
Commit message: "Uploaded"

| author | davidvanzessen |
|---|---|
| date | Mon, 29 Aug 2016 05:36:10 -0400 |
| parents | (none) |
| children | (none) |

```python
#!/usr/bin/env python3
"""
Assign Ig sequences into clones
"""
# Info
__author__ = 'Namita Gupta, Jason Anthony Vander Heiden, Gur Yaari, Mohamed Uduman'
from changeo import __version__, __date__

# Imports
import os
import re
import sys
import numpy as np
from argparse import ArgumentParser
from collections import OrderedDict
from itertools import chain
from textwrap import dedent
from time import time
from Bio import pairwise2
from Bio.Seq import translate

# Presto and changeo imports
from presto.Defaults import default_out_args
from presto.IO import getFileType, getOutputHandle, printLog, printProgress
from presto.Multiprocessing import manageProcesses
from presto.Sequence import getDNAScoreDict
from changeo.Commandline import CommonHelpFormatter, getCommonArgParser, parseCommonArgs
from changeo.Distance import getDNADistMatrix, getAADistMatrix, \
                             hs1f_model, m1n_model, hs5f_model, \
                             calcDistances, formClusters
from changeo.IO import getDbWriter, readDbFile, countDbFile
from changeo.Multiprocessing import DbData, DbResult

# Defaults
default_translate = False
default_distance = 0.0
default_bygroup_model = 'hs1f'
default_hclust_model = 'chen2010'
default_seq_field = 'JUNCTION'
default_norm = 'len'
default_sym = 'avg'
default_linkage = 'single'

# TODO: should be in Distance, but needs to be after the function definitions
# Amino acid Hamming distance
aa_model = getAADistMatrix(mask_dist=1, gap_dist=0)

# DNA Hamming distance
ham_model = getDNADistMatrix(mask_dist=0, gap_dist=0)


# TODO: this function is an abstraction to facilitate later cleanup
def getModelMatrix(model):
    """
    Simple wrapper to get a distance matrix from a model name

    Arguments:
    model = model name

    Returns:
    a pandas.DataFrame containing the character distance matrix
    """
    if model == 'aa':
        return aa_model
    elif model == 'ham':
        return ham_model
    elif model == 'm1n':
        return m1n_model
    elif model == 'hs1f':
        return hs1f_model
    elif model == 'hs5f':
        return hs5f_model
    else:
        sys.stderr.write('Unrecognized distance model: %s.\n' % model)

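# Usage sketch (illustrative, not part of the original module): the matrices
# returned here are pandas DataFrames indexed by character, which is the form
# consumed by calcDistances below. For example:
#   dist_mat = getModelMatrix('ham')
#   dist_mat.at['A', 'C']   # 1 for a nucleotide mismatch, 0 on the diagonal
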
def indexJunctions(db_iter, fields=None, mode='gene', action='first'):
    """
    Identifies preclonal groups by V, J and junction length

    Arguments:
    db_iter = an iterator of IgRecords defined by readDbFile
    fields = additional annotation fields to use to group preclones;
             if None use only V, J and junction length
    mode = specificity of alignment call to use for assigning preclones;
           one of ('allele', 'gene')
    action = how to handle multiple value fields when assigning preclones;
             one of ('first', 'set')

    Returns:
    a dictionary of {(V, J, junction length):[IgRecords]}
    """
    # Define functions for grouping keys
    if mode == 'allele' and fields is None:
        def _get_key(rec, act):
            return (rec.getVAllele(act), rec.getJAllele(act),
                    None if rec.junction is None else len(rec.junction))
    elif mode == 'gene' and fields is None:
        def _get_key(rec, act):
            return (rec.getVGene(act), rec.getJGene(act),
                    None if rec.junction is None else len(rec.junction))
    elif mode == 'allele' and fields is not None:
        def _get_key(rec, act):
            vdj = [rec.getVAllele(act), rec.getJAllele(act),
                   None if rec.junction is None else len(rec.junction)]
            ann = [rec.toDict().get(k, None) for k in fields]
            return tuple(chain(vdj, ann))
    elif mode == 'gene' and fields is not None:
        def _get_key(rec, act):
            vdj = [rec.getVGene(act), rec.getJGene(act),
                   None if rec.junction is None else len(rec.junction)]
            ann = [rec.toDict().get(k, None) for k in fields]
            return tuple(chain(vdj, ann))

    start_time = time()
    clone_index = {}
    rec_count = 0
    for rec in db_iter:
        key = _get_key(rec, action)

        # Print progress
        if rec_count == 0:
            print('PROGRESS> Grouping sequences')

        printProgress(rec_count, step=1000, start_time=start_time)
        rec_count += 1

        # Assign passed preclone records to their key and failed records to the None key
        if all([k is not None and k != '' for k in key]):
            # TODO: the 'set' branch is slow and should be optimized
            if action == 'set':
                f_range = list(range(2, 3 + (len(fields) if fields else 0)))
                vdj_range = list(range(2))

                # Check for any keys with matching annotation columns and junction
                # length, and overlapping gene/allele sets
                to_remove = []
                if len(clone_index) > (1 if None in clone_index else 0) and key not in clone_index:
                    key = list(key)
                    for k in clone_index:
                        if k is not None and all([key[i] == k[i] for i in f_range]):
                            if all([not set(key[i]).isdisjoint(set(k[i])) for i in vdj_range]):
                                for i in vdj_range:  key[i] = tuple(set(key[i]).union(set(k[i])))
                                to_remove.append(k)

                # Remove the original keys, replace with the union of all genes/alleles
                # and append the values to the new key
                val = [rec]
                val += list(chain(*(clone_index.pop(k) for k in to_remove)))
                clone_index[tuple(key)] = clone_index.get(tuple(key), []) + val
            elif action == 'first':
                clone_index.setdefault(key, []).append(rec)
        else:
            clone_index.setdefault(None, []).append(rec)

    printProgress(rec_count, step=1000, start_time=start_time, end=True)

    return clone_index

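# Example keys (illustrative): with mode='gene' and no extra fields, records
# are binned under tuples such as ('IGHV3-23', 'IGHJ4', 48). With action='set',
# gene calls are tuples and keys whose gene sets overlap are merged, e.g.
# (('IGHV3-23', 'IGHV3-30'), ('IGHJ4',), 48).
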
def distanceClones(records, model=default_bygroup_model, distance=default_distance,
                   dist_mat=None, norm=default_norm, sym=default_sym,
                   linkage=default_linkage, seq_field=default_seq_field):
    """
    Separates a set of IgRecords into clones

    Arguments:
    records = an iterable of IgRecords
    model = substitution model used to calculate distance
    distance = the distance threshold to assign clonal groups
    dist_mat = pandas DataFrame of pairwise nucleotide or amino acid distances
    norm = normalization method
    sym = symmetry method
    linkage = type of linkage
    seq_field = sequence field used to calculate distance between records

    Returns:
    a dictionary of lists defining {clone number: [IgRecords clonal group]}
    """
    # Get distance matrix if not provided
    if dist_mat is None:  dist_mat = getModelMatrix(model)

    # Determine length of n-mers
    if model in ['hs1f', 'm1n', 'aa', 'ham']:
        nmer_len = 1
    elif model in ['hs5f']:
        nmer_len = 5
    else:
        sys.stderr.write('Unrecognized distance model: %s.\n' % model)

    # Define unique junction mapping
    seq_map = {}
    for ig in records:
        seq = ig.getSeqField(seq_field)
        # Fail the whole group if any sequence is empty
        if len(seq) == 0:
            return None

        seq = re.sub(r'[.-]', 'N', str(seq))
        if model == 'aa':  seq = translate(seq)

        seq_map.setdefault(seq, []).append(ig)

    # Process records
    if len(seq_map) == 1:
        return {1: records}

    # Define sequences
    seqs = list(seq_map.keys())

    # Calculate pairwise distance matrix
    dists = calcDistances(seqs, nmer_len, dist_mat, norm, sym)

    # Perform hierarchical clustering
    clusters = formClusters(dists, linkage, distance)

    # Turn clusters into clone dictionary
    clone_dict = {}
    for i, c in enumerate(clusters):
        clone_dict.setdefault(c, []).extend(seq_map[seqs[i]])

    return clone_dict

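# Call sketch (illustrative): cluster one preclonal group at a chosen
# threshold. The 0.2 cutoff is an example value, not a recommendation.
#   clones = distanceClones(records, model='hs1f', distance=0.2,
#                           dist_mat=getModelMatrix('hs1f'))
#   # -> {1: [IgRecord, ...], 2: [IgRecord, ...], ...}
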
def distChen2010(records):
    """
    Calculates pairwise distances as defined in Chen 2010

    Arguments:
    records = list of IgRecords where the first is the query to be
              compared to the others in the list

    Returns:
    list of distances
    """
    # Pull out query sequence and V/J information
    query, rest = records[0], records[1:]
    query_cdr3 = query.junction[3:-3]
    query_v_allele = query.getVAllele()
    query_v_gene = query.getVGene()
    query_v_family = query.getVFamily()
    query_j_allele = query.getJAllele()
    query_j_gene = query.getJGene()
    # Create alignment scoring dictionary
    score_dict = getDNAScoreDict()

    scores = [0] * len(rest)
    for i, rec in enumerate(rest):
        rec_cdr3 = rec.junction[3:-3]
        # pairwise2 returns a list of alignments; element [2] of an alignment
        # is its score. Convert the score to an edit-distance-like value
        # (assumption: under this match/mismatch dictionary a higher score
        # means fewer differences; the original added penalties directly to
        # the alignment list, which raises a TypeError)
        aln = pairwise2.align.globalds(query_cdr3, rec_cdr3,
                                       score_dict, -1, -1, one_alignment_only=True)
        ld = max(len(query_cdr3), len(rec_cdr3)) - aln[0][2]
        # Check V similarity
        if rec.getVAllele() == query_v_allele:  ld += 0
        elif rec.getVGene() == query_v_gene:  ld += 1
        elif rec.getVFamily() == query_v_family:  ld += 3
        else:  ld += 5
        # Check J similarity
        if rec.getJAllele() == query_j_allele:  ld += 0
        elif rec.getJGene() == query_j_gene:  ld += 1
        else:  ld += 3
        # Normalize by CDR3 length
        scores[i] = ld / max(len(rec_cdr3), len(query_cdr3))

    return scores

def distAdemokun2011(records):
    """
    Calculates pairwise distances as defined in Ademokun 2011

    Arguments:
    records = list of IgRecords where the first is the query to be
              compared to the others in the list

    Returns:
    list of distances
    """
    # Pull out query sequence and V family information
    query, rest = records[0], records[1:]
    query_cdr3 = query.junction[3:-3]
    query_v_family = query.getVFamily()
    # Create alignment scoring dictionary
    score_dict = getDNAScoreDict()

    scores = [0] * len(rest)
    for i, rec in enumerate(rest):
        rec_cdr3 = rec.junction[3:-3]
        if abs(len(query_cdr3) - len(rec_cdr3)) > 10:
            scores[i] = 1
        elif query_v_family != rec.getVFamily():
            scores[i] = 1
        else:
            # Same score-to-distance conversion as in distChen2010 (assumption)
            aln = pairwise2.align.globalds(query_cdr3, rec_cdr3,
                                           score_dict, -1, -1, one_alignment_only=True)
            ld = max(len(query_cdr3), len(rec_cdr3)) - aln[0][2]
            scores[i] = ld / min(len(rec_cdr3), len(query_cdr3))

    return scores

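# Note on the two metrics (descriptive): distChen2010 adds V/J penalties on
# top of the CDR3 alignment distance (0/1/3/5 for allele/gene/family/none
# matches), while distAdemokun2011 short-circuits to a distance of 1 when the
# junction lengths differ by more than 10 nt or the V families differ.
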
def hierClust(dist_mat, method='chen2010'):
    """
    Calculate hierarchical clustering

    Arguments:
    dist_mat = square-formed distance matrix of pairwise CDR3 comparisons
    method = clustering method; one of ('chen2010', 'ademokun2011')

    Returns:
    list of cluster ids
    """
    if method == 'chen2010':
        clusters = formClusters(dist_mat, 'average', 0.32)
    elif method == 'ademokun2011':
        clusters = formClusters(dist_mat, 'complete', 0.25)
    else:
        clusters = np.ones(dist_mat.shape[0])

    return clusters

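# Usage sketch (illustrative): each method pairs a linkage with a fixed cut
# height; any other method name places all sequences in one cluster.
#   clusters = hierClust(dist_mat, method='chen2010')  # average linkage, cut at 0.32
#   # -> one cluster id per row of dist_mat
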
# TODO: Merge duplicate feed, process and collect functions.
def feedQueue(alive, data_queue, db_file, group_func, group_args={}):
    """
    Feeds the data queue with Ig records

    Arguments:
    alive = a multiprocessing.Value boolean controlling whether processing continues;
            if False exit process
    data_queue = a multiprocessing.Queue to hold data for processing
    db_file = the Ig record database file
    group_func = the function to use for assigning preclones
    group_args = a dictionary of arguments to pass to group_func

    Returns:
    None
    """
    # Open input file and perform grouping
    try:
        # Iterate over Ig records and assign groups
        db_iter = readDbFile(db_file)
        clone_dict = group_func(db_iter, **group_args)
    except:
        # Exception in feeder grouping step
        alive.value = False
        raise

    # Add groups to data queue
    try:
        # Iterate over groups and feed data queue
        clone_iter = iter(clone_dict.items())
        while alive.value:
            # Get data from queue
            if data_queue.full():  continue
            else:  data = next(clone_iter, None)
            # Exit upon reaching end of iterator
            if data is None:  break

            # Feed queue
            data_queue.put(DbData(*data))
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        # Exception in feeder queue feeding step
        alive.value = False
        raise

    return None


def feedQueueClust(alive, data_queue, db_file, group_func=None, group_args={}):
    """
    Feeds the data queue with Ig records

    Arguments:
    alive = a multiprocessing.Value boolean controlling whether processing continues;
            if False exit process
    data_queue = a multiprocessing.Queue to hold data for processing
    db_file = the Ig record database file
    group_func = unused; present for feed-function API compatibility
    group_args = unused; present for feed-function API compatibility

    Returns:
    None
    """
    # Open input file and perform grouping
    try:
        # Iterate over Ig records and order by junction length
        records = {}
        db_iter = readDbFile(db_file)
        for rec in db_iter:
            records[rec.id] = rec
        records = OrderedDict(sorted(records.items(), key=lambda i: i[1].junction_length))
        # Pair each query record with all records that follow it in length
        # order; the query must be the first element of each list (see
        # distChen2010). The original built this with list.append, which
        # returns None.
        dist_dict = {}
        for __ in range(len(records)):
            k, v = records.popitem(last=False)
            dist_dict[k] = [v] + list(records.values())
    except:
        # Exception in feeder grouping step
        alive.value = False
        raise

    # Add groups to data queue
    try:
        # Iterate over groups and feed data queue
        dist_iter = iter(dist_dict.items())
        while alive.value:
            # Get data from queue
            if data_queue.full():  continue
            else:  data = next(dist_iter, None)
            # Exit upon reaching end of iterator
            if data is None:  break

            # Feed queue
            data_queue.put(DbData(*data))
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        # Exception in feeder queue feeding step
        alive.value = False
        raise

    return None


def processQueue(alive, data_queue, result_queue, clone_func, clone_args):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
    alive = a multiprocessing.Value boolean controlling whether processing continues;
            if False exit process
    data_queue = a multiprocessing.Queue holding data to process
    result_queue = a multiprocessing.Queue to hold processed results
    clone_func = the function to call for clonal assignment
    clone_args = a dictionary of arguments to pass to clone_func

    Returns:
    None
    """
    try:
        # Iterate over data queue until sentinel object reached
        while alive.value:
            # Get data from queue
            if data_queue.empty():  continue
            else:  data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None:  break

            # Define result object for iteration and get data records
            records = data.data
            result = DbResult(data.id, records)

            # Check for invalid data (due to failed indexing) and add failed result
            if not data:
                result_queue.put(result)
                continue

            # Add V(D)J to log
            result.log['ID'] = ','.join([str(x) for x in data.id])
            result.log['VALLELE'] = ','.join(set([(r.getVAllele() or '') for r in records]))
            result.log['DALLELE'] = ','.join(set([(r.getDAllele() or '') for r in records]))
            result.log['JALLELE'] = ','.join(set([(r.getJAllele() or '') for r in records]))
            result.log['JUNCLEN'] = ','.join(set([(str(len(r.junction)) if r.junction is not None else '0') for r in records]))
            result.log['SEQUENCES'] = len(records)

            # Check for preclone failure and assign clones
            clones = clone_func(records, **clone_args) if data else None

            # To profile the worker, wrap the call above with cProfile:
            #   import cProfile
            #   prof = cProfile.Profile()
            #   clones = prof.runcall(clone_func, records, **clone_args)
            #   prof.dump_stats('worker-%d.prof' % os.getpid())

            if clones is not None:
                result.results = clones
                result.valid = True
                result.log['CLONES'] = len(clones)
            else:
                result.log['CLONES'] = 0

            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        # Exception in worker
        alive.value = False
        raise

    return None


def processQueueClust(alive, data_queue, result_queue, clone_func, clone_args):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
    alive = a multiprocessing.Value boolean controlling whether processing continues;
            if False exit process
    data_queue = a multiprocessing.Queue holding data to process
    result_queue = a multiprocessing.Queue to hold processed results
    clone_func = the function to call for calculating pairwise distances between sequences
    clone_args = a dictionary of arguments to pass to clone_func

    Returns:
    None
    """
    try:
        # Iterate over data queue until sentinel object reached
        while alive.value:
            # Get data from queue
            if data_queue.empty():  continue
            else:  data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None:  break

            # Define result object for iteration and get data records
            records = data.data
            result = DbResult(data.id, records)

            # Create row of distance matrix and check for error
            dist_row = clone_func(records, **clone_args) if data else None
            if dist_row is not None:
                result.results = dist_row
                result.valid = True

            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        # Exception in worker
        alive.value = False
        raise

    return None

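# Pipeline sketch (descriptive; the orchestration itself lives in presto's
# manageProcesses): the feeder puts DbData items on data_queue until its
# iterator is exhausted, workers wrap each group in a DbResult on
# result_queue, the collector writes pass/fail output, and None sentinels
# mark the end of each queue.
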
def collectQueue(alive, result_queue, collect_queue, db_file, out_args, cluster_func=None, cluster_args={}):
    """
    Assembles results from a queue of individual sequence results and manages log/file I/O

    Arguments:
    alive = a multiprocessing.Value boolean controlling whether processing continues;
            if False exit process
    result_queue = a multiprocessing.Queue holding processQueue results
    collect_queue = a multiprocessing.Queue to store collector return values
    db_file = the input database file name
    out_args = common output argument dictionary from parseCommonArgs
    cluster_func = unused; present for collect-function API compatibility
    cluster_args = unused; present for collect-function API compatibility

    Returns:
    None
    (adds 'log' and 'out_files' to collect_dict)
    """
    # Open output files
    try:
        # Count records and define output format
        out_type = getFileType(db_file) if out_args['out_type'] is None \
                   else out_args['out_type']
        result_count = countDbFile(db_file)

        # Define successful output handle
        pass_handle = getOutputHandle(db_file,
                                      out_label='clone-pass',
                                      out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'],
                                      out_type=out_type)
        pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')

        # Define failed clonal grouping output handle
        if out_args['failed']:
            fail_handle = getOutputHandle(db_file,
                                          out_label='clone-fail',
                                          out_dir=out_args['out_dir'],
                                          out_name=out_args['out_name'],
                                          out_type=out_type)
            fail_writer = getDbWriter(fail_handle, db_file)
        else:
            fail_handle = None
            fail_writer = None

        # Define log handle
        if out_args['log_file'] is None:
            log_handle = None
        else:
            log_handle = open(out_args['log_file'], 'w')
    except:
        # Exception in collector file opening step
        alive.value = False
        raise

    # Get results from queue and write to files
    try:
        # Iterate over results queue until sentinel object reached
        start_time = time()
        rec_count = clone_count = pass_count = fail_count = 0
        while alive.value:
            # Get result from queue
            if result_queue.empty():  continue
            else:  result = result_queue.get()
            # Exit upon reaching sentinel
            if result is None:  break

            # Print progress for previous iteration and update record count
            if rec_count == 0:
                print('PROGRESS> Assigning clones')
            printProgress(rec_count, result_count, 0.05, start_time)
            rec_count += len(result.data)

            # Write passed and failed records
            if result:
                for clone in result.results.values():
                    clone_count += 1
                    for i, rec in enumerate(clone):
                        rec.annotations['CLONE'] = clone_count
                        pass_writer.writerow(rec.toDict())
                        pass_count += 1
                        result.log['CLONE%i-%i' % (clone_count, i + 1)] = str(rec.junction)
            else:
                for i, rec in enumerate(result.data):
                    if fail_writer is not None:  fail_writer.writerow(rec.toDict())
                    fail_count += 1
                    result.log['CLONE0-%i' % (i + 1)] = str(rec.junction)

            # Write log
            printLog(result.log, handle=log_handle)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None

        # Print total counts
        printProgress(rec_count, result_count, 0.05, start_time)

        # Close file handles
        pass_handle.close()
        if fail_handle is not None:  fail_handle.close()
        if log_handle is not None:  log_handle.close()

        # Update return list
        log = OrderedDict()
        log['OUTPUT'] = os.path.basename(pass_handle.name)
        log['CLONES'] = clone_count
        log['RECORDS'] = rec_count
        log['PASS'] = pass_count
        log['FAIL'] = fail_count
        collect_dict = {'log': log, 'out_files': [pass_handle.name]}
        collect_queue.put(collect_dict)
    except:
        # Exception in collector result processing step
        alive.value = False
        raise

    return None


def collectQueueClust(alive, result_queue, collect_queue, db_file, out_args, cluster_func, cluster_args):
    """
    Assembles results from a queue of individual sequence results and manages log/file I/O

    Arguments:
    alive = a multiprocessing.Value boolean controlling whether processing continues;
            if False exit process
    result_queue = a multiprocessing.Queue holding processQueueClust results
    collect_queue = a multiprocessing.Queue to store collector return values
    db_file = the input database file name
    out_args = common output argument dictionary from parseCommonArgs
    cluster_func = the function to call for carrying out clustering on the distance matrix
    cluster_args = a dictionary of arguments to pass to cluster_func

    Returns:
    None
    (adds 'log' and 'out_files' to collect_dict)
    """
    # Open output files
    try:
        # Iterate over Ig records to count and order by junction length
        result_count = 0
        records = {}
        db_iter = readDbFile(db_file)
        for rec in db_iter:
            records[rec.id] = rec
            result_count += 1
        records = OrderedDict(sorted(records.items(), key=lambda i: i[1].junction_length))

        # Define empty matrix to store assembled results
        dist_mat = np.zeros((result_count, result_count))

        # Define output format
        out_type = getFileType(db_file) if out_args['out_type'] is None \
                   else out_args['out_type']

        # Define successful output handle
        pass_handle = getOutputHandle(db_file,
                                      out_label='clone-pass',
                                      out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'],
                                      out_type=out_type)
        pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')

        # Define failed cloning output handle
        if out_args['failed']:
            fail_handle = getOutputHandle(db_file,
                                          out_label='clone-fail',
                                          out_dir=out_args['out_dir'],
                                          out_name=out_args['out_name'],
                                          out_type=out_type)
            fail_writer = getDbWriter(fail_handle, db_file)
        else:
            fail_handle = None
            fail_writer = None

        # Open log file
        if out_args['log_file'] is None:
            log_handle = None
        else:
            log_handle = open(out_args['log_file'], 'w')
    except:
        # Exception in collector file opening step
        alive.value = False
        raise

    try:
        # Iterate over results queue until sentinel object reached
        start_time = time()
        row_count = rec_count = 0
        while alive.value:
            # Get result from queue
            if result_queue.empty():  continue
            else:  result = result_queue.get()
            # Exit upon reaching sentinel
            if result is None:  break

            # Print progress for previous iteration
            if row_count == 0:
                print('PROGRESS> Assigning clones')
            printProgress(row_count, result_count, 0.05, start_time)

            # Update counts for iteration
            row_count += 1
            rec_count += len(result)

            # Add result row to the distance matrix. Each result holds the
            # distances from one query record to every record after it in
            # junction length order, so it fills the query's column below the
            # diagonal. (Indexing adjusted to the list-based groups built in
            # feedQueueClust; an assumption about the original intent.)
            if result:
                n = len(result.results)
                q = result_count - n - 1
                dist_mat[q + 1:result_count, q] = result.results
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None

        # Calculate linkage and carry out clustering
        clusters = cluster_func(dist_mat, **cluster_args) if dist_mat is not None else None
        clones = {}
        for i, c in enumerate(clusters):
            clones.setdefault(c, []).append(records[list(records.keys())[i]])

        # Write passed and failed records
        clone_count = pass_count = fail_count = 0
        if clones:
            for clone in clones.values():
                clone_count += 1
                for i, rec in enumerate(clone):
                    rec.annotations['CLONE'] = clone_count
                    pass_writer.writerow(rec.toDict())
                    pass_count += 1
        else:
            # If clustering produced no clones, fail all records (assumption:
            # the original iterated the last queue result here, which covered
            # only a subset of the records)
            for i, rec in enumerate(records.values()):
                if fail_writer is not None:  fail_writer.writerow(rec.toDict())
                fail_count += 1

        # Print final progress
        printProgress(row_count, result_count, 0.05, start_time)

        # Close file handles
        pass_handle.close()
        if fail_handle is not None:  fail_handle.close()
        if log_handle is not None:  log_handle.close()

        # Update return list
        log = OrderedDict()
        log['OUTPUT'] = os.path.basename(pass_handle.name)
        log['CLONES'] = clone_count
        log['RECORDS'] = rec_count
        log['PASS'] = pass_count
        log['FAIL'] = fail_count
        collect_dict = {'log': log, 'out_files': [pass_handle.name]}
        collect_queue.put(collect_dict)
    except:
        # Exception in collector result processing step
        alive.value = False
        raise

    return None


def defineClones(db_file, feed_func, work_func, collect_func, clone_func, cluster_func=None,
                 group_func=None, group_args={}, clone_args={}, cluster_args={},
                 out_args=default_out_args, nproc=None, queue_size=None):
    """
    Define clonally related sequences

    Arguments:
    db_file = filename of input database
    feed_func = the function that feeds the queue
    work_func = the worker function that will run on each CPU
    collect_func = the function that collects results from the workers
    clone_func = the function to use for determining clones within preclonal groups
    cluster_func = the function to use for clustering the assembled distance matrix
    group_func = the function to use for assigning preclones
    group_args = a dictionary of arguments to pass to group_func
    clone_args = a dictionary of arguments to pass to clone_func
    cluster_args = a dictionary of arguments to pass to cluster_func
    out_args = common output argument dictionary from parseCommonArgs
    nproc = the number of processQueue processes;
            if None defaults to the number of CPUs
    queue_size = maximum size of the argument queue;
                 if None defaults to 2*nproc

    Returns:
    a list of successful output file names
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'DefineClones'
    log['DB_FILE'] = os.path.basename(db_file)
    if group_func is not None:
        log['GROUP_FUNC'] = group_func.__name__
        log['GROUP_ARGS'] = group_args
    log['CLONE_FUNC'] = clone_func.__name__

    # TODO: this is yucky, but can be fixed by using a model class
    clone_log = clone_args.copy()
    if 'dist_mat' in clone_log:  del clone_log['dist_mat']
    log['CLONE_ARGS'] = clone_log

    if cluster_func is not None:
        log['CLUSTER_FUNC'] = cluster_func.__name__
        log['CLUSTER_ARGS'] = cluster_args
    log['NPROC'] = nproc
    printLog(log)

    # Define feeder function and arguments
    feed_args = {'db_file': db_file,
                 'group_func': group_func,
                 'group_args': group_args}
    # Define worker function and arguments
    work_args = {'clone_func': clone_func,
                 'clone_args': clone_args}
    # Define collector function and arguments
    collect_args = {'db_file': db_file,
                    'out_args': out_args,
                    'cluster_func': cluster_func,
                    'cluster_args': cluster_args}

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func,
                             feed_args, work_args, collect_args,
                             nproc, queue_size)

    # Print log
    result['log']['END'] = 'DefineClones'
    printLog(result['log'])

    return result['out_files']

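# Wiring sketch (illustrative; 'db.tab' stands in for a real tab-delimited
# Change-O database). The bygroup subcommand below effectively calls:
#   defineClones(db_file='db.tab',
#                feed_func=feedQueue, work_func=processQueue,
#                collect_func=collectQueue,
#                group_func=indexJunctions,
#                group_args={'fields': None, 'action': 'set', 'mode': 'gene'},
#                clone_func=distanceClones,
#                clone_args={'model': 'hs1f', 'distance': 0.0, 'norm': 'len',
#                            'sym': 'avg', 'linkage': 'single',
#                            'seq_field': 'JUNCTION',
#                            'dist_mat': getModelMatrix('hs1f')})
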
def getArgParser():
    """
    Defines the ArgumentParser

    Arguments:
    None

    Returns:
    an ArgumentParser object
    """
    # Define input and output fields
    fields = dedent(
             '''
             output files:
                 clone-pass
                     database with assigned clonal group numbers.
                 clone-fail
                     database with records failing clonal grouping.

             required fields:
                 SEQUENCE_ID, V_CALL or V_CALL_GENOTYPED, D_CALL, J_CALL, JUNCTION_LENGTH

                 <field>
                     sequence field specified by the --sf parameter

             output fields:
                 CLONE
             ''')

    # Define ArgumentParser
    parser = ArgumentParser(description=__doc__, epilog=fields,
                            formatter_class=CommonHelpFormatter)
    parser.add_argument('--version', action='version',
                        version='%(prog)s:' + ' %s-%s' % (__version__, __date__))
    subparsers = parser.add_subparsers(title='subcommands', dest='command', metavar='',
                                       help='Cloning method')
    # TODO: This is a temporary fix for Python issue 9253
    subparsers.required = True

    # Parent parser
    parser_parent = getCommonArgParser(seq_in=False, seq_out=False, db_in=True,
                                       multiproc=True)

    # Distance cloning method
    parser_bygroup = subparsers.add_parser('bygroup', parents=[parser_parent],
                                           formatter_class=CommonHelpFormatter,
                                           help='''Defines clones as having the same V assignment,
                                                J assignment, and junction length, with the
                                                specified substitution distance model.''')
    parser_bygroup.add_argument('-f', nargs='+', action='store', dest='fields', default=None,
                                help='Additional fields to use for grouping clones (non V(D)J)')
    parser_bygroup.add_argument('--mode', action='store', dest='mode',
                                choices=('allele', 'gene'), default='gene',
                                help='''Specifies whether to use the V(D)J allele or gene for
                                     initial grouping.''')
    parser_bygroup.add_argument('--act', action='store', dest='action', default='set',
                                choices=('first', 'set'),
                                help='''Specifies how to handle multiple V(D)J assignments
                                     for initial grouping.''')
    parser_bygroup.add_argument('--model', action='store', dest='model',
                                choices=('aa', 'ham', 'm1n', 'hs1f', 'hs5f'),
                                default=default_bygroup_model,
                                help='''Specifies which substitution model to use for
                                     calculating distance between sequences. m1n is the
                                     mouse single-nucleotide transition/transversion model
                                     of Smith et al, 1996; hs1f is the human single-nucleotide
                                     model derived from Yaari et al, 2013; hs5f is the human
                                     S5F model of Yaari et al, 2013; ham is nucleotide
                                     Hamming distance; and aa is amino acid Hamming
                                     distance. The hs5f model should be considered
                                     experimental.''')
    parser_bygroup.add_argument('--dist', action='store', dest='distance', type=float,
                                default=default_distance,
                                help='The distance threshold for clonal grouping')
    parser_bygroup.add_argument('--norm', action='store', dest='norm',
                                choices=('len', 'mut', 'none'), default=default_norm,
                                help='''Specifies how to normalize distances. One of none
                                     (do not normalize), len (normalize by length),
                                     or mut (normalize by number of mutations between sequences).''')
    parser_bygroup.add_argument('--sym', action='store', dest='sym',
                                choices=('avg', 'min'), default=default_sym,
                                help='''Specifies how to combine asymmetric distances. One of avg
                                     (average of A->B and B->A) or min (minimum of A->B and B->A).''')
    parser_bygroup.add_argument('--link', action='store', dest='linkage',
                                choices=('single', 'average', 'complete'), default=default_linkage,
                                help='Type of linkage to use for hierarchical clustering.')
    parser_bygroup.add_argument('--sf', action='store', dest='seq_field',
                                default=default_seq_field,
                                help='''The name of the field to be used to calculate
                                     distance between records.''')
    parser_bygroup.set_defaults(feed_func=feedQueue)
    parser_bygroup.set_defaults(work_func=processQueue)
    parser_bygroup.set_defaults(collect_func=collectQueue)
    parser_bygroup.set_defaults(group_func=indexJunctions)
    parser_bygroup.set_defaults(clone_func=distanceClones)

    # Hierarchical clustering cloning method
    parser_hclust = subparsers.add_parser('hclust', parents=[parser_parent],
                                          formatter_class=CommonHelpFormatter,
                                          help='''Defines clones by the specified distance metric
                                               on CDR3s and cutting of the hierarchical
                                               clustering tree.''')
    parser_hclust.add_argument('--method', action='store', dest='method',
                               choices=('chen2010', 'ademokun2011'), default=default_hclust_model,
                               help='''Specifies which cloning method to use for calculating distance
                                    between CDR3s, computing linkage, and cutting clusters.''')
    parser_hclust.set_defaults(feed_func=feedQueueClust)
    parser_hclust.set_defaults(work_func=processQueueClust)
    parser_hclust.set_defaults(collect_func=collectQueueClust)
    parser_hclust.set_defaults(cluster_func=hierClust)

    return parser

if __name__ == '__main__':
    """
    Parses command line arguments and calls main function
    """
    # Parse arguments
    parser = getArgParser()
    args = parser.parse_args()
    args_dict = parseCommonArgs(args)
    # Convert case of fields
    if 'seq_field' in args_dict:
        args_dict['seq_field'] = args_dict['seq_field'].upper()
    if 'fields' in args_dict and args_dict['fields'] is not None:
        args_dict['fields'] = [f.upper() for f in args_dict['fields']]

    # Define group_args and clone_args for the bygroup method
    if args.command == 'bygroup':
        args_dict['group_args'] = {'fields': args_dict['fields'],
                                   'action': args_dict['action'],
                                   'mode': args_dict['mode']}
        args_dict['clone_args'] = {'model': args_dict['model'],
                                   'distance': args_dict['distance'],
                                   'norm': args_dict['norm'],
                                   'sym': args_dict['sym'],
                                   'linkage': args_dict['linkage'],
                                   'seq_field': args_dict['seq_field']}

        # TODO: can be cleaned up with an abstract model class
        args_dict['clone_args']['dist_mat'] = getModelMatrix(args_dict['model'])

        del args_dict['fields']
        del args_dict['action']
        del args_dict['mode']
        del args_dict['model']
        del args_dict['distance']
        del args_dict['norm']
        del args_dict['sym']
        del args_dict['linkage']
        del args_dict['seq_field']

    # Define clone_func and cluster_args for the hclust method
    if args.command == 'hclust':
        dist_funcs = {'chen2010': distChen2010, 'ademokun2011': distAdemokun2011}
        args_dict['clone_func'] = dist_funcs[args_dict['method']]
        args_dict['cluster_args'] = {'method': args_dict['method']}
        del args_dict['method']

    # Call defineClones for each input database file
    del args_dict['command']
    del args_dict['db_files']
    for f in args.__dict__['db_files']:
        args_dict['db_file'] = f
        defineClones(**args_dict)
```

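Example invocation (illustrative; the flags correspond to the bygroup parser
defined above, and the database file name is hypothetical):

```
DefineClones.py bygroup -d sample_db.tab --model hs1f --dist 0.2 --mode gene --act set
```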