#!/usr/bin/env python3
"""
Assign Ig sequences into clones
"""
# Info
__author__ = 'Namita Gupta, Jason Anthony Vander Heiden, Gur Yaari, Mohamed Uduman'
from changeo import __version__, __date__

# Imports
import os
import re
import sys
import csv
import numpy as np
from argparse import ArgumentParser
from collections import OrderedDict
from itertools import chain
from textwrap import dedent
from time import time
from Bio import pairwise2
from Bio.Seq import translate

# Presto and changeo imports
from presto.Defaults import default_out_args
from presto.IO import getFileType, getOutputHandle, printLog, printProgress
from presto.Multiprocessing import manageProcesses
from presto.Sequence import getDNAScoreDict
from changeo.Commandline import CommonHelpFormatter, checkArgs, getCommonArgParser, parseCommonArgs
from changeo.Distance import distance_models, calcDistances, formClusters
from changeo.IO import getDbWriter, readDbFile, countDbFile
from changeo.Multiprocessing import DbData, DbResult

# Defaults
default_translate = False
default_distance = 0.0
default_index_mode = 'gene'
default_index_action = 'set'
default_bygroup_model = 'ham'
default_hclust_model = 'chen2010'
default_seq_field = 'JUNCTION'
default_norm = 'len'
default_sym = 'avg'
default_linkage = 'single'
choices_bygroup_model = ('ham', 'aa', 'hh_s1f', 'hh_s5f', 'mk_rs1nf', 'mk_rs5nf', 'hs1f_compat', 'm1n_compat')
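
# Example invocations (illustrative only; a minimal sketch assuming the common
# Change-O argument parser exposes the tab-delimited database through -d, as in
# the other Change-O tools; check DefineClones.py --help for the exact flags):
#
#   DefineClones.py bygroup -d sample_db-pass.tab --act set --model ham \
#       --dist 0.16 --norm len --sym avg --link single --sf JUNCTION
#   DefineClones.py chen2010 -d sample_db-pass.tab
#   DefineClones.py ademokun2011 -d sample_db-pass.tab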


def indexByIdentity(index, key, rec, fields=None):
    """
    Updates a preclone index with a simple key

    Arguments:
      index = preclone index from indexJunctions
      key = index key
      rec = IgRecord to add to the index
      fields = additional annotation fields to use to group preclones;
               if None use only V, J and junction length

    Returns:
      None. Updates index with new key and records.
    """
    index.setdefault(tuple(key), []).append(rec)
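
# Illustrative sketch of the flat index built by indexByIdentity (assumed data,
# not taken from a real run): each key is the list produced by indexJunctions,
# e.g. ['IGHV3-23', 'IGHJ4', 48], stored as a tuple mapping to the IgRecords
# that share it:
#
#   index = {}
#   indexByIdentity(index, ['IGHV3-23', 'IGHJ4', 48], rec)
#   # index == {('IGHV3-23', 'IGHJ4', 48): [rec]}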


def indexByUnion(index, key, rec, fields=None):
    """
    Updates a preclone index with the union of nested keys

    Arguments:
      index = preclone index from indexJunctions
      key = index key
      rec = IgRecord to add to the index
      fields = additional annotation fields to use to group preclones;
               if None use only V, J and junction length

    Returns:
      None. Updates index with new key and records.
    """
    # List of values for this/new key
    val = [rec]
    f_range = list(range(2, 3 + (len(fields) if fields else 0)))

    # See if field/junction length combination exists in index
    outer_dict = index
    for field in f_range:
        try:
            outer_dict = outer_dict[key[field]]
        except KeyError:
            outer_dict = None
            break
    # If field combination exists, look through Js
    j_matches = []
    if outer_dict is not None:
        for j in outer_dict.keys():
            if not set(key[1]).isdisjoint(set(j)):
                key[1] = tuple(set(key[1]).union(set(j)))
                j_matches += [j]
    # If J overlap exists, look through Vs for each J
    for j in j_matches:
        v_matches = []
        # Collect V matches for this J
        for v in outer_dict[j].keys():
            if not set(key[0]).isdisjoint(set(v)):
                key[0] = tuple(set(key[0]).union(set(v)))
                v_matches += [v]
        # If there are V overlaps for this J, pop them out
        if v_matches:
            val += list(chain(*(outer_dict[j].pop(v) for v in v_matches)))
            # If the J dict is now empty, remove it
            if not outer_dict[j]:
                outer_dict.pop(j, None)

    # Add value(s) into index nested dictionary
    # OMG Python pointers are the best!
    # Add field dictionaries into index
    outer_dict = index
    for field in f_range:
        outer_dict.setdefault(key[field], {})
        outer_dict = outer_dict[key[field]]
    # Add J, then V into index
    if key[1] in outer_dict:
        outer_dict[key[1]].update({key[0]: val})
    else:
        outer_dict[key[1]] = {key[0]: val}
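
# Illustrative sketch of the nested index built by indexByUnion (assumed values):
# junction length and any extra annotation fields form the outer levels, then the
# J call set, then the V call set, e.g.
#
#   index == {48: {('IGHJ4',): {('IGHV3-23', 'IGHV3-30'): [rec1, rec2]}}}
#
# When a new key shares any V (or J) call with an existing entry, the call sets
# are merged and the associated records are pooled under the union key.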


def indexJunctions(db_iter, fields=None, mode=default_index_mode,
                   action=default_index_action):
    """
    Identifies preclonal groups by V, J and junction length

    Arguments:
      db_iter = an iterator of IgRecords defined by readDbFile
      fields = additional annotation fields to use to group preclones;
               if None use only V, J and junction length
      mode = specificity of alignment call to use for assigning preclones;
             one of ('allele', 'gene')
      action = how to handle multiple value fields when assigning preclones;
               one of ('first', 'set')

    Returns:
      a dictionary of {(V, J, junction length):[IgRecords]}
    """
    # print(fields)
    # Define functions for grouping keys
    if mode == 'allele' and fields is None:
        def _get_key(rec, act):
            return [rec.getVAllele(act), rec.getJAllele(act),
                    None if rec.junction is None else len(rec.junction)]
    elif mode == 'gene' and fields is None:
        def _get_key(rec, act):
            return [rec.getVGene(act), rec.getJGene(act),
                    None if rec.junction is None else len(rec.junction)]
    elif mode == 'allele' and fields is not None:
        def _get_key(rec, act):
            vdj = [rec.getVAllele(act), rec.getJAllele(act),
                   None if rec.junction is None else len(rec.junction)]
            ann = [rec.toDict().get(k, None) for k in fields]
            return list(chain(vdj, ann))
    elif mode == 'gene' and fields is not None:
        def _get_key(rec, act):
            vdj = [rec.getVGene(act), rec.getJGene(act),
                   None if rec.junction is None else len(rec.junction)]
            ann = [rec.toDict().get(k, None) for k in fields]
            return list(chain(vdj, ann))

    # Function to flatten nested dictionary
    def _flatten_dict(d, parent_key=''):
        items = []
        for k, v in d.items():
            new_key = parent_key + [k] if parent_key else [k]
            if isinstance(v, dict):
                items.extend(_flatten_dict(v, new_key).items())
            else:
                items.append((new_key, v))
        flat_dict = {None if None in i[0] else tuple(i[0]): i[1] for i in items}
        return flat_dict

    if action == 'first':
        index_func = indexByIdentity
    elif action == 'set':
        index_func = indexByUnion
    else:
        sys.exit('Unrecognized action: %s' % action)

    start_time = time()
    clone_index = {}
    rec_count = 0
    for rec in db_iter:
        key = _get_key(rec, action)

        # Print progress
        if rec_count == 0:
            print('PROGRESS> Grouping sequences')

        printProgress(rec_count, step=1000, start_time=start_time)
        rec_count += 1

        # Assign records with a complete key to the preclone index and
        # records with a missing key component to the None (failed) group
        if all([k is not None and k != '' for k in key]):
            # Update index dictionary
            index_func(clone_index, key, rec, fields)
        else:
            clone_index.setdefault(None, []).append(rec)

    printProgress(rec_count, step=1000, start_time=start_time, end=True)

    if action == 'set':
        clone_index = _flatten_dict(clone_index)

    return clone_index
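
# Illustrative result of indexJunctions (assumed values). With action='first'
# the keys are (V, J, junction length [, extra fields]) tuples; with action='set'
# the nested index is flattened, so keys become (junction length, extra fields...,
# J set, V set) tuples. In both cases the None key collects records that could
# not be grouped:
#
#   clone_index == {(48, ('IGHJ4',), ('IGHV3-23',)): [rec1, rec2], None: [rec3]}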


def distanceClones(records, model=default_bygroup_model, distance=default_distance,
                   dist_mat=None, norm=default_norm, sym=default_sym,
                   linkage=default_linkage, seq_field=default_seq_field):
    """
    Separates a set of IgRecords into clones

    Arguments:
      records = an iterator of IgRecords
      model = substitution model used to calculate distance
      distance = the distance threshold to assign clonal groups
      dist_mat = pandas DataFrame of pairwise nucleotide or amino acid distances
      norm = normalization method
      sym = symmetry method
      linkage = type of linkage
      seq_field = sequence field used to calculate distance between records

    Returns:
      a dictionary of lists defining {clone number: [IgRecords clonal group]}
    """
    # Get distance matrix if not provided
    if dist_mat is None:
        try:
            dist_mat = distance_models[model]
        except KeyError:
            sys.exit('Unrecognized distance model: %s' % model)

    # TODO: can be cleaned up with abstract model class
    # Determine length of n-mers
    if model in ['hs1f_compat', 'm1n_compat', 'aa', 'ham', 'hh_s1f', 'mk_rs1nf']:
        nmer_len = 1
    elif model in ['hh_s5f', 'mk_rs5nf']:
        nmer_len = 5
    else:
        sys.exit('Unrecognized distance model: %s' % model)

    # Define unique junction mapping
    seq_map = {}
    for ig in records:
        seq = ig.getSeqField(seq_field)
        # Check if sequence length is 0
        if len(seq) == 0:
            return None

        seq = re.sub(r'[\.-]', 'N', str(seq))
        if model == 'aa': seq = translate(seq)

        seq_map.setdefault(seq, []).append(ig)

    # Process records
    if len(seq_map) == 1:
        return {1: records}

    # Define sequences
    seqs = list(seq_map.keys())

    # Calculate pairwise distance matrix
    dists = calcDistances(seqs, nmer_len, dist_mat, sym=sym, norm=norm)

    # Perform hierarchical clustering
    clusters = formClusters(dists, linkage, distance)

    # Turn clusters into clone dictionary
    clone_dict = {}
    for i, c in enumerate(clusters):
        clone_dict.setdefault(c, []).extend(seq_map[seqs[i]])

    return clone_dict
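
# Minimal usage sketch for distanceClones (assumed records; the preclone group
# would normally come from indexJunctions via the multiprocessing pipeline):
#
#   group = [rec1, rec2, rec3]   # IgRecords sharing V, J and junction length
#   clones = distanceClones(group, model='ham', distance=0.16, norm='len')
#   # e.g. clones == {1: [rec1, rec2], 2: [rec3]}   # clone number -> records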


def distChen2010(records):
    """
    Calculate pairwise distances as defined in Chen 2010

    Arguments:
      records = list of IgRecords where first is query to be compared to others in list

    Returns:
      list of distances
    """
    # Pull out query sequence and V/J information
    # The query is the first record; the remainder are the comparison set
    query = records.pop(0)
    query_cdr3 = query.junction[3:-3]
    query_v_allele = query.getVAllele()
    query_v_gene = query.getVGene()
    query_v_family = query.getVFamily()
    query_j_allele = query.getJAllele()
    query_j_gene = query.getJGene()
    # Create alignment scoring dictionary
    score_dict = getDNAScoreDict()

    scores = [0]*len(records)
    for i in range(len(records)):
        # pairwise2 returns a list of alignment tuples; take the score of the best one
        ld = pairwise2.align.globalds(query_cdr3, records[i].junction[3:-3],
                                      score_dict, -1, -1, one_alignment_only=True)[0][2]
        # Check V similarity
        if records[i].getVAllele() == query_v_allele: ld += 0
        elif records[i].getVGene() == query_v_gene: ld += 1
        elif records[i].getVFamily() == query_v_family: ld += 3
        else: ld += 5
        # Check J similarity
        if records[i].getJAllele() == query_j_allele: ld += 0
        elif records[i].getJGene() == query_j_gene: ld += 1
        else: ld += 3
        # Divide by length
        scores[i] = ld/max(len(records[i].junction[3:-3]), len(query_cdr3))

    return scores


def distAdemokun2011(records):
    """
    Calculate pairwise distances as defined in Ademokun 2011

    Arguments:
      records = list of IgRecords where first is query to be compared to others in list

    Returns:
      list of distances
    """
    # Pull out query sequence and V family information
    # The query is the first record; the remainder are the comparison set
    query = records.pop(0)
    query_cdr3 = query.junction[3:-3]
    query_v_family = query.getVFamily()
    # Create alignment scoring dictionary
    score_dict = getDNAScoreDict()

    scores = [0]*len(records)
    for i in range(len(records)):

        if abs(len(query_cdr3) - len(records[i].junction[3:-3])) > 10:
            scores[i] = 1
        elif query_v_family != records[i].getVFamily():
            scores[i] = 1
        else:
            # pairwise2 returns a list of alignment tuples; take the score of the best one
            ld = pairwise2.align.globalds(query_cdr3, records[i].junction[3:-3],
                                          score_dict, -1, -1, one_alignment_only=True)[0][2]
            scores[i] = ld/min(len(records[i].junction[3:-3]), len(query_cdr3))

    return scores


def hierClust(dist_mat, method='chen2010'):
    """
    Calculate hierarchical clustering

    Arguments:
      dist_mat = square-formed distance matrix of pairwise CDR3 comparisons
      method = clustering method; one of 'chen2010' or 'ademokun2011'

    Returns:
      list of cluster ids
    """
    if method == 'chen2010':
        clusters = formClusters(dist_mat, 'average', 0.32)
    elif method == 'ademokun2011':
        clusters = formClusters(dist_mat, 'complete', 0.25)
    else: clusters = np.ones(dist_mat.shape[0])

    return clusters
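
# Minimal sketch of the clustering step used by the chen2010/ademokun2011 paths
# (assumed distance values):
#
#   import numpy as np
#   dist_mat = np.array([[0.0, 0.1, 0.9],
#                        [0.1, 0.0, 0.8],
#                        [0.9, 0.8, 0.0]])
#   hierClust(dist_mat, method='chen2010')   # e.g. array([1, 1, 2])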


# TODO: Merge duplicate feed, process and collect functions.
def feedQueue(alive, data_queue, db_file, group_func, group_args={}):
    """
    Feeds the data queue with Ig records

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue to hold data for processing
      db_file = the Ig record database file
      group_func = the function to use for assigning preclones
      group_args = a dictionary of arguments to pass to group_func

    Returns:
      None
    """
    # Open input file and perform grouping
    try:
        # Iterate over Ig records and assign groups
        db_iter = readDbFile(db_file)
        clone_dict = group_func(db_iter, **group_args)
    except:
        #sys.stderr.write('Exception in feeder grouping step\n')
        alive.value = False
        raise

    # Add groups to data queue
    try:
        #print 'START FEED', alive.value
        # Iterate over groups and feed data queue
        clone_iter = iter(clone_dict.items())
        while alive.value:
            # Get data from queue
            if data_queue.full(): continue
            else: data = next(clone_iter, None)
            # Exit upon reaching end of iterator
            if data is None: break
            #print "FEED", alive.value, k

            # Feed queue
            data_queue.put(DbData(*data))
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        #sys.stderr.write('Exception in feeder queue feeding step\n')
        alive.value = False
        raise

    return None


def feedQueueClust(alive, data_queue, db_file, group_func=None, group_args={}):
    """
    Feeds the data queue with Ig records

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue to hold data for processing
      db_file = the Ig record database file
      group_func = unused; retained for compatibility with feedQueue
      group_args = unused; retained for compatibility with feedQueue

    Returns:
      None
    """
    # Open input file and perform grouping
    try:
        # Iterate over Ig records and order by junction length
        records = {}
        db_iter = readDbFile(db_file)
        for rec in db_iter:
            records[rec.id] = rec
        records = OrderedDict(sorted(list(records.items()), key=lambda i: i[1].junction_length))
        dist_dict = {}
        for __ in range(len(records)):
            k, v = records.popitem(last=False)
            # Pair each query record with the remaining records in the ordering
            dist_dict[k] = [v] + list(records.values())
    except:
        #sys.stderr.write('Exception in feeder grouping step\n')
        alive.value = False
        raise

    # Add groups to data queue
    try:
        # print 'START FEED', alive.value
        # Iterate over groups and feed data queue
        dist_iter = iter(dist_dict.items())
        while alive.value:
            # Get data from queue
            if data_queue.full(): continue
            else: data = next(dist_iter, None)
            # Exit upon reaching end of iterator
            if data is None: break
            #print "FEED", alive.value, k

            # Feed queue
            data_queue.put(DbData(*data))
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        #sys.stderr.write('Exception in feeder queue feeding step\n')
        alive.value = False
        raise

    return None


def processQueue(alive, data_queue, result_queue, clone_func, clone_args):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue holding data to process
      result_queue = a multiprocessing.Queue to hold processed results
      clone_func = the function to call for clonal assignment
      clone_args = a dictionary of arguments to pass to clone_func

    Returns:
      None
    """
    try:
        # Iterate over the data queue until the sentinel object is reached
        while alive.value:
            # Get data from queue
            if data_queue.empty(): continue
            else: data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None: break

            # Define result object for iteration and get data records
            records = data.data
            # print(data.id)
            result = DbResult(data.id, records)

            # Check for invalid data (due to failed indexing) and add failed result
            if not data:
                result_queue.put(result)
                continue

            # Add V(D)J to log
            result.log['ID'] = ','.join([str(x) for x in data.id])
            result.log['VALLELE'] = ','.join(set([(r.getVAllele() or '') for r in records]))
            result.log['DALLELE'] = ','.join(set([(r.getDAllele() or '') for r in records]))
            result.log['JALLELE'] = ','.join(set([(r.getJAllele() or '') for r in records]))
            result.log['JUNCLEN'] = ','.join(set([(str(len(r.junction)) or '0') for r in records]))
            result.log['SEQUENCES'] = len(records)

            # Check for preclone failure and assign clones
            clones = clone_func(records, **clone_args) if data else None

            # import cProfile
            # prof = cProfile.Profile()
            # clones = prof.runcall(clone_func, records, **clone_args)
            # prof.dump_stats('worker-%d.prof' % os.getpid())

            if clones is not None:
                result.results = clones
                result.valid = True
                result.log['CLONES'] = len(clones)
            else:
                result.log['CLONES'] = 0

            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        #sys.stderr.write('Exception in worker\n')
        alive.value = False
        raise

    return None


def processQueueClust(alive, data_queue, result_queue, clone_func, clone_args):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue holding data to process
      result_queue = a multiprocessing.Queue to hold processed results
      clone_func = the function to call for calculating pairwise distances between sequences
      clone_args = a dictionary of arguments to pass to clone_func

    Returns:
      None
    """

    try:
        # print 'START WORK', alive.value
        # Iterate over the data queue until the sentinel object is reached
        while alive.value:
            # Get data from queue
            if data_queue.empty(): continue
            else: data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None: break
            # print "WORK", alive.value, data['id']

            # Define result object for iteration and get data records
            records = data.data
            result = DbResult(data.id, records)

            # Create row of distance matrix and check for error
            dist_row = clone_func(records, **clone_args) if data else None
            if dist_row is not None:
                result.results = dist_row
                result.valid = True

            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        #sys.stderr.write('Exception in worker\n')
        alive.value = False
        raise

    return None


def collectQueue(alive, result_queue, collect_queue, db_file, out_args, cluster_func=None, cluster_args={}):
    """
    Assembles results from a queue of individual sequence results and manages log/file I/O

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      result_queue = a multiprocessing.Queue holding processQueue results
      collect_queue = a multiprocessing.Queue to store collector return values
      db_file = the input database file name
      out_args = common output argument dictionary from parseCommonArgs
      cluster_func = the function to call for carrying out clustering on distance matrix
      cluster_args = a dictionary of arguments to pass to cluster_func

    Returns:
      None
      (adds 'log' and 'out_files' to collect_dict)
    """
    # Open output files
    try:
        # Count records and define output format
        out_type = getFileType(db_file) if out_args['out_type'] is None \
                   else out_args['out_type']
        result_count = countDbFile(db_file)

        # Define successful output handle
        pass_handle = getOutputHandle(db_file,
                                      out_label='clone-pass',
                                      out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'],
                                      out_type=out_type)
        pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')

        # Define failed alignment output handle
        if out_args['failed']:
            fail_handle = getOutputHandle(db_file,
                                          out_label='clone-fail',
                                          out_dir=out_args['out_dir'],
                                          out_name=out_args['out_name'],
                                          out_type=out_type)
            fail_writer = getDbWriter(fail_handle, db_file)
        else:
            fail_handle = None
            fail_writer = None

        # Define log handle
        if out_args['log_file'] is None:
            log_handle = None
        else:
            log_handle = open(out_args['log_file'], 'w')
    except:
        #sys.stderr.write('Exception in collector file opening step\n')
        alive.value = False
        raise

    # Get results from queue and write to files
    try:
        #print 'START COLLECT', alive.value
        # Iterate over the results queue until the sentinel object is reached
        start_time = time()
        rec_count = clone_count = pass_count = fail_count = 0
        while alive.value:
            # Get result from queue
            if result_queue.empty(): continue
            else: result = result_queue.get()
            # Exit upon reaching sentinel
            if result is None: break
            #print "COLLECT", alive.value, result['id']

            # Print progress for previous iteration and update record count
            if rec_count == 0:
                print('PROGRESS> Assigning clones')
            printProgress(rec_count, result_count, 0.05, start_time)
            rec_count += len(result.data)

            # Write passed and failed records
            if result:
                for clone in result.results.values():
                    clone_count += 1
                    for i, rec in enumerate(clone):
                        rec.annotations['CLONE'] = clone_count
                        pass_writer.writerow(rec.toDict())
                        pass_count += 1
                        result.log['CLONE%i-%i' % (clone_count, i + 1)] = str(rec.junction)

            else:
                for i, rec in enumerate(result.data):
                    if fail_writer is not None: fail_writer.writerow(rec.toDict())
                    fail_count += 1
                    result.log['CLONE0-%i' % (i + 1)] = str(rec.junction)

            # Write log
            printLog(result.log, handle=log_handle)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None

        # Print total counts
        printProgress(rec_count, result_count, 0.05, start_time)

        # Close file handles
        pass_handle.close()
        if fail_handle is not None: fail_handle.close()
        if log_handle is not None: log_handle.close()

        # Update return list
        log = OrderedDict()
        log['OUTPUT'] = os.path.basename(pass_handle.name)
        log['CLONES'] = clone_count
        log['RECORDS'] = rec_count
        log['PASS'] = pass_count
        log['FAIL'] = fail_count
        collect_dict = {'log': log, 'out_files': [pass_handle.name]}
        collect_queue.put(collect_dict)
    except:
        #sys.stderr.write('Exception in collector result processing step\n')
        alive.value = False
        raise

    return None


def collectQueueClust(alive, result_queue, collect_queue, db_file, out_args, cluster_func, cluster_args):
    """
    Assembles results from a queue of individual sequence results and manages log/file I/O

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      result_queue = a multiprocessing.Queue holding processQueue results
      collect_queue = a multiprocessing.Queue to store collector return values
      db_file = the input database file name
      out_args = common output argument dictionary from parseCommonArgs
      cluster_func = the function to call for carrying out clustering on distance matrix
      cluster_args = a dictionary of arguments to pass to cluster_func

    Returns:
      None
      (adds 'log' and 'out_files' to collect_dict)
    """
    # Open output files
    try:

        # Iterate over Ig records to count and order by junction length
        result_count = 0
        records = {}
        # print 'Reading file...'
        db_iter = readDbFile(db_file)
        for rec in db_iter:
            records[rec.id] = rec
            result_count += 1
        records = OrderedDict(sorted(list(records.items()), key=lambda i: i[1].junction_length))

        # Define empty matrix to store assembled results
        dist_mat = np.zeros((result_count, result_count))

        # Count records and define output format
        out_type = getFileType(db_file) if out_args['out_type'] is None \
                   else out_args['out_type']

        # Define successful output handle
        pass_handle = getOutputHandle(db_file,
                                      out_label='clone-pass',
                                      out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'],
                                      out_type=out_type)
        pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')

        # Define failed cloning output handle
        if out_args['failed']:
            fail_handle = getOutputHandle(db_file,
                                          out_label='clone-fail',
                                          out_dir=out_args['out_dir'],
                                          out_name=out_args['out_name'],
                                          out_type=out_type)
            fail_writer = getDbWriter(fail_handle, db_file)
        else:
            fail_handle = None
            fail_writer = None

        # Open log file
        if out_args['log_file'] is None:
            log_handle = None
        else:
            log_handle = open(out_args['log_file'], 'w')
    except:
        alive.value = False
        raise

    try:
        # Iterate over the results queue until the sentinel object is reached
        start_time = time()
        row_count = rec_count = 0
        while alive.value:
            # Get result from queue
            if result_queue.empty(): continue
            else: result = result_queue.get()
            # Exit upon reaching sentinel
            if result is None: break

            # Print progress for previous iteration
            if row_count == 0:
                print('PROGRESS> Assigning clones')
            printProgress(row_count, result_count, 0.05, start_time)

            # Update counts for iteration
            row_count += 1
            rec_count += len(result)

            # Add result row to distance matrix
            if result:
                dist_mat[list(range(result_count - len(result), result_count)), result_count - len(result)] = result.results

        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None

        # Calculate linkage and carry out clustering
        # print dist_mat
        clusters = cluster_func(dist_mat, **cluster_args) if dist_mat is not None else None
        clones = {}
        # print clusters
        for i, c in enumerate(clusters):
            clones.setdefault(c, []).append(records[list(records.keys())[i]])

        # Write passed and failed records
        clone_count = pass_count = fail_count = 0
        if clones:
            for clone in clones.values():
                clone_count += 1
                for i, rec in enumerate(clone):
                    rec.annotations['CLONE'] = clone_count
                    pass_writer.writerow(rec.toDict())
                    pass_count += 1
                    #result.log['CLONE%i-%i' % (clone_count, i + 1)] = str(rec.junction)

        else:
            for i, rec in enumerate(result.data):
                if fail_writer is not None: fail_writer.writerow(rec.toDict())
                fail_count += 1
                #result.log['CLONE0-%i' % (i + 1)] = str(rec.junction)

        # Print final progress
        printProgress(row_count, result_count, 0.05, start_time)

        # Close file handles
        pass_handle.close()
        if fail_handle is not None: fail_handle.close()
        if log_handle is not None: log_handle.close()

        # Update return list
        log = OrderedDict()
        log['OUTPUT'] = os.path.basename(pass_handle.name)
        log['CLONES'] = clone_count
        log['RECORDS'] = rec_count
        log['PASS'] = pass_count
        log['FAIL'] = fail_count
        collect_dict = {'log': log, 'out_files': [pass_handle.name]}
        collect_queue.put(collect_dict)
    except:
        alive.value = False
        raise

    return None


def defineClones(db_file, feed_func, work_func, collect_func, clone_func, cluster_func=None,
                 group_func=None, group_args={}, clone_args={}, cluster_args={},
                 out_args=default_out_args, nproc=None, queue_size=None):
    """
    Define clonally related sequences

    Arguments:
      db_file = filename of input database
      feed_func = the function that feeds the queue
      work_func = the worker function that will run on each CPU
      collect_func = the function that collects results from the workers
      group_func = the function to use for assigning preclones
      clone_func = the function to use for determining clones within preclonal groups
      cluster_func = the function to use for carrying out clustering on the distance matrix
      group_args = a dictionary of arguments to pass to group_func
      clone_args = a dictionary of arguments to pass to clone_func
      cluster_args = a dictionary of arguments to pass to cluster_func
      out_args = common output argument dictionary from parseCommonArgs
      nproc = the number of processQueue processes;
              if None defaults to the number of CPUs
      queue_size = maximum size of the argument queue;
                   if None defaults to 2*nproc

    Returns:
      a list of successful output file names
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'DefineClones'
    log['DB_FILE'] = os.path.basename(db_file)
    if group_func is not None:
        log['GROUP_FUNC'] = group_func.__name__
        log['GROUP_ARGS'] = group_args
    log['CLONE_FUNC'] = clone_func.__name__

    # TODO: this is yucky, but can be fixed by using a model class
    clone_log = clone_args.copy()
    if 'dist_mat' in clone_log: del clone_log['dist_mat']
    log['CLONE_ARGS'] = clone_log

    if cluster_func is not None:
        log['CLUSTER_FUNC'] = cluster_func.__name__
        log['CLUSTER_ARGS'] = cluster_args
    log['NPROC'] = nproc
    printLog(log)

    # Define feeder function and arguments
    feed_args = {'db_file': db_file,
                 'group_func': group_func,
                 'group_args': group_args}
    # Define worker function and arguments
    work_args = {'clone_func': clone_func,
                 'clone_args': clone_args}
    # Define collector function and arguments
    collect_args = {'db_file': db_file,
                    'out_args': out_args,
                    'cluster_func': cluster_func,
                    'cluster_args': cluster_args}

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func,
                             feed_args, work_args, collect_args,
                             nproc, queue_size)

    # Print log
    result['log']['END'] = 'DefineClones'
    printLog(result['log'])

    return result['out_files']
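
# Minimal sketch of calling defineClones directly with the bygroup pipeline
# (mirrors what __main__ assembles from the parsed arguments; the file name and
# thresholds are assumptions):
#
#   out_files = defineClones('sample_db-pass.tab',
#                            feed_func=feedQueue, work_func=processQueue,
#                            collect_func=collectQueue, group_func=indexJunctions,
#                            clone_func=distanceClones,
#                            group_args={'mode': 'gene', 'action': 'set'},
#                            clone_args={'model': 'ham', 'distance': 0.16, 'norm': 'len'})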


def getArgParser():
    """
    Defines the ArgumentParser

    Arguments:
      None

    Returns:
      an ArgumentParser object
    """
    # Define input and output fields
    fields = dedent(
             '''
             output files:
                 clone-pass
                     database with assigned clonal group numbers.
                 clone-fail
                     database with records failing clonal grouping.

             required fields:
                 SEQUENCE_ID, V_CALL or V_CALL_GENOTYPED, D_CALL, J_CALL, JUNCTION

                 <field>
                     sequence field specified by the --sf parameter

             output fields:
                 CLONE
             ''')

    # Define ArgumentParser
    parser = ArgumentParser(description=__doc__, epilog=fields,
                            formatter_class=CommonHelpFormatter)
    parser.add_argument('--version', action='version',
                        version='%(prog)s:' + ' %s-%s' % (__version__, __date__))
    subparsers = parser.add_subparsers(title='subcommands', dest='command', metavar='',
                                       help='Cloning method')
    # TODO: This is a temporary fix for Python issue 9253
    subparsers.required = True

    # Parent parser
    parser_parent = getCommonArgParser(seq_in=False, seq_out=False, db_in=True,
                                       multiproc=True)

    # Distance cloning method
    parser_bygroup = subparsers.add_parser('bygroup', parents=[parser_parent],
                                           formatter_class=CommonHelpFormatter,
                                           help='''Defines clones as having same V assignment,
                                                J assignment, and junction length with
                                                specified substitution distance model.''',
                                           description='''Defines clones as having same V assignment,
                                                        J assignment, and junction length with
                                                        specified substitution distance model.''')
    parser_bygroup.add_argument('-f', nargs='+', action='store', dest='fields', default=None,
                                help='Additional fields to use for grouping clones (non VDJ)')
    parser_bygroup.add_argument('--mode', action='store', dest='mode',
                                choices=('allele', 'gene'), default=default_index_mode,
                                help='''Specifies whether to use the V(D)J allele or gene for
                                     initial grouping.''')
    parser_bygroup.add_argument('--act', action='store', dest='action',
                                choices=('first', 'set'), default=default_index_action,
                                help='''Specifies how to handle multiple V(D)J assignments
                                     for initial grouping.''')
    parser_bygroup.add_argument('--model', action='store', dest='model',
                                choices=choices_bygroup_model,
                                default=default_bygroup_model,
                                help='''Specifies which substitution model to use for calculating distance
                                     between sequences. The "ham" model is nucleotide Hamming distance and
                                     "aa" is amino acid Hamming distance. The "hh_s1f" and "hh_s5f" models are
                                     human specific single nucleotide and 5-mer content models, respectively,
                                     from Yaari et al, 2013. The "mk_rs1nf" and "mk_rs5nf" models are
                                     mouse specific single nucleotide and 5-mer content models, respectively,
                                     from Cui et al, 2016. The "m1n_compat" and "hs1f_compat" models are
                                     deprecated models provided for backwards compatibility with the "m1n" and
                                     "hs1f" models in Change-O v0.3.3 and SHazaM v0.1.4. Both
                                     5-mer models should be considered experimental.''')
    parser_bygroup.add_argument('--dist', action='store', dest='distance', type=float,
                                default=default_distance,
                                help='The distance threshold for clonal grouping')
    parser_bygroup.add_argument('--norm', action='store', dest='norm',
                                choices=('len', 'mut', 'none'), default=default_norm,
                                help='''Specifies how to normalize distances. One of none
                                     (do not normalize), len (normalize by length),
                                     or mut (normalize by number of mutations between sequences).''')
    parser_bygroup.add_argument('--sym', action='store', dest='sym',
                                choices=('avg', 'min'), default=default_sym,
                                help='''Specifies how to combine asymmetric distances. One of avg
                                     (average of A->B and B->A) or min (minimum of A->B and B->A).''')
    parser_bygroup.add_argument('--link', action='store', dest='linkage',
                                choices=('single', 'average', 'complete'), default=default_linkage,
                                help='''Type of linkage to use for hierarchical clustering.''')
    parser_bygroup.add_argument('--sf', action='store', dest='seq_field',
                                default=default_seq_field,
                                help='''The name of the field to be used to calculate
                                     distance between records''')
    parser_bygroup.set_defaults(feed_func=feedQueue)
    parser_bygroup.set_defaults(work_func=processQueue)
    parser_bygroup.set_defaults(collect_func=collectQueue)
    parser_bygroup.set_defaults(group_func=indexJunctions)
    parser_bygroup.set_defaults(clone_func=distanceClones)

    # Chen2010
    parser_chen = subparsers.add_parser('chen2010', parents=[parser_parent],
                                        formatter_class=CommonHelpFormatter,
                                        help='''Defines clones by method specified in Chen, 2010.''',
                                        description='''Defines clones by method specified in Chen, 2010.''')
    parser_chen.set_defaults(feed_func=feedQueueClust)
    parser_chen.set_defaults(work_func=processQueueClust)
    parser_chen.set_defaults(collect_func=collectQueueClust)
    parser_chen.set_defaults(cluster_func=hierClust)

    # Ademokun2011
    parser_ade = subparsers.add_parser('ademokun2011', parents=[parser_parent],
                                       formatter_class=CommonHelpFormatter,
                                       help='''Defines clones by method specified in Ademokun, 2011.''',
                                       description='''Defines clones by method specified in Ademokun, 2011.''')
    parser_ade.set_defaults(feed_func=feedQueueClust)
    parser_ade.set_defaults(work_func=processQueueClust)
    parser_ade.set_defaults(collect_func=collectQueueClust)
    parser_ade.set_defaults(cluster_func=hierClust)

    return parser


if __name__ == '__main__':
    """
    Parses command line arguments and calls main function
    """
    # Parse arguments
    parser = getArgParser()
    checkArgs(parser)
    args = parser.parse_args()
    args_dict = parseCommonArgs(args)
    # Convert case of fields
    if 'seq_field' in args_dict:
        args_dict['seq_field'] = args_dict['seq_field'].upper()
    if 'fields' in args_dict and args_dict['fields'] is not None:
        args_dict['fields'] = [f.upper() for f in args_dict['fields']]

    # Define clone_args
    if args.command == 'bygroup':
        args_dict['group_args'] = {'fields': args_dict['fields'],
                                   'action': args_dict['action'],
                                   'mode': args_dict['mode']}
        args_dict['clone_args'] = {'model': args_dict['model'],
                                   'distance': args_dict['distance'],
                                   'norm': args_dict['norm'],
                                   'sym': args_dict['sym'],
                                   'linkage': args_dict['linkage'],
                                   'seq_field': args_dict['seq_field']}

        # Get distance matrix
        try:
            args_dict['clone_args']['dist_mat'] = distance_models[args_dict['model']]
        except KeyError:
            sys.exit('Unrecognized distance model: %s' % args_dict['model'])

        del args_dict['fields']
        del args_dict['action']
        del args_dict['mode']
        del args_dict['model']
        del args_dict['distance']
        del args_dict['norm']
        del args_dict['sym']
        del args_dict['linkage']
        del args_dict['seq_field']

    # Define clone_args
    if args.command == 'chen2010':
        args_dict['clone_func'] = distChen2010
        args_dict['cluster_args'] = {'method': args.command}

    if args.command == 'ademokun2011':
        args_dict['clone_func'] = distAdemokun2011
        args_dict['cluster_args'] = {'method': args.command}

    # Call defineClones
    del args_dict['command']
    del args_dict['db_files']
    for f in args.__dict__['db_files']:
        args_dict['db_file'] = f
        defineClones(**args_dict)