diff planemo/lib/python3.7/site-packages/rdflib/plugins/sparql/aggregates.py @ 1:56ad4e20f292 draft

"planemo upload commit 6eee67778febed82ddd413c3ca40b3183a3898f1"
author guerler
date Fri, 31 Jul 2020 00:32:28 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/planemo/lib/python3.7/site-packages/rdflib/plugins/sparql/aggregates.py	Fri Jul 31 00:32:28 2020 -0400
@@ -0,0 +1,260 @@
+from rdflib import Literal, XSD
+
+from rdflib.plugins.sparql.evalutils import _eval, NotBoundError, _val
+from rdflib.plugins.sparql.operators import numeric
+from rdflib.plugins.sparql.datatypes import type_promotion
+
+from rdflib.plugins.sparql.sparql import SPARQLTypeError
+
+from decimal import Decimal
+
+"""
+Aggregation functions
+"""
+
+class Accumulator(object):
+    """Abstract base class for the different aggregation functions.
+
+    Subclasses provide ``update(row, aggregator)`` and, unless they override
+    ``set_value``, a ``get_value()`` method.
+    """
+
+    def __init__(self, aggregation):
+        """Read the target variable, the expression and the DISTINCT flag.
+
+        :param aggregation: object exposing ``res`` (key the result is stored
+            under), ``vars`` (expression evaluated for every row) and
+            ``distinct``.
+        """
+        self.var = aggregation.res
+        self.expr = aggregation.vars
+        if not aggregation.distinct:
+            # shadow the method on the instance so that every row is used
+            self.use_row = self.dont_care
+            self.distinct = False
+        else:
+            self.distinct = aggregation.distinct
+            # values that were aggregated already; filled by update()
+            self.seen = set()
+
+    def dont_care(self, row):
+        """Skip the distinct test: every row is used."""
+        return True
+
+    def use_row(self, row):
+        """Test DISTINCT: use the row only if its value was not seen yet.
+
+        A row in which the expression is unbound is not used. This matches
+        the ``update`` methods, which skip UNDEF, and keeps the error from
+        escaping to the caller.
+        """
+        try:
+            return _eval(self.expr, row) not in self.seen
+        except NotBoundError:
+            return False
+
+    def set_value(self, bindings):
+        """Store the final value of this accumulator in ``bindings``."""
+        bindings[self.var] = self.get_value()
+
+
+class Counter(Accumulator):
+    """COUNT: counts the rows for which the expression is bound."""
+
+    def __init__(self, aggregation):
+        """Start at zero; for "*" the full row stands in for the value."""
+        super(Counter, self).__init__(aggregation)
+        self.value = 0
+        if self.expr == "*":
+            # cannot eval "*" => always use the full row
+            self.eval_row = self.eval_full_row
+
+    def update(self, row, aggregator):
+        """Count ``row`` unless its expression is unbound."""
+        try:
+            val = self.eval_row(row)
+        except NotBoundError:
+            # skip UNDEF
+            return
+        self.value += 1
+        if self.distinct:
+            self.seen.add(val)
+
+    def get_value(self):
+        """Return the count wrapped in a Literal."""
+        return Literal(self.value)
+
+    def eval_row(self, row):
+        """Evaluate the counted expression for ``row``."""
+        return _eval(self.expr, row)
+
+    def eval_full_row(self, row):
+        """Replacement for ``eval_row`` when counting "*": the row itself."""
+        return row
+
+    def use_row(self, row):
+        """Test DISTINCT; a row with an unbound expression is never used."""
+        try:
+            return self.eval_row(row) not in self.seen
+        except NotBoundError:
+            # the counted expression is unbound in this row: skip the row
+            # instead of letting the error abort the whole aggregation
+            return False
+
+
+def type_safe_numbers(*args):
+    """Make the given numbers safe to add up.
+
+    When both ``float`` and ``Decimal`` values are present, every argument is
+    converted to ``float`` and a list is returned; otherwise the argument
+    tuple is returned unchanged.
+    """
+    kinds = {type(number) for number in args}
+    if float not in kinds or Decimal not in kinds:
+        return args
+    return [float(number) for number in args]
+
+
+class Sum(Accumulator):
+    """SUM: adds up the numeric values of the expression."""
+
+    def __init__(self, aggregation):
+        """Start with a total of 0 and no datatype."""
+        super(Sum, self).__init__(aggregation)
+        # datatype of the result; stays None until a value was aggregated
+        self.datatype = None
+        self.value = 0
+
+    def update(self, row, aggregator):
+        """Add the value of ``row`` to the total; UNDEF is skipped."""
+        try:
+            term = _eval(self.expr, row)
+            # the datatype is settled before the value is converted
+            if self.datatype is None:
+                self.datatype = term.datatype
+            else:
+                self.datatype = type_promotion(self.datatype, term.datatype)
+            operands = type_safe_numbers(self.value, numeric(term))
+            self.value = sum(operands)
+            if self.distinct:
+                self.seen.add(term)
+        except NotBoundError:
+            # skip UNDEF
+            pass
+
+    def get_value(self):
+        """Return the total as a Literal carrying the tracked datatype."""
+        return Literal(self.value, datatype=self.datatype)
+
+class Average(Accumulator):
+    """AVG: keeps a running sum and a count of the aggregated values."""
+
+    def __init__(self, aggregation):
+        """Start with an empty sum, a zero count and no datatype."""
+        super(Average, self).__init__(aggregation)
+        self.datatype = None
+        self.sum = 0
+        self.counter = 0
+
+    def update(self, row, aggregator):
+        """Add the value of ``row``; UNDEF and type errors are skipped."""
+        try:
+            term = _eval(self.expr, row)
+            self.sum = sum(type_safe_numbers(self.sum, numeric(term)))
+            if self.datatype is None:
+                self.datatype = term.datatype
+            else:
+                self.datatype = type_promotion(self.datatype, term.datatype)
+            if self.distinct:
+                self.seen.add(term)
+            self.counter += 1
+        except (NotBoundError, SPARQLTypeError):
+            # skip UNDEF or BNode => SPARQLTypeError
+            pass
+
+    def get_value(self):
+        """Return the mean as a Literal, or ``Literal(0)`` without values."""
+        if self.counter == 0:
+            return Literal(0)
+        if self.datatype in (XSD.float, XSD.double):
+            mean = self.sum / self.counter
+        else:
+            # everything else is divided as Decimal
+            mean = Decimal(self.sum) / Decimal(self.counter)
+        return Literal(mean)
+
+
+class Extremum(Accumulator):
+    """Abstract base class for Minimum and Maximum.
+
+    Subclasses implement ``compare(val1, val2)`` returning the value to keep.
+    """
+
+    def __init__(self, aggregation):
+        """Start without a value and use every row."""
+        super(Extremum, self).__init__(aggregation)
+        # DISTINCT would not change the value for MIN or MAX
+        self.use_row = self.dont_care
+        self.value = None
+
+    def set_value(self, bindings):
+        """Store the extremum in ``bindings`` if there is one."""
+        if self.value is None:
+            # nothing was aggregated: simply do not set the variable
+            return
+        bindings[self.var] = Literal(self.value)
+
+    def update(self, row, aggregator):
+        """Keep the value of ``row`` if it beats the current extremum."""
+        try:
+            candidate = _eval(self.expr, row)
+            if self.value is None:
+                self.value = candidate
+            else:
+                # self.compare is implemented by Minimum/Maximum
+                self.value = self.compare(self.value, candidate)
+        except (NotBoundError, SPARQLTypeError):
+            # skip UNDEF or BNode => SPARQLTypeError
+            pass
+
+
+class Minimum(Extremum):
+    """MIN: keeps the smaller value."""
+
+    def compare(self, val1, val2):
+        """Return whichever of the two values is smaller under ``_val``."""
+        candidates = (val1, val2)
+        return min(candidates, key=_val)
+
+
+class Maximum(Extremum):
+    """MAX: keeps the larger value."""
+
+    def compare(self, val1, val2):
+        """Return whichever of the two values is larger under ``_val``."""
+        candidates = (val1, val2)
+        return max(candidates, key=_val)
+
+
+class Sample(Accumulator):
+    """Takes the first eligible value."""
+
+    def __init__(self, aggregation):
+        """Use every row, whether or not DISTINCT was requested."""
+        super(Sample, self).__init__(aggregation)
+        # DISTINCT would not change the value
+        self.use_row = self.dont_care
+
+    def update(self, row, aggregator):
+        """Bind the value of the first row that has one, then retire."""
+        try:
+            sampled = _eval(self.expr, row)
+        except NotBoundError:
+            # no value in this row: wait for the next one
+            return
+        # set the value now ...
+        aggregator.bindings[self.var] = sampled
+        # ... and skip this accumulator for future rows
+        del aggregator.accumulators[self.var]
+
+    def get_value(self):
+        """Return None: only reached if no row ever provided a value."""
+        return None
+
+class GroupConcat(Accumulator):
+    """GROUP_CONCAT: joins the string forms of the values with a separator."""
+
+    def __init__(self, aggregation):
+        """Start with an empty list and pick the separator.
+
+        :param aggregation: as for ``Accumulator``; additionally its
+            ``separator`` is read, and a single space is used when it is None.
+        """
+        super(GroupConcat, self).__init__(aggregation)
+        # only GROUPCONCAT needs to have a list as accumulator
+        self.value = []
+        # fall back to the default only if no separator was given at all:
+        # an explicitly requested empty separator is falsy but must be kept
+        if aggregation.separator is None:
+            self.separator = " "
+        else:
+            self.separator = aggregation.separator
+
+    def update(self, row, aggregator):
+        """Append the value of ``row``; UNDEF is skipped."""
+        try:
+            value = _eval(self.expr, row)
+            self.value.append(value)
+            if self.distinct:
+                self.seen.add(value)
+        # skip UNDEF
+        except NotBoundError:
+            pass
+
+    def get_value(self):
+        """Return all collected values joined by the separator as a Literal."""
+        return Literal(self.separator.join(str(v) for v in self.value))
+
+
+class Aggregator(object):
+    """Combines different Accumulator objects."""
+
+    # maps the name of an aggregation to the class that accumulates it
+    accumulator_classes = {
+        "Aggregate_Count": Counter,
+        "Aggregate_Sample": Sample,
+        "Aggregate_Sum": Sum,
+        "Aggregate_Avg": Average,
+        "Aggregate_Min": Minimum,
+        "Aggregate_Max": Maximum,
+        "Aggregate_GroupConcat": GroupConcat,
+    }
+
+    def __init__(self, aggregations):
+        """Create one accumulator per aggregation, keyed by its ``res``.
+
+        :raises Exception: if the name of an aggregation is not a key of
+            ``accumulator_classes``.
+        """
+        self.bindings = {}
+        self.accumulators = {}
+        for a in aggregations:
+            factory = self.accumulator_classes.get(a.name)
+            if factory is None:
+                raise Exception("Unknown aggregate function " + a.name)
+            self.accumulators[a.res] = factory(a)
+
+    def update(self, row):
+        """Update all own accumulators that want to use ``row``."""
+        # SAMPLE accumulators may delete themselves from self.accumulators
+        # while being updated => iterate over a snapshot, not the live view
+        for acc in tuple(self.accumulators.values()):
+            if not acc.use_row(row):
+                continue
+            acc.update(row, self)
+
+    def get_bindings(self):
+        """Calculate and set the last values, then return the bindings."""
+        result = self.bindings
+        for acc in self.accumulators.values():
+            acc.set_value(result)
+        return result