# Source code for stupidb.core

"""StupiDB. The stupidest database you'll ever come across.

This is a project designed to illustrate the concepts that underlie a typical
relational database implementation, from naive execution of table-stakes
features up to rule-based query optimization.

.. warning::
   Please do not use this for any other reason than learning. There are no
   guarantees here except that there will be bugs.

"""

from __future__ import annotations

import abc
import collections
import functools
import itertools
import typing
from typing import Any, Generic, Iterable, Iterator, Mapping

import toolz

from .aggregation import (
    AggregateSpecification,
    Nulls,
    WindowAggregateSpecification,
    row_key_compare,
)
from .functions.associative.core import AssociativeAggregate
from .row import AbstractRow, JoinedRow, Row
from .typehints import (
    JoinPredicate,
    OrderBy,
    PartitionBy,
    PartitionKey,
    Predicate,
    Projector,
)


class Relation(abc.ABC):
    """An abstract relation: a lazily produced, iterable collection of rows.

    Concrete relations implement :meth:`_produce`; :meth:`__iter__` wraps it
    and assigns each emitted row a positional identifier.
    """

    __slots__ = ("partitioners",)

    def __init__(self) -> None:
        # Grouping key functions by column name. Empty unless a subclass
        # (GroupBy, for example) assigns its own mapping.
        self.partitioners: Mapping[str, PartitionBy] = {}

    def __iter__(self) -> Iterator[AbstractRow]:
        """Iterate over the rows of a :class:`~stupidb.core.Relation`.

        Falsy rows coming out of :meth:`_produce` are skipped. Every remaining
        row is re-wrapped in a new :class:`~stupidb.row.Row` whose identifier
        is its 0-based position in this iteration.
        """
        produced = self._produce()
        kept = filter(None, produced)
        return (
            Row.from_mapping(row, _id=row_id) for row_id, row in enumerate(kept)
        )

    @abc.abstractmethod
    def _produce(self) -> Iterator[AbstractRow]:
        """Yield the raw rows of this relation.

        Implementations should avoid assigning row identifiers where possible;
        :meth:`~stupidb.core.Relation.__iter__` takes care of that.
        """

    def __repr__(self) -> str:
        # Imported lazily; presumably to avoid a circular import with the
        # api module -- confirm.
        from stupidb.api import pretty

        return pretty(self)
class Table(Relation):
    """A base relation backed by a collection of rows.

    Attributes
    ----------
    rows
        The table's rows. May be a re-iterable container or a one-shot
        iterator (``from_iterable`` passes a generator); both can be iterated
        any number of times through the table.
    """

    __slots__ = ("rows",)

    def __init__(self, rows: Iterable[AbstractRow]) -> None:
        super().__init__()
        self.rows = rows

    @classmethod
    def from_iterable(cls, iterable: Iterable[Mapping[str, Any]]) -> Table:
        """Build a table from mappings, giving each row its position as id."""
        return cls(
            Row.from_mapping(mapping, _id=i) for i, mapping in enumerate(iterable)
        )

    def _produce(self) -> Iterator[AbstractRow]:
        rows = self.rows
        # A one-shot iterator returns itself from iter(). Iterating it directly
        # would exhaust it, and every later iteration of the table would be
        # silently empty. Tee it, keep one branch for future iterations, and
        # consume the other. Re-iterable containers are used as-is.
        if iter(rows) is rows:
            self.rows, rows = itertools.tee(rows)
        return iter(rows)
class Projection(Relation):
    """A relation representing column selection.

    Each output row combines the results of the scalar projections with the
    results of the window aggregations. On a name clash, the window
    aggregation value wins because it is merged last.

    Attributes
    ----------
    child
        The input relation.
    aggregations
        Window aggregation specifications, keyed by output column name.
    projections
        Callables taking a row and returning a value, keyed by output column
        name.
    """

    __slots__ = "child", "aggregations", "projections"

    def __init__(
        self,
        child: Relation,
        projections: Mapping[str, Projector | WindowAggregateSpecification],
    ) -> None:
        super().__init__()
        self.child = child
        window_specs: dict[str, WindowAggregateSpecification] = {}
        scalar_projectors: dict[str, Projector] = {}
        for name, spec in projections.items():
            # Two independent checks rather than if/elif: an entry is
            # classified purely by isinstance and by callable().
            if isinstance(spec, WindowAggregateSpecification):
                window_specs[name] = spec
            if callable(spec):
                scalar_projectors[name] = spec
        self.aggregations: Mapping[str, WindowAggregateSpecification] = window_specs
        self.projections: Mapping[str, Projector] = scalar_projectors

    def _produce(self) -> Iterator[AbstractRow]:
        window_specs = self.aggregations
        scalar_projectors = self.projections

        # The window aggregations share no structure, so each one gets its own
        # copy of the child's rows. One extra copy feeds all scalar projections.
        base_rows, *window_inputs = itertools.tee(
            self.child, len(window_specs) + 1
        )

        # Each compute() result is one column of window values; transposing
        # with zip() turns those columns into per-row tuples.
        window_columns = [
            WindowAggregateSpecification.compute(spec, rows)
            for spec, rows in zip(window_specs.values(), window_inputs)
        ]
        window_rows = (
            dict(zip(window_specs.keys(), values))
            for values in zip(*window_columns)
        )

        def project(row):
            return {name: fn(row) for name, fn in scalar_projectors.items()}

        scalar_rows = map(project, base_rows)

        # Either stream may be empty (no projections, or no window
        # aggregations), so pad the shorter one with empty dicts.
        paired = itertools.zip_longest(window_rows, scalar_rows, fillvalue={})
        return (
            Row(toolz.merge(scalar, window), _id=-1) for window, scalar in paired
        )
class Mutate(Projection):
    """A relation representing appending columns to an existing relation.

    Each output row is the child row merged with the columns computed by the
    :class:`Projection` machinery. Computed columns win on name clashes
    because they are merged last.
    """

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        original_child = self.child
        # The child's rows are needed twice: once unchanged, and once as input
        # to Projection._produce, which reads ``self.child``. Tee a single pass
        # over the child and point ``self.child`` at one branch only while
        # Projection._produce runs. It reads the attribute before returning its
        # lazy generators. Then restore the original relation.
        #
        # Previously the tee branch was left in ``self.child`` for good, so
        # the relation was exhausted after one iteration and every later
        # iteration came back empty.
        passthrough, projection_input = itertools.tee(original_child)
        self.child = projection_input
        try:
            computed = super()._produce()
        finally:
            self.child = original_child
        return (
            Row.from_mapping(row, _id=-1)
            for row in map(toolz.merge, passthrough, computed)
        )
class Aggregation(Generic[AssociativeAggregate], Relation):
    """A relation representing aggregation of columns.

    Rows of the child are bucketed by the key functions in the child's
    ``partitioners`` mapping (e.g. as set by :class:`GroupBy`). Each bucket
    gets its own fresh aggregate instances: every row is fed to them with
    ``step`` and the results are read with ``finalize``. One output row is
    emitted per bucket, containing the grouping columns followed by the
    metrics.
    """

    __slots__ = "child", "metrics"

    def __init__(
        self,
        child: Relation,
        metrics: Mapping[str, AggregateSpecification[AssociativeAggregate]],
    ) -> None:
        super().__init__()
        self.child = child
        self.metrics: Mapping[
            str, AggregateSpecification[AssociativeAggregate]
        ] = metrics

    def _produce(self) -> Iterator[AbstractRow]:
        specs = self.metrics

        def fresh_aggregates() -> Mapping[str, AssociativeAggregate]:
            return {name: spec.aggregate_type() for name, spec in specs.items()}

        buckets: Mapping[
            PartitionKey, Mapping[str, AssociativeAggregate]
        ] = collections.defaultdict(fresh_aggregates)

        source = typing.cast(Relation, self.child)
        key_functions = source.partitioners

        for row in source:
            bucket_key = tuple(
                (column, key_fn(row)) for column, key_fn in key_functions.items()
            )
            for name, aggregate in buckets[bucket_key].items():
                aggregate.step(*(getter(row) for getter in specs[name].getters))

        # NOTE(review): with an empty child no bucket is ever created, so no
        # row is emitted even without grouping keys (SQL would return one
        # row). Confirm this is intended.
        for bucket_key, aggregates in buckets.items():
            output = dict(bucket_key)
            for name, aggregate in aggregates.items():
                output[name] = aggregate.finalize()
            yield Row.from_mapping(output)
class Selection(Relation):
    """A relation keeping only the child rows that satisfy a predicate.

    Attributes
    ----------
    child
        The input relation.
    predicate
        A callable taking an :class:`~stupidb.row.AbstractRow`. Rows for
        which it returns a truthy value are kept.
    """

    __slots__ = "child", "predicate"

    def __init__(self, child: Relation, predicate: Predicate) -> None:
        super().__init__()
        self.child = child
        self.predicate = predicate

    def _produce(self) -> Iterator[AbstractRow]:
        keep = self.predicate
        return (row for row in self.child if keep(row))
class GroupBy(Relation):
    """A relation representing a partitioning of rows by one or more keys.

    Rows pass through unchanged. The grouping key functions are exposed as
    :attr:`partitioners` so that a downstream :class:`Aggregation` can bucket
    rows by them.

    Attributes
    ----------
    child
        The input relation.
    group_by
        A mapping from output column name to a callable that takes an
        :class:`~stupidb.row.AbstractRow` and returns a hashable key.
    """

    # ``partitioners`` is already a slot on Relation. Redeclaring it here would
    # shadow the base-class slot, which the data model documents as undefined
    # behaviour.
    __slots__ = "child", "group_by"

    def __init__(self, child: Relation, group_by: Mapping[str, PartitionBy]) -> None:
        super().__init__()
        self.child = child
        # Previously never assigned, so reading ``.group_by`` raised
        # AttributeError despite being a declared slot.
        self.group_by = group_by
        self.partitioners = group_by

    def _produce(self) -> Iterator[AbstractRow]:
        return iter(self.child)
class SortBy(Relation):
    """A relation representing rows of its child sorted by one or more keys.

    Attributes
    ----------
    child
        The input relation.
    order_by
        A tuple of callables, each taking an :class:`~stupidb.row.AbstractRow`
        and returning a sort key. Earlier callables take precedence.
    null_ordering
        Whether to place the nulls of a column first or last.
    """

    __slots__ = "child", "order_by", "null_ordering"

    def __init__(
        self, child: Relation, order_by: tuple[OrderBy, ...], null_ordering: Nulls
    ) -> None:
        super().__init__()
        self.child = child
        self.order_by = order_by
        self.null_ordering = null_ordering

    def _produce(self) -> Iterator[AbstractRow]:
        # Combine all key functions into one that returns a tuple of keys, and
        # let row_key_compare apply the null ordering to that tuple.
        sort_keys = toolz.juxt(*self.order_by)
        compare = functools.partial(row_key_compare, sort_keys, self.null_ordering)
        ordered = sorted(self.child, key=functools.cmp_to_key(compare))
        return iter(ordered)
class Limit(Relation):
    """The child's rows after skipping ``offset`` rows, capped at ``limit`` rows.

    Attributes
    ----------
    child
        The input relation.
    offset
        Number of leading rows to skip.
    limit
        Maximum number of rows to emit, or ``None`` for no cap.
    """

    __slots__ = "child", "offset", "limit"

    def __init__(self, child: Relation, *, offset: int, limit: int | None) -> None:
        super().__init__()
        self.child = child
        self.offset = offset
        self.limit = limit

    def _produce(self) -> Iterator[AbstractRow]:
        start = self.offset
        stop = start + self.limit if self.limit is not None else None
        return itertools.islice(self.child, start, stop)
class Join(Relation):
    """Base class for joins between two relations.

    Subclasses consume either :attr:`grouped` (each left row paired with the
    joined rows built from it) or :attr:`rows` (all joined rows, flattened).
    Both are recomputed on every access, so a join can be iterated more than
    once.

    Attributes
    ----------
    left
        The left input relation.
    right
        The right input relation.
    """

    __slots__ = "left", "right"

    def __init__(self, left: Relation, right: Relation) -> None:
        super().__init__()
        self.left = left
        self.right = right

    @property
    def grouped(self) -> Iterator[tuple[AbstractRow, list[JoinedRow]]]:
        """Yield ``(left_row, joined_rows)`` for every row of :attr:`left`.

        ``joined_rows`` pairs ``left_row`` with every right row, in right
        order. Groups are formed per left row rather than by comparing rows
        for equality, so duplicate left rows stay separate groups.
        """
        right_rows = list(self.right)
        # Like a cartesian product, an empty right side produces no pairs and
        # therefore no groups. LeftJoin relies on every group being non-empty.
        if not right_rows:
            return iter(())
        return (
            (
                left_row,
                [JoinedRow(left_row, right_row, _id=-1) for right_row in right_rows],
            )
            for left_row in self.left
        )

    @property
    def rows(self) -> Iterator[JoinedRow]:
        """Yield every joined row (the cartesian product), left-major order."""
        return itertools.chain.from_iterable(
            joined_rows for _, joined_rows in self.grouped
        )
class CrossJoin(Join):
    """Cartesian product of two relations: every left row with every right row."""

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        # No predicate to apply: every joined row is emitted as-is.
        all_pairs = self.rows
        return iter(all_pairs)
class InnerJoin(Join):
    """Inner join: joined rows for which ``predicate(left, right)`` is truthy.

    Attributes
    ----------
    predicate
        A callable taking the left row and the right row of a candidate pair.
    """

    __slots__ = ("predicate",)

    def __init__(
        self, left: Relation, right: Relation, predicate: JoinPredicate
    ) -> None:
        super().__init__(left, right)
        self.predicate = predicate

    def _produce(self) -> Iterator[AbstractRow]:
        def satisfies(joined):
            return self.predicate(joined.left, joined.right)

        return filter(satisfies, self.rows)
class LeftJoin(Join):
    """Left outer join.

    Every left row is paired with each right row for which
    ``predicate(left_row, right_row)`` is truthy. A left row with no match is
    emitted once, paired with a right-hand mapping whose values are all
    ``None``.

    Attributes
    ----------
    predicate
        A callable taking a left row and a right row.
    """

    __slots__ = ("predicate",)

    def __init__(
        self, left: Relation, right: Relation, predicate: JoinPredicate
    ) -> None:
        super().__init__(left, right)
        self.predicate = predicate

    def _produce(self) -> Iterator[AbstractRow]:
        # NOTE(review): if ``self.grouped`` yields no groups (e.g. the right
        # relation is empty), no rows are produced at all, unlike SQL's
        # LEFT JOIN.
        for left_row, candidates in self.grouped:
            any_match = False
            for candidate in candidates:
                right_row = candidate.right
                if not self.predicate(left_row, right_row):
                    continue
                any_match = True
                yield JoinedRow(left_row, right_row, _id=-1)
            if not any_match:
                # Null-pad using the column names of the last right row seen.
                yield JoinedRow(left_row, dict.fromkeys(right_row), _id=-1)
class RightJoin(LeftJoin):
    """Right outer join: every right row is kept, null-padded when unmatched.

    Implemented as a :class:`LeftJoin` with the operands swapped. Output rows
    are flipped back so the left relation's columns come first, and the
    predicate still receives ``(left_row, right_row)`` in the caller's order.
    """

    __slots__ = ()

    def __init__(
        self, left: Relation, right: Relation, predicate: JoinPredicate
    ) -> None:
        def caller_order_predicate(
            right_row: AbstractRow, left_row: AbstractRow
        ) -> bool:
            # LeftJoin calls its predicate as (its left, its right), which here
            # is (our right, our left). Restore the (left, right) order the
            # caller wrote the predicate for.
            return predicate(left_row, right_row)

        super().__init__(right, left, caller_order_predicate)

    def _produce(self) -> Iterator[AbstractRow]:
        for row in super()._produce():
            # Swap the sides back, and pass ``_id`` like every other JoinedRow
            # construction in this module.
            yield JoinedRow(row.right, row.left, _id=-1)
class SetOperation(Relation):
    """Abstract base for binary set operations over two relations.

    Attributes
    ----------
    left
        The left input relation.
    right
        The right input relation.
    """

    __slots__ = "left", "right"

    def __init__(self, left: Relation, right: Relation) -> None:
        super().__init__()
        self.left, self.right = left, right

    @staticmethod
    def itemize(
        mappings: Iterable[AbstractRow],
    ) -> frozenset[tuple[tuple[str, Any], ...]]:
        """Return the distinct rows of `mappings` as hashable item tuples.

        Each row becomes ``tuple(row.items())``, so two rows compare equal only
        if they have the same items in the same order. Duplicates collapse.
        """
        return frozenset({tuple(mapping.items()) for mapping in mappings})
class Union(SetOperation):
    """Union between two relations, keeping the first occurrence of each row."""

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        def row_identity(row):
            # Order-insensitive identity: rows with the same items are equal.
            return frozenset(row.items())

        combined = itertools.chain(self.left, self.right)
        return toolz.unique(combined, key=row_identity)
class UnionAll(SetOperation):
    """Non-unique union: all left rows followed by all right rows."""

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        yield from self.left
        yield from self.right
class IntersectAll(SetOperation):
    """Non-unique intersection between two relations (multiset semantics).

    A row occurring ``m`` times on the left and ``n`` times on the right
    appears ``min(m, n)`` times in the output, in left order. Rows are
    compared by ``tuple(row.items())``.
    """

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        # Previously both sides were collapsed to sets and each common row was
        # emitted once per side, i.e. always exactly twice, whatever the
        # multiplicities. Count the right side instead and let each left
        # occurrence consume one right occurrence.
        remaining = collections.Counter(tuple(row.items()) for row in self.right)
        for row in self.left:
            items = tuple(row.items())
            if remaining[items] > 0:
                remaining[items] -= 1
                yield Row.from_mapping(dict(items))
class Intersect(SetOperation):
    """Intersection of two relations: distinct rows present on both sides.

    Rows are compared by ``tuple(row.items())`` and emitted in the order of
    their first occurrence in the left relation.
    """

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        # Iterating a frozenset intersection gave an order that depends on
        # string hash randomization, so output order varied between runs.
        # Walk the left side instead, so the order is deterministic.
        right_set = self.itemize(self.right)
        distinct_left = toolz.unique(tuple(row.items()) for row in self.left)
        return (
            Row.from_mapping(dict(items))
            for items in distinct_left
            if items in right_set
        )
class Difference(SetOperation):
    """Unique difference: distinct left rows that do not occur on the right.

    Rows are compared by ``tuple(row.items())`` and emitted in the order of
    their first occurrence in the left relation.
    """

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        right_set = self.itemize(self.right)
        surviving = (
            items
            for items in (tuple(row.items()) for row in self.left)
            if items not in right_set
        )
        # De-duplicate on the item tuples, the same identity used for
        # ``right_set``, instead of relying on how Row objects hash.
        return (Row.from_mapping(dict(items)) for items in toolz.unique(surviving))
class DifferenceAll(SetOperation):
    """Non-unique difference between two relations (multiset semantics).

    Each right occurrence of a row cancels one left occurrence. A row
    occurring ``m`` times on the left and ``n`` times on the right appears
    ``max(m - n, 0)`` times, in left order. Rows are compared by
    ``tuple(row.items())``.
    """

    __slots__ = ()

    def _produce(self) -> Iterator[AbstractRow]:
        # Previously any row present on the right removed *every* left copy,
        # which is set semantics rather than "non-unique" semantics.
        to_cancel = collections.Counter(tuple(row.items()) for row in self.right)
        for row in self.left:
            items = tuple(row.items())
            if to_cancel[items] > 0:
                to_cancel[items] -= 1
            else:
                yield Row.from_mapping(dict(items))