"""Algorithms for aggregation."""
from __future__ import annotations
import abc
import bisect
import enum
import functools
import typing
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
NamedTuple,
Sequence,
TypeVar,
)
import toolz
from .aggregator import Aggregate, Aggregator
from .functions.associative import BinaryAssociativeAggregate, UnaryAssociativeAggregate
from .functions.navigation import (
BinaryNavigationAggregate,
TernaryNavigationAggregate,
UnaryNavigationAggregate,
)
from .functions.ranking import RankingAggregate
from .protocols import Comparable
from .row import AbstractRow
from .typehints import Following, OrderBy, OrderingKey, PartitionBy, Preceding, T
class StartStop(NamedTuple):
    """A class to hold start and stop values for a range of rows.

    Attributes
    ----------
    start
        Index of the first row of the range.
    stop
        Index marking the end of the range.

    """

    start: int
    stop: int
@enum.unique
class Nulls(enum.Enum):
    """An enumeration indicating how to handle null values when sorting.

    The member values double as comparator results: ``row_key_compare``
    returns ``value`` when only the left key is null and ``-value`` when only
    the right key is null.
    """

    # a null key compares as smaller than any non-null key
    FIRST = -1
    # a null key compares as larger than any non-null key
    LAST = 1
class FrameClause(abc.ABC):
    """Class for computing frame boundaries.

    Attributes
    ----------
    order_by
        The ``ORDER BY`` functions of the window.
    partition_by
        The ``PARTITION BY`` functions of the window.
    preceding
        The preceding bound of the frame. When ``None`` the frame starts at
        the beginning of the partition.
    following
        The following bound of the frame. When ``None`` the frame end is
        chosen by :meth:`compute_window_frame` depending on whether the
        window is ordered.
    nulls
        How null values are handled when sorting.

    """

    __slots__ = "order_by", "partition_by", "preceding", "following", "nulls"

    def __init__(
        self,
        order_by: Sequence[OrderBy],
        partition_by: Sequence[PartitionBy],
        preceding: Preceding | None,
        following: Following | None,
        nulls: Nulls,
    ) -> None:
        self.order_by = order_by
        self.partition_by = partition_by
        self.preceding = preceding
        self.following = following
        self.nulls = nulls

    @abc.abstractmethod
    def find_partition_begin(
        self,
        current_row: AbstractRow,
        row_id_in_partition: int,
        current_row_order_by_value: OrderingKey | None,
        order_by_values: Sequence[OrderingKey],
    ) -> int:
        """Find the beginning of a window in a partition.

        Parameters
        ----------
        current_row
            The row relative to which we are computing the window.
        row_id_in_partition
            The zero-based index of `current_row` in possible_peers.
        current_row_order_by_value
            The value of the ORDER BY key in the current row.
        order_by_values
            The order by values for the current partition.

        Returns
        -------
        int
            The start point of the window in the current partition

        """

    @abc.abstractmethod
    def find_partition_end(
        self,
        current_row: AbstractRow,
        row_id_in_partition: int,
        current_row_order_by_value: OrderingKey | None,
        order_by_values: Sequence[OrderingKey],
    ) -> int:
        """Find the end of a window in a partition.

        Parameters
        ----------
        current_row
            The row relative to which we are computing the window.
        row_id_in_partition
            The zero-based index of `current_row` in possible_peers.
        current_row_order_by_value
            The value of the ORDER BY key in the current row.
        order_by_values
            The order by values for the current partition.

        Returns
        -------
        int
            The end point of the window in the current partition

        """

    @abc.abstractmethod
    def setup_window(
        self,
        possible_peers: Sequence[AbstractRow],
        current_row: AbstractRow,
        order_by_columns: Sequence[str],
    ) -> tuple[OrderingKey, Sequence[OrderingKey]]:
        """Compute the current row's ordering keys.

        Parameters
        ----------
        possible_peers
            The sequence of possible rows of which the window could consist.
        current_row
            The row relative to which we are computing the window.
        order_by_columns
            The columns by which we have ordered our window, if any.

        Returns
        -------
        tuple[OrderingKey, Sequence[OrderingKey]]
            The ordering key of `current_row` followed by the ordering keys
            of the partition.

        """

    def compute_window_frame(
        self,
        possible_peers: Sequence[AbstractRow],
        current_row: AbstractRow,
        row_id_in_partition: int,
        order_by_columns: Sequence[str],
    ) -> StartStop:
        """Compute the bounds of the window frame.

        Parameters
        ----------
        possible_peers
            The sequence of possible rows of which the window could consist.
        current_row
            The row relative to which we are computing the window.
        row_id_in_partition
            The zero-based index of `current_row` in possible_peers.
        order_by_columns
            The columns by which we have ordered our window, if any.

        Returns
        -------
        StartStop
            The start and stop of the window frame. The start is never
            negative and the stop never exceeds ``len(possible_peers)``.

        """
        current_row_order_by_value, order_by_values = self.setup_window(
            possible_peers, current_row, order_by_columns
        )

        if self.preceding is not None:
            start = self.find_partition_begin(
                current_row,
                row_id_in_partition,
                current_row_order_by_value,
                order_by_values,
            )
        else:
            # no preceding bound: the frame starts at the first row
            start = 0

        npeers = len(possible_peers)

        if self.following is not None:
            stop = self.find_partition_end(
                current_row,
                row_id_in_partition,
                current_row_order_by_value,
                order_by_values,
            )
        elif not all(order_by_values):
            # Ordering keys are tuples, so an empty (falsy) key means there
            # is no order by. In that case all possible peers are the actual
            # peers of this row.
            stop = npeers
        else:
            # default to the current row if following is not provided. This
            # is consistent with the defaults in at least PostgreSQL and
            # SQLite.
            stop = row_id_in_partition + 1

        # keep the computed bounds from running off either end
        return StartStop(max(start, 0), min(stop, npeers))
class RowsMode(FrameClause):
    """A frame clause implementation for window function ``ROWS`` mode.

    ``ROWS`` mode computes the window frame relative to the difference between
    the row index of the current row and what is given by ``preceding`` and
    ``following``.

    See Also
    --------
    RangeMode

    """

    __slots__ = ()

    def find_partition_begin(
        self,
        current_row: AbstractRow,
        row_id_in_partition: int,
        current_row_order_by_value: OrderingKey | None,
        order_by_values: Sequence[OrderingKey],
    ) -> int:
        """Return the row index minus the value of ``preceding``.

        The result may be negative; the ordering key arguments are unused in
        ``ROWS`` mode.
        """
        preceding = self.preceding
        assert preceding is not None, "preceding is None"
        return row_id_in_partition - typing.cast(int, preceding(current_row))

    def find_partition_end(
        self,
        current_row: AbstractRow,
        row_id_in_partition: int,
        current_row_order_by_value: OrderingKey | None,
        order_by_values: Sequence[OrderingKey],
    ) -> int:
        """Return the row index plus the value of ``following``, plus one.

        The extra one makes the bound exclusive of the row after the last
        following row. The ordering key arguments are unused in ``ROWS`` mode.
        """
        following = self.following
        assert following is not None, "following is None"
        return row_id_in_partition + typing.cast(int, following(current_row)) + 1

    def setup_window(
        self,
        possible_peers: Sequence[AbstractRow],
        current_row: AbstractRow,
        order_by_columns: Sequence[str],
    ) -> tuple[OrderingKey, Sequence[OrderingKey]]:
        """Collect the ordering keys of the current row and of every peer.

        Each key is a tuple holding the values of `order_by_columns` for one
        row; with no order by columns every key is the empty tuple.
        """
        peer_keys = [
            tuple(map(peer.__getitem__, order_by_columns)) for peer in possible_peers
        ]
        current_key = tuple(map(current_row.__getitem__, order_by_columns))
        return current_key, peer_keys
class RangeMode(FrameClause):
    """A frame clause implementation for window function ``RANGE`` mode.

    ``RANGE`` mode computes the window frame relative to the difference between
    ``preceding`` and ``following`` and the current row's ordering key.

    See Also
    --------
    RowsMode

    """

    __slots__ = ()

    def __init__(
        self,
        order_by: Sequence[OrderBy],
        partition_by: Sequence[PartitionBy],
        preceding: Preceding | None,
        following: Following | None,
        nulls: Nulls,
    ) -> None:
        """Construct a ``RANGE`` mode frame clause.

        Raises
        ------
        ValueError
            If more than one order by function is given.

        """
        n_order_by = len(order_by)
        # zero order by functions are allowed, see `setup_window`
        if n_order_by > 1:
            raise ValueError(
                "Must have at most ONE order by to use range windows. "
                f"Got {n_order_by:d} functions."
            )
        super().__init__(order_by, partition_by, preceding, following, nulls)

    def setup_window(
        self,
        possible_peers: Sequence[AbstractRow],
        current_row: AbstractRow,
        order_by_columns: Sequence[str],
    ) -> tuple[OrderingKey, Sequence[OrderingKey]]:
        """Collect the ordering keys of the current row and of every peer.

        Keys are one-element tuples holding the value of the single order by
        column. Without an order by column the current key is ``()`` and the
        peer keys are ``[()]``.
        """
        # range mode allows no order by
        if not order_by_columns:
            return (), [()]

        ncolumns = len(order_by_columns)
        assert ncolumns == 1, f"ncolumns == {ncolumns:d}"
        (order_by_column,) = order_by_columns
        order_by_values = [(peer[order_by_column],) for peer in possible_peers]
        current_row_order_by_value = (current_row[order_by_column],)
        return current_row_order_by_value, order_by_values

    def find_partition_begin(
        self,
        current_row: AbstractRow,
        row_id_in_partition: int,
        current_row_order_by_values: OrderingKey | None,
        order_by_values: Sequence[OrderingKey],
    ) -> int:
        """Find the first peer whose key is within ``preceding`` of the row.

        Subtracts the value of ``preceding`` from the current ordering value
        and bisects `order_by_values` (which must be sorted ascending) for the
        leftmost position of that value. Returns 0 when there is no order by.
        """
        assert (
            current_row_order_by_values is not None
        ), "current_row_order_by_value is None"
        preceding = self.preceding
        assert preceding is not None, "preceding function is None"

        # an empty key means no order by: the frame starts at the first row
        if not current_row_order_by_values:
            return 0

        assert len(current_row_order_by_values) == 1
        (current_row_order_by_value,) = current_row_order_by_values
        value_to_find = current_row_order_by_value - preceding(current_row)
        return bisect.bisect_left(order_by_values, (value_to_find,))

    def find_partition_end(
        self,
        current_row: AbstractRow,
        row_id_in_partition: int,
        current_row_order_by_values: OrderingKey | None,
        order_by_values: Sequence[OrderingKey],
    ) -> int:
        """Find the end of the peers whose key is within ``following``.

        Adds the value of ``following`` to the current ordering value and
        bisects `order_by_values` (which must be sorted ascending) for the
        position just past the rightmost occurrence of that value. Returns
        ``len(order_by_values)`` when there is no order by.
        """
        assert (
            current_row_order_by_values is not None
        ), "current_row_order_by_values is None"
        following = self.following
        assert following is not None, "following function is None"

        # an empty key means no order by: the frame runs to the last key
        if not current_row_order_by_values:
            return len(order_by_values)

        assert len(current_row_order_by_values) == 1
        (current_row_order_by_value,) = current_row_order_by_values
        value_to_find = current_row_order_by_value + following(current_row)
        return bisect.bisect_right(order_by_values, (value_to_find,))
class Window:
    """A namespace class providing the user-facing API for windowing modes."""

    __slots__ = ()

    @staticmethod
    def rows(
        order_by: Sequence[OrderBy] = (),
        partition_by: Sequence[PartitionBy] = (),
        preceding: Preceding | None = None,
        following: Following | None = None,
        nulls: Nulls = Nulls.FIRST,
    ) -> FrameClause:
        """Construct a ``ROWS`` mode frame clause.

        ``ROWS`` windows are useful for computing over windows that can be
        determined by relative row index alone.

        Parameters
        ----------
        order_by
            The ``ORDER BY`` functions of the window.
        partition_by
            The ``PARTITION BY`` functions of the window.
        preceding
            The preceding bound of the frame, if any.
        following
            The following bound of the frame, if any.
        nulls
            How null values are handled when sorting.

        See Also
        --------
        Window.range

        """
        return RowsMode(order_by, partition_by, preceding, following, nulls)

    @staticmethod
    def range(
        order_by: Sequence[OrderBy] = (),
        partition_by: Sequence[PartitionBy] = (),
        preceding: Preceding | None = None,
        following: Following | None = None,
        nulls: Nulls = Nulls.FIRST,
    ) -> FrameClause:
        """Construct a ``RANGE`` mode frame clause.

        ``RANGE`` windows can be used to compute over windows whose bounds are
        not easily determined by row number such as time based windows.

        Parameters
        ----------
        order_by
            The ``ORDER BY`` functions of the window. ``RANGE`` mode rejects
            more than one.
        partition_by
            The ``PARTITION BY`` functions of the window.
        preceding
            The preceding bound of the frame, if any.
        following
            The following bound of the frame, if any.
        nulls
            How null values are handled when sorting.

        See Also
        --------
        Window.rows

        """
        return RangeMode(order_by, partition_by, preceding, following, nulls)
# A callable that takes a row and returns a value; these produce the
# arguments passed to an aggregation.
Getter = Callable[[AbstractRow], Any]
# Type variable constrained to the aggregate classes that the specification
# classes in this module can be parametrized with.
ConcreteAggregate = TypeVar(
"ConcreteAggregate",
UnaryAssociativeAggregate,
BinaryAssociativeAggregate,
UnaryNavigationAggregate,
BinaryNavigationAggregate,
TernaryNavigationAggregate,
RankingAggregate,
)
class AggregateSpecification(Generic[ConcreteAggregate]):
    """Specification for computing a (non-windowed) aggregation.

    Attributes
    ----------
    aggregate_type
        The aggregate class to use for aggregation.
    getters
        A tuple of callables used to produce the arguments for the aggregation.

    See Also
    --------
    WindowAggregateSpecification

    """

    __slots__ = "aggregate_type", "getters"

    def __init__(
        self,
        aggregate_type: type[ConcreteAggregate],
        *getters: Getter,
    ) -> None:
        """Store the aggregate class and its argument getters.

        Parameters
        ----------
        aggregate_type
            The aggregate class to use for aggregation.
        getters
            Callables, given as variadic positional arguments, that are
            stored as a tuple.

        """
        self.aggregate_type: type[ConcreteAggregate] = aggregate_type
        self.getters = getters
def row_key_compare(
    order_func: Callable[[AbstractRow], tuple[Comparable[T], ...]],
    null_ordering: Nulls,
    left_row: AbstractRow,
    right_row: AbstractRow,
) -> int:
    """Compare `left_row` and `right_row` using `order_func`.

    Parameters
    ----------
    order_func
        A function computing the tuple of ordering keys of a row.
    null_ordering
        How to order a ``NULL`` (``None``) key relative to a non-null key.
    left_row
        The left operand of the comparison.
    right_row
        The right operand of the comparison.

    Returns
    -------
    int
        A negative number if `left_row` sorts before `right_row`, a positive
        number if it sorts after, and zero if all ordering keys tie.

    Notes
    -----
    ``NULL`` ordering is handled using `null_ordering`. Two ``NULL`` keys in
    the same position tie, and the comparison moves on to the next key.

    """
    for left_key, right_key in zip(order_func(left_row), order_func(right_row)):
        left_is_null = left_key is None
        right_is_null = right_key is None

        if left_is_null and right_is_null:
            # a tie in this position: later keys must still get a chance to
            # break it, so do not return here
            continue
        if left_is_null:
            return null_ordering.value
        if right_is_null:
            return -null_ordering.value

        if left_key < right_key:
            return -1
        if left_key > right_key:
            return 1

    # every key tied, or there were no keys at all
    return 0
def make_key_func(
    order_func: Callable[[AbstractRow], tuple[Comparable[T], ...]],
    nulls: Nulls,
) -> Callable[[AbstractRow], OrderingKey[T]]:
    """Make a function usable with the key argument to sorting functions.

    The return value of this function can be passed to
    :func:`sorted`/:meth:`list.sort`.

    Parameters
    ----------
    order_func
        A function computing the tuple of ordering keys of an
        :class:`~stupidb.row.AbstractRow`.
    nulls
        How to order null keys relative to non-null keys.

    Returns
    -------
    Callable[[AbstractRow], OrderingKey[T]]
        A key function wrapping :func:`row_key_compare` with `order_func` and
        `nulls` already bound.

    """
    compare_rows = functools.partial(row_key_compare, order_func, nulls)
    return functools.cmp_to_key(compare_rows)
class WindowAggregateSpecification(Generic[ConcreteAggregate]):
    """A specification for a window aggregate.

    Attributes
    ----------
    aggregate_type
        The class of :data:`~stupidb.aggregation.ConcreteAggregate` to use for
        aggregation.
    getters
        A tuple of functions that produce single column values given an
        instance of :class:`~stupidb.row.AbstractRow`.
    frame_clause
        A thin struct encapsulating the details of the window such as ``ORDER
        BY`` (:attr:`stupidb.aggregation.FrameClause.order_by`), ``PARTITION
        BY`` (:attr:`stupidb.aggregation.FrameClause.partition_by`) and
        preceding and following.

    See Also
    --------
    stupidb.aggregation.FrameClause

    """

    __slots__ = "aggregate_type", "getters", "frame_clause"

    def __init__(
        self,
        aggregate_type: type[ConcreteAggregate],
        getters: tuple[Getter, ...],
        frame_clause: FrameClause,
    ) -> None:
        self.aggregate_type: type[ConcreteAggregate] = aggregate_type
        self.getters = getters
        self.frame_clause = frame_clause

    def compute(self, rows: Iterable[AbstractRow]) -> Iterator[T | None]:
        """Aggregate `rows` over a window, producing an iterator of results.

        Parameters
        ----------
        rows
            An :class:`~typing.Iterable` of rows.

        Returns
        -------
        Iterator[T | None]
            One result per input row, positioned by the ``_id`` attribute of
            the row rather than by partition order.

        """
        frame_clause = self.frame_clause
        order_by = frame_clause.order_by

        # Generate names for temporary order by columns, users never see these.
        #
        # TODO: If we had static schema information these wouldn't be necessary
        # in cases where the ordering keys are named columns (either physical
        # or computed)
        order_by_columns = [f"_order_by_{i:d}" for i in range(len(order_by))]

        # Add computed order by columns that are used when evaluating window
        # functions in range mode
        order_func = toolz.juxt(*order_by)
        rows_for_partition = (
            row.merge(dict(zip(order_by_columns, order_func(row)))) for row in rows
        )

        # Divide the input rows into partitions.
        #
        # Group on the tuple of partition values itself, not on its hash:
        # distinct keys can share a hash (in CPython, hash(-1) == hash(-2)),
        # and grouping by the hash would merge their partitions.
        #
        # We only need the partition values once the rows have been
        # partitioned, so the keys are dropped right away.
        partitions = toolz.groupby(
            toolz.juxt(*frame_clause.partition_by),
            rows_for_partition,
        ).values()

        # aggregation results, preallocated to avoid the need to sort
        # before returning: we later assign elements to this list using
        # the original row id
        results: list[T | None] = [None] * sum(map(len, partitions))

        # Aggregate over each partition
        aggregate_type = self.aggregate_type
        getters = self.getters
        key_func = make_key_func(order_func, frame_clause.nulls)

        for possible_peers in partitions:
            # sort the partition according to the ordering key
            possible_peers.sort(key=key_func)

            # Construct an aggregator for the function being computed
            #
            # For navigation functions like lead, lag, first, last and nth, we
            # construct a simple structure that computes the current value
            # of the navigation function given the inputs. We use the same
            # approach for ranking functions such as row_number, rank, and
            # dense_rank.
            #
            # For associative aggregations we construct a segment tree
            # using `arguments` as the leaves, with `aggregate` instances as
            # the interior nodes. Each node (both leaves and non-leaves) is a
            # state of the aggregation. The leaves are the initial states, the
            # root is the final state.
            aggregator: Aggregator[Aggregate, T] = aggregate_type.prepare(
                possible_peers, getters, order_by_columns
            )

            # For every row in the set of possible peers of the current row
            # compute the window frame, and query the aggregator for the value
            # of the aggregation within that frame.
            for row_id_in_partition, row in enumerate(possible_peers):
                start, stop = frame_clause.compute_window_frame(
                    possible_peers, row, row_id_in_partition, order_by_columns
                )
                # Assign the result to the position of the original row id
                # because we processed them in partition order, which might not
                # be the same as the input order.
                results[row._id] = aggregator.query(start, stop)

        return iter(results)