Mercurial > repos > guerler > springsuite
comparison planemo/lib/python3.7/site-packages/rdflib/plugins/sparql/aggregates.py @ 1:56ad4e20f292 draft
"planemo upload commit 6eee67778febed82ddd413c3ca40b3183a3898f1"
author | guerler |
---|---|
date | Fri, 31 Jul 2020 00:32:28 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:d30785e31577 | 1:56ad4e20f292 |
---|---|
1 from rdflib import Literal, XSD | |
2 | |
3 from rdflib.plugins.sparql.evalutils import _eval, NotBoundError, _val | |
4 from rdflib.plugins.sparql.operators import numeric | |
5 from rdflib.plugins.sparql.datatypes import type_promotion | |
6 | |
7 from rdflib.plugins.sparql.sparql import SPARQLTypeError | |
8 | |
9 from decimal import Decimal | |
10 | |
11 """ | |
12 Aggregation functions | |
13 """ | |
14 | |
class Accumulator(object):
    """Abstract base class for the different aggregation functions.

    An accumulator reads the target variable (``aggregation.res``), the
    aggregated expression (``aggregation.vars``) and the DISTINCT flag
    (``aggregation.distinct``) from the aggregation it is built from.

    Subclasses are expected to provide ``update(row, aggregator)`` and
    ``get_value()``; ``set_value`` relies on the latter.
    """

    def __init__(self, aggregation):
        self.var = aggregation.res
        self.expr = aggregation.vars
        if not aggregation.distinct:
            # Without DISTINCT every row is used: shadow the class-level
            # ``use_row`` with the always-true test on this instance.
            self.use_row = self.dont_care
            self.distinct = False
        else:
            self.distinct = aggregation.distinct
            # values already aggregated; filled in by the subclasses' update()
            self.seen = set()

    def dont_care(self, row):
        """Skip the distinct test: every row is usable."""
        return True

    def use_row(self, row):
        """Test DISTINCT with the ``seen`` set.

        Returns True when the value of ``self.expr`` for *row* has not been
        recorded in ``self.seen`` yet. A row for which the expression is
        unbound (``NotBoundError``) is reported as not usable: the ``update``
        methods skip such rows anyway, and letting the error escape here
        would abort the aggregation of the whole group.
        """
        try:
            return _eval(self.expr, row) not in self.seen
        except NotBoundError:
            # UNDEF: nothing to aggregate for this row
            return False

    def set_value(self, bindings):
        """Store the final value under ``self.var`` in *bindings*."""
        bindings[self.var] = self.get_value()
39 | |
40 | |
class Counter(Accumulator):
    """COUNT aggregate: counts the rows for which the expression is bound.

    For ``COUNT(*)`` (``self.expr == "*"``) the whole row stands in for the
    evaluated value, since ``"*"`` cannot be evaluated.
    """

    def __init__(self, aggregation):
        super(Counter, self).__init__(aggregation)
        self.value = 0
        if self.expr == "*":
            # cannot eval "*" => always use the full row
            self.eval_row = self.eval_full_row

    def update(self, row, aggregator):
        """Count *row* unless its value is unbound."""
        try:
            val = self.eval_row(row)
        except NotBoundError:
            # skip UNDEF
            return
        self.value += 1
        if self.distinct:
            self.seen.add(val)

    def get_value(self):
        """Return the number of counted rows as a ``Literal``."""
        return Literal(self.value)

    def eval_row(self, row):
        """Evaluate the counted expression for *row*."""
        return _eval(self.expr, row)

    def eval_full_row(self, row):
        """Replacement for ``eval_row`` used by ``COUNT(*)``."""
        return row

    def use_row(self, row):
        """Test DISTINCT: True when the row's value was not counted before.

        Only reached for DISTINCT counts; otherwise the base class replaces
        this method with ``dont_care`` on the instance. A row whose value
        is unbound is reported as not usable instead of letting
        ``NotBoundError`` escape: ``update`` would skip that row anyway.
        """
        try:
            return self.eval_row(row) not in self.seen
        except NotBoundError:
            # UNDEF under DISTINCT: nothing to count for this row
            return False
71 | |
72 | |
def type_safe_numbers(*args):
    """Make a mix of ``float`` and ``Decimal`` arguments safe to add up.

    When both exact types occur among *args*, every argument is converted
    to ``float`` and the result is returned as a list. Otherwise the
    arguments are returned untouched, as the original tuple.
    """
    kinds = {type(arg) for arg in args}
    if float in kinds and Decimal in kinds:
        return [float(arg) for arg in args]
    return args
78 | |
79 | |
class Sum(Accumulator):
    """SUM aggregate: adds up the numeric value of every row it is given.

    The running total starts at ``0``. The result datatype is the datatype
    of the first value, combined with each later value's datatype through
    ``type_promotion``.
    """

    def __init__(self, aggregation):
        super(Sum, self).__init__(aggregation)
        self.datatype = None
        self.value = 0

    def update(self, row, aggregator):
        """Fold the value of ``self.expr`` for *row* into the running sum."""
        try:
            term = _eval(self.expr, row)
            if self.datatype is None:
                # the first value seen decides the initial datatype
                self.datatype = term.datatype
            else:
                self.datatype = type_promotion(self.datatype, term.datatype)
            operands = type_safe_numbers(self.value, numeric(term))
            self.value = sum(operands)
            if self.distinct:
                self.seen.add(term)
        except NotBoundError:
            # skip UNDEF
            pass

    def get_value(self):
        """Return the total as a ``Literal`` carrying the promoted datatype."""
        return Literal(self.value, datatype=self.datatype)
105 | |
class Average(Accumulator):
    """AVG aggregate: keeps a running sum and a count of the values used."""

    def __init__(self, aggregation):
        super(Average, self).__init__(aggregation)
        self.sum = 0
        self.counter = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Add the value of ``self.expr`` for *row* to the running average."""
        try:
            term = _eval(self.expr, row)
            # numeric() runs before any datatype bookkeeping, as in the
            # original statement order
            self.sum = sum(type_safe_numbers(self.sum, numeric(term)))
            if self.datatype is None:
                self.datatype = term.datatype
            else:
                self.datatype = type_promotion(self.datatype, term.datatype)
            if self.distinct:
                self.seen.add(term)
            self.counter += 1
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass

    def get_value(self):
        """Return the average as a ``Literal`` (``Literal(0)`` when empty)."""
        if not self.counter:
            return Literal(0)
        is_floating = self.datatype in (XSD.float, XSD.double)
        if is_floating:
            return Literal(self.sum / self.counter)
        # everything else is divided as Decimal
        return Literal(Decimal(self.sum) / Decimal(self.counter))
140 | |
141 | |
class Extremum(Accumulator):
    """Abstract base class shared by ``Minimum`` and ``Maximum``.

    Subclasses supply ``compare(val1, val2)``, which picks the value to
    keep out of the current extremum and a new candidate.
    """

    def __init__(self, aggregation):
        super(Extremum, self).__init__(aggregation)
        # DISTINCT would not change the value for MIN or MAX
        self.use_row = self.dont_care
        self.value = None

    def set_value(self, bindings):
        """Bind the extremum, leaving *bindings* untouched if none was seen."""
        if self.value is None:
            # simply do not set if self.value is still None
            return
        bindings[self.var] = Literal(self.value)

    def update(self, row, aggregator):
        """Offer the value of ``self.expr`` for *row* as a new extremum."""
        try:
            candidate = _eval(self.expr, row)
            if self.value is None:
                self.value = candidate
            else:
                # self.compare is implemented by Minimum/Maximum
                self.value = self.compare(self.value, candidate)
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass
168 | |
169 | |
class Minimum(Extremum):
    """MIN aggregate."""

    def compare(self, val1, val2):
        """Return whichever of the two values is smaller under ``_val``."""
        pair = (val1, val2)
        return min(pair, key=_val)
174 | |
175 | |
class Maximum(Extremum):
    """MAX aggregate."""

    def compare(self, val1, val2):
        """Return whichever of the two values is larger under ``_val``."""
        pair = (val1, val2)
        return max(pair, key=_val)
180 | |
181 | |
class Sample(Accumulator):
    """SAMPLE aggregate: takes the first eligible value."""

    def __init__(self, aggregation):
        super(Sample, self).__init__(aggregation)
        # DISTINCT would not change the value
        self.use_row = self.dont_care

    def update(self, row, aggregator):
        """Bind the first bound value and retire this accumulator."""
        try:
            picked = _eval(self.expr, row)
            # set the value now ...
            aggregator.bindings[self.var] = picked
            # ... and skip this accumulator for future rows
            del aggregator.accumulators[self.var]
        except NotBoundError:
            pass

    def get_value(self):
        """Return None: only reached when no value was ever sampled."""
        # set None if no value was set
        return None
202 | |
class GroupConcat(Accumulator):
    """GROUP_CONCAT aggregate: joins the string forms of all values.

    The separator comes from ``aggregation.separator``; a single space is
    used only when no separator was given at all.
    """

    def __init__(self, aggregation):
        super(GroupConcat, self).__init__(aggregation)
        # only GROUPCONCAT needs to have a list as accumulator
        self.value = []
        separator = aggregation.separator
        # Default only when the separator is absent: a plain truthiness
        # test would also replace an explicitly empty separator ("") with
        # a space.
        self.separator = " " if separator is None else separator

    def update(self, row, aggregator):
        """Append the value of ``self.expr`` for *row* to the collected list."""
        try:
            value = _eval(self.expr, row)
            self.value.append(value)
            if self.distinct:
                self.seen.add(value)
        # skip UNDEF
        except NotBoundError:
            pass

    def get_value(self):
        """Return the collected values joined by the separator as a ``Literal``."""
        return Literal(self.separator.join(str(v) for v in self.value))
223 | |
224 | |
class Aggregator(object):
    """Combines different Accumulator objects.

    One accumulator is created per aggregation, keyed by the aggregation's
    result variable (``res``).
    """

    # aggregation name => accumulator implementation
    accumulator_classes = {
        "Aggregate_Count": Counter,
        "Aggregate_Sample": Sample,
        "Aggregate_Sum": Sum,
        "Aggregate_Avg": Average,
        "Aggregate_Min": Minimum,
        "Aggregate_Max": Maximum,
        "Aggregate_GroupConcat": GroupConcat,
    }

    def __init__(self, aggregations):
        self.bindings = {}
        self.accumulators = {}
        for aggregation in aggregations:
            if aggregation.name not in self.accumulator_classes:
                raise Exception("Unknown aggregate function " + aggregation.name)
            factory = self.accumulator_classes[aggregation.name]
            self.accumulators[aggregation.res] = factory(aggregation)

    def update(self, row):
        """Feed *row* to every accumulator that wants it."""
        # SAMPLE accumulators may delete themselves while we iterate
        # => walk over a snapshot, not the live view
        snapshot = tuple(self.accumulators.values())
        for accumulator in snapshot:
            if not accumulator.use_row(row):
                continue
            accumulator.update(row, self)

    def get_bindings(self):
        """Calculate and set the final values, then return the bindings."""
        result = self.bindings
        for accumulator in self.accumulators.values():
            accumulator.set_value(result)
        return result