Mercurial > repos > guerler > springsuite
comparison planemo/lib/python3.7/site-packages/rdflib/plugins/sparql/aggregates.py @ 1:56ad4e20f292 draft
"planemo upload commit 6eee67778febed82ddd413c3ca40b3183a3898f1"
| author | guerler |
|---|---|
| date | Fri, 31 Jul 2020 00:32:28 -0400 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 0:d30785e31577 | 1:56ad4e20f292 |
|---|---|
| 1 from rdflib import Literal, XSD | |
| 2 | |
| 3 from rdflib.plugins.sparql.evalutils import _eval, NotBoundError, _val | |
| 4 from rdflib.plugins.sparql.operators import numeric | |
| 5 from rdflib.plugins.sparql.datatypes import type_promotion | |
| 6 | |
| 7 from rdflib.plugins.sparql.sparql import SPARQLTypeError | |
| 8 | |
| 9 from decimal import Decimal | |
| 10 | |
| 11 """ | |
| 12 Aggregation functions | |
| 13 """ | |
| 14 | |
class Accumulator(object):
    """Abstract base class for the different aggregation functions.

    Subclasses implement ``update(row, aggregator)`` to fold one solution
    row into their state and ``get_value()`` to produce the final value.

    Attributes set here:

    * ``var`` -- variable that receives the result (``aggregation.res``).
    * ``expr`` -- expression being aggregated (``aggregation.vars``).
    * ``distinct`` -- ``False`` when DISTINCT was not requested, otherwise
      the value of ``aggregation.distinct``.
    * ``seen`` -- set of values already aggregated; only created when
      DISTINCT was requested.
    """

    def __init__(self, aggregation):
        self.var = aggregation.res
        self.expr = aggregation.vars
        if not aggregation.distinct:
            # Without DISTINCT every row is eligible: shadow the set-based
            # test with the trivial one on this instance.
            self.use_row = self.dont_care
            self.distinct = False
        else:
            self.distinct = aggregation.distinct
            self.seen = set()

    def dont_care(self, row):
        """Skip the distinct test: accept every row."""
        return True

    def use_row(self, row):
        """Test DISTINCT: True if the row's value has not been aggregated yet.

        If the expression is unbound on *row*, the row is rejected instead of
        letting ``NotBoundError`` escape into the caller (which would abort
        the whole query). The ``update`` methods of the accumulators in this
        module skip UNDEF values anyway, so the result is unchanged.
        """
        try:
            return _eval(self.expr, row) not in self.seen
        except NotBoundError:
            return False

    def set_value(self, bindings):
        """Store the final value (``get_value()``) under ``self.var``."""
        bindings[self.var] = self.get_value()
| 40 | |
class Counter(Accumulator):
    """COUNT aggregate: counts the rows on which the expression is bound.

    For ``COUNT(*)`` the whole row is used as the counted value. With
    DISTINCT, each distinct value (or distinct row for ``*``) counts once.
    """

    def __init__(self, aggregation):
        super(Counter, self).__init__(aggregation)
        self.value = 0
        if self.expr == "*":
            # cannot eval "*" => always use the full row
            self.eval_row = self.eval_full_row

    def update(self, row, aggregator):
        """Count *row* unless the counted expression is unbound on it."""
        try:
            val = self.eval_row(row)
        except NotBoundError:
            # skip UNDEF
            return
        self.value += 1
        if self.distinct:
            self.seen.add(val)

    def get_value(self):
        """Return the count as a ``Literal``."""
        return Literal(self.value)

    def eval_row(self, row):
        """Evaluate the counted expression against *row*."""
        return _eval(self.expr, row)

    def eval_full_row(self, row):
        """Return the row itself (used for ``COUNT(*)``)."""
        return row

    def use_row(self, row):
        """Test DISTINCT: True if *row*'s value has not been counted yet.

        An unbound expression makes the row ineligible instead of raising
        ``NotBoundError``. Raising would abort the query, e.g. for
        ``COUNT(DISTINCT ?x)`` where some OPTIONAL matches leave ``?x``
        unbound. ``update`` would skip such a row anyway.
        """
        try:
            return self.eval_row(row) not in self.seen
        except NotBoundError:
            return False
| 71 | |
| 72 | |
def type_safe_numbers(*args):
    """Make *args* safe to add together with ``sum``.

    ``float`` and ``Decimal`` cannot be added to each other. When both types
    occur among the arguments, every argument is converted to ``float`` and
    a list is returned. Otherwise the arguments are returned unchanged, as
    the original tuple.
    """
    present = {type(arg) for arg in args}
    if float not in present or Decimal not in present:
        return args
    return [float(arg) for arg in args]
| 78 | |
| 79 | |
class Sum(Accumulator):
    """SUM aggregate: running total of numeric values with type promotion.

    ``datatype`` tracks the promoted XSD datatype of all summed values. It is
    used as the datatype of the result literal.
    """

    def __init__(self, aggregation):
        super(Sum, self).__init__(aggregation)
        self.value = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Add the row's numeric value to the running total.

        Unbound values are skipped. The value is converted with ``numeric``
        before its ``datatype`` is read or any state is changed. A term that
        is not numeric is therefore rejected by ``numeric`` (presumably with
        ``SPARQLTypeError``, which Average relies on). This also covers IRIs
        and blank nodes, which previously failed with ``AttributeError``.
        The running total and datatype are left untouched in that case.
        """
        try:
            value = _eval(self.expr, row)
            number = numeric(value)
            dt = self.datatype
            if dt is None:
                dt = value.datatype
            else:
                dt = type_promotion(dt, value.datatype)
            # Commit only after every check above succeeded.
            self.value = sum(type_safe_numbers(self.value, number))
            self.datatype = dt
            if self.distinct:
                self.seen.add(value)
        except NotBoundError:
            # skip UNDEF
            pass

    def get_value(self):
        """Return the total as a ``Literal`` of the promoted datatype."""
        return Literal(self.value, datatype=self.datatype)
| 105 | |
class Average(Accumulator):
    """AVG aggregate: keeps a running sum and count of numeric values.

    Rows whose value is unbound, or which raise ``SPARQLTypeError`` while
    being evaluated or converted, are left out of the average.
    """

    def __init__(self, aggregation):
        super(Average, self).__init__(aggregation)
        self.counter = 0
        self.sum = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Fold the row's numeric value into the running sum and count."""
        try:
            term = _eval(self.expr, row)
            previous = self.datatype
            self.sum = sum(type_safe_numbers(self.sum, numeric(term)))
            self.datatype = (
                term.datatype
                if previous is None
                else type_promotion(previous, term.datatype)
            )
            if self.distinct:
                self.seen.add(term)
            self.counter += 1
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass

    def get_value(self):
        """Return the mean as a ``Literal``; ``Literal(0)`` if nothing was counted.

        xsd:float / xsd:double inputs use plain division. All other
        datatypes are divided as ``Decimal``.
        """
        if not self.counter:
            return Literal(0)
        if self.datatype not in (XSD.float, XSD.double):
            return Literal(Decimal(self.sum) / Decimal(self.counter))
        return Literal(self.sum / self.counter)
| 140 | |
| 141 | |
class Extremum(Accumulator):
    """Abstract base class for Minimum and Maximum.

    Subclasses provide ``compare(val1, val2)``, which returns whichever of
    the two values should be kept.
    """

    def __init__(self, aggregation):
        super(Extremum, self).__init__(aggregation)
        self.value = None
        # DISTINCT would not change the value for MIN or MAX
        self.use_row = self.dont_care

    def set_value(self, bindings):
        """Bind the kept value to ``self.var``, or leave it unbound if none.

        The stored value is the term returned by ``_eval`` and is bound as-is,
        just as Sample binds its value. Wrapping it in ``Literal`` would turn
        an IRI result of MIN/MAX into a plain string literal.
        """
        if self.value is not None:
            # simply do not set if self.value is still None
            bindings[self.var] = self.value

    def update(self, row, aggregator):
        """Keep the extremum of the current value and the row's value."""
        try:
            if self.value is None:
                self.value = _eval(self.expr, row)
            else:
                # self.compare is implemented by Minimum/Maximum
                self.value = self.compare(self.value, _eval(self.expr, row))
        # skip UNDEF (NotBoundError) and type errors (SPARQLTypeError)
        except NotBoundError:
            pass
        except SPARQLTypeError:
            pass
| 168 | |
| 169 | |
class Minimum(Extremum):
    """MIN aggregate: keeps the smaller value under the ``_val`` ordering."""

    def compare(self, val1, val2):
        """Return *val2* only if it sorts strictly before *val1*; ties keep *val1*."""
        first_key = _val(val1)
        if _val(val2) < first_key:
            return val2
        return val1
| 174 | |
| 175 | |
class Maximum(Extremum):
    """MAX aggregate: keeps the larger value under the ``_val`` ordering."""

    def compare(self, val1, val2):
        """Return *val2* only if it sorts strictly after *val1*; ties keep *val1*."""
        first_key = _val(val1)
        if _val(val2) > first_key:
            return val2
        return val1
| 180 | |
| 181 | |
class Sample(Accumulator):
    """SAMPLE aggregate: binds the first value that turns out to be bound."""

    def __init__(self, aggregation):
        super(Sample, self).__init__(aggregation)
        # Which value gets sampled does not depend on DISTINCT
        self.use_row = self.dont_care

    def update(self, row, aggregator):
        """Bind the row's value directly and retire this accumulator.

        After a successful update the accumulator removes itself from
        ``aggregator.accumulators``. It then sees no further rows and is not
        asked for a final value.
        """
        try:
            picked = _eval(self.expr, row)
        except NotBoundError:
            return
        aggregator.bindings[self.var] = picked
        del aggregator.accumulators[self.var]

    def get_value(self):
        """Return None; reached only when no row had a bound value."""
        return None
| 202 | |
class GroupConcat(Accumulator):
    """GROUP_CONCAT aggregate: joins the string forms of the values.

    The separator comes from ``aggregation.separator``. It defaults to a
    single space only when no separator was given.
    """

    def __init__(self, aggregation):
        super(GroupConcat, self).__init__(aggregation)
        # only GROUPCONCAT needs to have a list as accumulator
        self.value = []
        # Test for None rather than truthiness: an explicit SEPARATOR ""
        # must be honoured, not replaced by the default space.
        separator = aggregation.separator
        self.separator = " " if separator is None else separator

    def update(self, row, aggregator):
        """Append the row's value; unbound values are skipped."""
        try:
            value = _eval(self.expr, row)
            self.value.append(value)
            if self.distinct:
                self.seen.add(value)
        # skip UNDEF
        except NotBoundError:
            pass

    def get_value(self):
        """Return the joined string forms of the collected values as a ``Literal``."""
        return Literal(self.separator.join(str(v) for v in self.value))
| 223 | |
| 224 | |
class Aggregator(object):
    """Combine the Accumulator objects for the aggregates of one group.

    Each aggregation's ``name`` selects an Accumulator class from
    ``accumulator_classes``. The resulting accumulator is stored under the
    aggregation's result variable (``res``). Final values are collected
    into ``bindings``.
    """

    accumulator_classes = {
        "Aggregate_Count": Counter,
        "Aggregate_Sample": Sample,
        "Aggregate_Sum": Sum,
        "Aggregate_Avg": Average,
        "Aggregate_Min": Minimum,
        "Aggregate_Max": Maximum,
        "Aggregate_GroupConcat": GroupConcat,
    }

    def __init__(self, aggregations):
        self.bindings = {}
        self.accumulators = {}
        for aggregation in aggregations:
            make_accumulator = self.accumulator_classes.get(aggregation.name)
            if make_accumulator is None:
                raise Exception("Unknown aggregate function " + aggregation.name)
            self.accumulators[aggregation.res] = make_accumulator(aggregation)

    def update(self, row):
        """Feed *row* to every accumulator that accepts it."""
        # Iterate over a snapshot: a Sample accumulator deletes itself from
        # self.accumulators while it is being updated.
        snapshot = list(self.accumulators.values())
        for accumulator in snapshot:
            if not accumulator.use_row(row):
                continue
            accumulator.update(row, self)

    def get_bindings(self):
        """Have each remaining accumulator write its final value; return the bindings."""
        for accumulator in self.accumulators.values():
            accumulator.set_value(self.bindings)
        return self.bindings
