Source code for stupidb.associative.segmenttree
from __future__ import annotations
import collections
import math
from typing import Generic, Iterator, MutableSequence, Sequence
from ..aggregator import Aggregator
from ..functions.associative.core import AssociativeAggregate
from ..typehints import Result, T
from . import indextree
from .bitset import BitSet
def make_segment_tree(
    leaf_arguments: Sequence[tuple[T, ...]],
    aggregate_type: type[AssociativeAggregate],
    *,
    fanout: int,
) -> Sequence[AssociativeAggregate]:
    """Build the flat list of aggregates backing a segment tree.

    One `aggregate_type` instance is created per node of an
    :class:`indextree.IndexTree`.  The leaf aggregates are fed the tuples in
    `leaf_arguments` through ``step``, and the tree is then walked from the
    leaves upward, combining each node into its parent the first time that
    node is dequeued.

    Parameters
    ----------
    leaf_arguments
        A sequence of tuples; each tuple is unpacked as the positional
        arguments of ``step`` on the corresponding leaf aggregate.
    aggregate_type
        The aggregate class whose instances compose the tree.
    fanout
        Forwarded to the index tree and used as the base of the logarithm
        that determines the tree height.

    Returns
    -------
    Sequence[AssociativeAggregate]
        One aggregate per index-tree node, indexed by node number.

    """
    # height = ceil(log_fanout(number of leaves)) + 1
    levels_above_leaves = int(math.ceil(math.log(len(leaf_arguments), fanout)))
    tree = indextree.IndexTree(height=levels_above_leaves + 1, fanout=fanout)

    aggregates: MutableSequence[AssociativeAggregate] = [
        aggregate_type() for _ in range(len(tree))
    ]

    # The work queue starts out holding every leaf index of the index tree.
    pending = collections.deque(tree.leaves)

    # Seed the leaves; zip stops at the shorter of the two sequences.
    for position, arguments in zip(pending, leaf_arguments):
        aggregates[position].step(*arguments)

    visited = BitSet()
    while pending:
        child = pending.popleft()
        if child in visited:
            # A parent is enqueued once per child; only the first dequeue of
            # a given node combines it upward.
            continue
        visited.add(child)

        parent = tree.parent(child)
        aggregates[parent].combine(aggregates[child])

        if parent:
            # Index 0 is never enqueued: by the time parent == 0 the child
            # has already been aggregated into that node.
            pending.append(parent)

    return aggregates
class SegmentTree(
    Generic[T, AssociativeAggregate, Result],
    Aggregator[AssociativeAggregate, Result],
):
    """A segment tree for window aggregation.

    Attributes
    ----------
    nodes
        The nodes of the segment tree
    aggregate_type
        The class of the aggregate to use
    levels
        A list of the nodes in each level of the tree
    fanout
        The number of leaves to aggregate into each interior node
    height
        ``ceil(log(len(leaves), fanout)) + 1`` for the leaves given to the
        constructor

    """

    # Every attribute assigned in ``__init__`` is listed here, including
    # ``height``.
    __slots__ = "nodes", "aggregate_type", "levels", "fanout", "height"

    def __init__(
        self,
        leaves: Sequence[tuple[T | None, ...]],
        aggregate_type: type[AssociativeAggregate],
        *,
        fanout: int,
    ) -> None:
        """Construct a segment tree.

        Parameters
        ----------
        leaves
            A sequence of tuples handed to :func:`make_segment_tree` as the
            leaf arguments.
        aggregate_type
            The aggregate class instantiated for every node and for every
            query.
        fanout
            The fanout passed to :func:`make_segment_tree` and used for all
            level and parent index arithmetic.

        """
        self.nodes: Sequence[AssociativeAggregate] = make_segment_tree(
            leaves, aggregate_type, fanout=fanout
        )
        self.aggregate_type: type[AssociativeAggregate] = aggregate_type
        self.fanout = fanout
        # Same formula that make_segment_tree uses for the index tree height.
        self.height: int = int(math.ceil(math.log(len(leaves), fanout))) + 1
        self.levels: Sequence[Sequence[AssociativeAggregate]] = list(
            self.iterlevels(self.nodes, fanout=fanout)
        )

    @staticmethod
    def iterlevels(
        nodes: Sequence[AssociativeAggregate], *, fanout: int
    ) -> Iterator[Sequence[AssociativeAggregate]]:
        """Iterate over every level in the tree.

        Parameters
        ----------
        nodes
            The nodes of the tree whose levels will be yielded.
        fanout
            The number of child nodes per interior node

        Yields
        ------
        Sequence[AssociativeAggregate]
            For each ``level`` in ``range(ceil(log(len(nodes), fanout)))``,
            the slice of `nodes` between ``indextree.first_node(level)`` and
            ``indextree.last_node(level)``.

        """
        height = int(math.ceil(math.log(len(nodes), fanout)))
        for level in range(height):
            start = indextree.first_node(level, fanout=fanout)
            stop = indextree.last_node(level, fanout=fanout)
            yield nodes[start:stop]

    def __repr__(self) -> str:
        """Return the rendering produced by ``indextree.reprtree``."""
        return indextree.reprtree(self.nodes, fanout=self.fanout)

    def query(self, begin: int, end: int) -> Result | None:
        """Aggregate the values between `begin` and `end`.

        The levels are visited in reverse order of ``self.levels``.  At each
        level the nodes that do not fill a whole group of `fanout` siblings
        are combined directly, and the remaining range is re-expressed in
        terms of parent indices for the next level.

        Parameters
        ----------
        begin
            The start of the range to aggregate; used as the lower bound of
            a slice, so it is inclusive.
        end
            The end of the range to aggregate; used as the upper bound of a
            slice, so it is exclusive.

        Returns
        -------
        Result | None
            The value of ``finalize()`` on a fresh aggregate after combining
            the covering nodes, or ``None`` if every level is visited
            without `begin` and `end` falling under the same parent.

        """
        fanout = self.fanout
        aggregate: AssociativeAggregate = self.aggregate_type()

        for level in reversed(self.levels):
            parent_begin = begin // fanout
            parent_end = end // fanout

            if parent_begin == parent_end:
                # Both bounds share a parent: what is left of the range lies
                # in a single sibling group, so finish here.
                for item in level[begin:end]:
                    aggregate.combine(item)
                return aggregate.finalize()

            group_begin = parent_begin * fanout
            if begin != group_begin:
                # `begin` is not on a group boundary: combine the tail of its
                # group and skip that group's parent on the next level.
                limit = group_begin + fanout
                for item in level[begin:limit]:
                    aggregate.combine(item)
                parent_begin += 1

            group_end = parent_end * fanout
            if end != group_end:
                # `end` is not on a group boundary: combine the head of its
                # group; the parent at `parent_end` stays excluded.
                for item in level[group_end:end]:
                    aggregate.combine(item)

            begin = parent_begin
            end = parent_end
        return None  # pragma: no cover