Mercurial > repos > xuebing > sharplabtool
diff tools/taxonomy/lca.py @ 0:9071e359b9a3
Uploaded
author | xuebing |
---|---|
date | Fri, 09 Mar 2012 19:37:19 -0500 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tools/taxonomy/lca.py Fri Mar 09 19:37:19 2012 -0500 @@ -0,0 +1,193 @@ +#!/usr/bin/env python +#Guruprasad Ananda +""" +Least Common Ancestor tool. +""" +import sys, string, re, commands, tempfile, random + +def stop_err(msg): + sys.stderr.write(msg) + sys.exit() + +def main(): + try: + inputfile = sys.argv[1] + outfile = sys.argv[2] + rank_bound = int( sys.argv[3] ) + """ + Mapping of ranks: + root :2, + superkingdom:3, + kingdom :4, + subkingdom :5, + superphylum :6, + phylum :7, + subphylum :8, + superclass :9, + class :10, + subclass :11, + superorder :12, + order :13, + suborder :14, + superfamily :15, + family :16, + subfamily :17, + tribe :18, + subtribe :19, + genus :20, + subgenus :21, + species :22, + subspecies :23, + """ + except: + stop_err("Syntax error: Use correct syntax: program infile outfile") + + fin = open(sys.argv[1],'r') + for j, line in enumerate( fin ): + elems = line.strip().split('\t') + if len(elems) < 24: + stop_err("The format of the input dataset is incorrect. Taxonomy datatype should contain at least 24 columns.") + if j > 30: + break + cols = range(1,len(elems)) + fin.close() + + group_col = 0 + tmpfile = tempfile.NamedTemporaryFile() + + try: + """ + The -k option for the Posix sort command is as follows: + -k, --key=POS1[,POS2] + start a key at POS1, end it at POS2 (origin 1) + In other words, column positions start at 1 rather than 0, so + we need to add 1 to group_col. + if POS2 is not specified, the newer versions of sort will consider the entire line for sorting. To prevent this, we set POS2=POS1. + """ + command_line = "sort -f -k " + str(group_col+1) +"," + str(group_col+1) + " -o " + tmpfile.name + " " + inputfile + except Exception, exc: + stop_err( 'Initialization error -> %s' %str(exc) ) + + error_code, stdout = commands.getstatusoutput(command_line) + + if error_code != 0: + stop_err( "Sorting input dataset resulted in error: %s: %s" %( error_code, stdout )) + + prev_item = "" + prev_vals = [] + remaining_vals = [] + skipped_lines = 0 + fout = open(outfile, "w") + block_valid = False + + + for ii, line in enumerate( file( tmpfile.name )): + if line and not line.startswith( '#' ) and len(line.split('\t')) >= 24: #Taxonomy datatype should have at least 24 columns + line = line.rstrip( '\r\n' ) + try: + fields = line.split("\t") + item = fields[group_col] + if prev_item != "": + # At this level, we're grouping on values (item and prev_item) in group_col + if item == prev_item: + # Keep iterating and storing values until a new value is encountered. + if block_valid: + for i, col in enumerate(cols): + if col >= 3: + prev_vals[i].append(fields[col].strip()) + if len(set(prev_vals[i])) > 1: + block_valid = False + break + + else: + """ + When a new value is encountered, write the previous value and the + corresponding aggregate values into the output file. This works + due to the sort on group_col we've applied to the data above. + """ + out_list = ['']*24 + out_list[0] = str(prev_item) + out_list[1] = str(prev_vals[0][0]) + out_list[2] = str(prev_vals[1][0]) + + for k, col in enumerate(cols): + if col >= 3 and col < 24: + if len(set(prev_vals[k])) == 1: + out_list[col] = prev_vals[k][0] + else: + break + while k < 23: + out_list[k+1] = 'n' + k += 1 + + j = 0 + while True: + try: + out_list.append(str(prev_vals[23+j][0])) + j += 1 + except: + break + + if rank_bound == 0: + print >>fout, '\t'.join(out_list).strip() + else: + if ''.join(out_list[rank_bound:24]) != 'n'*( 24 - rank_bound ): + print >>fout, '\t'.join(out_list).strip() + + block_valid = True + prev_item = item + prev_vals = [] + for col in cols: + val_list = [] + val_list.append(fields[col].strip()) + prev_vals.append(val_list) + + else: + # This only occurs once, right at the start of the iteration. + block_valid = True + prev_item = item #groupby item + for col in cols: #everyting else + val_list = [] + val_list.append(fields[col].strip()) + prev_vals.append(val_list) + + except: + skipped_lines += 1 + else: + skipped_lines += 1 + + # Handle the last grouped value + out_list = ['']*24 + out_list[0] = str(prev_item) + out_list[1] = str(prev_vals[0][0]) + out_list[2] = str(prev_vals[1][0]) + + for k, col in enumerate(cols): + if col >= 3 and col < 24: + if len(set(prev_vals[k])) == 1: + out_list[col] = prev_vals[k][0] + else: + break + while k < 23: + out_list[k+1] = 'n' + k += 1 + + j = 0 + while True: + try: + out_list.append(str(prev_vals[23+j][0])) + j += 1 + except: + break + + if rank_bound == 0: + print >>fout, '\t'.join(out_list).strip() + else: + if ''.join(out_list[rank_bound:24]) != 'n'*( 24 - rank_bound ): + print >>fout, '\t'.join(out_list).strip() + + if skipped_lines > 0: + print "Skipped %d invalid lines." % ( skipped_lines ) + +if __name__ == "__main__": + main() \ No newline at end of file