comparison planemo/lib/python3.7/site-packages/rdflib/plugins/sparql/aggregates.py @ 1:56ad4e20f292 draft

"planemo upload commit 6eee67778febed82ddd413c3ca40b3183a3898f1"
author guerler
date Fri, 31 Jul 2020 00:32:28 -0400
parents
children
comparison
equal deleted inserted replaced
0:d30785e31577 1:56ad4e20f292
1 from rdflib import Literal, XSD
2
3 from rdflib.plugins.sparql.evalutils import _eval, NotBoundError, _val
4 from rdflib.plugins.sparql.operators import numeric
5 from rdflib.plugins.sparql.datatypes import type_promotion
6
7 from rdflib.plugins.sparql.sparql import SPARQLTypeError
8
9 from decimal import Decimal
10
11 """
12 Aggregation functions
13 """
14
class Accumulator(object):
    """Abstract base class for the different aggregation functions.

    The base class stores the result variable and the aggregated
    expression, and keeps the bookkeeping needed for DISTINCT.
    ``set_value`` relies on a ``get_value()`` method that this class does
    not define itself, so concrete accumulators have to supply it.
    """

    def __init__(self, aggregation):
        """Initialise from *aggregation*.

        Reads ``aggregation.res`` (result variable), ``aggregation.vars``
        (expression to aggregate) and ``aggregation.distinct``.
        """
        self.var = aggregation.res
        self.expr = aggregation.vars
        if not aggregation.distinct:
            # Without DISTINCT every row is used: shadow the method on the
            # instance so the per-row membership test is skipped entirely.
            self.use_row = self.dont_care
            self.distinct = False
        else:
            self.distinct = aggregation.distinct
            # values already aggregated; only exists in DISTINCT mode
            self.seen = set()

    def dont_care(self, row):
        """Skip the distinct test: every row is used."""
        return True

    def use_row(self, row):
        """Return True if the value of the expression in *row* is new.

        A row in which the expression is not bound is reported as unusable
        instead of letting ``NotBoundError`` escape from the distinct test.
        """
        try:
            return _eval(self.expr, row) not in self.seen
        except NotBoundError:
            # nothing to aggregate for this row
            return False

    def set_value(self, bindings):
        """Store the final aggregated value in *bindings* under ``self.var``."""
        bindings[self.var] = self.get_value()
39
40
class Counter(Accumulator):
    """Accumulator for COUNT: counts the rows whose expression is bound."""

    def __init__(self, aggregation):
        """Start the count at zero; ``"*"`` switches to whole-row counting."""
        super(Counter, self).__init__(aggregation)
        self.value = 0
        if self.expr == "*":
            # cannot eval "*" => always use the full row
            self.eval_row = self.eval_full_row

    def update(self, row, aggregator):
        """Count *row* unless its expression is unbound."""
        try:
            val = self.eval_row(row)
        except NotBoundError:
            # skip UNDEF
            return
        self.value += 1
        if self.distinct:
            self.seen.add(val)

    def get_value(self):
        """Return the count as a Literal."""
        return Literal(self.value)

    def eval_row(self, row):
        """Evaluate the counted expression against *row*."""
        return _eval(self.expr, row)

    def eval_full_row(self, row):
        """Use the row itself as the counted value (COUNT(*))."""
        return row

    def use_row(self, row):
        """Return True if the value for *row* has not been counted yet.

        Only reached in DISTINCT mode; otherwise the base initialiser
        replaces this method on the instance with ``dont_care``.
        """
        try:
            return self.eval_row(row) not in self.seen
        except NotBoundError:
            # An unbound expression contributes nothing to the count, so
            # the row is simply not used (update() would skip it anyway).
            return False
71
72
def type_safe_numbers(*args):
    """Make a mix of float and Decimal operands safe to combine.

    If the arguments contain at least one ``float`` and at least one
    ``Decimal``, every argument is converted to ``float`` and returned as
    a list. Otherwise the arguments are returned untouched (as the
    original tuple).
    """
    kinds = {type(number) for number in args}
    if {float, Decimal} <= kinds:
        return [float(number) for number in args]
    return args
78
79
class Sum(Accumulator):
    """Accumulator for SUM: running total plus the promoted datatype."""

    def __init__(self, aggregation):
        """Start with a total of 0 and no datatype seen yet."""
        super(Sum, self).__init__(aggregation)
        self.value = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Add the value of the expression in *row* to the running total."""
        try:
            term = _eval(self.expr, row)
            if self.datatype is None:
                # first value decides the initial datatype
                self.datatype = term.datatype
            else:
                self.datatype = type_promotion(self.datatype, term.datatype)
            operands = type_safe_numbers(self.value, numeric(term))
            self.value = sum(operands)
            if self.distinct:
                self.seen.add(term)
        except NotBoundError:
            # skip UNDEF
            pass

    def get_value(self):
        """Return the total as a Literal carrying the tracked datatype."""
        return Literal(self.value, datatype=self.datatype)
105
class Average(Accumulator):
    """Accumulator for AVG: keeps a running sum and a count of values."""

    def __init__(self, aggregation):
        """Start with an empty sum, a zero count and no datatype."""
        super(Average, self).__init__(aggregation)
        self.counter = 0
        self.sum = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Fold the value of the expression in *row* into sum and count."""
        try:
            term = _eval(self.expr, row)
            # numeric() runs first, so a non-numeric term leaves the
            # datatype, the seen-set and the counter untouched
            self.sum = sum(type_safe_numbers(self.sum, numeric(term)))
            if self.datatype is None:
                self.datatype = term.datatype
            else:
                self.datatype = type_promotion(self.datatype, term.datatype)
            if self.distinct:
                self.seen.add(term)
            self.counter += 1
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass

    def get_value(self):
        """Return the average as a Literal; ``Literal(0)`` if nothing was counted."""
        if not self.counter:
            return Literal(0)
        if self.datatype in (XSD.float, XSD.double):
            return Literal(self.sum / self.counter)
        # every other datatype is averaged with Decimal arithmetic
        return Literal(Decimal(self.sum) / Decimal(self.counter))
140
141
class Extremum(Accumulator):
    """Abstract base class shared by Minimum and Maximum.

    Subclasses supply ``compare(val1, val2)``, which picks the value to
    keep.
    """

    def __init__(self, aggregation):
        """Start without a value and disable the distinct test."""
        super(Extremum, self).__init__(aggregation)
        self.value = None
        # DISTINCT would not change the value for MIN or MAX
        self.use_row = self.dont_care

    def set_value(self, bindings):
        """Bind the extremum in *bindings*; leave it unset if none was seen."""
        if self.value is None:
            return
        bindings[self.var] = Literal(self.value)

    def update(self, row, aggregator):
        """Keep whichever of the current value and the row's value wins."""
        try:
            candidate = _eval(self.expr, row)
            if self.value is None:
                self.value = candidate
            else:
                # self.compare is implemented by Minimum/Maximum
                self.value = self.compare(self.value, candidate)
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass
168
169
class Minimum(Extremum):
    """Accumulator for MIN."""

    def compare(self, val1, val2):
        """Return the smaller of the two values, ordered by ``_val``."""
        pair = (val1, val2)
        return min(pair, key=_val)
174
175
class Maximum(Extremum):
    """Accumulator for MAX."""

    def compare(self, val1, val2):
        """Return the larger of the two values, ordered by ``_val``."""
        pair = (val1, val2)
        return max(pair, key=_val)
180
181
class Sample(Accumulator):
    """Accumulator for SAMPLE: takes the first eligible value."""

    def __init__(self, aggregation):
        """Disable the distinct test for this accumulator."""
        super(Sample, self).__init__(aggregation)
        # DISTINCT would not change the value
        self.use_row = self.dont_care

    def update(self, row, aggregator):
        """Bind the first evaluable value, then retire this accumulator."""
        try:
            sampled = _eval(self.expr, row)
            # the value is written into the aggregator's bindings right away
            aggregator.bindings[self.var] = sampled
            # removing ourselves means later rows are never offered to us
            del aggregator.accumulators[self.var]
        except NotBoundError:
            pass

    def get_value(self):
        """Return None: reached only if no row ever supplied a value."""
        return None
202
class GroupConcat(Accumulator):
    """Accumulator for GROUP_CONCAT: joins the string forms of all values."""

    def __init__(self, aggregation):
        """Collect values in a list and read the separator from *aggregation*.

        A separator of ``None`` falls back to a single space. Any other
        value, including the empty string, is used as given.
        """
        super(GroupConcat, self).__init__(aggregation)
        # only GROUPCONCAT needs to have a list as accumulator
        self.value = []
        separator = aggregation.separator
        # Test against None rather than truthiness: an explicit empty
        # separator must not be replaced by the default.
        self.separator = " " if separator is None else separator

    def update(self, row, aggregator):
        """Append the value of the expression in *row* to the collected list."""
        try:
            value = _eval(self.expr, row)
            self.value.append(value)
            if self.distinct:
                self.seen.add(value)
        # skip UNDEF
        except NotBoundError:
            pass

    def get_value(self):
        """Return a Literal of all collected values joined by the separator."""
        return Literal(self.separator.join(str(v) for v in self.value))
223
224
class Aggregator(object):
    """Combines several Accumulator objects, keyed by result variable."""

    # maps the aggregation's ``name`` to the accumulator implementing it
    accumulator_classes = {
        "Aggregate_Count": Counter,
        "Aggregate_Sample": Sample,
        "Aggregate_Sum": Sum,
        "Aggregate_Avg": Average,
        "Aggregate_Min": Minimum,
        "Aggregate_Max": Maximum,
        "Aggregate_GroupConcat": GroupConcat,
    }

    def __init__(self, aggregations):
        """Create one accumulator per entry of *aggregations*.

        Raises ``Exception`` if an entry's ``name`` is not a key of
        ``accumulator_classes``.
        """
        self.bindings = {}
        self.accumulators = {}
        for aggregation in aggregations:
            if aggregation.name not in self.accumulator_classes:
                raise Exception("Unknown aggregate function " + aggregation.name)
            factory = self.accumulator_classes[aggregation.name]
            self.accumulators[aggregation.res] = factory(aggregation)

    def update(self, row):
        """Offer *row* to every accumulator that wants to use it."""
        # SAMPLE accumulators may delete themselves from the dict while we
        # iterate => walk over a snapshot, not the live view
        snapshot = tuple(self.accumulators.values())
        for accumulator in snapshot:
            if not accumulator.use_row(row):
                continue
            accumulator.update(row, self)

    def get_bindings(self):
        """Let each remaining accumulator set its final value; return the bindings."""
        for accumulator in self.accumulators.values():
            accumulator.set_value(self.bindings)
        return self.bindings