comparison DefineClones.py @ 0:183edf446dcf draft default tip

Uploaded
author davidvanzessen
date Mon, 17 Jul 2017 07:44:27 -0400
1 #!/usr/bin/env python3
2 """
3 Assign Ig sequences into clones
4 """
5 # Info
6 __author__ = 'Namita Gupta, Jason Anthony Vander Heiden, Gur Yaari, Mohamed Uduman'
7 from changeo import __version__, __date__
8
9 # Imports
10 import os
11 import re
12 import sys
13 import csv
14 import numpy as np
15 from argparse import ArgumentParser
16 from collections import OrderedDict
17 from itertools import chain
18 from textwrap import dedent
19 from time import time
20 from Bio import pairwise2
21 from Bio.Seq import translate
22
23 # Presto and changeo imports
24 from presto.Defaults import default_out_args
25 from presto.IO import getFileType, getOutputHandle, printLog, printProgress
26 from presto.Multiprocessing import manageProcesses
27 from presto.Sequence import getDNAScoreDict
28 from changeo.Commandline import CommonHelpFormatter, checkArgs, getCommonArgParser, parseCommonArgs
29 from changeo.Distance import distance_models, calcDistances, formClusters
30 from changeo.IO import getDbWriter, readDbFile, countDbFile
31 from changeo.Multiprocessing import DbData, DbResult
32
33 # Defaults
34 default_translate = False
35 default_distance = 0.0
36 default_index_mode = 'gene'
37 default_index_action = 'set'
38 default_bygroup_model = 'ham'
39 default_hclust_model = 'chen2010'
40 default_seq_field = 'JUNCTION'
41 default_norm = 'len'
42 default_sym = 'avg'
43 default_linkage = 'single'
44 choices_bygroup_model = ('ham', 'aa', 'hh_s1f', 'hh_s5f', 'mk_rs1nf', 'mk_rs5nf', 'hs1f_compat', 'm1n_compat')
45
46
47 def indexByIdentity(index, key, rec, fields=None):
48 """
49 Updates a preclone index with a simple key
50
51 Arguments:
52 index = preclone index from indexJunctions
53 key = index key
54 rec = IgRecord to add to the index
55 fields = additional annotation fields to use to group preclones;
56 if None use only V, J and junction length
57
58 Returns:
59 None. Updates index with new key and records.
60 """
61 index.setdefault(tuple(key), []).append(rec)
62
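# Illustrative sketch (not part of the original module; gene names and the
# junction length below are hypothetical): with action='first', each record is
# filed under the exact key tuple produced by indexJunctions, so the index
# stays a flat dictionary of lists.
#
#   index = {}
#   indexByIdentity(index, ['IGHV1-2', 'IGHJ4', 48], rec1)
#   indexByIdentity(index, ['IGHV1-2', 'IGHJ4', 48], rec2)
#   # index == {('IGHV1-2', 'IGHJ4', 48): [rec1, rec2]}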
63
64 def indexByUnion(index, key, rec, fields=None):
65 """
66 Updates a preclone index with the union of nested keys
67
68 Arguments:
69 index = preclone index from indexJunctions
70 key = index key
71 rec = IgRecord to add to the index
72 fields = additional annotation fields to use to group preclones;
73 if None use only V, J and junction length
74
75 Returns:
76 None. Updates index with new key and records.
77 """
78 # List of values for this/new key
79 val = [rec]
80 f_range = list(range(2, 3 + (len(fields) if fields else 0)))
81
82 # See if field/junction length combination exists in index
83 outer_dict = index
84 for field in f_range:
85 try:
86 outer_dict = outer_dict[key[field]]
87 except KeyError:
88 outer_dict = None
89 break
90 # If field combination exists, look through Js
91 j_matches = []
92 if outer_dict is not None:
93 for j in outer_dict.keys():
94 if not set(key[1]).isdisjoint(set(j)):
95 key[1] = tuple(set(key[1]).union(set(j)))
96 j_matches += [j]
97 # If J overlap exists, look through Vs for each J
98 for j in j_matches:
99 v_matches = []
100 # Collect V matches for this J
101 for v in outer_dict[j].keys():
102 if not set(key[0]).isdisjoint(set(v)):
103 key[0] = tuple(set(key[0]).union(set(v)))
104 v_matches += [v]
105 # If there are V overlaps for this J, pop them out
106 if v_matches:
107 val += list(chain(*(outer_dict[j].pop(v) for v in v_matches)))
108 # If the J dict is now empty, remove it
109 if not outer_dict[j]:
110 outer_dict.pop(j, None)
111
112 # Add value(s) into index nested dictionary
113 # (nested dictionaries are updated in place via references, so changes propagate to index)
114 # Add field dictionaries into index
115 outer_dict = index
116 for field in f_range:
117 outer_dict.setdefault(key[field], {})
118 outer_dict = outer_dict[key[field]]
119 # Add J, then V into index
120 if key[1] in outer_dict:
121 outer_dict[key[1]].update({key[0]: val})
122 else:
123 outer_dict[key[1]] = {key[0]: val}
124
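# Illustrative sketch (not part of the original module; calls and gene names
# are hypothetical): with action='set', the index is nested as
# index[junction length][extra fields...][J tuple][V tuple] -> records, and
# existing keys whose J or V sets overlap are merged by set union.
#
#   index = {}
#   indexByUnion(index, [('IGHV1-2',), ('IGHJ4',), 48], rec1)
#   indexByUnion(index, [('IGHV1-2', 'IGHV1-3'), ('IGHJ4',), 48], rec2)
#   # rec1 and rec2 now share one entry keyed by the union of their V calls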
125
126 def indexJunctions(db_iter, fields=None, mode=default_index_mode,
127 action=default_index_action):
128 """
129 Identifies preclonal groups by V, J and junction length
130
131 Arguments:
132 db_iter = an iterator of IgRecords defined by readDbFile
133 fields = additional annotation fields to use to group preclones;
134 if None use only V, J and junction length
135 mode = specificity of alignment call to use for assigning preclones;
136 one of ('allele', 'gene')
137 action = how to handle multiple value fields when assigning preclones;
138 one of ('first', 'set')
139
140 Returns:
141 a dictionary of {(V, J, junction length):[IgRecords]}
142 """
143 # print(fields)
144 # Define functions for grouping keys
145 if mode == 'allele' and fields is None:
146 def _get_key(rec, act):
147 return [rec.getVAllele(act), rec.getJAllele(act),
148 None if rec.junction is None else len(rec.junction)]
149 elif mode == 'gene' and fields is None:
150 def _get_key(rec, act):
151 return [rec.getVGene(act), rec.getJGene(act),
152 None if rec.junction is None else len(rec.junction)]
153 elif mode == 'allele' and fields is not None:
154 def _get_key(rec, act):
155 vdj = [rec.getVAllele(act), rec.getJAllele(act),
156 None if rec.junction is None else len(rec.junction)]
157 ann = [rec.toDict().get(k, None) for k in fields]
158 return list(chain(vdj, ann))
159 elif mode == 'gene' and fields is not None:
160 def _get_key(rec, act):
161 vdj = [rec.getVGene(act), rec.getJGene(act),
162 None if rec.junction is None else len(rec.junction)]
163 ann = [rec.toDict().get(k, None) for k in fields]
164 return list(chain(vdj, ann))
165
166 # Function to flatten nested dictionary
167 def _flatten_dict(d, parent_key=''):
168 items = []
169 for k, v in d.items():
170 new_key = parent_key + [k] if parent_key else [k]
171 if isinstance(v, dict):
172 items.extend(_flatten_dict(v, new_key).items())
173 else:
174 items.append((new_key, v))
175 flat_dict = {None if None in i[0] else tuple(i[0]): i[1] for i in items}
176 return flat_dict
177
178 if action == 'first':
179 index_func = indexByIdentity
180 elif action == 'set':
181 index_func = indexByUnion
182 else:
183 sys.exit('Unrecognized action: %s.\n' % action)
184
185 start_time = time()
186 clone_index = {}
187 rec_count = 0
188 for rec in db_iter:
189 key = _get_key(rec, action)
190
191 # Print progress
192 if rec_count == 0:
193 print('PROGRESS> Grouping sequences')
194
195 printProgress(rec_count, step=1000, start_time=start_time)
196 rec_count += 1
197
198 # Assign passing preclone records to their key; index failing records under None
199 if all([k is not None and k != '' for k in key]):
200 # Update index dictionary
201 index_func(clone_index, key, rec, fields)
202 else:
203 clone_index.setdefault(None, []).append(rec)
204
205 printProgress(rec_count, step=1000, start_time=start_time, end=True)
206
207 if action == 'set':
208 clone_index = _flatten_dict(clone_index)
209
210 return clone_index
211
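# Illustrative usage sketch ('sample_db-pass.tab' is a hypothetical Change-O
# database file):
#
#   db_iter = readDbFile('sample_db-pass.tab')
#   groups = indexJunctions(db_iter, mode='gene', action='set')
#   # groups maps preclone keys (V, J, junction length and any extra fields)
#   # to lists of IgRecords; records missing any key value fall under None.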
212
213 def distanceClones(records, model=default_bygroup_model, distance=default_distance,
214 dist_mat=None, norm=default_norm, sym=default_sym,
215 linkage=default_linkage, seq_field=default_seq_field):
216 """
217 Separates a set of IgRecords into clones
218
219 Arguments:
220 records = an iterator of IgRecords
221 model = substitution model used to calculate distance
222 distance = the distance threshold to assign clonal groups
223 dist_mat = pandas DataFrame of pairwise nucleotide or amino acid distances
224 norm = normalization method
225 sym = symmetry method
226 linkage = type of linkage
227 seq_field = sequence field used to calculate distance between records
228
229 Returns:
230 a dictionary of lists defining {clone number: [IgRecords clonal group]}
231 """
232 # Get distance matrix if not provided
233 if dist_mat is None:
234 try:
235 dist_mat = distance_models[model]
236 except KeyError:
237 sys.exit('Unrecognized distance model: %s' % model)
238
239 # TODO: can be cleaned up with abstract model class
240 # Determine length of n-mers
241 if model in ['hs1f_compat', 'm1n_compat', 'aa', 'ham', 'hh_s1f', 'mk_rs1nf']:
242 nmer_len = 1
243 elif model in ['hh_s5f', 'mk_rs5nf']:
244 nmer_len = 5
245 else:
246 sys.exit('Unrecognized distance model: %s.\n' % model)
247
248 # Define unique junction mapping
249 seq_map = {}
250 for ig in records:
251 seq = ig.getSeqField(seq_field)
252 # Check if sequence length is 0
253 if len(seq) == 0:
254 return None
255
256 seq = re.sub(r'[.-]', 'N', str(seq))
257 if model == 'aa': seq = translate(seq)
258
259 seq_map.setdefault(seq, []).append(ig)
260
261 # Process records
262 if len(seq_map) == 1:
263 return {1:records}
264
265 # Define sequences
266 seqs = list(seq_map.keys())
267
268 # Calculate pairwise distance matrix
269 dists = calcDistances(seqs, nmer_len, dist_mat, sym=sym, norm=norm)
270
271 # Perform hierarchical clustering
272 clusters = formClusters(dists, linkage, distance)
273
274 # Turn clusters into clone dictionary
275 clone_dict = {}
276 for i, c in enumerate(clusters):
277 clone_dict.setdefault(c, []).extend(seq_map[seqs[i]])
278
279 return clone_dict
280
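# Illustrative usage sketch (group_records is a hypothetical list of IgRecords
# from a single preclone group; 0.16 is only an example threshold):
#
#   clones = distanceClones(group_records, model='ham', distance=0.16,
#                           norm='len', sym='avg', linkage='single')
#   # clones is {clone number: [IgRecords]}; junctions that link below the
#   # threshold under the chosen linkage end up in the same clone.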
281
282 def distChen2010(records):
283 """
284 Calculate pairwise distances as defined in Chen 2010
285
286 Arguments:
287 records = list of IgRecords where first is query to be compared to others in list
288
289 Returns:
290 list of distances
291 """
292 # Pull out query sequence and V/J information
293 query, records = records[0], records[1:]
294 query_cdr3 = query.junction[3:-3]
295 query_v_allele = query.getVAllele()
296 query_v_gene = query.getVGene()
297 query_v_family = query.getVFamily()
298 query_j_allele = query.getJAllele()
299 query_j_gene = query.getJGene()
300 # Create alignment scoring dictionary
301 score_dict = getDNAScoreDict()
302
303 scores = [0]*len(records)
304 for i in range(len(records)):
305 ld = pairwise2.align.globalds(query_cdr3, records[i].junction[3:-3],
306 score_dict, -1, -1, one_alignment_only=True, score_only=True)  # score_only yields a numeric score
307 # Check V similarity
308 if records[i].getVAllele() == query_v_allele: ld += 0
309 elif records[i].getVGene() == query_v_gene: ld += 1
310 elif records[i].getVFamily() == query_v_family: ld += 3
311 else: ld += 5
312 # Check J similarity
313 if records[i].getJAllele() == query_j_allele: ld += 0
314 elif records[i].getJGene() == query_j_gene: ld += 1
315 else: ld += 3
316 # Divide by length
317 scores[i] = ld/max(len(records[i].junction[3:-3]), len(query_cdr3))
318
319 return scores
320
321
322 def distAdemokun2011(records):
323 """
324 Calculate pairwise distances as defined in Ademokun 2011
325
326 Arguments:
327 records = list of IgRecords where first is query to be compared to others in list
328
329 Returns:
330 list of distances
331 """
332 # Pull out query sequence and V family information
333 query, records = records[0], records[1:]
334 query_cdr3 = query.junction[3:-3]
335 query_v_family = query.getVFamily()
336 # Create alignment scoring dictionary
337 score_dict = getDNAScoreDict()
338
339 scores = [0]*len(records)
340 for i in range(len(records)):
341
342 if abs(len(query_cdr3) - len(records[i].junction[3:-3])) > 10:
343 scores[i] = 1
344 elif query_v_family != records[i].getVFamily():
345 scores[i] = 1
346 else:
347 ld = pairwise2.align.globalds(query_cdr3, records[i].junction[3:-3],
348 score_dict, -1, -1, one_alignment_only=True, score_only=True)  # score_only yields a numeric score
349 scores[i] = ld/min(len(records[i].junction[3:-3]), len(query_cdr3))
350
351 return scores
352
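# Illustrative sketch for both distance functions (assumes the list form built
# by feedQueueClust: the query record first, then all remaining records):
#
#   row = distChen2010([query_rec, rec_a, rec_b])
#   row = distAdemokun2011([query_rec, rec_a, rec_b])
#   # each call returns one row of length-normalized values comparing the query
#   # junction to every other record, used to fill the pairwise distance matrix.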
353
354 def hierClust(dist_mat, method='chen2010'):
355 """
356 Calculate hierarchical clustering
357
358 Arguments:
359 dist_mat = square-formed distance matrix of pairwise CDR3 comparisons
360
361 Returns:
362 list of cluster ids
363 """
364 if method == 'chen2010':
365 clusters = formClusters(dist_mat, 'average', 0.32)
366 elif method == 'ademokun2011':
367 clusters = formClusters(dist_mat, 'complete', 0.25)
368 else: clusters = np.ones(dist_mat.shape[0])
369
370 return clusters
371
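# Illustrative sketch (dist_mat here is a small hypothetical distance matrix):
#
#   dist_mat = np.array([[0.0, 0.2, 0.9],
#                        [0.2, 0.0, 0.8],
#                        [0.9, 0.8, 0.0]])
#   clusters = hierClust(dist_mat, method='chen2010')
#   # chen2010 cuts an average-linkage tree at 0.32; ademokun2011 cuts a
#   # complete-linkage tree at 0.25; any other method yields a single cluster.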
372 # TODO: Merge duplicate feed, process and collect functions.
373 def feedQueue(alive, data_queue, db_file, group_func, group_args={}):
374 """
375 Feeds the data queue with Ig records
376
377 Arguments:
378 alive = a multiprocessing.Value boolean controlling whether processing continues
379 if False exit process
380 data_queue = a multiprocessing.Queue to hold data for processing
381 db_file = the Ig record database file
382 group_func = the function to use for assigning preclones
383 group_args = a dictionary of arguments to pass to group_func
384
385 Returns:
386 None
387 """
388 # Open input file and perform grouping
389 try:
390 # Iterate over Ig records and assign groups
391 db_iter = readDbFile(db_file)
392 clone_dict = group_func(db_iter, **group_args)
393 except:
394 #sys.stderr.write('Exception in feeder grouping step\n')
395 alive.value = False
396 raise
397
398 # Add groups to data queue
399 try:
400 #print 'START FEED', alive.value
401 # Iterate over groups and feed data queue
402 clone_iter = iter(clone_dict.items())
403 while alive.value:
404 # Get data from queue
405 if data_queue.full(): continue
406 else: data = next(clone_iter, None)
407 # Exit upon reaching end of iterator
408 if data is None: break
409 #print "FEED", alive.value, k
410
411 # Feed queue
412 data_queue.put(DbData(*data))
413 else:
414 sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
415 % os.getpid())
416 return None
417 except:
418 #sys.stderr.write('Exception in feeder queue feeding step\n')
419 alive.value = False
420 raise
421
422 return None
423
424
425 def feedQueueClust(alive, data_queue, db_file, group_func=None, group_args={}):
426 """
427 Feeds the data queue with Ig records
428
429 Arguments:
430 alive = a multiprocessing.Value boolean controlling whether processing continues
431 if False exit process
432 data_queue = a multiprocessing.Queue to hold data for processing
433 db_file = the Ig record database file
434
435 Returns:
436 None
437 """
438 # Open input file and perform grouping
439 try:
440 # Iterate over Ig records and order by junction length
441 records = {}
442 db_iter = readDbFile(db_file)
443 for rec in db_iter:
444 records[rec.id] = rec
445 records = OrderedDict(sorted(list(records.items()), key=lambda i: i[1].junction_length))
446 dist_dict = {}
447 for __ in range(len(records)):
448 k,v = records.popitem(last=False)
449 dist_dict[k] = [v] + list(records.values())  # query record first, then all remaining records
450 except:
451 #sys.stderr.write('Exception in feeder grouping step\n')
452 alive.value = False
453 raise
454
455 # Add groups to data queue
456 try:
457 # print 'START FEED', alive.value
458 # Iterate over groups and feed data queue
459 dist_iter = iter(dist_dict.items())
460 while alive.value:
461 # Get data from queue
462 if data_queue.full(): continue
463 else: data = next(dist_iter, None)
464 # Exit upon reaching end of iterator
465 if data is None: break
466 #print "FEED", alive.value, k
467
468 # Feed queue
469 data_queue.put(DbData(*data))
470 else:
471 sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
472 % os.getpid())
473 return None
474 except:
475 #sys.stderr.write('Exception in feeder queue feeding step\n')
476 alive.value = False
477 raise
478
479 return None
480
481
482 def processQueue(alive, data_queue, result_queue, clone_func, clone_args):
483 """
484 Pulls from data queue, performs calculations, and feeds results queue
485
486 Arguments:
487 alive = a multiprocessing.Value boolean controlling whether processing continues
488 if False exit process
489 data_queue = a multiprocessing.Queue holding data to process
490 result_queue = a multiprocessing.Queue to hold processed results
491 clone_func = the function to call for clonal assignment
492 clone_args = a dictionary of arguments to pass to clone_func
493
494 Returns:
495 None
496 """
497 try:
498 # Iterate over data queue until sentinel object reached
499 while alive.value:
500 # Get data from queue
501 if data_queue.empty(): continue
502 else: data = data_queue.get()
503 # Exit upon reaching sentinel
504 if data is None: break
505
506 # Define result object for iteration and get data records
507 records = data.data
508 # print(data.id)
509 result = DbResult(data.id, records)
510
511 # Check for invalid data (due to failed indexing) and add failed result
512 if not data:
513 result_queue.put(result)
514 continue
515
516 # Add V(D)J to log
517 result.log['ID'] = ','.join([str(x) for x in data.id])
518 result.log['VALLELE'] = ','.join(set([(r.getVAllele() or '') for r in records]))
519 result.log['DALLELE'] = ','.join(set([(r.getDAllele() or '') for r in records]))
520 result.log['JALLELE'] = ','.join(set([(r.getJAllele() or '') for r in records]))
521 result.log['JUNCLEN'] = ','.join(set([(str(len(r.junction)) or '0') for r in records]))
522 result.log['SEQUENCES'] = len(records)
523
524 # Check for preclone failure and assign clones
525 clones = clone_func(records, **clone_args) if data else None
526
527 # import cProfile
528 # prof = cProfile.Profile()
529 # clones = prof.runcall(clone_func, records, **clone_args)
530 # prof.dump_stats('worker-%d.prof' % os.getpid())
531
532 if clones is not None:
533 result.results = clones
534 result.valid = True
535 result.log['CLONES'] = len(clones)
536 else:
537 result.log['CLONES'] = 0
538
539 # Feed results to result queue
540 result_queue.put(result)
541 else:
542 sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
543 % os.getpid())
544 return None
545 except:
546 #sys.stderr.write('Exception in worker\n')
547 alive.value = False
548 raise
549
550 return None
551
552
553 def processQueueClust(alive, data_queue, result_queue, clone_func, clone_args):
554 """
555 Pulls from data queue, performs calculations, and feeds results queue
556
557 Arguments:
558 alive = a multiprocessing.Value boolean controlling whether processing continues
559 if False exit process
560 data_queue = a multiprocessing.Queue holding data to process
561 result_queue = a multiprocessing.Queue to hold processed results
562 clone_func = the function to call for calculating pairwise distances between sequences
563 clone_args = a dictionary of arguments to pass to clone_func
564
565 Returns:
566 None
567 """
568
569 try:
570 # print 'START WORK', alive.value
571 # Iterate over data queue until sentinel object reached
572 while alive.value:
573 # Get data from queue
574 if data_queue.empty(): continue
575 else: data = data_queue.get()
576 # Exit upon reaching sentinel
577 if data is None: break
578 # print "WORK", alive.value, data['id']
579
580 # Define result object for iteration and get data records
581 records = data.data
582 result = DbResult(data.id, records)
583
584 # Create row of distance matrix and check for error
585 dist_row = clone_func(records, **clone_args) if data else None
586 if dist_row is not None:
587 result.results = dist_row
588 result.valid = True
589
590 # Feed results to result queue
591 result_queue.put(result)
592 else:
593 sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
594 % os.getpid())
595 return None
596 except:
597 #sys.stderr.write('Exception in worker\n')
598 alive.value = False
599 raise
600
601 return None
602
603
604 def collectQueue(alive, result_queue, collect_queue, db_file, out_args, cluster_func=None, cluster_args={}):
605 """
606 Assembles results from a queue of individual sequence results and manages log/file I/O
607
608 Arguments:
609 alive = a multiprocessing.Value boolean controlling whether processing continues
610 if False exit process
611 result_queue = a multiprocessing.Queue holding processQueue results
612 collect_queue = a multiprocessing.Queue to store collector return values
613 db_file = the input database file name
614 out_args = common output argument dictionary from parseCommonArgs
615 cluster_func = the function to call for carrying out clustering on distance matrix
616 cluster_args = a dictionary of arguments to pass to cluster_func
617
618 Returns:
619 None
620 (adds 'log' and 'out_files' to collect_dict)
621 """
622 # Open output files
623 try:
624 # Count records and define output format
625 out_type = getFileType(db_file) if out_args['out_type'] is None \
626 else out_args['out_type']
627 result_count = countDbFile(db_file)
628
629 # Define successful output handle
630 pass_handle = getOutputHandle(db_file,
631 out_label='clone-pass',
632 out_dir=out_args['out_dir'],
633 out_name=out_args['out_name'],
634 out_type=out_type)
635 pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')
636
637 # Define failed clonal assignment output handle
638 if out_args['failed']:
639 fail_handle = getOutputHandle(db_file,
640 out_label='clone-fail',
641 out_dir=out_args['out_dir'],
642 out_name=out_args['out_name'],
643 out_type=out_type)
644 fail_writer = getDbWriter(fail_handle, db_file)
645 else:
646 fail_handle = None
647 fail_writer = None
648
649 # Define log handle
650 if out_args['log_file'] is None:
651 log_handle = None
652 else:
653 log_handle = open(out_args['log_file'], 'w')
654 except:
655 #sys.stderr.write('Exception in collector file opening step\n')
656 alive.value = False
657 raise
658
659 # Get results from queue and write to files
660 try:
661 #print 'START COLLECT', alive.value
662 # Iterate over results queue until sentinel object reached
663 start_time = time()
664 rec_count = clone_count = pass_count = fail_count = 0
665 while alive.value:
666 # Get result from queue
667 if result_queue.empty(): continue
668 else: result = result_queue.get()
669 # Exit upon reaching sentinel
670 if result is None: break
671 #print "COLLECT", alive.value, result['id']
672
673 # Print progress for previous iteration and update record count
674 if rec_count == 0:
675 print('PROGRESS> Assigning clones')
676 printProgress(rec_count, result_count, 0.05, start_time)
677 rec_count += len(result.data)
678
679 # Write passed and failed records
680 if result:
681 for clone in result.results.values():
682 clone_count += 1
683 for i, rec in enumerate(clone):
684 rec.annotations['CLONE'] = clone_count
685 pass_writer.writerow(rec.toDict())
686 pass_count += 1
687 result.log['CLONE%i-%i' % (clone_count, i + 1)] = str(rec.junction)
688
689 else:
690 for i, rec in enumerate(result.data):
691 if fail_writer is not None: fail_writer.writerow(rec.toDict())
692 fail_count += 1
693 result.log['CLONE0-%i' % (i + 1)] = str(rec.junction)
694
695 # Write log
696 printLog(result.log, handle=log_handle)
697 else:
698 sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
699 % os.getpid())
700 return None
701
702 # Print total counts
703 printProgress(rec_count, result_count, 0.05, start_time)
704
705 # Close file handles
706 pass_handle.close()
707 if fail_handle is not None: fail_handle.close()
708 if log_handle is not None: log_handle.close()
709
710 # Update return list
711 log = OrderedDict()
712 log['OUTPUT'] = os.path.basename(pass_handle.name)
713 log['CLONES'] = clone_count
714 log['RECORDS'] = rec_count
715 log['PASS'] = pass_count
716 log['FAIL'] = fail_count
717 collect_dict = {'log':log, 'out_files': [pass_handle.name]}
718 collect_queue.put(collect_dict)
719 except:
720 #sys.stderr.write('Exception in collector result processing step\n')
721 alive.value = False
722 raise
723
724 return None
725
726
727 def collectQueueClust(alive, result_queue, collect_queue, db_file, out_args, cluster_func, cluster_args):
728 """
729 Assembles results from a queue of individual sequence results and manages log/file I/O
730
731 Arguments:
732 alive = a multiprocessing.Value boolean controlling whether processing continues
733 if False exit process
734 result_queue = a multiprocessing.Queue holding processQueue results
735 collect_queue = a multiprocessing.Queue to store collector return values
736 db_file = the input database file name
737 out_args = common output argument dictionary from parseCommonArgs
738 cluster_func = the function to call for carrying out clustering on distance matrix
739 cluster_args = a dictionary of arguments to pass to cluster_func
740
741 Returns:
742 None
743 (adds 'log' and 'out_files' to collect_dict)
744 """
745 # Open output files
746 try:
747
748 # Iterate over Ig records to count and order by junction length
749 result_count = 0
750 records = {}
751 # print 'Reading file...'
752 db_iter = readDbFile(db_file)
753 for rec in db_iter:
754 records[rec.id] = rec
755 result_count += 1
756 records = OrderedDict(sorted(list(records.items()), key=lambda i: i[1].junction_length))
757
758 # Define empty matrix to store assembled results
759 dist_mat = np.zeros((result_count,result_count))
760
761 # Count records and define output format
762 out_type = getFileType(db_file) if out_args['out_type'] is None \
763 else out_args['out_type']
764
765 # Define successful output handle
766 pass_handle = getOutputHandle(db_file,
767 out_label='clone-pass',
768 out_dir=out_args['out_dir'],
769 out_name=out_args['out_name'],
770 out_type=out_type)
771 pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')
772
773 # Define failed clonal assignment output handle
774 if out_args['failed']:
775 fail_handle = getOutputHandle(db_file,
776 out_label='clone-fail',
777 out_dir=out_args['out_dir'],
778 out_name=out_args['out_name'],
779 out_type=out_type)
780 fail_writer = getDbWriter(fail_handle, db_file)
781 else:
782 fail_handle = None
783 fail_writer = None
784
785 # Open log file
786 if out_args['log_file'] is None:
787 log_handle = None
788 else:
789 log_handle = open(out_args['log_file'], 'w')
790 except:
791 alive.value = False
792 raise
793
794 try:
795 # Iterate over results queue until sentinel object reached
796 start_time = time()
797 row_count = rec_count = 0
798 while alive.value:
799 # Get result from queue
800 if result_queue.empty(): continue
801 else: result = result_queue.get()
802 # Exit upon reaching sentinel
803 if result is None: break
804
805 # Print progress for previous iteration
806 if row_count == 0:
807 print('PROGRESS> Assigning clones')
808 printProgress(row_count, result_count, 0.05, start_time)
809
810 # Update counts for iteration
811 row_count += 1
812 rec_count += len(result)
813
814 # Add result row to distance matrix
815 if result:
816 dist_mat[list(range(result_count-len(result),result_count)),result_count-len(result)] = result.results
817
818 else:
819 sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
820 % os.getpid())
821 return None
822
823 # Calculate linkage and carry out clustering
824 # print dist_mat
825 clusters = cluster_func(dist_mat, **cluster_args) if dist_mat is not None else None
826 clones = {}
827 # print clusters
828 for i, c in enumerate(clusters):
829 clones.setdefault(c, []).append(records[list(records.keys())[i]])
830
831 # Write passed and failed records
832 clone_count = pass_count = fail_count = 0
833 if clones:
834 for clone in clones.values():
835 clone_count += 1
836 for i, rec in enumerate(clone):
837 rec.annotations['CLONE'] = clone_count
838 pass_writer.writerow(rec.toDict())
839 pass_count += 1
840 #result.log['CLONE%i-%i' % (clone_count, i + 1)] = str(rec.junction)
841
842 else:
843 for i, rec in enumerate(result.data):
844 if fail_writer is not None: fail_writer.writerow(rec.toDict())
845 fail_count += 1
846 #result.log['CLONE0-%i' % (i + 1)] = str(rec.junction)
847
848 # Print final progress
849 printProgress(row_count, result_count, 0.05, start_time)
850
851 # Close file handles
852 pass_handle.close()
853 if fail_handle is not None: fail_handle.close()
854 if log_handle is not None: log_handle.close()
855
856 # Update return list
857 log = OrderedDict()
858 log['OUTPUT'] = os.path.basename(pass_handle.name)
859 log['CLONES'] = clone_count
860 log['RECORDS'] = rec_count
861 log['PASS'] = pass_count
862 log['FAIL'] = fail_count
863 collect_dict = {'log':log, 'out_files': [pass_handle.name]}
864 collect_queue.put(collect_dict)
865 except:
866 alive.value = False
867 raise
868
869 return None
870
871
872 def defineClones(db_file, feed_func, work_func, collect_func, clone_func, cluster_func=None,
873 group_func=None, group_args={}, clone_args={}, cluster_args={},
874 out_args=default_out_args, nproc=None, queue_size=None):
875 """
876 Define clonally related sequences
877
878 Arguments:
879 db_file = filename of input database
880 feed_func = the function that feeds the queue
881 work_func = the worker function that will run on each CPU
882 collect_func = the function that collects results from the workers
883 group_func = the function to use for assigning preclones
884 clone_func = the function to use for determining clones within preclonal groups
885 group_args = a dictionary of arguments to pass to group_func
886 clone_args = a dictionary of arguments to pass to clone_func
887 out_args = common output argument dictionary from parseCommonArgs
888 nproc = the number of processQueue processes;
889 if None defaults to the number of CPUs
890 queue_size = maximum size of the argument queue;
891 if None defaults to 2*nproc
892
893 Returns:
894 a list of successful output file names
895 """
896 # Print parameter info
897 log = OrderedDict()
898 log['START'] = 'DefineClones'
899 log['DB_FILE'] = os.path.basename(db_file)
900 if group_func is not None:
901 log['GROUP_FUNC'] = group_func.__name__
902 log['GROUP_ARGS'] = group_args
903 log['CLONE_FUNC'] = clone_func.__name__
904
905 # TODO: this is yucky, but can be fixed by using a model class
906 clone_log = clone_args.copy()
907 if 'dist_mat' in clone_log: del clone_log['dist_mat']
908 log['CLONE_ARGS'] = clone_log
909
910 if cluster_func is not None:
911 log['CLUSTER_FUNC'] = cluster_func.__name__
912 log['CLUSTER_ARGS'] = cluster_args
913 log['NPROC'] = nproc
914 printLog(log)
915
916 # Define feeder function and arguments
917 feed_args = {'db_file': db_file,
918 'group_func': group_func,
919 'group_args': group_args}
920 # Define worker function and arguments
921 work_args = {'clone_func': clone_func,
922 'clone_args': clone_args}
923 # Define collector function and arguments
924 collect_args = {'db_file': db_file,
925 'out_args': out_args,
926 'cluster_func': cluster_func,
927 'cluster_args': cluster_args}
928
929 # Call process manager
930 result = manageProcesses(feed_func, work_func, collect_func,
931 feed_args, work_args, collect_args,
932 nproc, queue_size)
933
934 # Print log
935 result['log']['END'] = 'DefineClones'
936 printLog(result['log'])
937
938 return result['out_files']
939
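# Illustrative sketch of wiring the 'bygroup' path by hand (mirrors the defaults
# that getArgParser sets below; the file name and threshold are hypothetical):
#
#   out_files = defineClones('sample_db-pass.tab',
#                            feed_func=feedQueue, work_func=processQueue,
#                            collect_func=collectQueue,
#                            group_func=indexJunctions, clone_func=distanceClones,
#                            group_args={'mode': 'gene', 'action': 'set'},
#                            clone_args={'model': 'ham', 'distance': 0.16,
#                                        'dist_mat': distance_models['ham']})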
940
941 def getArgParser():
942 """
943 Defines the ArgumentParser
944
945 Arguments:
946 None
947
948 Returns:
949 an ArgumentParser object
950 """
951 # Define input and output fields
952 fields = dedent(
953 '''
954 output files:
955 clone-pass
956 database with assigned clonal group numbers.
957 clone-fail
958 database with records failing clonal grouping.
959
960 required fields:
961 SEQUENCE_ID, V_CALL or V_CALL_GENOTYPED, D_CALL, J_CALL, JUNCTION
962
963 <field>
964 sequence field specified by the --sf parameter
965
966 output fields:
967 CLONE
968 ''')
969
970 # Define ArgumentParser
971 parser = ArgumentParser(description=__doc__, epilog=fields,
972 formatter_class=CommonHelpFormatter)
973 parser.add_argument('--version', action='version',
974 version='%(prog)s:' + ' %s-%s' %(__version__, __date__))
975 subparsers = parser.add_subparsers(title='subcommands', dest='command', metavar='',
976 help='Cloning method')
977 # TODO: This is a temporary fix for Python issue 9253
978 subparsers.required = True
979
980 # Parent parser
981 parser_parent = getCommonArgParser(seq_in=False, seq_out=False, db_in=True,
982 multiproc=True)
983
984 # Distance cloning method
985 parser_bygroup = subparsers.add_parser('bygroup', parents=[parser_parent],
986 formatter_class=CommonHelpFormatter,
987 help='''Defines clones as having the same V assignment,
988 J assignment, and junction length, with distances computed
989 under the specified substitution distance model.''',
990 description='''Defines clones as having the same V assignment,
991 J assignment, and junction length, with distances computed
992 under the specified substitution distance model.''')
993 parser_bygroup.add_argument('-f', nargs='+', action='store', dest='fields', default=None,
994 help='Additional fields to use for grouping clones (non VDJ)')
995 parser_bygroup.add_argument('--mode', action='store', dest='mode',
996 choices=('allele', 'gene'), default=default_index_mode,
997 help='''Specifies whether to use the V(D)J allele or gene for
998 initial grouping.''')
999 parser_bygroup.add_argument('--act', action='store', dest='action',
1000 choices=('first', 'set'), default=default_index_action,
1001 help='''Specifies how to handle multiple V(D)J assignments
1002 for initial grouping.''')
1003 parser_bygroup.add_argument('--model', action='store', dest='model',
1004 choices=choices_bygroup_model,
1005 default=default_bygroup_model,
1006 help='''Specifies which substitution model to use for calculating distance
1007 between sequences. The "ham" model is nucleotide Hamming distance and
1008 "aa" is amino acid Hamming distance. The "hh_s1f" and "hh_s5f" models are
1009 human specific single nucleotide and 5-mer content models, respectively,
1010 from Yaari et al, 2013. The "mk_rs1nf" and "mk_rs5nf" models are
1011 mouse specific single nucleotide and 5-mer content models, respectively,
1012 from Cui et al, 2016. The "m1n_compat" and "hs1f_compat" models are
1013 deprecated models provided for backwards compatibility with the "m1n" and
1014 "hs1f" models in Change-O v0.3.3 and SHazaM v0.1.4. Both
1015 5-mer models should be considered experimental.''')
1016 parser_bygroup.add_argument('--dist', action='store', dest='distance', type=float,
1017 default=default_distance,
1018 help='The distance threshold for clonal grouping')
1019 parser_bygroup.add_argument('--norm', action='store', dest='norm',
1020 choices=('len', 'mut', 'none'), default=default_norm,
1021 help='''Specifies how to normalize distances. One of none
1022 (do not normalize), len (normalize by length),
1023 or mut (normalize by number of mutations between sequences).''')
1024 parser_bygroup.add_argument('--sym', action='store', dest='sym',
1025 choices=('avg', 'min'), default=default_sym,
1026 help='''Specifies how to combine asymmetric distances. One of avg
1027 (average of A->B and B->A) or min (minimum of A->B and B->A).''')
1028 parser_bygroup.add_argument('--link', action='store', dest='linkage',
1029 choices=('single', 'average', 'complete'), default=default_linkage,
1030 help='''Type of linkage to use for hierarchical clustering.''')
1031 parser_bygroup.add_argument('--sf', action='store', dest='seq_field',
1032 default=default_seq_field,
1033 help='''The name of the field to be used to calculate
1034 distance between records''')
1035 parser_bygroup.set_defaults(feed_func=feedQueue)
1036 parser_bygroup.set_defaults(work_func=processQueue)
1037 parser_bygroup.set_defaults(collect_func=collectQueue)
1038 parser_bygroup.set_defaults(group_func=indexJunctions)
1039 parser_bygroup.set_defaults(clone_func=distanceClones)
1040
1041 # Chen2010
1042 parser_chen = subparsers.add_parser('chen2010', parents=[parser_parent],
1043 formatter_class=CommonHelpFormatter,
1044 help='''Defines clones by the method described in Chen, 2010.''',
1045 description='''Defines clones by the method described in Chen, 2010.''')
1046 parser_chen.set_defaults(feed_func=feedQueueClust)
1047 parser_chen.set_defaults(work_func=processQueueClust)
1048 parser_chen.set_defaults(collect_func=collectQueueClust)
1049 parser_chen.set_defaults(cluster_func=hierClust)
1050
1051 # Ademokun2011
1052 parser_ade = subparsers.add_parser('ademokun2011', parents=[parser_parent],
1053 formatter_class=CommonHelpFormatter,
1054 help='''Defines clones by the method described in Ademokun, 2011.''',
1055 description='''Defines clones by the method described in Ademokun, 2011.''')
1056 parser_ade.set_defaults(feed_func=feedQueueClust)
1057 parser_ade.set_defaults(work_func=processQueueClust)
1058 parser_ade.set_defaults(collect_func=collectQueueClust)
1059 parser_ade.set_defaults(cluster_func=hierClust)
1060
1061 return parser
1062
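# Illustrative command-line usage (the input file and threshold are example
# values; -d is the database argument added by getCommonArgParser):
#
#   DefineClones.py bygroup -d sample_db-pass.tab --model ham --dist 0.16 \
#       --mode gene --act set --norm len --sym avg --link single --sf JUNCTION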
1063
1064 if __name__ == '__main__':
1065 """
1066 Parses command line arguments and calls main function
1067 """
1068 # Parse arguments
1069 parser = getArgParser()
1070 checkArgs(parser)
1071 args = parser.parse_args()
1072 args_dict = parseCommonArgs(args)
1073 # Convert case of fields
1074 if 'seq_field' in args_dict:
1075 args_dict['seq_field'] = args_dict['seq_field'].upper()
1076 if 'fields' in args_dict and args_dict['fields'] is not None:
1077 args_dict['fields'] = [f.upper() for f in args_dict['fields']]
1078
1079 # Define clone_args
1080 if args.command == 'bygroup':
1081 args_dict['group_args'] = {'fields': args_dict['fields'],
1082 'action': args_dict['action'],
1083 'mode':args_dict['mode']}
1084 args_dict['clone_args'] = {'model': args_dict['model'],
1085 'distance': args_dict['distance'],
1086 'norm': args_dict['norm'],
1087 'sym': args_dict['sym'],
1088 'linkage': args_dict['linkage'],
1089 'seq_field': args_dict['seq_field']}
1090
1091 # Get distance matrix
1092 try:
1093 args_dict['clone_args']['dist_mat'] = distance_models[args_dict['model']]
1094 except KeyError:
1095 sys.exit('Unrecognized distance model: %s' % args_dict['model'])
1096
1097 del args_dict['fields']
1098 del args_dict['action']
1099 del args_dict['mode']
1100 del args_dict['model']
1101 del args_dict['distance']
1102 del args_dict['norm']
1103 del args_dict['sym']
1104 del args_dict['linkage']
1105 del args_dict['seq_field']
1106
1107 # Define clone_args
1108 if args.command == 'chen2010':
1109 args_dict['clone_func'] = distChen2010
1110 args_dict['cluster_args'] = {'method': args.command }
1111
1112 if args.command == 'ademokun2011':
1113 args_dict['clone_func'] = distAdemokun2011
1114 args_dict['cluster_args'] = {'method': args.command }
1115
1116 # Call defineClones
1117 del args_dict['command']
1118 del args_dict['db_files']
1119 for f in args.__dict__['db_files']:
1120 args_dict['db_file'] = f
1121 defineClones(**args_dict)