Mercurial > repos > xuebing > sharplabtool
comparison tools/stats/grouping.py @ 0:9071e359b9a3
Uploaded
author | xuebing |
---|---|
date | Fri, 09 Mar 2012 19:37:19 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:9071e359b9a3 |
---|---|
1 #!/usr/bin/env python | |
2 # Guruprasad Ananda | |
3 # Refactored 2011, Kanwei Li | |
4 # Refactored to use numpy instead of rpy | |
5 """ | |
6 This tool provides the SQL "group by" functionality. | |
7 """ | |
8 import sys, commands, tempfile, random | |
9 try: | |
10 import numpy | |
11 except: | |
12 from galaxy import eggs | |
13 eggs.require( "numpy" ) | |
14 import numpy | |
15 | |
16 from itertools import groupby | |
17 | |
def stop_err(msg):
    """Report *msg* on stderr and terminate the script with a failure status.

    A trailing newline is added if *msg* lacks one so the message is not
    glued to whatever the calling framework prints next.
    """
    if not msg.endswith("\n"):
        msg += "\n"
    sys.stderr.write(msg)
    # Exit non-zero: a bare sys.exit() reports status 0, i.e. success.
    sys.exit(1)
21 | |
def mode(data):
    """Return the most frequent value(s) of *data* as a comma-separated string.

    Every value is rendered with str(). When several values share the highest
    count they are all returned, in the order each first appears in *data*,
    so the output is deterministic. An empty *data* yields "".
    """
    counts = {}
    first_seen = []  # distinct values in order of first appearance
    for x in data:
        if x not in counts:
            counts[x] = 0
            first_seen.append(x)
        counts[x] += 1
    if not counts:
        return ''
    maxcount = max(counts.values())
    return ','.join(str(x) for x in first_seen if counts[x] == maxcount)
32 | |
# Names shown in the summary message for ops whose internal name differs.
_OP_LABELS = {
    'cat': 'concat',
    'cat_uniq': 'concat_distinct',
    'length': 'count',
    'unique': 'count_distinct',
    'random': 'randomly_pick',
}


def _parse_op_specs(specs):
    """Split "<op> <col> <round>" arguments into three parallel lists.

    For example ["mean 1 no", "min 3 yes"] gives
    ops=['mean', 'min'], cols=['1', '3'], round_val=['no', 'yes'].
    """
    ops, cols, round_val = [], [], []
    for spec in specs:
        parts = spec.split()
        if len(parts) != 3:
            stop_err('Invalid operation specification "%s"; expected "<op> <column> <round>".\n' % spec)
        op, col, do_round = parts
        ops.append(op)
        cols.append(col)
        round_val.append(do_round)
    return ops, cols, round_val


def _column_indexes(cols):
    """Convert 1-based column numbers (strings, as in Galaxy's cXX) to 0-based ints."""
    indexes = []
    for col in cols:
        try:
            index = int(col) - 1
        except ValueError:
            index = -1
        if index < 0:
            # A negative index would silently read the last column instead.
            stop_err('Invalid column number "%s"; columns are numbered from 1.\n' % col)
        indexes.append(index)
    return indexes


def _sort_by_column(inputfile, outfile, group_col, ignorecase):
    """Sort tab-delimited *inputfile* on 0-based column *group_col* into *outfile*.

    Exits via stop_err if sort cannot be started or returns non-zero.
    """
    import subprocess
    key_col = group_col + 1  # sort -k positions are 1-based
    cmd = ['sort', '-t', '\t']
    if ignorecase == 1:
        cmd.append('-f')
    # POS2 is pinned to POS1: if POS2 is omitted, newer versions of sort use
    # the rest of the line as the key. An argument list (no shell) keeps
    # paths containing spaces or shell metacharacters intact, and '--' stops
    # an input path beginning with '-' from being parsed as an option.
    cmd += ['-k%d,%d' % (key_col, key_col), '-o', outfile, '--', inputfile]
    try:
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, universal_newlines=True)
    except OSError as exc:
        stop_err('Initialization error -> %s' % str(exc))
    output, errors = proc.communicate()
    if proc.returncode != 0:
        stop_err("Sorting input dataset resulted in error: %s: %s"
                 % (proc.returncode, (output + errors).strip()))


def _read_rows(handle):
    """Yield (raw_line, fields) for every non-blank line of a tab-delimited file."""
    for line in handle:
        if not line.strip():
            continue  # blank lines have no group key to group on
        # Strip only the line terminator: a full strip() would also remove
        # leading/trailing tabs and shift the columns of a row whose first or
        # last field is empty.
        yield line, line.rstrip('\r\n').split('\t')


def _aggregate(op, data, do_round):
    """Reduce one group's raw string values *data* with operation *op*.

    mode, length, random, cat, cat_uniq and unique work on the text as-is.
    Any other op is looked up as a numpy function and applied to the values
    converted to float. For those numeric results, do_round == 'yes' rounds
    the value; otherwise it is formatted with '%g'.
    """
    if op == "mode":
        return mode(data)
    if op == "length":
        return len(data)
    if op == "random":
        return random.choice(data)
    if op == "cat":
        return ','.join(data)
    if op == "cat_uniq":
        return ','.join(numpy.unique(data))
    if op == "unique":
        return len(numpy.unique(data))
    func = getattr(numpy, op, None)
    if not callable(func):
        stop_err("Unknown operation: %s\n" % op)
    try:
        values = [float(v) for v in data]  # a list, not map(): numpy needs a sequence
    except ValueError:
        sys.stderr.write("Operation %s expected number values but got %s instead.\n" % (op, data))
        sys.exit(1)
    rval = func(values)
    if do_round == 'yes':
        return round(rval)
    return '%g' % rval


def main():
    """Group a tab-delimited file by one column and aggregate other columns.

    Command line (sys.argv):
      [1]  output file path
      [2]  input file path
      [3]  1-based group-by column
      [4]  1 to group case-insensitively; any other integer is case-sensitive
      [5:] one "<op> <col> <round>" string per output column, e.g. "mean 3 no"

    Writes one line per group: the group key (lower-cased when ignoring case)
    followed by one aggregated value per op, then prints a summary line.
    """
    inputfile = sys.argv[2]
    ignorecase = int(sys.argv[4])
    ops, cols, round_val = _parse_op_specs(sys.argv[5:])

    try:
        group_col = int(sys.argv[3]) - 1
    except ValueError:
        stop_err("Group column not specified.")
    if group_col < 0:
        stop_err("Group column must be 1 or greater.")

    col_indexes = _column_indexes(cols)

    def group_key(row):
        # Key for groupby; must fold case the same way `sort -f` does so that
        # lines sort treats as equal end up in one group.
        line, fields = row
        if group_col >= len(fields):
            stop_err('Could not access the group column %s on line: "%s". Make sure file is tab-delimited.\n'
                     % (group_col + 1, line.rstrip('\r\n')))
        item = fields[group_col]
        if ignorecase == 1:
            return item.lower()
        return item

    tmpfile = tempfile.NamedTemporaryFile()
    try:
        _sort_by_column(inputfile, tmpfile.name, group_col, ignorecase)
        with open(sys.argv[1], 'w') as fout:
            with open(tmpfile.name) as sorted_file:
                # groupby only merges adjacent lines, hence the sort above.
                for key, rows in groupby(_read_rows(sorted_file), key=group_key):
                    op_vals = [[] for _ in ops]
                    for line, fields in rows:
                        for i, col in enumerate(col_indexes):
                            try:
                                op_vals[i].append(fields[col].strip())
                            except IndexError:
                                sys.stderr.write('Could not access the value for column %s on line: "%s". Make sure file is tab-delimited.\n' % (col + 1, line))
                                sys.exit(1)
                    out_fields = [key]
                    for op, data, do_round in zip(ops, op_vals, round_val):
                        out_fields.append(str(_aggregate(op, data, do_round)))
                    fout.write('\t'.join(out_fields) + '\n')
    finally:
        tmpfile.close()

    # Generate a useful info message.
    msg = "--Group by c%d: " % (group_col + 1)
    for op, col in zip(ops, cols):
        msg += _OP_LABELS.get(op, op) + "[c" + col + "] "
    print(msg)
160 | |
if __name__ == '__main__':
    # Run the tool only when executed as a script, not when imported.
    main()