Source code for stupidb.api

"""StupiDB user-facing API.

.. note::

   The join functions all take `right`, `predicate` and then `left` as
   arguments, **in that order**.

   This is intentional, and is the way the functions must be written to enable
   `currying <https://en.wikipedia.org/wiki/Currying>`_.  Currying is the
   technique that allows us to use the right shift operator (``>>``) to chain
   operations.

"""

from __future__ import annotations

import inspect
import operator
from typing import Any, Callable, Iterable, Mapping

import tabulate
import toolz
from public import private, public

from .aggregation import (
    AggregateSpecification,
    FrameClause,
    Nulls,
    WindowAggregateSpecification,
)
from .core import (
    Aggregation,
    CrossJoin,
    Difference,
    DifferenceAll,
    GroupBy,
    InnerJoin,
    Intersect,
    IntersectAll,
    Join,
    JoinPredicate,
    LeftJoin,
    Limit,
    Mutate,
    PartitionBy,
    Predicate,
    Projection,
    Projector,
    Relation,
    RightJoin,
    Selection,
    SortBy,
    Table,
    Union,
    UnionAll,
)
from .functions.associative import (
    Count,
    Max,
    Mean,
    Min,
    PopulationCovariance,
    PopulationStandardDeviation,
    PopulationVariance,
    SampleCovariance,
    SampleStandardDeviation,
    SampleVariance,
    Sum,
    Total,
)
from .functions.navigation import First, Lag, Last, Lead, Nth
from .functions.ranking import DenseRank, Rank, RowNumber
from .protocols import Comparable
from .row import AbstractRow
from .typehints import R1, R2, OrderBy, R, T


@public  # type: ignore[misc]
def const(x: T | None) -> Callable[[AbstractRow], T | None]:
    """Build a function that ignores its argument and always returns `x`.

    Parameters
    ----------
    x
        The value the returned callable produces for every input.

    """

    def _constant(_: AbstractRow) -> T | None:
        # The row is deliberately ignored; only the captured value matters.
        return x

    return _constant
@public  # type: ignore[misc]
def get(name: str) -> Callable[[AbstractRow], T | None]:
    """Build a function that fetches the `name` field from a row.

    Parameters
    ----------
    name
        The key handed to :func:`operator.itemgetter`.

    """
    field_getter = operator.itemgetter(name)
    return field_getter
@private  # type: ignore[misc]
class shiftable(toolz.curry):
    """A :class:`toolz.curry` that can also be applied with ``>>``.

    When the left operand of ``>>`` does not handle the operator itself,
    Python falls back to :meth:`__rrshift__`, which calls the curried
    function with that operand.
    """

    @property
    def __signature__(self) -> inspect.Signature:
        """Return the signature of the wrapped function."""
        return inspect.signature(self.func)  # pragma: no cover

    def __rrshift__(self, other: Relation) -> shiftable:
        """Handle ``other >> self`` by calling ``self`` with `other`."""
        applied = self(other)
        return applied
@public  # type: ignore[misc]
@shiftable
def table(rows: Iterable[Mapping[str, Any]]) -> Table:
    """Build a relation out of an iterable of mappings.

    Parameters
    ----------
    rows
        An iterable of mappings whose keys are :class:`str` instances.

    Examples
    --------
    >>> from stupidb import table
    >>> rows = [
    ...     dict(name="Bob", balance=-300),
    ...     dict(name="Bob", balance=-100),
    ...     dict(name="Alice", balance=400),
    ...     dict(name="Alice", balance=700),
    ... ]
    >>> t = table(rows)
    >>> t  # doctest: +ELLIPSIS
    name      balance
    ------  ---------
    Bob          -300
    Bob          -100
    Alice         400
    Alice         700

    """
    relation = Table.from_iterable(rows)
    return relation
@public  # type: ignore[misc]
@shiftable
def cross_join(right: Relation, left: Relation) -> Join:
    """Compute the Cartesian product of the tuples of `left` and `right`.

    Parameters
    ----------
    right
        A relation
    left
        A relation

    Examples
    --------
    >>> from stupidb import cross_join, select, table
    >>> rows = [
    ...     dict(name="Bob", balance=-300),
    ...     dict(name="Bob", balance=-100),
    ...     dict(name="Alice", balance=400),
    ...     dict(name="Alice", balance=700),
    ... ]
    >>> t = table(rows)
    >>> s = table(rows)
    >>> crossed = cross_join(t, s)
    >>> crossed >> select(
    ...     left_name=lambda row: row.left["name"],
    ...     left_balance=lambda row: row.left["balance"],
    ...     right_name=lambda row: row.right["name"],
    ...     right_balance=lambda row: row.right["balance"]
    ... )
    left_name      left_balance  right_name      right_balance
    -----------  --------------  ------------  ---------------
    Bob                    -300  Bob                      -300
    Bob                    -300  Bob                      -100
    Bob                    -300  Alice                     400
    Bob                    -300  Alice                     700
    Bob                    -100  Bob                      -300
    Bob                    -100  Bob                      -100
    Bob                    -100  Alice                     400
    Bob                    -100  Alice                     700
    Alice                   400  Bob                      -300
    Alice                   400  Bob                      -100

    """
    # `right` comes first in the signature so that `left` can arrive via `>>`.
    product = CrossJoin(left, right)
    return product
@public  # type: ignore[misc]
@shiftable
def inner_join(right: Relation, predicate: JoinPredicate, left: Relation) -> Join:
    """Join `left` and `right` relations using `predicate`.

    Drop rows if `predicate` returns ``False``.

    Parameters
    ----------
    right
        A relation
    predicate
        A callable taking two arguments and returning a :class:`bool`.
    left
        A relation

    Examples
    --------
    >>> from stupidb import inner_join, select, table
    >>> rows = [
    ...     dict(name="Bob", balance=-300),
    ...     dict(name="Bob", balance=-100),
    ...     dict(name="Alice", balance=400),
    ...     dict(name="Alice", balance=700),
    ... ]
    >>> t = table(rows)
    >>> s = table(rows)
    >>> t >> inner_join(
    ...     s,
    ...     lambda left, right: left["balance"] < right["balance"]
    ... ) >> select(name=lambda r: r.left["name"], bal=lambda r: r.left["balance"])
    name      bal
    ------  -----
    Bob      -300
    Bob      -300
    Bob      -300
    Bob      -100
    Bob      -100
    Alice     400

    """
    # `left` is the last parameter so it can be supplied by the `>>` operator.
    return InnerJoin(left, right, predicate)
@public  # type: ignore[misc]
@shiftable
def left_join(right: Relation, predicate: JoinPredicate, left: Relation) -> LeftJoin:
    """Left-join `left` and `right` relations using `predicate`.

    Drop rows if `predicate` returns ``False``.  Returns at least one of every
    row from `left`.

    Parameters
    ----------
    right
        A relation
    predicate
        A callable taking two arguments and returning a :class:`bool`.
    left
        A relation

    """
    joined = LeftJoin(left, right, predicate)
    return joined
@public  # type: ignore[misc]
@shiftable
def right_join(right: Relation, predicate: JoinPredicate, left: Relation) -> RightJoin:
    """Right-join `left` and `right` relations using `predicate`.

    Drop rows if `predicate` returns ``False``.  Returns at least one of every
    row from `right`.

    Parameters
    ----------
    right
        A relation
    predicate
        A callable taking two arguments and returning a :class:`bool`.
    left
        A relation

    """
    joined = RightJoin(left, right, predicate)
    return joined
@public  # type: ignore[misc]
@shiftable
def full_join(right: Relation, predicate: JoinPredicate, left: Relation) -> Relation:
    """Full outer join (not implemented).

    Parameters
    ----------
    right
        A relation
    predicate
        A callable taking two arguments and returning a :class:`bool`.
    left
        A relation

    Raises
    ------
    NotImplementedError
        Always, once the function body runs.

    """
    message = "full outer joins are not yet supported"
    raise NotImplementedError(message)
@private  # type: ignore[misc]
@shiftable
def _order_by(order_by: tuple[OrderBy, ...], nulls: Nulls, child: Relation) -> SortBy:
    """Curried :class:`SortBy` constructor; `child` is the last argument."""
    sorted_relation = SortBy(child, order_by, nulls)
    return sorted_relation
@public  # type: ignore[misc]
def order_by(*order_by: OrderBy, nulls: Nulls = Nulls.FIRST) -> SortBy:
    """Sort the rows of the child operator according to `order_by`.

    Parameters
    ----------
    order_by
        A sequence of ``OrderBy`` instances
    nulls
        One of :class:`~stupidb.aggregation.Nulls` indicating how to treat
        nulls when sorting. :attr:`~stupidb.aggregation.Nulls.FIRST` treats
        nulls as less than every other value, and
        :attr:`~stupidb.aggregation.Nulls.LAST` treats them as greater than
        every other value.

    Examples
    --------
    >>> from stupidb import order_by, table
    >>> rows = [
    ...     dict(name="Bob", balance=-300),
    ...     dict(name="Alice", balance=400),
    ...     dict(name="Bob", balance=-100),
    ...     dict(name="Alice", balance=700),
    ... ]
    >>> ordered = table(rows) >> order_by(lambda r: r.balance)
    >>> balances = [row.balance for row in ordered]
    >>> balances
    [-300, -100, 400, 700]

    """
    # The child relation is left unbound here; it is supplied later via `>>`.
    sort_keys = order_by
    return _order_by(sort_keys, nulls)
@private  # type: ignore[misc]
@shiftable
def _select(
    projectors: Mapping[str, Projector | WindowAggregateSpecification],
    child: Relation,
) -> Projection:
    """Curried :class:`Projection` constructor; `child` is the last argument."""
    projection = Projection(child, projectors)
    return projection
@public  # type: ignore[misc]
def select(**projectors: Projector | WindowAggregateSpecification) -> Projection:
    """Subset or compute new columns from `projectors`.

    Parameters
    ----------
    projectors
        A mapping from :class:`str` to :data:`FullProjector` instances.

    Raises
    ------
    TypeError
        If any projector is neither callable nor a
        :class:`~stupidb.aggregation.WindowAggregateSpecification`.

    Examples
    --------
    >>> from stupidb import select, table
    >>> rows = [
    ...     dict(name="Bob", balance=-300),
    ...     dict(name="Alice", balance=400),
    ...     dict(name="Bob", balance=-100),
    ...     dict(name="Alice", balance=700),
    ... ]
    >>> names = table(rows) >> select(lower_name=lambda r: r.name.lower())
    >>> names = [row.lower_name for row in names]
    >>> names
    ['bob', 'alice', 'bob', 'alice']

    See Also
    --------
    mutate

    """
    # Reject the whole call as soon as one projector has an unusable type.
    for projector in projectors.values():
        is_window_spec = isinstance(projector, WindowAggregateSpecification)
        if not (callable(projector) or is_window_spec):
            raise TypeError("Invalid projection")
    return _select(projectors)
@private  # type: ignore[misc]
@shiftable
def _mutate(
    mutators: Mapping[str, Projector | WindowAggregateSpecification],
    child: Relation,
) -> Mutate:
    """Curried :class:`Mutate` constructor; `child` is the last argument."""
    mutation = Mutate(child, mutators)
    return mutation
@public  # type: ignore[misc]
def mutate(**mutators: Projector | WindowAggregateSpecification) -> Mutate:
    """Add new columns specified by `mutators`.

    Parameters
    ----------
    mutators
        A mapping from :class:`str` to :data:`FullProjector` instances.

    Notes
    -----
    Columns are appended, unlike :func:`~stupidb.api.select`.

    Examples
    --------
    >>> from stupidb import mutate, table
    >>> rows = [
    ...     dict(name="Bob", balance=-300),
    ...     dict(name="Alice", balance=400),
    ...     dict(name="Bob", balance=-100),
    ...     dict(name="Alice", balance=700),
    ... ]
    >>> rows = table(rows) >> mutate(lower_name=lambda r: r.name.lower())
    >>> rows
    name      balance  lower_name
    ------  ---------  ------------
    Bob          -300  bob
    Alice         400  alice
    Bob          -100  bob
    Alice         700  alice

    See Also
    --------
    select

    """
    # The child relation is left unbound here; it is supplied later via `>>`.
    return _mutate(mutators)
@public  # type: ignore[misc]
@shiftable
def sift(predicate: Predicate, child: Relation) -> Selection:
    """Keep the rows of `child` selected by `predicate`.

    Parameters
    ----------
    predicate
        A callable of one argument taking an :class:`~stupidb.row.AbstractRow`
        and returning a ``bool``.
    child
        The relation whose rows are filtered.

    Examples
    --------
    >>> from stupidb import sift, table
    >>> rows = [
    ...     dict(name="Bob", balance=-300),
    ...     dict(name="Alice", balance=400),
    ...     dict(name="Bob", balance=-100),
    ...     dict(name="Alice", balance=700),
    ... ]
    >>> rows = table(rows) >> sift(lambda r: r.name.lower().startswith("a"))
    >>> rows
    name      balance
    ------  ---------
    Alice         400
    Alice         700

    """
    selection = Selection(child, predicate)
    return selection
@public  # type: ignore[misc]
def exists(relation: Relation) -> bool:
    """Compute whether any of the rows in `relation` are truthy.

    This is useful for computing semi-joins.

    Parameters
    ----------
    relation
        The relation whose rows are inspected.

    """
    # Stop at the first truthy row, exactly as `any` would.
    for row in relation:
        if row:
            return True
    return False
@private  # type: ignore[misc]
@shiftable
def _aggregate(
    aggregations: Mapping[str, AggregateSpecification],
    child: Relation,
) -> Aggregation:
    """Curried :class:`Aggregation` constructor; `child` is the last argument."""
    aggregation = Aggregation(child, aggregations)
    return aggregation
@public  # type: ignore[misc]
def aggregate(**aggregations: AggregateSpecification) -> Aggregation:
    """Aggregate values from the child operator using `aggregations`.

    Parameters
    ----------
    aggregations
        A mapping from :class:`str` column names to
        :class:`~stupidb.aggregation.AggregateSpecification` instances.

    Examples
    --------
    Compute the average of a column:

    >>> from stupidb import aggregate, group_by, mean, table
    >>> rows = [
    ...     dict(name="Bob", age=30, timezone="America/New_York"),
    ...     dict(name="Susan", age=20, timezone="America/New_York"),
    ...     dict(name="Joe", age=41, timezone="America/Los_Angeles"),
    ...     dict(name="Alice", age=39, timezone="America/Los_Angeles"),
    ... ]
    >>> average_age = table(rows) >> aggregate(avg_age=mean(lambda r: r.age))
    >>> average_age
      avg_age
    ---------
         32.5

    Compute the average of a column, grouped by another column:

    >>> average_age_by_timezone = (
    ...     table(rows) >> group_by(tz=lambda r: r.timezone)
    ...                 >> aggregate(avg_age=mean(lambda r: r.age))
    ... )
    >>> average_age_by_timezone
    tz                     avg_age
    -------------------  ---------
    America/New_York            25
    America/Los_Angeles         40

    See Also
    --------
    group_by

    """
    # The child relation is left unbound here; it is supplied later via `>>`.
    specifications = aggregations
    return _aggregate(specifications)
@public  # type: ignore[misc]
@shiftable
def over(
    window: FrameClause, child: AggregateSpecification
) -> WindowAggregateSpecification:
    """Construct a window aggregate.

    Parameters
    ----------
    window
        A :class:`~stupidb.aggregation.FrameClause` instance constructed from
        :meth:`~stupidb.aggregation.Window.rows` or
        :meth:`~stupidb.aggregation.Window.range`.
    child
        The aggregation to compute over `window`

    Notes
    -----
    This is one of the few user-facing functions that does **not** return a
    :class:`~stupidb.core.Relation`. The behavior of materializing the rows of
    the result of calling this function is undefined.

    Examples
    --------
    >>> from stupidb import Window, over, mean, order_by, select, table
    >>> from datetime import date, timedelta
    >>> today = date(2019, 2, 9)
    >>> days = timedelta(days=1)
    >>> rows = [
    ...     {"name": "Alice", "balance": 400, "date": today},
    ...     {"name": "Alice", "balance": 300, "date": today + 1 * days},
    ...     {"name": "Alice", "balance": 100, "date": today + 2 * days},
    ...     {"name": "Bob", "balance": -150, "date": today - 4 * days},
    ...     {"name": "Bob", "balance": 200, "date": today - 3 * days},
    ... ]
    >>> t = table(rows)
    >>> window = Window.range(
    ...     partition_by=[lambda r: r.name],
    ...     order_by=[lambda r: r.date],
    ...     preceding=lambda r: 2 * days  # two days behind + the current row
    ... )
    >>> avg_balance_per_person = table(rows) >> select(
    ...     name=lambda r: r.name,
    ...     avg_balance=mean(lambda r: r.balance) >> over(window),
    ...     balance=lambda r: r.balance,
    ...     date=lambda r: r.date,
    ... ) >> order_by(lambda r: r.name, lambda r: r.date)
    >>> avg_balance_per_person
    name      balance  date          avg_balance
    ------  ---------  ----------  -------------
    Alice         400  2019-02-09        400
    Alice         300  2019-02-10        350
    Alice         100  2019-02-11        266.667
    Bob          -150  2019-02-05       -150
    Bob           200  2019-02-06         25

    """
    # Re-wrap the plain aggregate's type and getters together with the frame.
    return WindowAggregateSpecification(child.aggregate_type, child.getters, window)
@private  # type: ignore[misc]
@shiftable
def _group_by(group_by: Mapping[str, PartitionBy], child: Relation) -> GroupBy:
    """Curried :class:`GroupBy` constructor; `child` is the last argument."""
    grouped = GroupBy(child, group_by)
    return grouped
@public  # type: ignore[misc]
def group_by(**group_by: PartitionBy) -> GroupBy:
    """Group the rows of the child operator according to `group_by`.

    Parameters
    ----------
    group_by
        A mapping of :class:`str` column names to functions that compute
        grouping keys.

    Notes
    -----
    Iterating over the rows of the result of this function is not very useful,
    since its :meth:`~stupidb.core.GroupBy.__iter__` method just yields the
    rows of its child. A call to this function is best followed by a call to
    :func:`~stupidb.api.aggregate`.

    Examples
    --------
    >>> from stupidb import aggregate, group_by, mean, table
    >>> rows = [
    ...     dict(name="Bob", age=30, timezone="America/New_York"),
    ...     dict(name="Susan", age=20, timezone="America/New_York"),
    ...     dict(name="Joe", age=41, timezone="America/Los_Angeles"),
    ...     dict(name="Alice", age=39, timezone="America/Los_Angeles"),
    ... ]
    >>> average_age_by_timezone = (
    ...     table(rows) >> group_by(tz=lambda r: r.timezone)
    ...                 >> aggregate(avg_age=mean(lambda r: r.age))
    ... )
    >>> average_age_by_timezone
    tz                     avg_age
    -------------------  ---------
    America/New_York            25
    America/Los_Angeles         40

    See Also
    --------
    aggregate

    """
    # Inside the body `group_by` is the keyword mapping, not this function.
    grouping_keys = group_by
    return _group_by(grouping_keys)
@public  # type: ignore[misc]
@shiftable
def union(right: Relation, left: Relation) -> Union:
    """Compute the union of `left` and `right`, ignoring duplicate rows.

    Parameters
    ----------
    right
        A relation
    left
        A relation

    See Also
    --------
    union_all

    """
    combined = Union(left, right)
    return combined
@public  # type: ignore[misc]
@shiftable
def union_all(right: Relation, left: Relation) -> UnionAll:
    """Compute the union of `left` and `right`, preserving duplicate rows.

    Parameters
    ----------
    right
        A relation
    left
        A relation

    See Also
    --------
    union

    """
    combined = UnionAll(left, right)
    return combined
@public  # type: ignore[misc]
@shiftable
def intersect(right: Relation, left: Relation) -> Intersect:
    """Compute the intersection of `left` and `right`, ignoring duplicate rows.

    Parameters
    ----------
    right
        A relation
    left
        A relation

    See Also
    --------
    intersect_all

    """
    common = Intersect(left, right)
    return common
@public  # type: ignore[misc]
@shiftable
def intersect_all(right: Relation, left: Relation) -> IntersectAll:
    """Compute the intersection of `left` and `right`, preserving duplicates.

    Parameters
    ----------
    right
        A relation
    left
        A relation

    See Also
    --------
    intersect

    """
    common = IntersectAll(left, right)
    return common
@public  # type: ignore[misc]
@shiftable
def difference(right: Relation, left: Relation) -> Difference:
    """Compute the set difference of `left` and `right`.

    Parameters
    ----------
    right
        A relation
    left
        A relation

    See Also
    --------
    difference_all

    """
    remainder = Difference(left, right)
    return remainder
@public  # type: ignore[misc]
@shiftable
def difference_all(right: Relation, left: Relation) -> DifferenceAll:
    """Compute the set difference of `left` and `right`, preserving duplicates.

    Parameters
    ----------
    right
        A relation
    left
        A relation

    See Also
    --------
    difference

    """
    remainder = DifferenceAll(left, right)
    return remainder
@public  # type: ignore[misc]
@shiftable
def limit(limit: int | None, relation: Relation, *, offset: int = 0) -> Limit:
    """Return the rows in `relation` starting from `offset` up to `limit`.

    Parameters
    ----------
    limit
        The number of rows starting from offset to produce
    relation
        Relation whose rows to limit
    offset
        The number of rows to skip before yielding

    Raises
    ------
    ValueError
        If `offset` is negative, or if `limit` is negative.

    """
    # The offset is validated first, so it wins when both arguments are bad.
    offset_is_negative = offset < 0
    if offset_is_negative:
        raise ValueError(f"invalid offset, must be non-negative: {offset}")

    # ``None`` is accepted as a limit and skips the range check entirely.
    if limit is not None:
        if limit < 0:
            raise ValueError(f"invalid limit, must be non-negative or None: {limit}")

    limited = Limit(relation, limit=limit, offset=offset)
    return limited
@public  # type: ignore[misc]
def count(x: Callable[[AbstractRow], T | None]) -> AggregateSpecification:
    """Count the number of non-NULL values of `x`.

    Parameters
    ----------
    x
        A column getter.

    """
    specification = AggregateSpecification(Count, x)
    return specification
@public  # type: ignore[misc]
def sum(x: Callable[[AbstractRow], R | None]) -> AggregateSpecification:
    """Compute the sum of `x`, with an empty column summing to NULL.

    Parameters
    ----------
    x
        A column getter.

    See Also
    --------
    total

    """
    specification = AggregateSpecification(Sum, x)
    return specification
@public  # type: ignore[misc]
def total(x: Callable[[AbstractRow], R | None]) -> AggregateSpecification:
    """Compute the sum of `x`, with an empty column summing to zero.

    Parameters
    ----------
    x
        A column getter.

    See Also
    --------
    sum

    """
    specification = AggregateSpecification(Total, x)
    return specification
@public  # type: ignore[misc]
def first(x: Callable[[AbstractRow], T | None]) -> AggregateSpecification:
    """Compute the first row of `x` over a window.

    Parameters
    ----------
    x
        A column getter.

    See Also
    --------
    last, nth

    """
    specification = AggregateSpecification(First, x)
    return specification
@public  # type: ignore[misc]
def last(x: Callable[[AbstractRow], T | None]) -> AggregateSpecification:
    """Compute the last row of `x` over a window.

    Parameters
    ----------
    x
        A column getter.

    See Also
    --------
    first, nth

    """
    specification = AggregateSpecification(Last, x)
    return specification
@public  # type: ignore[misc]
def nth(
    x: Callable[[AbstractRow], T | None],
    i: Callable[[AbstractRow], int | None],
) -> AggregateSpecification:
    """Compute the `i`-th row of `x` over a window.

    Parameters
    ----------
    x
        Column selector.
    i
        Callable to compute the row offset of the window to return.

    See Also
    --------
    first, last

    """
    getters = (x, i)
    return AggregateSpecification(Nth, *getters)
@public  # type: ignore[misc]
def row_number() -> AggregateSpecification:
    """Compute the row number over a window."""
    specification = AggregateSpecification(RowNumber)
    return specification
@public  # type: ignore[misc]
def rank() -> AggregateSpecification:
    """Rank the rows of a relation based on the ordering key given in `over`.

    See Also
    --------
    dense_rank

    """
    specification = AggregateSpecification(Rank)
    return specification
@public  # type: ignore[misc]
def dense_rank() -> AggregateSpecification:
    """Densely rank the rows of a relation by the ordering key given in `over`.

    See Also
    --------
    rank

    """
    specification = AggregateSpecification(DenseRank)
    return specification
@public  # type: ignore[misc]
def lead(
    x: Callable[[AbstractRow], T | None],
    n: Callable[[AbstractRow], int | None] = const(1),
    default: Callable[[AbstractRow], T | None] = const(None),
) -> AggregateSpecification:
    """Lead a column `x` by `n` rows, using `default` for NULL values.

    Parameters
    ----------
    x
        A column selector.
    n
        A callable computing the number of rows to lead. Defaults to a lead of
        **1** row. The callable takes the current row as input and thus the
        lead can be computed relative to the current row.
    default
        A callable computing the default value for the lead if the row would
        produce a NULL value when led. The callable takes the current row as
        input and thus the default can be computed relative to the current
        row.

    See Also
    --------
    lag

    """
    getters = (x, n, default)
    return AggregateSpecification(Lead, *getters)
@public  # type: ignore[misc]
def lag(
    x: Callable[[AbstractRow], T | None],
    n: Callable[[AbstractRow], int | None] = const(1),
    default: Callable[[AbstractRow], T | None] = const(None),
) -> AggregateSpecification:
    """Lag a column `x` by `n` rows, using `default` for NULL values.

    Parameters
    ----------
    x
        A column selector.
    n
        A callable computing the number of rows to lag. Defaults to a lag of
        **1** row. The callable takes the current row as input and thus the
        lag can be computed relative to the current row.
    default
        A callable computing the default value for the lag if the row would
        produce a NULL value when lagged. The callable takes the current row
        as input and thus the default can be computed relative to the current
        row.

    See Also
    --------
    lead

    """
    getters = (x, n, default)
    return AggregateSpecification(Lag, *getters)
@public  # type: ignore[misc]
def mean(x: Callable[[AbstractRow], R | None]) -> AggregateSpecification:
    """Compute the average of a column.

    Parameters
    ----------
    x
        A column selector.

    """
    specification = AggregateSpecification(Mean, x)
    return specification
@public  # type: ignore[misc]
def min(x: Callable[[AbstractRow], Comparable | None]) -> AggregateSpecification:
    """Compute the minimum of a column.

    Parameters
    ----------
    x
        A column selector.

    See Also
    --------
    max

    """
    specification = AggregateSpecification(Min, x)
    return specification
@public  # type: ignore[misc]
def max(x: Callable[[AbstractRow], Comparable | None]) -> AggregateSpecification:
    """Compute the maximum of a column.

    Parameters
    ----------
    x
        A column selector.

    See Also
    --------
    min

    """
    specification = AggregateSpecification(Max, x)
    return specification
@public  # type: ignore[misc]
def cov_samp(
    x: Callable[[AbstractRow], R1 | None],
    y: Callable[[AbstractRow], R2 | None],
) -> AggregateSpecification:
    """Compute the sample covariance of two columns.

    Parameters
    ----------
    x
        A column selector.
    y
        A column selector.

    See Also
    --------
    cov_pop

    """
    getters = (x, y)
    return AggregateSpecification(SampleCovariance, *getters)
@public  # type: ignore[misc]
def var_samp(x: Callable[[AbstractRow], R | None]) -> AggregateSpecification:
    """Compute the sample variance of a column.

    Parameters
    ----------
    x
        A column selector.

    See Also
    --------
    var_pop

    """
    specification = AggregateSpecification(SampleVariance, x)
    return specification
@public  # type: ignore[misc]
def stdev_samp(x: Callable[[AbstractRow], R | None]) -> AggregateSpecification:
    """Compute the sample standard deviation of a column.

    Parameters
    ----------
    x
        A column selector.

    See Also
    --------
    stdev_pop

    """
    specification = AggregateSpecification(SampleStandardDeviation, x)
    return specification
@public  # type: ignore[misc]
def cov_pop(
    x: Callable[[AbstractRow], R1 | None],
    y: Callable[[AbstractRow], R2 | None],
) -> AggregateSpecification:
    """Compute the population covariance of two columns.

    Parameters
    ----------
    x
        A column selector.
    y
        A column selector.

    See Also
    --------
    cov_samp

    """
    getters = (x, y)
    return AggregateSpecification(PopulationCovariance, *getters)
@public  # type: ignore[misc]
def var_pop(x: Callable[[AbstractRow], R | None]) -> AggregateSpecification:
    """Compute the population variance of a column.

    Parameters
    ----------
    x
        A column selector.

    See Also
    --------
    var_samp

    """
    # The getter annotation accepts ``None`` results, matching the sibling
    # numeric aggregates (`var_samp`, `stdev_pop`, `mean`, ...).
    return AggregateSpecification(PopulationVariance, x)
@public  # type: ignore[misc]
def stdev_pop(x: Callable[[AbstractRow], R | None]) -> AggregateSpecification:
    """Compute the population standard deviation of a column.

    Parameters
    ----------
    x
        A column selector.

    See Also
    --------
    stdev_samp

    """
    specification = AggregateSpecification(PopulationStandardDeviation, x)
    return specification
@public  # type: ignore[misc]
@shiftable
def pretty(
    rows: Relation,
    *,
    n: int | None = 10,
    tablefmt: str = "simple",
    headers: str = "keys",
    **kwargs: Any,
) -> str:
    """Pretty-format a relation.

    Parameters
    ----------
    rows
        The relation to print
    n
        The number of rows to format; A value of `None` will pretty format all
        rows
    tablefmt
        The kind of table to use for formatting
    headers
        A string indicating how to compute column names
    kwargs
        Additional keyword arguments passed to the `tabulate.tabulate` function

    Returns
    -------
    str
        Pretty-formatted relation

    See Also
    --------
    stupidb.api.show

    """
    # Cap the relation first so only the requested rows are formatted.
    capped_rows = limit(n, rows)
    formatted = tabulate.tabulate(
        capped_rows,
        headers=headers,
        tablefmt=tablefmt,
        **kwargs,
    )
    return formatted
@public  # type: ignore[misc]
@shiftable
def show(rows: Relation, **kwargs: Any) -> None:
    """Pretty-print a relation.

    Parameters
    ----------
    rows
        The relation to print
    kwargs
        Additional keyword arguments passed to the `stupidb.api.pretty`
        function

    See Also
    --------
    stupidb.api.pretty

    """
    rendered = pretty(rows, **kwargs)
    print(rendered)