| 
0
 | 
     1 #!/usr/bin/env python
 | 
| 
 | 
     2 #Guruprasad Ananda
 | 
| 
 | 
     3 """
 | 
| 
 | 
     4 Least Common Ancestor tool.
 | 
| 
 | 
     5 """
 | 
| 
 | 
     6 import sys, string, re, commands, tempfile, random
 | 
| 
 | 
     7 
 | 
| 
 | 
     8 def stop_err(msg):
 | 
| 
 | 
     9     sys.stderr.write(msg)
 | 
| 
 | 
    10     sys.exit()
 | 
| 
 | 
    11 
 | 
| 
 | 
    12 def main():
 | 
| 
 | 
    13     try:
 | 
| 
 | 
    14         inputfile = sys.argv[1]
 | 
| 
 | 
    15         outfile = sys.argv[2]
 | 
| 
 | 
    16         rank_bound = int( sys.argv[3] )
 | 
| 
 | 
    17         """
 | 
| 
 | 
    18         Mapping of ranks:
 | 
| 
 | 
    19         root        :2, 
 | 
| 
 | 
    20         superkingdom:3, 
 | 
| 
 | 
    21         kingdom     :4, 
 | 
| 
 | 
    22         subkingdom  :5, 
 | 
| 
 | 
    23         superphylum :6, 
 | 
| 
 | 
    24         phylum      :7, 
 | 
| 
 | 
    25         subphylum   :8, 
 | 
| 
 | 
    26         superclass  :9, 
 | 
| 
 | 
    27         class       :10, 
 | 
| 
 | 
    28         subclass    :11, 
 | 
| 
 | 
    29         superorder  :12, 
 | 
| 
 | 
    30         order       :13, 
 | 
| 
 | 
    31         suborder    :14, 
 | 
| 
 | 
    32         superfamily :15,
 | 
| 
 | 
    33         family      :16,
 | 
| 
 | 
    34         subfamily   :17,
 | 
| 
 | 
    35         tribe       :18,
 | 
| 
 | 
    36         subtribe    :19,
 | 
| 
 | 
    37         genus       :20,
 | 
| 
 | 
    38         subgenus    :21,
 | 
| 
 | 
    39         species     :22,
 | 
| 
 | 
    40         subspecies  :23,
 | 
| 
 | 
    41         """
 | 
| 
 | 
    42     except:
 | 
| 
 | 
    43         stop_err("Syntax error: Use correct syntax: program infile outfile")
 | 
| 
 | 
    44     
 | 
| 
 | 
    45     fin = open(sys.argv[1],'r')
 | 
| 
 | 
    46     for j, line in enumerate( fin ):
 | 
| 
 | 
    47         elems = line.strip().split('\t')
 | 
| 
 | 
    48         if len(elems) < 24:
 | 
| 
 | 
    49             stop_err("The format of the input dataset is incorrect. Taxonomy datatype should contain at least 24 columns.")
 | 
| 
 | 
    50         if j > 30:
 | 
| 
 | 
    51             break
 | 
| 
 | 
    52         cols = range(1,len(elems))
 | 
| 
 | 
    53     fin.close()
 | 
| 
 | 
    54        
 | 
| 
 | 
    55     group_col = 0
 | 
| 
 | 
    56     tmpfile = tempfile.NamedTemporaryFile()
 | 
| 
 | 
    57 
 | 
| 
 | 
    58     try:
 | 
| 
 | 
    59         """
 | 
| 
 | 
    60         The -k option for the Posix sort command is as follows:
 | 
| 
 | 
    61         -k, --key=POS1[,POS2]
 | 
| 
 | 
    62         start a key at POS1, end it at POS2 (origin 1)
 | 
| 
 | 
    63         In other words, column positions start at 1 rather than 0, so 
 | 
| 
 | 
    64         we need to add 1 to group_col.
 | 
| 
 | 
    65         if POS2 is not specified, the newer versions of sort will consider the entire line for sorting. To prevent this, we set POS2=POS1.
 | 
| 
 | 
    66         """
 | 
| 
 | 
    67         command_line = "sort -f -k " + str(group_col+1) +"," + str(group_col+1) + " -o " + tmpfile.name + " " + inputfile
 | 
| 
 | 
    68     except Exception, exc:
 | 
| 
 | 
    69         stop_err( 'Initialization error -> %s' %str(exc) )
 | 
| 
 | 
    70         
 | 
| 
 | 
    71     error_code, stdout = commands.getstatusoutput(command_line)
 | 
| 
 | 
    72     
 | 
| 
 | 
    73     if error_code != 0:
 | 
| 
 | 
    74         stop_err( "Sorting input dataset resulted in error: %s: %s" %( error_code, stdout ))    
 | 
| 
 | 
    75 
 | 
| 
 | 
    76     prev_item = ""
 | 
| 
 | 
    77     prev_vals = []
 | 
| 
 | 
    78     remaining_vals = []
 | 
| 
 | 
    79     skipped_lines = 0
 | 
| 
 | 
    80     fout = open(outfile, "w")
 | 
| 
 | 
    81     block_valid = False
 | 
| 
 | 
    82     
 | 
| 
 | 
    83     
 | 
| 
 | 
    84     for ii, line in enumerate( file( tmpfile.name )):
 | 
| 
 | 
    85         if line and not line.startswith( '#' ) and len(line.split('\t')) >= 24: #Taxonomy datatype should have at least 24 columns
 | 
| 
 | 
    86             line = line.rstrip( '\r\n' )
 | 
| 
 | 
    87             try:
 | 
| 
 | 
    88                 fields = line.split("\t")
 | 
| 
 | 
    89                 item = fields[group_col]
 | 
| 
 | 
    90                 if prev_item != "":
 | 
| 
 | 
    91                     # At this level, we're grouping on values (item and prev_item) in group_col
 | 
| 
 | 
    92                     if item == prev_item:
 | 
| 
 | 
    93                         # Keep iterating and storing values until a new value is encountered.
 | 
| 
 | 
    94                         if block_valid:
 | 
| 
 | 
    95                             for i, col in enumerate(cols):
 | 
| 
 | 
    96                                 if col >= 3:
 | 
| 
 | 
    97                                     prev_vals[i].append(fields[col].strip())
 | 
| 
 | 
    98                                     if len(set(prev_vals[i])) > 1:
 | 
| 
 | 
    99                                         block_valid = False
 | 
| 
 | 
   100                                         break
 | 
| 
 | 
   101                             
 | 
| 
 | 
   102                     else:   
 | 
| 
 | 
   103                         """
 | 
| 
 | 
   104                         When a new value is encountered, write the previous value and the 
 | 
| 
 | 
   105                         corresponding aggregate values into the output file.  This works 
 | 
| 
 | 
   106                         due to the sort on group_col we've applied to the data above.
 | 
| 
 | 
   107                         """
 | 
| 
 | 
   108                         out_list = ['']*24
 | 
| 
 | 
   109                         out_list[0] = str(prev_item)
 | 
| 
 | 
   110                         out_list[1] = str(prev_vals[0][0])
 | 
| 
 | 
   111                         out_list[2] = str(prev_vals[1][0])
 | 
| 
 | 
   112                         
 | 
| 
 | 
   113                         for k, col in enumerate(cols):
 | 
| 
 | 
   114                             if col >= 3 and col < 24:
 | 
| 
 | 
   115                                 if len(set(prev_vals[k])) == 1:
 | 
| 
 | 
   116                                     out_list[col] = prev_vals[k][0]
 | 
| 
 | 
   117                                 else:
 | 
| 
 | 
   118                                     break
 | 
| 
 | 
   119                         while k < 23:
 | 
| 
 | 
   120                             out_list[k+1] = 'n' 
 | 
| 
 | 
   121                             k += 1
 | 
| 
 | 
   122                         
 | 
| 
 | 
   123                         j = 0
 | 
| 
 | 
   124                         while True:
 | 
| 
 | 
   125                             try:
 | 
| 
 | 
   126                                 out_list.append(str(prev_vals[23+j][0]))
 | 
| 
 | 
   127                                 j += 1
 | 
| 
 | 
   128                             except:
 | 
| 
 | 
   129                                 break
 | 
| 
 | 
   130                             
 | 
| 
 | 
   131                         if rank_bound == 0:     
 | 
| 
 | 
   132                             print >>fout, '\t'.join(out_list).strip()
 | 
| 
 | 
   133                         else:
 | 
| 
 | 
   134                             if ''.join(out_list[rank_bound:24]) != 'n'*( 24 - rank_bound ):
 | 
| 
 | 
   135                                 print >>fout, '\t'.join(out_list).strip()
 | 
| 
 | 
   136                         
 | 
| 
 | 
   137                         block_valid = True
 | 
| 
 | 
   138                         prev_item = item   
 | 
| 
 | 
   139                         prev_vals = [] 
 | 
| 
 | 
   140                         for col in cols:
 | 
| 
 | 
   141                             val_list = []
 | 
| 
 | 
   142                             val_list.append(fields[col].strip())
 | 
| 
 | 
   143                             prev_vals.append(val_list)
 | 
| 
 | 
   144                         
 | 
| 
 | 
   145                 else:
 | 
| 
 | 
   146                     # This only occurs once, right at the start of the iteration.
 | 
| 
 | 
   147                     block_valid = True
 | 
| 
 | 
   148                     prev_item = item    #groupby item
 | 
| 
 | 
   149                     for col in cols:    #everyting else
 | 
| 
 | 
   150                         val_list = []
 | 
| 
 | 
   151                         val_list.append(fields[col].strip())
 | 
| 
 | 
   152                         prev_vals.append(val_list)
 | 
| 
 | 
   153             
 | 
| 
 | 
   154             except:
 | 
| 
 | 
   155                 skipped_lines += 1
 | 
| 
 | 
   156         else:
 | 
| 
 | 
   157             skipped_lines += 1
 | 
| 
 | 
   158             
 | 
| 
 | 
   159     # Handle the last grouped value
 | 
| 
 | 
   160     out_list = ['']*24
 | 
| 
 | 
   161     out_list[0] = str(prev_item)
 | 
| 
 | 
   162     out_list[1] = str(prev_vals[0][0])
 | 
| 
 | 
   163     out_list[2] = str(prev_vals[1][0])
 | 
| 
 | 
   164     
 | 
| 
 | 
   165     for k, col in enumerate(cols):
 | 
| 
 | 
   166         if col >= 3 and col < 24:
 | 
| 
 | 
   167             if len(set(prev_vals[k])) == 1:
 | 
| 
 | 
   168                 out_list[col] = prev_vals[k][0]
 | 
| 
 | 
   169             else:
 | 
| 
 | 
   170                 break
 | 
| 
 | 
   171     while k < 23:
 | 
| 
 | 
   172         out_list[k+1] = 'n' 
 | 
| 
 | 
   173         k += 1
 | 
| 
 | 
   174     
 | 
| 
 | 
   175     j = 0
 | 
| 
 | 
   176     while True:
 | 
| 
 | 
   177         try:
 | 
| 
 | 
   178             out_list.append(str(prev_vals[23+j][0]))
 | 
| 
 | 
   179             j += 1
 | 
| 
 | 
   180         except:
 | 
| 
 | 
   181             break
 | 
| 
 | 
   182         
 | 
| 
 | 
   183     if rank_bound == 0:     
 | 
| 
 | 
   184         print >>fout, '\t'.join(out_list).strip()
 | 
| 
 | 
   185     else:
 | 
| 
 | 
   186         if ''.join(out_list[rank_bound:24]) != 'n'*( 24 - rank_bound ):
 | 
| 
 | 
   187             print >>fout, '\t'.join(out_list).strip()
 | 
| 
 | 
   188         
 | 
| 
 | 
   189     if skipped_lines > 0:
 | 
| 
 | 
   190         print "Skipped %d invalid lines." % ( skipped_lines )
 | 
| 
 | 
   191     
 | 
| 
 | 
   192 if __name__ == "__main__":
 | 
| 
 | 
   193     main() |