view tools/taxonomy/lca.py @ 1:cdcb0ce84a1b

Uploaded
author xuebing
date Fri, 09 Mar 2012 19:45:15 -0500
parents 9071e359b9a3
children
line wrap: on
line source

#!/usr/bin/env python
#Guruprasad Ananda
"""
Least Common Ancestor tool.
"""
import sys, string, re, commands, tempfile, random

def stop_err(msg):
    """Write *msg* to stderr and terminate the tool with exit status 1.

    The non-zero status lets callers detect the failure from the exit code,
    not only from the presence of stderr output.
    """
    sys.stderr.write(msg)
    sys.exit(1)

# Taxonomy layout: column 0 name, column 1 taxon id, columns 2..23 ranks
# (root .. subspecies). Columns from 24 on are optional extras.
_TAXONOMY_COLUMNS = 24
# First rank column compared across group members; column 2 (root) is taken
# from the group's first line without comparison.
_FIRST_COMPARED_RANK = 3


def _check_input(path, max_lines=32):
    """Validate the head of *path* and return the column count to require.

    Up to *max_lines* data lines (lines that are neither empty nor start with
    '#') are inspected. stop_err() aborts the tool if any of them has fewer
    than 24 tab-separated columns. Returns the column count of the last
    inspected line, or 24 when the file contains no data lines.
    """
    ncols = _TAXONOMY_COLUMNS
    inspected = 0
    with open(path) as handle:
        for line in handle:
            line = line.rstrip('\r\n')
            # Comments and empty lines are ignored by the grouping pass, so
            # they must not fail validation either.
            if not line or line.startswith('#'):
                continue
            elems = line.split('\t')
            if len(elems) < _TAXONOMY_COLUMNS:
                stop_err("The format of the input dataset is incorrect. "
                         "Taxonomy datatype should contain at least 24 columns.")
            ncols = len(elems)
            inspected += 1
            if inspected >= max_lines:
                break
    return ncols


def _sort_by_first_column(inputfile, outputfile):
    """Sort *inputfile* on its first tab-separated column into *outputfile*.

    Returns (exit status, combined stdout/stderr text) of sort(1).
    """
    import os
    import subprocess
    # Argument list (no shell) so file names are never interpreted by a shell.
    # LC_ALL=C makes sort's last-resort whole-line comparison byte-wise, which
    # guarantees that lines with an identical first column end up adjacent;
    # locale collation (ignoring tabs/punctuation) does not guarantee that.
    env = dict(os.environ, LC_ALL='C')
    proc = subprocess.Popen(
        ['sort', '-f', '-t', '\t', '-k', '1,1', '-o', outputfile, '--', inputfile],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        universal_newlines=True, env=env)
    output = proc.communicate()[0]
    return proc.returncode, output.rstrip('\n')


def _lca_row(name, first, depth):
    """Build the output row for one group.

    *first* holds the stripped fields of the group's first line; *depth* is
    the first rank column on which some member differs from it (24 when all
    compared ranks agree). Ranks from *depth* through column 23 become 'n';
    extra columns (24 and up) are copied from *first*.
    """
    ranks = [first[col] if col < depth else 'n'
             for col in range(_FIRST_COMPARED_RANK, _TAXONOMY_COLUMNS)]
    return [name, first[1], first[2]] + ranks + first[_TAXONOMY_COLUMNS:]


def _emit(fout, row, rank_bound):
    """Write *row* as a tab-separated line unless *rank_bound* filters it.

    rank_bound == 0 writes every row; otherwise the row is written only if at
    least one of columns rank_bound..23 is not 'n'.
    """
    if rank_bound == 0 or any(value != 'n'
                              for value in row[rank_bound:_TAXONOMY_COLUMNS]):
        fout.write('\t'.join(row).strip() + '\n')


def _write_lcas(sorted_lines, fout, ncols, rank_bound):
    """Write one LCA row per run of consecutive lines sharing column 0.

    *sorted_lines* must be ordered so that lines with the same name are
    adjacent. Comment lines and lines with fewer than *ncols* columns are
    ignored and counted. Returns the number of ignored lines.
    """
    skipped_lines = 0
    name = None
    first = None                    # stripped fields of the group's first line
    depth = _TAXONOMY_COLUMNS       # first rank column where members disagree
    for line in sorted_lines:
        if not line or line.startswith('#'):
            skipped_lines += 1
            continue
        fields = line.rstrip('\r\n').split('\t')
        # Reject short lines before touching any group state.
        if len(fields) < ncols:
            skipped_lines += 1
            continue
        if first is not None and fields[0] == name:
            # Every member can only shorten the agreed rank prefix, so each
            # line is compared with the first one up to the current depth.
            for col in range(_FIRST_COMPARED_RANK, depth):
                if fields[col].strip() != first[col]:
                    depth = col
                    break
            continue
        if first is not None:
            _emit(fout, _lca_row(name, first, depth), rank_bound)
        name = fields[0]
        first = [field.strip() for field in fields[:ncols]]
        depth = _TAXONOMY_COLUMNS
    if first is not None:
        _emit(fout, _lca_row(name, first, depth), rank_bound)
    return skipped_lines


def main():
    """Compute the least common ancestor (LCA) of each group of taxonomy lines.

    Usage: program infile outfile rank_bound

    infile is a tab-separated taxonomy dataset with at least 24 columns:
    column 0 is the name lines are grouped on, column 1 a taxon id, and
    columns 2..23 the ranks:

        root        :2,    superorder  :12,
        superkingdom:3,    order       :13,
        kingdom     :4,    suborder    :14,
        subkingdom  :5,    superfamily :15,
        superphylum :6,    family      :16,
        phylum      :7,    subfamily   :17,
        subphylum   :8,    tribe       :18,
        superclass  :9,    subtribe    :19,
        class       :10,   genus       :20,
        subclass    :11,   subgenus    :21,
                           species     :22,
                           subspecies  :23

    For every group of lines sharing column 0, one row is written to
    outfile. The row holds the name, the taxon id and root of the group's
    first line (in sorted order), and each rank from column 3 up to the first
    rank on which any member differs from that first line. That rank and all
    deeper ranks are written as 'n'. Extra columns (24 and up) are copied
    from the first line.

    rank_bound: 0 writes every row; otherwise a row is written only if at
    least one of columns rank_bound..23 is not 'n'.
    """
    try:
        inputfile = sys.argv[1]
        outfile = sys.argv[2]
        rank_bound = int(sys.argv[3])
    except (IndexError, ValueError):
        stop_err("Syntax error: Use correct syntax: program infile outfile rank_bound")

    ncols = _check_input(inputfile)

    tmpfile = tempfile.NamedTemporaryFile()
    try:
        status, output = _sort_by_first_column(inputfile, tmpfile.name)
        if status != 0:
            stop_err("Sorting input dataset resulted in error: %s: %s" % (status, output))
        with open(tmpfile.name) as sorted_lines:
            with open(outfile, 'w') as fout:
                skipped_lines = _write_lcas(sorted_lines, fout, ncols, rank_bound)
    finally:
        tmpfile.close()

    if skipped_lines > 0:
        print("Skipped %d invalid lines." % skipped_lines)
    
if __name__ == '__main__':
    # Run the tool only when executed as a script, never on import.
    main()