From 718f56d973b68e32a495480558a9a2a787fe2320 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 6 May 2026 16:21:46 -0400 Subject: [PATCH 01/10] Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors --- effectful/internals/disjoint_set.py | 99 +++++ effectful/ops/monoid.py | 556 +++++++++++++++++++++++++++ effectful/ops/syntax.py | 78 ++++ pyproject.toml | 1 + tests/_monoid_helpers.py | 85 ++++ tests/test_internals_disjoint_set.py | 124 ++++++ tests/test_ops_monoid.py | 518 +++++++++++++++++++++++++ 7 files changed, 1461 insertions(+) create mode 100644 effectful/internals/disjoint_set.py create mode 100644 effectful/ops/monoid.py create mode 100644 tests/_monoid_helpers.py create mode 100644 tests/test_internals_disjoint_set.py create mode 100644 tests/test_ops_monoid.py diff --git a/effectful/internals/disjoint_set.py b/effectful/internals/disjoint_set.py new file mode 100644 index 00000000..73b5c5c5 --- /dev/null +++ b/effectful/internals/disjoint_set.py @@ -0,0 +1,99 @@ +class DisjointSet: + """Disjoint Set Union (Union-Find) data structure. + + Maintains a collection of disjoint sets over the integers 0..n-1, + supporting near-constant-time union and find operations via + path compression and union by rank. + + The amortized time complexity per operation is O(α(n)), where α + is the inverse Ackermann function (effectively constant for any + practical n). + + Example: + >>> dsu = DisjointSet(5) + >>> dsu.union(0, 1) + True + >>> dsu.union(1, 2) + True + >>> dsu.find(0) == dsu.find(2) + True + >>> dsu.find(0) == dsu.find(3) + False + """ + + def __init__(self, n): + """Initialize n singleton sets: {0}, {1}, ..., {n-1}. + + Args: + n: The number of elements. Elements are labeled 0..n-1. + """ + self.parent = list(range(n)) + self.rank = [0] * n + + def _validate(self, x): + if x < 0 or x >= len(self.parent): + raise IndexError(f"Element {x} out of bounds") + + def find(self, x): + """Return the representative (root) of the set containing x. + + Two elements belong to the same set if and only if they have + the same representative. Applies path compression: every node + traversed is re-parented directly to its grandparent, flattening + the tree to speed up future queries. + + Args: + x: The element to look up. + + Returns: + The root element of x's set. + """ + self._validate(x) + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] # path compression + x = self.parent[x] + return x + + def union(self, *elements): + """Merge the sets containing all given elements into one. + + Accepts any number of elements and unions them all together. + Uses union by rank: shallower trees are attached under the root + of the deeper one, keeping the combined tree shallow. + + Args: + *elements: Two or more elements to merge into a single set. + Calling with 0 or 1 elements is a no-op and returns False. + + Returns: + True if any merging occurred (i.e., at least two of the + elements were in different sets); False if all elements + were already in the same set or fewer than 2 were given. + """ + if len(elements) < 2: + return False + + merged = False + first = elements[0] + + for y in elements[1:]: + if self._union_pair(first, y): + merged = True + + return merged + + def _union_pair(self, x, y): + rx = self.find(x) + ry = self.find(y) + + if rx == ry: + return False + + if self.rank[rx] < self.rank[ry]: + rx, ry = ry, rx + + self.parent[ry] = rx + if self.rank[rx] == self.rank[ry]: + self.rank[rx] += 1 + + return True diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py new file mode 100644 index 00000000..58a10ba3 --- /dev/null +++ b/effectful/ops/monoid.py @@ -0,0 +1,556 @@ +import collections.abc +import functools +import itertools +import numbers +import typing +from collections import Counter, defaultdict +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass +from graphlib import TopologicalSorter +from typing import Annotated, Any + +from effectful.internals.disjoint_set import DisjointSet +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.syntax import ( + ObjectInterpretation, + Scoped, + _NumberTerm, + defdata, + implements, + iter_, + syntactic_eq, + syntactic_hash, +) +from effectful.ops.types import Interpretation, NotHandled, Operation, Term + +# Note: The streams value type should be something like Iterable[T], but some of +# our target stream types (e.g. jax.Array) are not subtypes of Iterable +type Streams[T] = Mapping[Operation[[], T], Any] + +type Body[T] = ( + Iterable[T] + | Callable[..., Body[T]] + | T + | Mapping[Any, Body[T]] + | Interpretation[T, Body[T]] +) + + +def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: + """Determine an order to evaluate the streams based on their dependencies""" + stream_vars = set(streams.keys()) + dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(dependencies) + topo.prepare() + while topo.is_active(): + node_group = topo.get_ready() + for op in sorted(node_group): + yield (op, streams[op]) + topo.done(*node_group) + + +class Monoid[T]: + kernel: Operation[[T, T], T] + identity: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T): + self.identity = identity + self.kernel = ( + kernel if isinstance(kernel, Operation) else Operation.define(kernel) + ) + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity})" + + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + """Monoid addition with broadcasting over common collection types, + callables, and interpretations. + + """ + if not args: + return typing.cast(S, self.identity) + + if any(isinstance(x, Term) for x in args): + return typing.cast(S, defdata(self.plus, *args)) + + return self._plus(*args) + + @functools.singledispatchmethod + def _plus[S](self, *args: S) -> S: + return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) + + @_plus.register(Sequence) + def _(self, *args): + return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Mapping) + def _(self, *args): + if isinstance(args[0], Interpretation): + keys = args[0].keys() + + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + + b_keys = b.keys() + if not keys == b_keys: + raise ValueError( + f"Expected interpretation of {keys} but got {b_keys}" + ) + + result = {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + return result + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + result = {k: self.plus(*vs) for (k, vs) in all_values.items()} + return result + + @Operation.define + @functools.singledispatchmethod + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + if callable(body): + return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) + + def generator(loop_order) -> Iterator[Interpretation]: + if len(loop_order) == 0: + return + + stream_key = loop_order[0][0] + stream_values = evaluate(streams[stream_key]) + stream_values_iter = iter(stream_values) # type: ignore[arg-type] + + # If we try to iterate and get a term instead of a real + # iterator, give up + if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + raise NotHandled + + if len(loop_order) == 1: + for val in stream_values_iter: + yield {stream_key: functools.partial(lambda v: v, val)} + else: + for val in stream_values_iter: + intp: Interpretation = { + stream_key: functools.partial(lambda v: v, val) + } + with handler(intp): + for intp2 in generator(loop_order[1:]): + yield coproduct(intp, intp2) + + loop_order = list(order_streams(streams)) + try: + return self.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) + except NotHandled: + return typing.cast(U, defdata(self.reduce, body, streams)) + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Mapping, streams): + return {k: self.reduce(v, streams) for (k, v) in body.items()} + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Sequence, streams): + return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Generator, streams): + return (self.reduce(x, streams) for x in body) + + +class IdempotentMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoidWithZero[T](CommutativeMonoid[T]): + zero: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): + super().__init__(kernel, identity) + self.zero = zero + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" + + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class Semilattice[T](IdempotentMonoid[T], CommutativeMonoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +@Operation.define +def _arg_min[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] < a[0] else a # type: ignore + + +@Operation.define +def _arg_max[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] > a[0] else a # type: ignore + + +Min = Semilattice(kernel=min, identity=float("inf")) +Max = Semilattice(kernel=max, identity=float("-inf")) +ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) +ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) +Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) +Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) + + +@dataclass +class _ExtensibleBinaryRelation[S, T]: + tuples: set[tuple[S, T]] + + def register(self, s: S, t: T) -> None: + self.tuples.add((s, t)) + + def __call__(self, s: S, t: T) -> bool: + return (s, t) in self.tuples + + +distributes_over = _ExtensibleBinaryRelation( + { + (Max.plus, Min.plus), + (Min.plus, Max.plus), + (Sum.plus, Min.plus), + (Sum.plus, Max.plus), + (Product.plus, Sum.plus), + } +) + + +class PlusEmpty(ObjectInterpretation): + """plus() = 0""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args: + return monoid.identity + return fwd() + + +class PlusSingle(ObjectInterpretation): + """plus(x) = x""" + + @implements(Monoid.plus) + def plus(self, _, *args): + if len(args) == 1: + return args[0] + return fwd() + + +class PlusIdentity(ObjectInterpretation): + """x₁ + ... + 0 + ... + xₙ = x₁ + ... + xₙ""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(x is monoid.identity for x in args): + return monoid.plus(*(x for x in args if x is not monoid.identity)) + return fwd() + + +class PlusAssoc(ObjectInterpretation): + """x + (y + z) = (x + y) + z = x + y + z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(isinstance(x, Term) and x.op is monoid.plus for x in args): + flat_args = itertools.chain.from_iterable( + t.args if isinstance(t, Term) and t.op is monoid.plus else (t,) + for t in args + ) + assert len(args) > 0 + return monoid.plus(*flat_args) + return fwd() + + +class PlusDistr(ObjectInterpretation): + """x + (y * z) = x * y + x * z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any( + isinstance(x, Term) and distributes_over(monoid.plus, x.op) for x in args + ): + non_terms = [] + + # group terms by head operation + by_head_op = defaultdict(list) + for t in args: + if isinstance(t, Term): + by_head_op[t.op].append(t) + else: + non_terms.append(t) + + # distribute over each group + progress = False + final_sum = [] + for op, terms in by_head_op.items(): + if ( + len(terms) > 1 + and distributes_over(monoid.plus, op) + and not distributes_over(op, monoid.plus) + ): + progress = True + term_args = (t.args for t in terms) + dist_terms = ( + monoid.plus(*args) for args in itertools.product(*term_args) + ) + final_sum.append(op(*dist_terms)) + else: + final_sum += terms + if progress: + return monoid.plus(*non_terms, *final_sum) + return fwd() + + +class PlusZero(ObjectInterpretation): + """x₁ * ... * 0 * ... * xₙ = 0""" + + @implements(CommutativeMonoidWithZero.plus) + def plus(self, monoid, *args): + if any(x is monoid.zero for x in args): + return monoid.zero + return fwd() + + +class PlusConsecutiveDups(ObjectInterpretation): + """x ⊕ x ⊕ y = x ⊕ y""" + + @implements(IdempotentMonoid.plus) + def plus(self, monoid, *args): + dedup_args = ( + args[i] + for i in range(len(args)) + if i == 0 or not syntactic_eq(args[i - 1], args[i]) + ) + return fwd(monoid, *dedup_args) + + +class PlusDups(ObjectInterpretation): + """x ⊕ y ⊕ x = x ⊕ y""" + + @dataclass + class _HashableTerm: + term: Term + + def __eq__(self, other): + return syntactic_eq(self, other) + + def __hash__(self): + return syntactic_hash(self) + + @implements(Semilattice.plus) + def plus(self, monoid, *args): + # elim dups + args_count = Counter(self._HashableTerm(t) for t in args) + if len(args_count) < len(args): + dedup_args = [] + for t in args: + ht = self._HashableTerm(t) + if ht in args_count: + dedup_args.append(t) + del args_count[ht] + return fwd(monoid, *dedup_args) + return fwd() + + +NormalizePlusIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), + ], + ), +) + + +class ReduceNoStreams(ObjectInterpretation): + """Implements the identity + reduce(R, ∅, body) = 0 + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, _, streams): + if len(streams) == 0: + return monoid.identity + return fwd() + + +class ReduceFusion(ObjectInterpretation): + """Implements the identity + reduce(R, S1, reduce(R, S2, body)) = reduce(R, S1 ∪ S2, body) + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.reduce: + return monoid.reduce(body.args[0], streams | body.args[1]) + return fwd() + + +class ReduceSplit(ObjectInterpretation): + """Implements the identity + reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.plus: + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) + return fwd() + + +class ReduceFactorization(ObjectInterpretation): + """ + Implements factorization of independent terms. + For example, when having two independent distributions, + we can rewrite their marginalization as: + ∫p(x)⋅q(y)dxdy => ∫p(x)dx ⋅ ∫q(y)dy + + More specifically, in terms of reduces we are performing: + reduce(R, (S₁ × ... × Sₖ) , A₁ * ... * Aₖ) + => reduce(R, S₁, A₁) * ... * reduce(R, Sₖ, Aₖ) + where free(Aᵢ) ∩ free(Aⱼ) ∩ S = ∅ + and free(Aᵢ) ∩ S ⊆ Sᵢ + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and distributes_over(body.op, monoid.plus): + stream_vars = set(streams.keys()) + factors = [(arg, fvsof(arg)) for arg in body.args] + stream_ids = {v: i for (i, v) in enumerate(stream_vars)} + ds = DisjointSet(len(streams)) + + # streams are in the same partition as their dependencies + for stream_var, stream_id in stream_ids.items(): + stream_body = streams[stream_var] + deps = sorted([stream_ids[v] for v in fvsof(stream_body) & stream_vars]) + ds.union(stream_id, *deps) + + # factors are in the same partition as their dependencies + for factor, factor_fvs in factors: + factor_streams = sorted( + [stream_ids[v] for v in (factor_fvs & stream_vars)] + ) + ds.union(*factor_streams) + + placed_streams = set() + new_reduces = [] + for stream_key in streams: + if stream_key in placed_streams: + continue + + partition = ds.find(stream_ids[stream_key]) + partition_streams = { + k: v + for (k, v) in streams.items() + if ds.find(stream_ids[k]) == partition + } + partition_stream_keys = set(partition_streams.keys()) + + partition_factors = [ + t for t in factors if (t[1] & partition_stream_keys) + ] + + assert all( + (t[1] & stream_vars) <= partition_stream_keys + for t in partition_factors + ), "partition contains all streams required by factor" + + partition_term = body.op(*(t[0] for t in partition_factors)) + new_reduces.append((partition_term, partition_streams)) + placed_streams |= partition_stream_keys + + constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] + + if len(new_reduces) > 1: + result = body.op( + *constant_factors, *(monoid.reduce(*args) for args in new_reduces) + ) + return result + + return fwd() + + +NormalizeReduceIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + ), +) + +NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 76401675..8fb12598 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -852,6 +852,84 @@ def _(x: object, other) -> bool: return x == other +@_CustomSingleDispatchCallable +def syntactic_hash(__dispatch: Callable[[type], Callable[[Any], int]], x) -> int: + """Structural hash compatible with :func:`syntactic_eq`. + + Guarantees that ``syntactic_eq(x, y)`` implies + ``syntactic_hash(x) == syntactic_hash(y)``. + + :param x: A term. + :returns: An integer hash. + """ + if dataclasses.is_dataclass(x) and not isinstance(x, type): + return hash( + ( + "dataclass", + type(x), + syntactic_hash( + { + field.name: getattr(x, field.name) + for field in dataclasses.fields(x) + } + ), + ) + ) + else: + return __dispatch(type(x))(x) + + +@syntactic_hash.register +def _(x: Term) -> int: + return hash( + ( + "term", + x.op, + len(x.args), + tuple(syntactic_hash(a) for a in x.args), + # sort kwargs so order doesn't affect the hash + tuple((k, syntactic_hash(x.kwargs[k])) for k in sorted(x.kwargs)), + ) + ) + + +@syntactic_hash.register +def _(x: collections.abc.Mapping) -> int: + # XOR over (key_hash, value_hash) pairs — order-independent, + # matching the set-based comparison in syntactic_eq's Mapping branch. + acc = 0 + for k in x: + acc ^= hash((hash(k), syntactic_hash(x[k]))) + return hash(("mapping", acc)) + + +@syntactic_hash.register +def _(x: collections.abc.Sequence) -> int: + if ( + isinstance(x, tuple) + and hasattr(x, "_fields") + and all(hasattr(x, f) for f in x._fields) + ): + return hash( + ( + "namedtuple", + type(x), + tuple(syntactic_hash(getattr(x, f)) for f in x._fields), + ) + ) + else: + # Use the abstract Sequence tag (not type(x)) because syntactic_eq + # treats any two Sequences of equal length and elementwise-equal + # contents as equal — e.g. [1,2] and (1,2) compare equal. + return hash(("sequence", len(x), tuple(syntactic_hash(a) for a in x))) + + +@syntactic_hash.register(object) +@syntactic_hash.register(str | bytes) +def _(x: object) -> int: + return hash(x) + + class ObjectInterpretation[T, V](collections.abc.Mapping): """A helper superclass for defining an ``Interpretation`` of many :class:`~effectful.ops.types.Operation` instances with shared state or behavior. diff --git a/pyproject.toml b/pyproject.toml index d565403f..685aaf55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-benchmark", + "hypothesis", "mypy", "ruff", "nbval", diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py new file mode 100644 index 00000000..4532ae72 --- /dev/null +++ b/tests/_monoid_helpers.py @@ -0,0 +1,85 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import Any, get_args, get_origin + +from hypothesis import strategies as st + +from effectful.ops.syntax import deffn +from effectful.ops.types import Operation + + +def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: + """Strategy for the value an *0-arg* Operation should return.""" + if annotation is int: + return st.integers() + if annotation is float: + return st.floats(allow_nan=False) + if get_origin(annotation) is list and get_args(annotation) == (int,): + return st.lists(st.integers()) + raise NotImplementedError( + f"No value strategy for return annotation {annotation!r}; " + "supported: int, list[int]" + ) + + +_UNARY_INT_FNS: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, +] + +_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, +] + +_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], +] + + +def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: + """Pick a strategy producing a callable suitable for binding `op` in an + interpretation. Inspects the operation's signature. + """ + sig = op.__signature__ + params = list(sig.parameters.values()) + ret = sig.return_annotation + param_types = tuple(p.annotation for p in params) + + if not params: + return _value_strategy_for(ret).map(deffn) + if ret is int and param_types == (int,): + return st.sampled_from(_UNARY_INT_FNS) + if ret is int and param_types == (int, int): + return st.sampled_from(_BINARY_INT_FNS) + if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): + return st.sampled_from(_UNARY_LIST_FNS) + raise NotImplementedError( + f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + ) + + +@st.composite +def random_interpretation( + draw: st.DrawFn, free_vars: Sequence[Operation] +) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `case.free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op in free_vars: + intp[op] = draw(_strategy_for_op(op)) + return intp + + +__all__ = ["random_interpretation"] diff --git a/tests/test_internals_disjoint_set.py b/tests/test_internals_disjoint_set.py new file mode 100644 index 00000000..808b8d25 --- /dev/null +++ b/tests/test_internals_disjoint_set.py @@ -0,0 +1,124 @@ +import random + +import pytest + +from effectful.internals.disjoint_set import DisjointSet + + +@pytest.fixture +def dsu(): + return DisjointSet(10) + + +def test_initial_state(dsu): + for i in range(10): + assert dsu.find(i) == i + + +def test_simple_union(dsu): + assert dsu.union(1, 2) is True + assert dsu.find(1) == dsu.find(2) + + +def test_union_idempotent(dsu): + dsu.union(1, 2) + assert dsu.union(1, 2) is False + + +def test_union_chain(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + assert dsu.find(1) == dsu.find(3) + + +def test_union_multiple_elements_all_connected(dsu): + dsu.union(1, 2, 3, 4, 5) + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def test_union_multiple_elements_partial_overlap(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + dsu.union(2, 3, 5) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def test_union_multiple_elements_with_existing_connections(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4, 5, 6) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5, 6]} + assert len(roots) == 1 + + +def test_union_single_element(dsu): + assert dsu.union(1) is False + + +def test_union_no_elements(dsu): + assert dsu.union() is False + + +def test_union_self(dsu): + assert dsu.union(3, 3) is False + assert dsu.find(3) == 3 + + +def test_transitivity(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + assert dsu.find(1) == dsu.find(4) + + +def test_disjoint_sets_remain_separate(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + assert dsu.find(1) != dsu.find(3) + + +def test_randomized_unions(): + n = 50 + dsu = DisjointSet(n) + + groups = [{i} for i in range(n)] + + def find_group(x): + for g in groups: + if x in g: + return g + + for _ in range(100): + elems = random.sample(range(n), random.randint(2, 5)) + dsu.union(*elems) + + # merge ground-truth groups + merged = set() + for e in elems: + merged |= find_group(e) + + groups = [g for g in groups if g.isdisjoint(merged)] + groups.append(merged) + + # verify structure matches ground truth + for g in groups: + roots = {dsu.find(x) for x in g} + assert len(roots) == 1 + + +def test_path_compression_effect(): + dsu = DisjointSet(6) + dsu.union(0, 1) + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + + # Trigger compression + root_before = dsu.find(4) + root_after = dsu.find(4) + + assert root_before == root_after diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py new file mode 100644 index 00000000..a22928cc --- /dev/null +++ b/tests/test_ops_monoid.py @@ -0,0 +1,518 @@ +import functools +import itertools + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.semantics import apply, evaluate, fvsof, handler +from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq +from effectful.ops.types import NotHandled, Operation +from tests._monoid_helpers import random_interpretation + +_INT = st.integers(min_value=-100, max_value=100) + +ALL_MONOIDS = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +COMMUTATIVE = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +IDEMPOTENT = [ + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +WITH_ZERO = [ + pytest.param(Product, id="Product"), +] + + +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +@functools.cache +def _canonical_op(idx: int) -> Operation: + """Globally cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + return Operation.define(int, name=f"__cv_{idx}") + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + return syntactic_eq(_canonicalize(x), _canonicalize(y)) + + +def _canonicalize(expr): + counter = itertools.count() + + def _passthrough(op, *args, **kwargs): + return defdata(op, *args, **kwargs) + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _passthrough, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set): + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs): + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return defdata(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter)) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT, b=_INT, c=_INT) +@settings(max_examples=50, deadline=None) +def test_associativity(monoid, a, b, c): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert left == right + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_identity(monoid, a): + assert monoid.plus(monoid.identity, a) == a + assert monoid.plus(a, monoid.identity) == a + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +@given(a=_INT, b=_INT) +@settings(max_examples=50, deadline=None) +def test_commutativity(monoid, a, b): + assert monoid.plus(a, b) == monoid.plus(b, a) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_idempotence(monoid, a): + assert monoid.plus(a, a) == a + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_zero_absorbs(monoid, a): + assert monoid.plus(monoid.zero, a) == monoid.zero + assert monoid.plus(a, monoid.zero) == monoid.zero + + +def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: + """Run structural + semantic checks on a TermPair.""" + with handler(NormalizeIntp): + norm = evaluate(lhs) + + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=None) + def _check_semantics(intp): + with handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert lhs_val == rhs_val + + _check_semantics() + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid): + _check_pair(lhs=monoid.plus(), rhs=monoid.identity) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_single(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_right(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_left(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_right(monoid): + x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(x(), monoid.plus(y(), z())), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_left(monoid): + x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(monoid.plus(x(), y()), z()), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_sequence(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus([a(), b()], [c(), d()]), + rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + free_vars=[a, b, c, d], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_mapping(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), + rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, + free_vars=[a, b, c, d], + ) + + +def test_plus_distributes(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) + rhs = Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_constant(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) + rhs = Product.plus( + 5, + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_multiple(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Sum.plus( + Min.plus(a(), b()), + Min.plus(c(), d()), + Max.plus(a(), b()), + Max.plus(c(), d()), + ) + rhs = Sum.plus( + Min.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + Max.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_consecutive(monoid): + """``a, a, b → a, b`` — only consecutive duplicates collapse.""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), a(), b()) + return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_non_consecutive(monoid): + """``a, b, a`` — Semilattice (Min/Max) collapses via commutative + PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), b(), a()) + if isinstance(monoid, Semilattice): + rhs = monoid.plus(a(), b()) + else: + rhs = monoid.plus(a(), b(), a()) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + + +def test_plus_commutative_idempotent_long(): + """Long alternation collapses via commutative dedup (Min/Max only).""" + a, b = define_vars("a", "b") + lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) + _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +def test_plus_zero(monoid): + a = define_vars("a") + lhs_right = monoid.plus(a(), monoid.zero) + lhs_left = monoid.plus(monoid.zero, a()) + _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) + _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(x())], {x: X()}) + rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence_2(monoid): + x, y = define_vars("x", "y") + X, Y = define_vars("X", "Y", typ=list[int]) + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) + rhs = [ + monoid.reduce(f(x()), {x: X(), y: Y()}), + monoid.reduce(g(y()), {x: X(), y: Y()}), + ] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_mapping(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + rhs = { + "a": monoid.reduce(f(x()), {x: X()}), + "b": monoid.reduce(g(x()), {x: X()}), + } + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_no_streams(monoid): + a = define_vars("a") + lhs = monoid.reduce(a(), {}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_reduce(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) + rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +def test_reduce_plus(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) + rhs = monoid.plus( + monoid.reduce(a(), {a: A(), b: B()}), + monoid.reduce(b(), {a: A(), b: B()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_1(): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) + rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_2(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +def test_reduce_independent_3_negative(): + """Stream `b` depends on `a` (b: g(a())), so the proposed factorization + is unsound — the normalizer must NOT apply it.""" + a, b, c = define_vars("a", "b", "c") + A, C = define_vars("A", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + @Operation.define + def g(_x: int) -> list[int]: + raise NotHandled + + with handler(NormalizeIntp): + lhs = Sum.reduce( + Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} + ) + bogus_rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), + ) + assert fvsof(bogus_rhs) != fvsof(lhs) + # Structural-only negative check: the normalizer correctly refused to apply + # the bogus factorization. + assert not syntactic_eq_alpha(lhs, bogus_rhs) + + +def test_reduce_independent_4(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + 7, + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) From 965df122d9741bc04560e0b860c6d3565d76778d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 14:57:58 -0400 Subject: [PATCH 02/10] Add inversion from `weighted` (#655) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * fix rule * wip * fix bug * cleanup * lin --- effectful/internals/product_n.py | 2 +- effectful/ops/monoid.py | 162 ++++++++++++++++++++++++++-- effectful/ops/semantics.py | 1 + effectful/ops/types.py | 5 +- tests/_monoid_helpers.py | 14 +-- tests/test_handlers_llm_provider.py | 2 +- tests/test_ops_monoid.py | 146 ++++++++++++++++++++++--- tests/test_ops_syntax.py | 1 - 8 files changed, 300 insertions(+), 33 deletions(-) diff --git a/effectful/internals/product_n.py b/effectful/internals/product_n.py index 4b8bd2a8..87a9c6a4 100644 --- a/effectful/internals/product_n.py +++ b/effectful/internals/product_n.py @@ -69,7 +69,7 @@ def map_structure(func, expr): else: return type(expr)(map_structure(func, tuple(expr.items()))) elif isinstance(expr, collections.abc.Sequence): - if isinstance(expr, str | bytes): + if isinstance(expr, str | bytes | range): return expr elif ( isinstance(expr, tuple) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 58a10ba3..ad83de47 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,24 +4,26 @@ import numbers import typing from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet +from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, defdata, + deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -80,9 +82,13 @@ def plus[S: Body[T]](self, *args: S) -> S: def _plus[S](self, *args: S) -> S: return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) - @_plus.register(Sequence) + @_plus.register(tuple) def _(self, *args): - return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Generator) + def _(self, *args): + return (self.plus(*vs) for vs in zip(*args, strict=True)) @_plus.register(Mapping) def _(self, *args): @@ -161,8 +167,8 @@ def _(self, body: Mapping, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} @reduce.register # type: ignore[attr-defined] - def _(self, body: Sequence, streams): - return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + def _(self, body: tuple, streams): + return tuple(self.reduce(x, streams) for x in body) @reduce.register # type: ignore[attr-defined] def _(self, body: Generator, streams): @@ -252,12 +258,26 @@ def _arg_max[T]( return b if b[0] > a[0] else a # type: ignore +@Operation.define +def product[T]( + a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] +) -> Iterable[tuple[T, ...]]: + if isinstance(a, Term) or isinstance(b, Term): + raise NotHandled + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] + + Min = Semilattice(kernel=min, identity=float("inf")) Max = Semilattice(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +CartesianProduct = Monoid(kernel=product, identity=[()]) @dataclass @@ -545,11 +565,139 @@ def reduce(self, monoid, body, streams): return fwd() +def inner_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: + """Returns the streams that can be ordered innermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + preds = fvsof(v) & stream_vars + if preds: + for pred in preds: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + topo.prepare() + return ( + ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) + for op in set(topo.get_ready()) | no_dependents + ) + + +def match_reduce(term: Term) -> tuple | None: + reduce_args = None + + def set_reduce_args(*args, **kwargs): + nonlocal reduce_args + reduce_args = args + + with interpreter({Monoid.reduce: set_reduce_args}): + term.op(*term.args, **term.kwargs) + return reduce_args + + +class ReduceDistributeCartesianProduct(ObjectInterpretation): + """Eliminates a reduce over a cartesian product. + ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) + This transform is also called inversion in the lifting + literature (e.g. [1]). + + More specifically, this transform implements the identity + reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2) + = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: body1}), S1), S2) + where × is the cartesian product and ⨂ distributes over ⨁. + + Note: This could be generalized to grouped inversion [2]. + + [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order + probabilistic inference." IJCAI. 2005. + [2] Taghipour, Nima, et al. "Completeness results for lifted + variable elimination." AISTATS. 2013. + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): + if not (isinstance(sum_body, Term)): + return fwd() + + # body is a product or multiplication of products + if distributes_over(sum_body.op, sum_monoid.plus): + prod_reduces = sum_body.args + else: + prod_reduces = [sum_body] + + products: list[tuple[Monoid, Callable, Operation, Term]] = [] + for prod_reduce in prod_reduces: + prod_args = match_reduce(prod_reduce) + if prod_args is None: + return fwd() + (prod_monoid, prod_body, prod_streams) = prod_args + if not ( + distributes_over(prod_monoid.plus, sum_monoid.plus) + and (len(products) == 0 or products[-1][0] == prod_monoid) + ): + return fwd() + + if len(prod_streams) > 1 or len(prod_streams) == 0: + return fwd() + (prod_op, prod_stream) = next(iter(prod_streams.items())) + products.append( + (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream) + ) + + assert len(products) > 0 + + for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): + if not ( + isinstance(cprod_term, Term) + and cprod_term.op == CartesianProduct.reduce + ): + continue + (cprod_body, cprod_streams) = cprod_term.args + + if not all( + prod_stream.op == cprod_op for (_, _, _, prod_stream) in products + ): + continue + + prod_op = Operation.define(products[0][2]) + prod_monoid = products[0][0] + inner_sum = sum_monoid.reduce( + prod_monoid.plus( + *(prod_body(prod_op()) for (_, prod_body, _, _) in products) + ), + {prod_op: cprod_body}, + ) + prod = prod_monoid.reduce(inner_sum, cprod_streams) + outer_sum = ( + sum_monoid.reduce(prod, outer_sum_streams) + if outer_sum_streams + else prod + ) + return outer_sum + + return fwd() + + NormalizeReduceIntp = functools.reduce( coproduct, typing.cast( list[Interpretation], - [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + [ + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + ], ), ) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index de041b61..e54729bb 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -151,6 +151,7 @@ def evaluate[T]( @evaluate.register(object) @evaluate.register(str) @evaluate.register(bytes) +@evaluate.register(range) def _evaluate_object[T](expr: T, **kwargs) -> T: if dataclasses.is_dataclass(expr) and not isinstance(expr, type): return typing.cast( diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 46419d7a..14795ec9 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -488,7 +488,10 @@ def _instance_op(instance, *args, **kwargs): else: return default_result - instance_op = self.define(types.MethodType(_instance_op, instance)) + name = ("" if owner is None else f"{owner.__name__}_") + self.__name__ + instance_op = self.define( + types.MethodType(_instance_op, instance), name=name + ) instance.__dict__[self._name_on_instance] = instance_op return instance_op elif instance is not None: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 4532ae72..9b311b25 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -14,14 +14,14 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers()) + return st.lists(st.integers(), max_size=2) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " "supported: int, list[int]" ) -_UNARY_INT_FNS: list[Callable[[int], int]] = [ +_UNARY_NUM_FNS: list[Callable[[int], int]] = [ lambda x: x, lambda x: x + 1, lambda x: x - 1, @@ -30,7 +30,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: 3 * x + 1, ] -_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ +_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ lambda x, y: x + y, lambda x, y: x - y, lambda x, y: x * y, @@ -58,10 +58,10 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: if not params: return _value_strategy_for(ret).map(deffn) - if ret is int and param_types == (int,): - return st.sampled_from(_UNARY_INT_FNS) - if ret is int and param_types == (int, int): - return st.sampled_from(_BINARY_INT_FNS) + if ret in (int, float) and param_types == (int,): + return st.sampled_from(_UNARY_NUM_FNS) + if ret in (int, float) and param_types == (int, int): + return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) raise NotImplementedError( diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index d6325041..08e92c8a 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -244,7 +244,7 @@ def test_agent_tool_names_are_valid_integration(): agent = _ToolNameAgent() template = agent.ask tools = template.tools - expected_helper_tool_name = f"self__{agent.helper.__name__}" + expected_helper_tool_name = "self__helper" assert tools assert expected_helper_tool_name in tools assert all(re.fullmatch(r"[a-zA-Z0-9_-]+", name) for name in tools) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index a22928cc..e73a9a7b 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,12 +1,23 @@ import functools import itertools +import typing import pytest from hypothesis import given, settings from hypothesis import strategies as st from effectful.internals.runtime import interpreter -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + Monoid, + NormalizeIntp, + Product, + Semilattice, + Sum, + distributes_over, +) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq from effectful.ops.types import NotHandled, Operation @@ -37,6 +48,18 @@ pytest.param(Product, id="Product"), ] +# Pairs (outer, inner) such that inner distributes over outer — i.e. the lifting +# identity ``outer(inner(body, A), CartesianProduct...) == inner(outer(body, D), ...)`` +# is valid for that semiring pair. +MONOID_PAIRS = [ + pytest.param(o.values[0], i.values[0], id=f"{o.id}-{i.id}") + for o in ALL_MONOIDS + for i in ALL_MONOIDS + if distributes_over( + typing.cast(Monoid, i.values[0]).plus, typing.cast(Monoid, o.values[0]).plus + ) +] + def define_vars(*names, typ=int): if len(names) == 1: @@ -70,14 +93,11 @@ def syntactic_eq_alpha(x, y) -> bool: def _canonicalize(expr): counter = itertools.count() - def _passthrough(op, *args, **kwargs): - return defdata(op, *args, **kwargs) - def _substitute(arg, renaming): """Apply a bound-variable renaming using ``evaluate`` for traversal.""" if not renaming: return arg - with interpreter({apply: _passthrough, **renaming}): + with interpreter({apply: _BaseTerm, **renaming}): return evaluate(arg) def _bound_var_order(args, kwargs, bound_set): @@ -121,7 +141,7 @@ def _apply_canonical(op, *args, **kwargs): *bindings.args, *bindings.kwargs.values() ) if not all_bound: - return defdata(op, *args, **kwargs) + return _BaseTerm(op, *args, **kwargs) order = _bound_var_order(args, kwargs, all_bound) canonical = {var: _canonical_op(next(counter)) for var in order} @@ -252,8 +272,8 @@ def test_plus_assoc_left(monoid): def test_plus_sequence(monoid): a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) _check_pair( - lhs=monoid.plus([a(), b()], [c(), d()]), - rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + lhs=monoid.plus((a(), b()), (c(), d())), + rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), free_vars=[a, b, c, d], ) @@ -368,8 +388,8 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(x())], {x: X()}) - rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + lhs = monoid.reduce((f(x()), g(x())), {x: X()}) + rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) @@ -385,11 +405,11 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) - rhs = [ + lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) + rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), - ] + ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) @@ -496,8 +516,6 @@ def g(_x: int) -> list[int]: Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), ) assert fvsof(bogus_rhs) != fvsof(lhs) - # Structural-only negative check: the normalizer correctly refused to apply - # the bogus factorization. assert not syntactic_eq_alpha(lhs, bogus_rhs) @@ -516,3 +534,101 @@ def f(_x: int, _y: int) -> int: Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_1(outer, inner): + a, i = define_vars("a", "i") + A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = outer.reduce( + inner.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N()})}, + ) + term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + + +def test_reduce_cartesian_1(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + assert term1 == term2 + + +def test_reduce_cartesian_2(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + assert term1 == term2 + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_multi_index(outer, inner): + a, i, j = define_vars("a", "i", "j") + A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = outer.reduce( + inner.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, + ) + term2 = inner.reduce( + outer.reduce(f(a()), {a: A_domain()}), + {i: N(), j: M()}, + ) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, M, A_domain, f]) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_2(outer, inner): + """The worked example on page 396 of 'Lifted Variable Elimination: + Decoupling the Operators from the Constraint Language'. + + """ + a, i, s, t = define_vars("a", "i", "s", "t") + A, N, T = define_vars("A", "N", "T", typ=list[int]) + + @Operation.define + def A_domain(_i: int) -> list[int]: + raise NotHandled + + @Operation.define + def f1(_a: int, _s: int) -> float: + raise NotHandled + + @Operation.define + def f2(_t: int, _a: int) -> float: + raise NotHandled + + term1 = outer.reduce( + inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), + {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, + ) + + term2 = outer.reduce( + inner.reduce( + outer.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + {i: N()}, + ), + {t: T()}, + ) + + _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 185b6132..1f5c4776 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,6 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {} From 81a465c01616ca41fb3aca6826bb3119ec4110fa Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 16:42:40 -0400 Subject: [PATCH 03/10] Refactor `monoid.py` to remove class structure (#661) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * fix rule * wip * fix bug * cleanup * lin * wip * fix tests * format * lint * wip --- effectful/ops/monoid.py | 255 ++++++++++++++++++--------------------- effectful/ops/types.py | 76 ++++++++++++ tests/test_ops_monoid.py | 6 +- 3 files changed, 197 insertions(+), 140 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index ad83de47..0d6e230c 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,20 +10,25 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, - defdata, deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, + _CustomSingleDispatchMethod, +) # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -64,60 +69,57 @@ def __init__(self, kernel: Callable[[T, T], T], identity: T): def __repr__(self): return f"{type(self)}({self.kernel}, {self.identity})" + def __eq__(self, other): + return id(self) == id(other) + + def __hash__(self): + return hash(id(self)) + @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: + @_CustomSingleDispatchMethod + def plus[S](self, dispatch, *args: S) -> S: """Monoid addition with broadcasting over common collection types, callables, and interpretations. - """ if not args: return typing.cast(S, self.identity) + return dispatch(type(args[0]))(self, *args) + @plus.register(object) # type: ignore[attr-defined] + def _(self, *args): if any(isinstance(x, Term) for x in args): - return typing.cast(S, defdata(self.plus, *args)) + raise NotHandled + return functools.reduce(self.kernel, args, self.identity) - return self._plus(*args) - - @functools.singledispatchmethod - def _plus[S](self, *args: S) -> S: - return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) - - @_plus.register(tuple) + @plus.register(tuple) # type: ignore[attr-defined] def _(self, *args): return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - @_plus.register(Generator) + @plus.register(Generator) # type: ignore[attr-defined] def _(self, *args): return (self.plus(*vs) for vs in zip(*args, strict=True)) - @_plus.register(Mapping) + @plus.register(Mapping) # type: ignore[attr-defined] def _(self, *args): if isinstance(args[0], Interpretation): keys = args[0].keys() - for b in args[1:]: if not isinstance(b, Interpretation): raise TypeError(f"Expected interpretation but got {b}") - - b_keys = b.keys() - if not keys == b_keys: + if not keys == b.keys(): raise ValueError( - f"Expected interpretation of {keys} but got {b_keys}" + f"Expected interpretation of {keys} but got {b.keys()}" ) - - result = {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} - return result + return {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} for b in args[1:]: if not isinstance(b, Mapping): raise TypeError(f"Expected mapping but got {b}") - all_values = collections.defaultdict(list) for d in args: for k, v in d.items(): all_values[k].append(v) - result = {k: self.plus(*vs) for (k, vs) in all_values.items()} - return result + return {k: self.plus(*vs) for (k, vs) in all_values.items()} @Operation.define @functools.singledispatchmethod @@ -155,12 +157,9 @@ def generator(loop_order) -> Iterator[Interpretation]: yield coproduct(intp, intp2) loop_order = list(order_streams(streams)) - try: - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - except NotHandled: - return typing.cast(U, defdata(self.reduce, body, streams)) + return self.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) @reduce.register # type: ignore[attr-defined] def _(self, body: Mapping, streams): @@ -175,35 +174,19 @@ def _(self, body: Generator, streams): return (self.reduce(x, streams) for x in body) -class IdempotentMonoid[T](Monoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus -class CommutativeMonoid[T](Monoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce -class CommutativeMonoidWithZero[T](CommutativeMonoid[T]): +class MonoidWithZero[T](Monoid[T]): zero: T def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): @@ -213,32 +196,6 @@ def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): def __repr__(self): return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - - -class Semilattice[T](IdempotentMonoid[T], CommutativeMonoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - @Operation.define def _arg_min[T]( @@ -271,15 +228,30 @@ def to_tuple(x): return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] -Min = Semilattice(kernel=min, identity=float("inf")) -Max = Semilattice(kernel=max, identity=float("-inf")) +Min = Monoid(kernel=min, identity=float("inf")) +Max = Monoid(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) -Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) -Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +Sum = Monoid(kernel=_NumberTerm.__add__, identity=0) +Product = MonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) CartesianProduct = Monoid(kernel=product, identity=[()]) +@dataclass +class _ExtensiblePredicate[T]: + elems: set[T] + + def register(self, t: T) -> None: + self.elems.add(t) + + def __call__(self, t: T) -> bool: + return t in self.elems + + +is_commutative = _ExtensiblePredicate({Max, Min, Sum, Product}) +is_idempotent = _ExtensiblePredicate({Max, Min}) + + @dataclass class _ExtensibleBinaryRelation[S, T]: tuples: set[tuple[S, T]] @@ -292,13 +264,7 @@ def __call__(self, s: S, t: T) -> bool: distributes_over = _ExtensibleBinaryRelation( - { - (Max.plus, Min.plus), - (Min.plus, Max.plus), - (Sum.plus, Min.plus), - (Sum.plus, Max.plus), - (Product.plus, Sum.plus), - } + {(Max, Min), (Min, Max), (Sum, Min), (Sum, Max), (Product, Sum)} ) @@ -337,10 +303,12 @@ class PlusAssoc(ObjectInterpretation): @implements(Monoid.plus) def plus(self, monoid, *args): - if any(isinstance(x, Term) and x.op is monoid.plus for x in args): + def is_nested_plus(x): + return isinstance(x, Term) and x.op is monoid.plus + + if any(is_nested_plus(x) for x in args): flat_args = itertools.chain.from_iterable( - t.args if isinstance(t, Term) and t.op is monoid.plus else (t,) - for t in args + t.args if is_nested_plus(t) else (t,) for t in args ) assert len(args) > 0 return monoid.plus(*flat_args) @@ -353,33 +321,36 @@ class PlusDistr(ObjectInterpretation): @implements(Monoid.plus) def plus(self, monoid, *args): if any( - isinstance(x, Term) and distributes_over(monoid.plus, x.op) for x in args + isinstance(x, Term) + and _is_monoid_plus(x.op) + and distributes_over(monoid, x.op.__self__) + for x in args ): non_terms = [] - # group terms by head operation - by_head_op = defaultdict(list) + # group terms by their monoid + by_monoid: dict[Monoid, list[Term]] = defaultdict(list) for t in args: - if isinstance(t, Term): - by_head_op[t.op].append(t) + if isinstance(t, Term) and _is_monoid_plus(t.op): + by_monoid[t.op.__self__].append(t) else: non_terms.append(t) # distribute over each group progress = False final_sum = [] - for op, terms in by_head_op.items(): + for m, terms in by_monoid.items(): if ( len(terms) > 1 - and distributes_over(monoid.plus, op) - and not distributes_over(op, monoid.plus) + and distributes_over(monoid, m) + and not distributes_over(m, monoid) ): progress = True term_args = (t.args for t in terms) dist_terms = ( monoid.plus(*args) for args in itertools.product(*term_args) ) - final_sum.append(op(*dist_terms)) + final_sum.append(m.plus(*dist_terms)) else: final_sum += terms if progress: @@ -390,8 +361,10 @@ def plus(self, monoid, *args): class PlusZero(ObjectInterpretation): """x₁ * ... * 0 * ... * xₙ = 0""" - @implements(CommutativeMonoidWithZero.plus) + @implements(Monoid.plus) def plus(self, monoid, *args): + if not (isinstance(monoid, MonoidWithZero)): + return fwd() if any(x is monoid.zero for x in args): return monoid.zero return fwd() @@ -400,8 +373,11 @@ def plus(self, monoid, *args): class PlusConsecutiveDups(ObjectInterpretation): """x ⊕ x ⊕ y = x ⊕ y""" - @implements(IdempotentMonoid.plus) + @implements(Monoid.plus) def plus(self, monoid, *args): + if not is_idempotent(monoid): + return fwd() + dedup_args = ( args[i] for i in range(len(args)) @@ -423,8 +399,11 @@ def __eq__(self, other): def __hash__(self): return syntactic_hash(self) - @implements(Semilattice.plus) + @implements(Monoid.plus) def plus(self, monoid, *args): + if not (is_idempotent(monoid) and is_commutative(monoid)): + return fwd() + # elim dups args_count = Counter(self._HashableTerm(t) for t in args) if len(args_count) < len(args): @@ -475,7 +454,7 @@ class ReduceFusion(ObjectInterpretation): @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and body.op == monoid.reduce: + if isinstance(body, Term) and body.op is monoid.reduce: return monoid.reduce(body.args[0], streams | body.args[1]) return fwd() @@ -485,9 +464,11 @@ class ReduceSplit(ObjectInterpretation): reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) """ - @implements(CommutativeMonoid.reduce) + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and body.op == monoid.plus: + if not is_commutative(monoid): + return fwd() + if isinstance(body, Term) and body.op is monoid.plus: return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) return fwd() @@ -506,9 +487,16 @@ class ReduceFactorization(ObjectInterpretation): and free(Aᵢ) ∩ S ⊆ Sᵢ """ - @implements(CommutativeMonoid.reduce) + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and distributes_over(body.op, monoid.plus): + if not is_commutative(monoid): + return fwd() + if ( + isinstance(body, Term) + and _is_monoid_plus(body.op) + and distributes_over(body.op.__self__, monoid) + ): + inner_monoid: Monoid = body.op.__self__ stream_vars = set(streams.keys()) factors = [(arg, fvsof(arg)) for arg in body.args] stream_ids = {v: i for (i, v) in enumerate(stream_vars)} @@ -521,7 +509,7 @@ def reduce(self, monoid, body, streams): ds.union(stream_id, *deps) # factors are in the same partition as their dependencies - for factor, factor_fvs in factors: + for _, factor_fvs in factors: factor_streams = sorted( [stream_ids[v] for v in (factor_fvs & stream_vars)] ) @@ -550,14 +538,14 @@ def reduce(self, monoid, body, streams): for t in partition_factors ), "partition contains all streams required by factor" - partition_term = body.op(*(t[0] for t in partition_factors)) + partition_term = inner_monoid.plus(*(t[0] for t in partition_factors)) new_reduces.append((partition_term, partition_streams)) placed_streams |= partition_stream_keys constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] if len(new_reduces) > 1: - result = body.op( + result = inner_monoid.plus( *constant_factors, *(monoid.reduce(*args) for args in new_reduces) ) return result @@ -592,18 +580,6 @@ def inner_stream( ) -def match_reduce(term: Term) -> tuple | None: - reduce_args = None - - def set_reduce_args(*args, **kwargs): - nonlocal reduce_args - reduce_args = args - - with interpreter({Monoid.reduce: set_reduce_args}): - term.op(*term.args, **term.kwargs) - return reduce_args - - class ReduceDistributeCartesianProduct(ObjectInterpretation): """Eliminates a reduce over a cartesian product. ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) @@ -623,25 +599,30 @@ class ReduceDistributeCartesianProduct(ObjectInterpretation): variable elimination." AISTATS. 2013. """ - @implements(CommutativeMonoid.reduce) + @implements(Monoid.reduce) def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): - if not (isinstance(sum_body, Term)): + if not (is_commutative(sum_monoid) and isinstance(sum_body, Term)): return fwd() # body is a product or multiplication of products - if distributes_over(sum_body.op, sum_monoid.plus): + if _is_monoid_plus(sum_body.op) and distributes_over( + sum_body.op.__self__, sum_monoid + ): prod_reduces = sum_body.args else: prod_reduces = [sum_body] products: list[tuple[Monoid, Callable, Operation, Term]] = [] for prod_reduce in prod_reduces: - prod_args = match_reduce(prod_reduce) - if prod_args is None: + if not ( + isinstance(prod_reduce, Term) and _is_monoid_reduce(prod_reduce.op) + ): return fwd() - (prod_monoid, prod_body, prod_streams) = prod_args + prod_monoid: Monoid = prod_reduce.op.__self__ + prod_body = prod_reduce.args[0] + prod_streams = typing.cast(Mapping, prod_reduce.args[1]) if not ( - distributes_over(prod_monoid.plus, sum_monoid.plus) + distributes_over(prod_monoid, sum_monoid) and (len(products) == 0 or products[-1][0] == prod_monoid) ): return fwd() @@ -658,7 +639,7 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): if not ( isinstance(cprod_term, Term) - and cprod_term.op == CartesianProduct.reduce + and cprod_term.op is CartesianProduct.reduce ): continue (cprod_body, cprod_streams) = cprod_term.args diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 14795ec9..77823b7b 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -42,6 +42,59 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return self.func(self.dispatch, *args, **kwargs) +class _CustomSingleDispatchMethod[**P, **Q, S, T]: + """Method analog of :class:`_CustomSingleDispatchCallable`. + + The wrapped function has signature ``(self, dispatch, *args, **kwargs)``, + where ``dispatch`` is :meth:`functools.singledispatch.dispatch`. As a + descriptor, it binds ``self`` on attribute access, so callers invoke it + as ``instance.method(*args, **kwargs)``. + """ + + def __init__( + self, + func: Callable[Concatenate[Any, Callable[[type], Callable[Q, S]], P], T], + ): + self.func = func + self._registry = functools.singledispatch(func) + self.__signature__ = inspect.signature( + functools.partial(func, None, None) # type: ignore[arg-type] + ) + functools.update_wrapper(self, func) # type: ignore[arg-type] + + @property + def dispatch(self): + return self._registry.dispatch + + @property + def register(self): + return self._registry.register + + def __get__(self, instance, owner=None): + if instance is None: + return self + return _BoundCustomSingleDispatchMethod(self, instance) + + +class _BoundCustomSingleDispatchMethod: + __slots__ = ("_method", "_instance") + + def __init__(self, method: _CustomSingleDispatchMethod, instance: Any): + self._method = method + self._instance = instance + + @property + def dispatch(self): + return self._method.dispatch + + @property + def register(self): + return self._method.register + + def __call__(self, *args, **kwargs): + return self._method.func(self._instance, self._method.dispatch, *args, **kwargs) + + class _ClassMethodOpDescriptor(classmethod): def __init__(self, define, *args, **kwargs): super().__init__(*args, **kwargs) @@ -311,6 +364,15 @@ def func(*args, **kwargs): return typing.cast(Operation[P, T], cls.define(func, **kwargs)) + @define.register(types.MethodType) + @classmethod + def _define_methodtype[**P, T]( + cls, t: Callable[P, T], *, name: str | None = None + ) -> "Operation[P, T]": + op = cls._define_callable(t, name=name) + op.__self__ = t.__self__ # type: ignore[attr-defined] + return typing.cast("Operation[P, T]", op) + @define.register(staticmethod) @classmethod def _define_staticmethod[**P, T](cls, t: "staticmethod[P, T]", **kwargs): @@ -350,6 +412,20 @@ def func(*args, **kwargs): op.register = default._registry.register # type: ignore[attr-defined] return op + @define.register(_CustomSingleDispatchMethod) + @classmethod + def _define_customsingledispatchmethod( + cls, default: _CustomSingleDispatchMethod, **kwargs + ): + @functools.wraps(default.func) + def _wrapper(obj, *args, **kwargs): + return default.__get__(obj)(*args, **kwargs) + + op = cls.define(_wrapper, **kwargs) + op.register = default.register # type: ignore[attr-defined] + op.dispatch = default.dispatch # type: ignore[attr-defined] + return op + @typing.final def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": """The default rule is used when the operation is not handled. diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index e73a9a7b..d881869a 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -14,9 +14,9 @@ Monoid, NormalizeIntp, Product, - Semilattice, Sum, distributes_over, + is_commutative, ) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq @@ -56,7 +56,7 @@ for o in ALL_MONOIDS for i in ALL_MONOIDS if distributes_over( - typing.cast(Monoid, i.values[0]).plus, typing.cast(Monoid, o.values[0]).plus + typing.cast(Monoid, i.values[0]), typing.cast(Monoid, o.values[0]) ) ] @@ -354,7 +354,7 @@ def test_plus_idempotent_non_consecutive(monoid): PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" a, b = define_vars("a", "b") lhs = monoid.plus(a(), b(), a()) - if isinstance(monoid, Semilattice): + if is_commutative(monoid): rhs = monoid.plus(a(), b()) else: rhs = monoid.plus(a(), b(), a()) From 567f0302a346d0ed80501f87cb4aceba2a6a3146 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 21 May 2026 15:11:59 -0400 Subject: [PATCH 04/10] Add `jax` array monoids and reduction rule (#658) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * wip * fix rule * wip * fix bug * cleanup * lin * wip * fix tests * format * lint * wip * wip * wip * wip * wip * wip * wip * wip * drop runtime typed dict lifting * wip * format * reorganize * stop using string dicts to avoid unification issue * wip * wip * wip * wip * wip * use check_rewrite in jax tests * lint * fix bugs --- effectful/handlers/jax/_handlers.py | 7 + effectful/handlers/jax/monoid.py | 162 +++++++ effectful/ops/monoid.py | 486 +++++++++++-------- effectful/ops/syntax.py | 2 + tests/_monoid_helpers.py | 284 ++++++++++- tests/test_handlers_jax_monoid.py | 96 ++++ tests/test_ops_monoid.py | 718 +++++++++++++++------------- 7 files changed, 1206 insertions(+), 549 deletions(-) create mode 100644 effectful/handlers/jax/monoid.py create mode 100644 tests/test_handlers_jax_monoid.py diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index 308cdb76..c5d10423 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -19,6 +19,7 @@ deffn, defop, syntactic_eq, + syntactic_hash, ) from effectful.ops.types import Expr, NotHandled, Operation, Term @@ -277,3 +278,9 @@ def _(x: jax.Array, other) -> bool: and x.shape == other.shape and bool((jnp.asarray(x) == jnp.asarray(other)).all()) ) + + +@syntactic_hash.register(jax.Array) +def _(x: jax.Array) -> int: + # Concrete arrays aren't hashable; hash by shape, dtype, and bytes. + return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x)))) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py new file mode 100644 index 00000000..a406cda5 --- /dev/null +++ b/effectful/handlers/jax/monoid.py @@ -0,0 +1,162 @@ +import functools + +import jax + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + Monoid, + NormalizeIntp, + Product, + Sum, + outer_stream, +) +from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof +from effectful.ops.syntax import ObjectInterpretation, deffn, implements +from effectful.ops.types import Operation + + +def cartesian_prod(x, y): + if x.ndim == 1: + x = x[:, None] + if y.ndim == 1: + y = y[:, None] + nx, dx = x.shape + ny, dy = y.shape + # Broadcast into (nx, ny, dx+dy), then flatten the first two axes + x_b = jnp.broadcast_to(x[:, None, :], (nx, ny, dx)) + y_b = jnp.broadcast_to(y[None, :, :], (nx, ny, dy)) + return jnp.concatenate([x_b, y_b], axis=-1).reshape(nx * ny, dx + dy) + + +LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) + + +def _jax_args(args): + """True iff ``args`` is non-empty and every arg is a concrete + :class:`jax.Array` (no Terms). + """ + typs = (typeof(a) for a in args) + return ( + bool(args) + and any(issubclass(t, jax.Array) for t in typs) + and all(issubclass(t, jax.typing.ArrayLike) for t in typs) + ) + + +class SumPlusJax(ObjectInterpretation): + @implements(Sum.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.add, args) + + +class ProductPlusJax(ObjectInterpretation): + @implements(Product.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.multiply, args) + + +class MinPlusJax(ObjectInterpretation): + @implements(Min.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.minimum, args) + + +class MaxPlusJax(ObjectInterpretation): + @implements(Max.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.maximum, args) + + +class LogSumExpPlusJax(ObjectInterpretation): + @implements(LogSumExp.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.logaddexp, args) + + +class CartesianProductPlusJax(ObjectInterpretation): + @implements(CartesianProduct.plus) + def plus(self, *args): + # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both + # sentinels arrive as Python lists alongside jax-array factors, so + # check for them explicitly before composing. + if not any(isinstance(a, jax.Array) for a in args): + return fwd() + result = None + for a in args: + if a is CartesianProduct.zero: + return CartesianProduct.zero + if a is CartesianProduct.identity: + continue + if not isinstance(a, jax.Array): + return fwd() + result = a if result is None else cartesian_prod(result, a) + return result if result is not None else CartesianProduct.identity + + +ARRAY_REDUCTORS = { + Sum: jnp.sum, + Product: jnp.prod, + Min: jnp.min, + Max: jnp.max, + LogSumExp: logsumexp, +} + + +class ArrayReduce(ObjectInterpretation): + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid not in ARRAY_REDUCTORS or typeof(body) is not jax.Array: + return fwd() + if not streams: + return monoid.identity + + reductor = ARRAY_REDUCTORS[monoid] + index = Operation.define(jax.Array) + for stream_key, stream_body, streams_tail in outer_stream(streams): + if not issubclass(typeof(stream_body), jax.Array): + continue + + if stream_key in fvsof(body): + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + eval_body = evaluate(body) + eval_streams_tail = evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + else: + # TODO: In this case, the stream is unused in the body. The body + # should be multiplied by the length of the stream. The current + # behavior is not efficient. + return fwd() + + return fwd() + + +NormalizeIntp.extend( + ArrayReduce(), + SumPlusJax(), + ProductPlusJax(), + MinPlusJax(), + MaxPlusJax(), + LogSumExpPlusJax(), + CartesianProductPlusJax(), +) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 0d6e230c..70bb5002 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -1,10 +1,10 @@ import collections.abc import functools import itertools -import numbers +import operator import typing -from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping +from collections import Counter, UserDict, defaultdict +from collections.abc import Callable, Generator, Iterable, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any @@ -14,21 +14,13 @@ from effectful.ops.syntax import ( ObjectInterpretation, Scoped, - _NumberTerm, deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import ( - Expr, - Interpretation, - NotHandled, - Operation, - Term, - _CustomSingleDispatchMethod, -) +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -43,31 +35,35 @@ ) -def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: - """Determine an order to evaluate the streams based on their dependencies""" +def outer_stream( + streams: Streams, +) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ stream_vars = set(streams.keys()) - dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} - topo = TopologicalSorter(dependencies) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) topo.prepare() - while topo.is_active(): - node_group = topo.get_ready() - for op in sorted(node_group): - yield (op, streams[op]) - topo.done(*node_group) + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) class Monoid[T]: - kernel: Operation[[T, T], T] + """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" + + _name: str identity: T - def __init__(self, kernel: Callable[[T, T], T], identity: T): + def __init__(self, identity: T, name: str): + self._name = name self.identity = identity - self.kernel = ( - kernel if isinstance(kernel, Operation) else Operation.define(kernel) - ) def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity})" + return f"Monoid({self._name!r})" def __eq__(self, other): return id(self) == id(other) @@ -75,166 +71,63 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) + # the weak typing allows us to write monoid.plus(monoid.identity, ) + # and monoid.plus(monoid.identity, ) @Operation.define - @_CustomSingleDispatchMethod - def plus[S](self, dispatch, *args: S) -> S: - """Monoid addition with broadcasting over common collection types, - callables, and interpretations. + def plus(self, *args: Any) -> Any: + """Monoid addition. Handlers supply per-monoid and broadcasting + behavior; the default rule only handles empty / Term cases. """ if not args: - return typing.cast(S, self.identity) - return dispatch(type(args[0]))(self, *args) - - @plus.register(object) # type: ignore[attr-defined] - def _(self, *args): - if any(isinstance(x, Term) for x in args): - raise NotHandled - return functools.reduce(self.kernel, args, self.identity) - - @plus.register(tuple) # type: ignore[attr-defined] - def _(self, *args): - return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - - @plus.register(Generator) # type: ignore[attr-defined] - def _(self, *args): - return (self.plus(*vs) for vs in zip(*args, strict=True)) - - @plus.register(Mapping) # type: ignore[attr-defined] - def _(self, *args): - if isinstance(args[0], Interpretation): - keys = args[0].keys() - for b in args[1:]: - if not isinstance(b, Interpretation): - raise TypeError(f"Expected interpretation but got {b}") - if not keys == b.keys(): - raise ValueError( - f"Expected interpretation of {keys} but got {b.keys()}" - ) - return {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} - - for b in args[1:]: - if not isinstance(b, Mapping): - raise TypeError(f"Expected mapping but got {b}") - all_values = collections.defaultdict(list) - for d in args: - for k, v in d.items(): - all_values[k].append(v) - return {k: self.plus(*vs) for (k, vs) in all_values.items()} + return self.identity + raise NotHandled @Operation.define - @functools.singledispatchmethod def reduce[A, B, U: Body]( self, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]], ) -> Annotated[U, Scoped[B]]: - if callable(body): - return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) - - def generator(loop_order) -> Iterator[Interpretation]: - if len(loop_order) == 0: - return - - stream_key = loop_order[0][0] - stream_values = evaluate(streams[stream_key]) - stream_values_iter = iter(stream_values) # type: ignore[arg-type] - - # If we try to iterate and get a term instead of a real - # iterator, give up + """Reduce ``body`` over ``streams``. Handlers supply per-monoid and + broadcasting behavior; the default rule only handles the empty-stream + case. + """ + for stream_key, stream_body, streams_tail in outer_stream(streams): + if isinstance(stream_body, Term): + continue + stream_values_iter = iter(stream_body) if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: - raise NotHandled - - if len(loop_order) == 1: - for val in stream_values_iter: - yield {stream_key: functools.partial(lambda v: v, val)} - else: - for val in stream_values_iter: - intp: Interpretation = { - stream_key: functools.partial(lambda v: v, val) - } - with handler(intp): - for intp2 in generator(loop_order[1:]): - yield coproduct(intp, intp2) - - loop_order = list(order_streams(streams)) - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - - @reduce.register # type: ignore[attr-defined] - def _(self, body: Mapping, streams): - return {k: self.reduce(v, streams) for (k, v) in body.items()} - - @reduce.register # type: ignore[attr-defined] - def _(self, body: tuple, streams): - return tuple(self.reduce(x, streams) for x in body) - - @reduce.register # type: ignore[attr-defined] - def _(self, body: Generator, streams): - return (self.reduce(x, streams) for x in body) - - -def _is_monoid_plus(op: Operation) -> bool: - """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.plus - - -def _is_monoid_reduce(op: Operation) -> bool: - """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.reduce + continue + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append( + self.reduce(*eval_args) if streams_tail else eval_args[0] + ) + return self.plus(*new_reduces) + raise NotHandled class MonoidWithZero[T](Monoid[T]): zero: T - def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): - super().__init__(kernel, identity) + def __init__(self, name: str, identity: T, zero: T): + super().__init__(name=name, identity=identity) self.zero = zero - def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" - - -@Operation.define -def _arg_min[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): - raise NotHandled - return b if b[0] < a[0] else a # type: ignore - - -@Operation.define -def _arg_max[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): - raise NotHandled - return b if b[0] > a[0] else a # type: ignore - - -@Operation.define -def product[T]( - a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] -) -> Iterable[tuple[T, ...]]: - if isinstance(a, Term) or isinstance(b, Term): - raise NotHandled - - def to_tuple(x): - return x if isinstance(x, tuple) else (x,) - return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] - - -Min = Monoid(kernel=min, identity=float("inf")) -Max = Monoid(kernel=max, identity=float("-inf")) -ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) -ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) -Sum = Monoid(kernel=_NumberTerm.__add__, identity=0) -Product = MonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) -CartesianProduct = Monoid(kernel=product, identity=[()]) +Min = Monoid(name="Min", identity=float("inf")) +Max = Monoid(name="Max", identity=-float("inf")) +ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) +ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) +Sum = Monoid(name="Sum", identity=0) +Product = MonoidWithZero(name="Product", identity=1, zero=0) +# CartesianProduct values are "two-level indexable" (rows × positions). The +# identity ``[()]`` is one row of zero positions (composing with it preserves +# shape); the zero ``[]`` is no rows (absorbs under product). +CartesianProduct = MonoidWithZero(name="CartesianProduct", identity=[()], zero=[]) @dataclass @@ -268,6 +161,18 @@ def __call__(self, s: S, t: T) -> bool: ) +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus + + +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" @@ -319,7 +224,7 @@ class PlusDistr(ObjectInterpretation): """x + (y * z) = x * y + x * z""" @implements(Monoid.plus) - def plus(self, monoid, *args): + def plus(self, monoid: Monoid, *args): if any( isinstance(x, Term) and _is_monoid_plus(x.op) @@ -417,24 +322,6 @@ def plus(self, monoid, *args): return fwd() -NormalizePlusIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - PlusEmpty(), - PlusSingle(), - PlusIdentity(), - PlusAssoc(), - PlusDistr(), - PlusZero(), - PlusConsecutiveDups(), - PlusDups(), - ], - ), -) - - class ReduceNoStreams(ObjectInterpretation): """Implements the identity reduce(R, ∅, body) = 0 @@ -668,18 +555,217 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() -NormalizeReduceIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - ReduceNoStreams(), - ReduceFusion(), - ReduceSplit(), - ReduceFactorization(), - ReduceDistributeCartesianProduct(), - ], - ), +class MonoidOverCallable(ObjectInterpretation): + """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Callable): + return fwd() + return lambda *a, **k: monoid.reduce(body(*a, **k), streams) + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or any( + isinstance(arg, Term) or not isinstance(arg, Callable) for arg in args + ): + return fwd() + return lambda *a, **k: monoid.plus(*(arg(*a, **k) for arg in args)) + + +class MonoidOverMapping(ObjectInterpretation): + """``monoid.reduce({k: v_k}, streams) = {k: monoid.reduce(v_k, streams)}``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Mapping): + return fwd() + return {k: monoid.reduce(v, streams) for (k, v) in body.items()} + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or not isinstance(args[0], Mapping): + return fwd() + + if isinstance(args[0], Interpretation): + keys = args[0].keys() + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + if not keys == b.keys(): + raise ValueError( + f"Expected interpretation of {keys} but got {b.keys()}" + ) + return {k: monoid.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + return {k: monoid.plus(*vs) for (k, vs) in all_values.items()} + + +def _scalar_args(args): + """True iff ``args`` is non-empty and every arg is a concrete int/float.""" + return ( + bool(args) + and not any(isinstance(x, Term) for x in args) + and all(isinstance(x, int | float) for x in args) + ) + + +class SumPlus(ObjectInterpretation): + """Scalar implementation of :data:`Sum`.""" + + @implements(Sum.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return sum(args) + + +class MinPlus(ObjectInterpretation): + """Scalar implementation of :data:`Min`.""" + + @implements(Min.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return min(args) + + +class MaxPlus(ObjectInterpretation): + """Scalar implementation of :data:`Max`.""" + + @implements(Max.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return max(args) + + +class ProductPlus(ObjectInterpretation): + """Scalar implementation of :data:`Product`.""" + + @implements(Product.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return functools.reduce(operator.mul, args) + + +class ArgMinPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMin`.""" + + @implements(ArgMin.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return min(args, key=lambda a: a[0]) + + +class ArgMaxPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMax`.""" + + @implements(ArgMax.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return max(args, key=lambda a: a[0]) + + +class CartesianProductPlus(ObjectInterpretation): + """Pure-Python implementation of :data:`CartesianProduct`.""" + + @implements(CartesianProduct.plus) + def plus(self, *args): + if not args: + return fwd() + if any(isinstance(x, Term) for x in args): + return fwd() + if not all(isinstance(x, Iterable) for x in args): + return fwd() + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [ + sum((to_tuple(v) for v in vals), ()) for vals in itertools.product(*args) + ] + + +is_scalar = _ExtensiblePredicate({Min, Max, Sum, Product}) + + +class MonoidOverSequence(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + if ( + not is_scalar(monoid) + or not args + or not isinstance(args[0], tuple | list | Generator) + ): + return fwd() + zipped = zip(*args, strict=True) + result = (monoid.plus(*vs) for vs in zipped) + if isinstance(args[0], tuple | list): + return type(args[0])(result) + return result + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_scalar(monoid) or not isinstance(body, tuple | list | Generator): + return fwd() + result = (monoid.reduce(x, streams) for x in body) + if isinstance(body, tuple | list): + return type(body)(result) + return result + + +class _ExtensibleInterpretation(UserDict, Interpretation): + def extend(self, *intps: Interpretation) -> typing.Self: + for intp in intps: + self.data = coproduct(self.data, intp) # type: ignore[assignment] + return self + + +NormalizeIntp = _ExtensibleInterpretation().extend( + MonoidOverSequence(), + MonoidOverMapping(), + MonoidOverCallable(), + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), + SumPlus(), + MinPlus(), + MaxPlus(), + ProductPlus(), + ArgMinPlus(), + ArgMaxPlus(), + CartesianProductPlus(), ) +"""``NormalizeIntp``applies pure-Term rewrites (associativity, distributivity, +identity elimination, fusion, factorization, etc.). -NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) +""" diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 8fb12598..5ea04fcb 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -849,6 +849,8 @@ def _(x: collections.abc.Sequence, other) -> bool: @syntactic_eq.register(object) @syntactic_eq.register(str | bytes) def _(x: object, other) -> bool: + if isinstance(other, Term): # Terms often override __eq__ + return False return x == other diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 9b311b25..f15103e3 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,23 +1,60 @@ +import itertools from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass from typing import Any, get_args, get_origin +import jax +from hypothesis import given, settings from hypothesis import strategies as st -from effectful.ops.syntax import deffn -from effectful.ops.types import Operation +import effectful.handlers.jax.numpy as _jnp +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import NormalizeIntp +from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq +from effectful.ops.types import NotHandled, Operation, Term + +_JAX_ARRAY_SHAPE = (2,) + + +def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: + return st.lists( + st.integers(min_value=-5, max_value=5), + min_size=_JAX_ARRAY_SHAPE[0], + max_size=_JAX_ARRAY_SHAPE[0], + ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + + +# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` +# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. +_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), +] + +_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, +] def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: """Strategy for the value an *0-arg* Operation should return.""" if annotation is int: - return st.integers() + return st.integers(min_value=-100, max_value=100) if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(), max_size=2) + return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) + if annotation is jax.Array: + return _jax_array_value_strategy() + if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): + return st.lists(_jax_array_value_strategy(), max_size=2) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int]" + "supported: int, list[int], jax.Array, list[jax.Array]" ) @@ -46,6 +83,13 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: [0, x, x + 1], ] +_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], +] + def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: """Pick a strategy producing a callable suitable for binding `op` in an @@ -64,8 +108,18 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) + if ret is jax.Array and param_types == (jax.Array,): + return st.sampled_from(_UNARY_JAX_FNS) + if ret is jax.Array and param_types == (jax.Array, jax.Array): + return st.sampled_from(_BINARY_JAX_FNS) + if ( + get_origin(ret) is list + and get_args(ret) == (jax.Array,) + and param_types == (jax.Array,) + ): + return st.sampled_from(_UNARY_JAX_LIST_FNS) raise NotImplementedError( - f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + f"No callable strategy for free var with return {ret!r}, params {param_types!r}" ) @@ -82,4 +136,220 @@ def random_interpretation( return intp -__all__ = ["random_interpretation"] +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + + _op_cache: dict[int, Operation] = {} + + def _canonical_op(idx: int, op: Operation) -> Operation: + """Cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + if idx in _op_cache: + return _op_cache[idx] + + op = Operation.define(op, name=f"__cv_{idx}") + _op_cache[idx] = op + return op + + cx = _canonicalize(x, _canonical_op) + cy = _canonicalize(y, _canonical_op) + return syntactic_eq(cx, cy) + + +def _canonicalize(expr, _canonical_op): + counter = itertools.count() + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _BaseTerm, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set: set[Operation]) -> list[Operation]: + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs) -> Term: + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return _BaseTerm(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter), var) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +@dataclass(frozen=True) +class Backend: + """A value-domain spec used to share monoid tests across int and jax.Array + backends. Provides the concrete value type, the hypothesis strategy for + drawing scalars in property tests, and an equality predicate that works + for that domain. + """ + + name: str + scalar_typ: Any + stream_typ: Any + scalar_strategy: st.SearchStrategy[Any] + eq: Callable[[Any, Any], bool] + + def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + """Build a fresh, unhandled Operation whose parameter and return + annotations are derived from this backend. + + ``ret`` is ``"scalar"`` for a scalar return or ``"stream"`` for a + stream-of-scalar return. The operation has ``n_args`` parameters, + each of type ``scalar_typ``. + """ + scalar = self.scalar_typ + out = self.stream_typ if ret == "stream" else scalar + params = ", ".join(f"_a{i}" for i in range(n_args)) + ns: dict[str, Any] = {"NotHandled": NotHandled} + exec(f"def _fn({params}):\n raise NotHandled\n", ns) + fn = ns["_fn"] + fn.__annotations__ = { + **{f"_a{i}": scalar for i in range(n_args)}, + "return": out, + } + return Operation.define(fn, name=name) + + +def _int_eq(a: Any, b: Any) -> bool: + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +def _jax_eq(a: Any, b: Any) -> bool: + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +def check_rewrite( + lhs, + rhs, + rule, + *, + backend: Backend, + free_vars=[], + max_examples: int = 25, + deadline=None, +) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=deadline) + def _check_semantics(intp): + with handler(NormalizeIntp), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert backend.eq(lhs_val, rhs_val) + + _check_semantics() + + +INT_BACKEND = Backend( + name="int", + scalar_typ=int, + stream_typ=list[int], + scalar_strategy=st.integers(min_value=-100, max_value=100), + eq=_int_eq, +) + + +JAX_BACKEND = Backend( + name="jax", + scalar_typ=jax.Array, + stream_typ=jax.Array, + scalar_strategy=_jax_array_value_strategy(), + eq=_jax_eq, +) + + +__all__ = [ + "Backend", + "INT_BACKEND", + "JAX_BACKEND", + "random_interpretation", + "define_vars", + "syntactic_eq_alpha", + "check_rewrite", +] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py new file mode 100644 index 00000000..35d041fe --- /dev/null +++ b/tests/test_handlers_jax_monoid.py @@ -0,0 +1,96 @@ +import jax +import pytest + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import Max, Min, Product, Sum +from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars + +MONOIDS = [ + pytest.param(Sum, jnp.sum, id="Sum"), + pytest.param(Product, jnp.prod, id="Product"), + pytest.param(Min, jnp.min, id="Min"), + pytest.param(Max, jnp.max, id="Max"), + pytest.param(LogSumExp, logsumexp, id="LogSumExp"), +] + + +@pytest.fixture +def backend() -> Backend: + return JAX_BACKEND + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_1(monoid, reductor, backend: Backend): + (x, k) = define_vars("x", "k", typ=jax.Array) + X = define_vars("X", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {x: X()}) + rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) + + check_rewrite( + lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_2(monoid, reductor, backend: Backend): + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) + rhs = reductor( + bind_dims( + reductor( + bind_dims(f(unbind_dims(X(), k1), unbind_dims(Y(), k2)), k2), + axis=0, + ), + k1, + ), + axis=0, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, Y, f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_3(monoid, reductor, backend: Backend): + """Stream `y` is `g(x())` — depends on the bound element of X. The reducer + must inline ``g`` along the same named dim used to unbind `x`.""" + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + X = define_vars("X", typ=backend.stream_typ) + + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) + rhs = reductor( + bind_dims( + reductor( + bind_dims( + f(unbind_dims(X(), k1), unbind_dims(g(unbind_dims(X(), k1)), k2)), + k2, + ), + axis=0, + ), + k1, + ), + axis=0, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, f, g], + ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index d881869a..c7ee7567 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,29 +1,51 @@ -import functools -import itertools import typing import pytest -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from effectful.internals.runtime import interpreter +import effectful.handlers.jax.monoid # noqa: F401 from effectful.ops.monoid import ( CartesianProduct, Max, Min, Monoid, + MonoidOverMapping, + MonoidOverSequence, NormalizeIntp, + PlusAssoc, + PlusConsecutiveDups, + PlusDistr, + PlusDups, + PlusEmpty, + PlusIdentity, + PlusSingle, + PlusZero, Product, + ReduceDistributeCartesianProduct, + ReduceFactorization, + ReduceFusion, + ReduceNoStreams, + ReduceSplit, Sum, distributes_over, - is_commutative, ) -from effectful.ops.semantics import apply, evaluate, fvsof, handler -from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq -from effectful.ops.types import NotHandled, Operation -from tests._monoid_helpers import random_interpretation +from effectful.ops.semantics import fvsof, handler +from effectful.ops.types import Operation +from tests._monoid_helpers import ( + INT_BACKEND, + JAX_BACKEND, + Backend, + check_rewrite, + define_vars, + syntactic_eq_alpha, +) + + +@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +def backend(request) -> Backend: + return request.param -_INT = st.integers(min_value=-100, max_value=100) ALL_MONOIDS = [ pytest.param(Sum, id="Sum"), @@ -61,247 +83,183 @@ ] -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - - -@functools.cache -def _canonical_op(idx: int) -> Operation: - """Globally cached canonical Operation, keyed by encounter index. - - Cached so that two independent canonicalize runs return the same - Operation object for the same index — letting ``syntactic_eq`` - compare canonical forms by Operation identity. - """ - return Operation.define(int, name=f"__cv_{idx}") - - -def syntactic_eq_alpha(x, y) -> bool: - """Alpha-equivalence-respecting variant of ``syntactic_eq``. - - Walks each expression bottom-up with :func:`evaluate` and renames - every bound variable to a deterministic canonical Operation. The - canonical names are assigned by a counter that increments in - ``evaluate``'s natural traversal order, so two alpha-equivalent - expressions canonicalize to syntactically identical results. - """ - return syntactic_eq(_canonicalize(x), _canonicalize(y)) - - -def _canonicalize(expr): - counter = itertools.count() - - def _substitute(arg, renaming): - """Apply a bound-variable renaming using ``evaluate`` for traversal.""" - if not renaming: - return arg - with interpreter({apply: _BaseTerm, **renaming}): - return evaluate(arg) - - def _bound_var_order(args, kwargs, bound_set): - """Return bound variables in deterministic encounter order.""" - seen: list[Operation] = [] - seen_set: set[Operation] = set() - - def _capture(op, *a, **kw): - if op in bound_set and op not in seen_set: - seen.append(op) - seen_set.add(op) - return defdata(op, *a, **kw) - - # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, - # etc. for free; the apply handler captures bound vars used as - # ``x()`` anywhere in the body. - with interpreter({apply: _capture}): - evaluate((args, kwargs)) - - # Binders bypass the apply handler. Pick them up with a small structural - # walk that visits dict keys too. - def _walk_bare(obj): - if isinstance(obj, Operation): - if obj in bound_set and obj not in seen_set: - seen.append(obj) - seen_set.add(obj) - elif isinstance(obj, dict): - for k, v in obj.items(): - _walk_bare(k) - _walk_bare(v) - elif isinstance(obj, list | set | frozenset | tuple): - for v in obj: - _walk_bare(v) - - _walk_bare((args, kwargs)) - return seen - - def _apply_canonical(op, *args, **kwargs): - bindings = op.__fvs_rule__(*args, **kwargs) - all_bound: set[Operation] = set().union( - *bindings.args, *bindings.kwargs.values() - ) - if not all_bound: - return _BaseTerm(op, *args, **kwargs) - - order = _bound_var_order(args, kwargs, all_bound) - canonical = {var: _canonical_op(next(counter)) for var in order} - assert all_bound <= set(order) - - new_args = tuple( - _substitute( - arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} - ) - for i, arg in enumerate(args) - ) - new_kwargs = { - k: _substitute( - v, - {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, - ) - for k, v in kwargs.items() - } - - # avoid the renaming from defdata - return _BaseTerm(op, *new_args, **new_kwargs) - - with interpreter({apply: _apply_canonical}): - return evaluate(expr) - - @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT, b=_INT, c=_INT) -@settings(max_examples=50, deadline=None) -def test_associativity(monoid, a, b, c): - left = monoid.plus(monoid.plus(a, b), c) - right = monoid.plus(a, monoid.plus(b, c)) - assert left == right +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_associativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + c = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert backend.eq(left, right) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_identity(monoid, a): - assert monoid.plus(monoid.identity, a) == a - assert monoid.plus(a, monoid.identity) == a +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_identity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(monoid.identity, a), a) + assert backend.eq(monoid.plus(a, monoid.identity), a) @pytest.mark.parametrize("monoid", COMMUTATIVE) -@given(a=_INT, b=_INT) -@settings(max_examples=50, deadline=None) -def test_commutativity(monoid, a, b): - assert monoid.plus(a, b) == monoid.plus(b, a) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_commutativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @pytest.mark.parametrize("monoid", IDEMPOTENT) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_idempotence(monoid, a): - assert monoid.plus(a, a) == a +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_idempotence(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, a), a) @pytest.mark.parametrize("monoid", WITH_ZERO) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_zero_absorbs(monoid, a): - assert monoid.plus(monoid.zero, a) == monoid.zero - assert monoid.plus(a, monoid.zero) == monoid.zero - - -def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: - """Run structural + semantic checks on a TermPair.""" +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_zero_absorbs(monoid, backend, data): + a = data.draw(backend.scalar_strategy) with handler(NormalizeIntp): - norm = evaluate(lhs) - - assert syntactic_eq_alpha(norm, rhs) + assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) + assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) - @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=None) - def _check_semantics(intp): - with handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert lhs_val == rhs_val - _check_semantics() +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid, backend): + check_rewrite( + lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid): - _check_pair(lhs=monoid.plus(), rhs=monoid.identity) +def test_plus_single(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + check_rewrite( + lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) +def test_plus_identity_right(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + lhs = monoid.plus(x(), monoid.identity) + rhs = monoid.plus(x()) -@pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) +def test_plus_identity_left(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + + lhs = monoid.plus(monoid.identity, x()) + rhs = monoid.plus(x()) + + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) - _check_pair( +def test_plus_assoc_right(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) - _check_pair( +def test_plus_assoc_left(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) - _check_pair( +def test_plus_sequence(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), + rule=MonoidOverSequence(), + backend=backend, free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) - _check_pair( - lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), - rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, +def test_plus_mapping(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + + lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) + rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, free_vars=[a, b, c, d], ) -def test_plus_distributes(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) - rhs = Sum.plus( - Product.plus(a(), c()), - Product.plus(a(), d()), - Product.plus(b(), c()), - Product.plus(b(), d()), + rhs = Product.plus( + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) -def test_plus_distributes_constant(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_constant(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -312,11 +270,13 @@ def test_plus_distributes_constant(): Product.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) -def test_plus_distributes_multiple(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_multiple(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -337,72 +297,123 @@ def test_plus_distributes_multiple(): Sum.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid): +def test_plus_idempotent_consecutive(monoid, backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b") + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), a(), b()) - return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + return check_rewrite( + lhs=lhs, + rhs=monoid.plus(a(), b()), + rule=PlusConsecutiveDups(), + backend=backend, + free_vars=[a, b], + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid): +def test_plus_idempotent_non_consecutive(monoid, backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative - PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" - a, b = define_vars("a", "b") + PlusDups.""" + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), b(), a()) - if is_commutative(monoid): - rhs = monoid.plus(a(), b()) - else: - rhs = monoid.plus(a(), b(), a()) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + rhs = monoid.plus(a(), b()) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) -def test_plus_commutative_idempotent_long(): +@pytest.mark.parametrize("monoid", [Min, Max]) +def test_plus_commutative_idempotent_long(monoid, backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b") - lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) - _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + a, b = define_vars("a", "b", typ=backend.scalar_typ) + lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) + rhs = monoid.plus(a(), b()) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid): - a = define_vars("a") +def test_plus_zero(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) - _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) - _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + rhs = monoid.zero + check_rewrite( + lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) + check_rewrite( + lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_1(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + lhs = monoid.reduce(x(), {x: []}) + rhs = monoid.identity + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {y: Y(), x: []}) + rhs = monoid.identity + + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") +def test_partial_3(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) + + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_4(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + f = backend.fresh_op("f", n_args=1, ret="stream") + + lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - @Operation.define - def f(_x: int) -> int: - raise NotHandled + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid): - x, y = define_vars("x", "y") - X, Y = define_vars("X", "Y", typ=list[int]) - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_sequence_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + X, Y = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) @@ -411,103 +422,115 @@ def f(_x: int) -> int: monoid.reduce(g(y()), {x: X(), y: Y()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, Y, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_mapping(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") - lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { - "a": monoid.reduce(f(x()), {x: X()}), - "b": monoid.reduce(g(x()), {x: X()}), + 0: monoid.reduce(f(x()), {x: X()}), + 1: monoid.reduce(g(x()), {x: X()}), } - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid): - a = define_vars("a") +def test_reduce_no_streams(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs = monoid.reduce(a(), {}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_reduce(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] + ) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_plus(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] + ) -def test_reduce_independent_1(): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_independent_1(backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) - rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) - + rhs = Product.plus( + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] + ) -def test_reduce_independent_2(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_2(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) -def test_reduce_independent_3_negative(): +def test_reduce_independent_3_negative(backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c") - A, C = define_vars("A", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, C = define_vars("A", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") - @Operation.define - def g(_x: int) -> list[int]: - raise NotHandled - - with handler(NormalizeIntp): + with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} ) @@ -519,104 +542,107 @@ def g(_x: int) -> list[int]: assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_4(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( 7, - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner): - a, i = define_vars("a", "i") - A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_1(outer, inner, backend): + a, i = define_vars("a", "i", typ=backend.scalar_typ) + A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, A_domain, f], + ) def test_reduce_cartesian_1(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) - term1 = Sum.reduce( - Product.reduce(a(), {a: []}), - {A: CartesianProduct.reduce([], {i: []})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) assert term1 == term2 def test_reduce_cartesian_2(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) - term1 = Sum.reduce( - Product.reduce(a(), {a: A()}), - {A: CartesianProduct.reduce([(0,)], {i: [0]})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) assert term1 == term2 @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner): - a, i, j = define_vars("a", "i", "j") - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_multi_index(outer, inner, backend): + a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) + A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) term2 = inner.reduce( - outer.reduce(f(a()), {a: A_domain()}), + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, M, A_domain, f]) + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, M, A_domain, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner): +def test_reduce_lifted_2(outer, inner, backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t") - A, N, T = define_vars("A", "N", "T", typ=list[int]) - - @Operation.define - def A_domain(_i: int) -> list[int]: - raise NotHandled - - @Operation.define - def f1(_a: int, _s: int) -> float: - raise NotHandled - - @Operation.define - def f2(_t: int, _a: int) -> float: - raise NotHandled + a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) + A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) + A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") + f1 = backend.fresh_op("f1", n_args=2, ret="scalar") + f2 = backend.fresh_op("f2", n_args=2, ret="scalar") term1 = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), @@ -625,10 +651,18 @@ def f2(_t: int, _a: int) -> float: term2 = outer.reduce( inner.reduce( - outer.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + outer.reduce( + inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} + ), {i: N()}, ), {t: T()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], + ) From f1491e43c991a3dd11db583f461599d5a2a361c8 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 12:51:04 -0400 Subject: [PATCH 05/10] Add `delta` terms for array construction in `handlers.jax.monoid` (#663) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * wip * wip * fix rule * wip * fix bug * cleanup * lin * wip * fix tests * format * lint * wip * wip * wip * wip * wip * wip * wip * wip * drop runtime typed dict lifting * wip * format * reorganize * stop using string dicts to avoid unification issue * wip * wip * wip * wip * wip * use check_rewrite in jax tests * lint * wip * fix bugs * comment on not implemented cases * format * simplify * lint * add matmul test --- effectful/handlers/jax/_handlers.py | 6 + effectful/handlers/jax/_terms.py | 20 +-- effectful/handlers/jax/monoid.py | 255 +++++++++++++++++++++++++++- effectful/ops/monoid.py | 6 +- tests/test_handlers_jax_monoid.py | 177 ++++++++++++++++++- 5 files changed, 442 insertions(+), 22 deletions(-) diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index c5d10423..91fba369 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -87,6 +87,12 @@ def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]: if not sized_fvs: return t + # if any dimension is zero sized, the result is empty + if any(size == 0 for size in sized_fvs.values()): + key = tuple(sized_fvs.keys()) + shape = tuple(sized_fvs[k] for k in key) + return jax_getitem(jnp.empty(shape), key) + def _is_eager(t): return not isinstance(t, Term) or t.op in sized_fvs or is_eager_array(t) diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index 81206293..c88fe934 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -8,7 +8,6 @@ import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax._handlers import ( IndexElement, - _partial_eval, _register_jax_op, bind_dims, jax_getitem, @@ -451,28 +450,15 @@ def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array >>> bind_dims(t, b, a).shape (3, 2) """ - - def _evaluate(expr): - if isinstance(expr, Term): - (args, kwargs) = jax.tree.map(_evaluate, (expr.args, expr.kwargs)) - return _partial_eval(expr) - if not jax.tree_util.treedef_is_leaf(jax.tree.structure(expr)): - return jax.tree.map(_evaluate, expr) - return expr - if not isinstance(t, Term): return t - result = _evaluate(t) - if not isinstance(result, Term) or not args: - return result - # ensure that the result is a jax_getitem with an array as the first argument - if not (result.op is jax_getitem and isinstance(result.args[0], jax.Array)): + if not (t.op is jax_getitem and isinstance(t.args[0], jax.Array)): raise NotHandled - array = result.args[0] - dims = result.args[1] + array = t.args[0] + dims = t.args[1] assert isinstance(dims, Sequence) # ensure that the order is a subset of the named dimensions diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index a406cda5..42d7866e 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -1,4 +1,6 @@ import functools +import typing +from collections.abc import Iterable import jax @@ -12,12 +14,13 @@ Monoid, NormalizeIntp, Product, + Streams, Sum, outer_stream, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements -from effectful.ops.types import Operation +from effectful.ops.types import Interpretation, NotHandled, Operation, Term def cartesian_prod(x, y): @@ -151,8 +154,258 @@ def reduce(self, monoid, body, streams): return fwd() +@Operation.define +def delta(_index: tuple[int, ...], _weight: jax.Array) -> jax.Array: + raise NotHandled + + +py_range = range + + +@Operation.define +def range(*args: int) -> Iterable[jax.Array]: + raise NotHandled + + +def _range_start(term: Term): + assert term.op == range + if len(term.args) < 2: + return 0 + return term.args[0] + + +def _range_stop(term: Term): + assert term.op == range + if len(term.args) < 2: + return term.args[0] + return term.args[1] + + +def _range_step(term: Term): + assert term.op == range + if len(term.args) < 3: + return 1 + return term.args[2] + + +def _is_simple_range(term: Term) -> bool: + if term.op != range: + return False + + start = _range_start(term) + step = _range_step(term) + return ( + not isinstance(start, Term) + and start == 0 + and not isinstance(step, Term) + and step == 1 + ) + + +class ReduceDeltaIndependent(ObjectInterpretation): + """Eliminate a Delta that has independent, dense index arguments. + + reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) + + reduce(M, streams ∪ {v: range(N)}, delta(idx' ++ (v(),), body)) + ═══════════════════════════════════════════════════════════════════════════ + reduce(M, streams, delta(idx', bind_dims(body[v() := unbind_dims(streams[v], fv)], fv))) + + Not yet supported: + + - **Strided index streams** (``range(0, N, k)`` for ``k != 1``): the + premise ``_is_simple_range`` requires ``start == 0`` and ``step == 1``. + A strided extension would substitute ``v() := unbind_dims(jnp.arange( + start, stop, step), fv)`` and otherwise follow the same shape — the + change is purely in the recognised range form, the bind/unbind cycle + below is unchanged. + - **Non-zero start** (``range(a, b, 1)`` with ``a != 0``): same template + as the strided case; only the recognised range form changes. + - **Non-bare index expressions** (``delta((2*v(),), w)``, + ``delta((f(v()),), w)``, etc.): currently requires the final index + entry to be a bare call ``v()`` of a stream var op. Generalizing to + arbitrary index expressions is a scatter, not a bind: materialize the + index expression and the weight separately over ``v``, then + ``jnp.zeros(N).at[indices].set(values)`` (for Sum; analogous for + other monoids using ``.add``/``.min``/``.max``/...). This is a + different leaf operation from ``bind_dims`` and warrants a sibling + rule rather than an extension of this one. + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + if not (isinstance(body, Term) and body.op == delta): + return fwd() + + indices, weight = body.args + assert isinstance(indices, tuple) + + if not indices: + return monoid.reduce(weight, streams) + + head_indices, tail_index = indices[:-1], indices[-1] + if not (isinstance(tail_index, Term) and tail_index.op in streams): + return fwd() + + tail_op: Operation = tail_index.op + tail_stream = streams[tail_op] + if not (isinstance(tail_stream, Term) and _is_simple_range(tail_stream)): + return fwd() + + fresh_op = Operation.define(tail_op) + indices = jnp.arange(_range_stop(tail_stream)) + if isinstance(indices, jax.Array) and len(indices) == 0: + return monoid.identity + + fresh_stream = unbind_dims(indices, fresh_op) + subst_intp = typing.cast(Interpretation, {tail_op: deffn(fresh_stream)}) + fresh_body = bind_dims(handler(subst_intp)(evaluate)(weight), fresh_op) + fresh_streams = {k: v for (k, v) in streams.items() if k != tail_op} + return monoid.reduce(delta(head_indices, fresh_body), fresh_streams) + + +class ReduceDependentRangeMask(ObjectInterpretation): + """Eliminate a dependent range by masking. + + reduce(M, streams ∪ {u: range(N), v: range(u())}, body) + ═══════════════════════════════════════════════════════════════════════════ + reduce(M, streams ∪ {u: range(N), v: range(N)}, where(v() < u(), body, M.identity)) + + Currently recognises only the lower-triangular form ``v: range(u())``: + constant start of 0, dependent stop equal to a bare call of another + stream var. + + Not yet supported: + + - **Upper-triangular** (``v: range(u(), N)`` — constant stop, dependent + start): bbox becomes ``range(0, N)`` (or ``range(0, bbox_N)``), guard + becomes ``v() >= u()``. Same shape of rewrite as lower-tri; differs + only in which side of the range carries the stream-var reference and + in the predicate direction. + - **Banded** (``v: range(u() - k, u() + k + 1)`` — two-sided dependent + bounds with constant width): bbox is ``range(0, N + k)`` (or similar + bounded by both endpoints' extents), guard is + ``(v() >= u() - k) & (v() < u() + k + 1)``. Needs both-sides + affine-bound recognition. + - **Strided dependent** (``v: range(0, u(), k)`` for ``k != 1``): bbox + stays ``range(0, N)`` and guard becomes + ``(v() < u()) & (v() % k == 0)`` (or equivalent), or alternatively + embed in a smaller bbox ``range(0, ceil(N/k))`` and remap the index. + - **Affine bounds** (``v: range(a*u() + b, c*u() + d)`` for affine + coefficients): bbox computed from ``ub(c*u() + d)`` over ``u``'s + range; guard is the conjunction of the two affine constraints. This + subsumes the upper/banded/strided cases under one affine recogniser. + - **Multi-stream-var dependent** (``v: range(u() + w())`` referencing + more than one outer stream var): bbox is the affine combination over + both referents' ranges; guard threads through all dependencies. + - **Reverse-order dependent ranges**: e.g. ``v: range(u(), 0, -1)``; + needs to handle negative step and the corresponding reverse + enumeration. + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + stream_vars = set(streams.keys()) + + # streams of the form k: range(X) + simple_ranges = { + k: v + for (k, v) in streams.items() + if isinstance(v, Term) and _is_simple_range(v) + } + for u, u_stream in simple_ranges.items(): + if fvsof(u_stream) & stream_vars: + continue + + for v, v_stream in simple_ranges.items(): + if ( + isinstance(v_stream, Term) + and isinstance(_range_stop(v_stream), Term) + and _range_stop(v_stream).op == u + ): + fresh_streams = { + a: (u_stream if a == v else b) for (a, b) in streams.items() + } + + # there are other commuting rules for delta that we do not + # currently include + if isinstance(body, Term) and body.op == delta: + fresh_body = delta( + body.args[0], + jnp.where(v() < u(), body.args[1], monoid.identity), # type: ignore[arg-type] + ) + else: + fresh_body = jnp.where(v() < u(), body, monoid.identity) + + return monoid.reduce(fresh_body, fresh_streams) + + return fwd() + + +class ReduceRange(ObjectInterpretation): + """Replace concrete-range stream values with materialized ``jnp.arange``. + + reduce(M, streams ∪ {v: range(a, b, s)}, body) + ≡ reduce(M, streams ∪ {v: jnp.arange(a, b, s)}, body) + + when ``a``, ``b``, ``s`` are concrete and ``body`` is not a delta term. + Delegates the actual reduction to whichever handler picks up the + materialized ``jax.Array`` streams. + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + if isinstance(body, Term) and body.op == delta: + return fwd() + + new_streams: dict = {} + any_replaced = False + for k, v in streams.items(): + if isinstance(v, Term) and v.op == range: + new_streams[k] = jnp.arange( + _range_start(v), _range_stop(v), _range_step(v) + ) + any_replaced = True + else: + new_streams[k] = v + + if not any_replaced: + return fwd() + return monoid.reduce(body, new_streams) + + +# Cross-cutting delta rules not yet implemented: +# +# - **Delta-commuting** (DC-hoist): for any pure op ``f`` (no Scoped binders +# that intersect a delta's index ops), push delta outward: +# f(args..., delta(idx, body), args...) +# ≡ delta(idx, f(args..., body, args...)) +# This normalizes delta to the outermost position so the reduce rules can +# pattern-match ``isinstance(body, Term) and body.op == delta`` cleanly. +# The soundness condition is mechanical via ``op.__fvs_rule__``: refuse to +# commute when a non-delta arg's scope binds any op in the delta's idx. +# +# - **Delta-merging** (DC-merge): under a pure binary op ``f`` (or +# generalized n-ary), merge multiple deltas when their index tuples are +# subsequence-compatible: +# f(delta(idx_a, v), delta(idx_b, w)) ≡ delta(idx_max, f(v, w)) +# where ``idx_max`` is the longer of ``idx_a``, ``idx_b`` and ``idx_a`` is +# a subsequence of ``idx_b`` (or vice versa). Refuse to fire when neither +# is a subsequence of the other, since that would silently insert an +# outer-product broadcast. +# +# - **Empty-domain detection at the term level**: currently size-0 named +# dims must be resolved by leaf consumers (``bind_dims``, reductors with +# ``initial=monoid.identity``). The empty-domain check is intentionally +# NOT a rule on its own — rewrites stay size-polymorphic and leaf ops +# carry the burden. See the conversation in monoid.py's history for why. + + NormalizeIntp.extend( ArrayReduce(), + ReduceRange(), + ReduceDeltaIndependent(), + ReduceDependentRangeMask(), SumPlusJax(), ProductPlusJax(), MinPlusJax(), diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 70bb5002..c9231510 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -16,7 +16,6 @@ Scoped, deffn, implements, - iter_, syntactic_eq, syntactic_hash, ) @@ -96,8 +95,11 @@ def reduce[A, B, U: Body]( if isinstance(stream_body, Term): continue stream_values_iter = iter(stream_body) - if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + + # if we iterate and get a term instead of a real iterator, skip + if isinstance(stream_values_iter, Term): continue + new_reduces = [] for stream_val in stream_values_iter: with handler({stream_key: deffn(stream_val)}): diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 35d041fe..fe888ad4 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,11 +1,20 @@ import jax import pytest +from jax import random as random import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax import bind_dims, unbind_dims -from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp +from effectful.handlers.jax.monoid import ( + ArrayReduce, + LogSumExp, + ReduceDeltaIndependent, + ReduceDependentRangeMask, + delta, +) +from effectful.handlers.jax.monoid import range as Range from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import Max, Min, Product, Sum +from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Sum +from effectful.ops.semantics import handler from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars MONOIDS = [ @@ -94,3 +103,167 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): backend=backend, free_vars=[x, y, k1, k2, X, f, g], ) + + +# --------------------------------------------------------------------------- +# Delta rules. All tests use the operation form ``delta(idx, body)`` rather +# than the ``Delta`` dataclass; the delta op is the user-facing surface. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_empty(monoid, reductor, backend: Backend): + """An empty-index delta unwraps to its body. + + reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) + """ + x = define_vars("x", typ=backend.scalar_typ) + X = define_vars("X", typ=backend.stream_typ) + + lhs = monoid.reduce(delta((), x()), {x: X()}) + rhs = monoid.reduce(x(), {x: X()}) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDeltaIndependent(), + backend=backend, + free_vars=[x, X], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_independent_one(monoid, reductor, backend: Backend): + """One R1 step: peel the final preserved index off a delta. + + reduce(M, {y: Y()}, delta((y(),), f(y()))) + ≡ reduce(M, {}, delta((), bind_dims(f(unbind_dims(Y(), k)), k))) + """ + (y, k) = define_vars("y", "k", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") + + # We use a concrete range here instead of an abstract one, because + # unbind_dims is undefined on empty arrays (and the rewrite produces a + # different rhs in this case) + lhs = monoid.reduce(delta((y(),), f(y())), {y: Range(3)}) + rhs = monoid.reduce(bind_dims(f(unbind_dims(jnp.arange(3), k)), k), {}) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDeltaIndependent(), + backend=backend, + free_vars=[y, k, Y, f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Backend): + """R1 peels only the final index. Streams not matching the peeled index op + stay untouched, as do earlier entries in the index tuple. + + reduce(M, {x: X(), y: Y()}, delta((x(), y()), f(x(), y()))) + ≡ reduce(M, {x: X()}, delta((x(),), bind_dims(f(x(), unbind_dims(Y(), k)), k))) + """ + (x, y, k) = define_vars("x", "y", "k", typ=backend.scalar_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + + lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: Range(2), y: Range(3)}) + rhs = monoid.reduce( + bind_dims( + bind_dims( + f(unbind_dims(jnp.arange(2), x), unbind_dims(jnp.arange(3), k)), k + ), + x, + ), + {}, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDeltaIndependent(), + backend=backend, + free_vars=[f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): + """A dependent range stream gets rewritten to the referent's bbox stream, + with the original constraint folded into the body as a where-guard. + + reduce(M, {u: range(0, N, 1), v: range(0, u(), 1)}, body) + ≡ reduce(M, {u: range(0, N, 1), v: range(0, N, 1)}, where(v() < u(), body, M.identity)) + """ + (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + N = 5 + f = backend.fresh_op("f", n_args=2, ret="scalar") + + body = f(u(), v()) + + lhs = monoid.reduce(body, {u: Range(0, N, 1), v: Range(0, u(), 1)}) + rhs = monoid.reduce( + jnp.where(v() < u(), body, monoid.identity), + {u: Range(0, N, 1), v: Range(0, N, 1)}, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDependentRangeMask(), + backend=backend, + free_vars=[u, v, f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backend): + """When the body is a delta term, R4 folds the constraint into the delta's + weight while leaving its index tuple untouched. + + reduce(M, {u: range(N), v: range(u())}, delta((u(), v()), w)) + ≡ reduce(M, {u: range(N), v: range(N)}, + delta((u(), v()), where(v() < u(), w, M.identity))) + """ + (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + N = 5 + f = backend.fresh_op("f", n_args=2, ret="scalar") + + weight = f(u(), v()) + idx = (u(), v()) + + lhs = monoid.reduce(delta(idx, weight), {u: Range(0, N, 1), v: Range(0, u(), 1)}) + rhs = monoid.reduce( + delta(idx, jnp.where(v() < u(), weight, monoid.identity)), + {u: Range(0, N, 1), v: Range(0, N, 1)}, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDependentRangeMask(), + backend=backend, + free_vars=[u, v, f], + ) + + +def test_reduce_matmul(): + key = jax.random.PRNGKey(0) + # Define dimensions + B, I, J, K = 2, 3, 4, 5 + + # Create sample matrices + X = random.normal(key, (B, I, J)) + Y = random.normal(key, (B, J, K)) + (b, i, j, k) = define_vars("b", "i", "j", "k", typ=jax.Array) + + with handler(NormalizeIntp): + actual = Sum.reduce( + delta((b(), i(), k()), unbind_dims(X, b, i, j) * unbind_dims(Y, b, j, k)), + {b: Range(B), i: Range(I), j: Range(J), k: Range(K)}, + ) + + expected = jnp.einsum("bij,bjk->bik", X, Y) + assert jnp.allclose(actual, expected) From 0d6ab3822f8f548e4bf0d783ebcaa84f389b00a7 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 17 Jun 2026 14:05:55 -0400 Subject: [PATCH 06/10] Add weighted streams (#665) * more precise stream type * add tests for weighted rules * add reduction rule for weighted streams and tests * add test to demo expectation * add numpyro monoid module * add quadrature * add tests * wip * refactor tests * wip * test composition of lifting and weighting * drop numpyro changes * drop unused ops * lint * make weighted a Monoid method * fix typing of jax arrays * change weighted typing to take callable * fix test * fix test * resolve type aliases before dispatching * wip * wip * remove typeof_full * wip * wip * wip * format * refactor test harness * drop unused test --- effectful/handlers/jax/_terms.py | 7 + effectful/handlers/jax/monoid.py | 15 +- effectful/ops/monoid.py | 125 ++++++- effectful/ops/semantics.py | 9 +- tests/_monoid_helpers.py | 466 +++++++++++++----------- tests/test_handlers_jax_monoid.py | 184 +++++----- tests/test_ops_monoid.py | 587 +++++++++++++++++------------- 7 files changed, 827 insertions(+), 566 deletions(-) diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index c88fe934..05a5390e 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -14,10 +14,17 @@ unbind_dims, ) from effectful.internals.tensor_utils import _desugar_tensor_index +from effectful.internals.unification import Box, nested_type from effectful.ops.syntax import defdata from effectful.ops.types import Expr, NotHandled, Operation, Term +@nested_type.register(jax.Array) +@nested_type.register(jax._src.core.Tracer) +def _(value): + return Box(jax.Array) + + class _IndexUpdateHelper: """Helper class to implement array-style .at[index].set() updates for effectful arrays.""" diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 42d7866e..3f6273be 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -16,6 +16,7 @@ Product, Streams, Sum, + distributes_over, outer_stream, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof @@ -38,6 +39,10 @@ def cartesian_prod(x, y): LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) +# ``Sum`` in log space is multiplication, which distributes over ``LogSumExp``: +# a + logsumexp(b, c) = logsumexp(a + b, a + c) +distributes_over.register(Sum, LogSumExp) + def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete @@ -108,7 +113,15 @@ def plus(self, *args): if not isinstance(a, jax.Array): return fwd() result = a if result is None else cartesian_prod(result, a) - return result if result is not None else CartesianProduct.identity + if result is None: + return CartesianProduct.identity + # CartesianProduct values are streams of rows. ``cartesian_prod`` + # already lifts 1D inputs to 2D, but a single-array call seeds + # ``result = a`` unchanged — promote so the rank invariant holds for + # every array-path return. + if result.ndim == 1: + result = result[:, None] + return result ARRAY_REDUCTORS = { diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index c9231510..76351fa6 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,7 +10,14 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.semantics import ( + coproduct, + evaluate, + fvsof, + fwd, + handler, + typeof, +) from effectful.ops.syntax import ( ObjectInterpretation, Scoped, @@ -19,11 +26,17 @@ syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, +) + +type Stream[T] = Iterable[T] -# Note: The streams value type should be something like Iterable[T], but some of -# our target stream types (e.g. jax.Array) are not subtypes of Iterable -type Streams[T] = Mapping[Operation[[], T], Any] +type Streams = Mapping[Operation[[], Any], Stream[Any]] type Body[T] = ( Iterable[T] @@ -34,9 +47,7 @@ ) -def outer_stream( - streams: Streams, -) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: +def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams]]: """Returns the streams that can be ordered outermost in the loop nest as well as the remaining streams in the nest. @@ -51,13 +62,13 @@ def outer_stream( ) -class Monoid[T]: +class Monoid[W]: """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" _name: str - identity: T + identity: W - def __init__(self, identity: T, name: str): + def __init__(self, identity: W, name: str): self._name = name self.identity = identity @@ -111,6 +122,18 @@ def reduce[A, B, U: Body]( return self.plus(*new_reduces) raise NotHandled + @Operation.define + def weighted[T]( + self, stream: Stream[T], weight: Callable[[T], W] | Operation[[T], W] + ) -> Stream[T]: + """A stream paired with a per-element weight. ``var`` is an + :class:`Operation` standing for "an element of ``stream``"; ``weight`` + is an expression that uses ``var`` and evaluates to the weight of that + element. + + """ + raise NotHandled + class MonoidWithZero[T](Monoid[T]): zero: T @@ -175,6 +198,12 @@ def _is_monoid_reduce(op: Operation) -> bool: return isinstance(owner, Monoid) and op is owner.reduce +def _is_monoid_weighted(op: Operation) -> bool: + """True if ``op`` is the ``weighted`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.weighted + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" @@ -557,6 +586,78 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() +class ReduceWeightedStream(ObjectInterpretation): + """reduce(M, body, {x: WM.weighted(s, v, w), ...}) = reduce(M, WM.plus(w[v:=x()], body), {x: s, ...}) + + requires distributes_over(WM, M). + + The substitution ``v -> x`` is done by beta-reducing ``deffn(w, v)`` on + ``x()`` — symbolic, no Python dispatch on the weight expression. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + for k, v in streams.items(): + if isinstance(v, Term) and _is_monoid_weighted(v.op): + v_stream, v_weight = v.args + v_monoid = v.op.__self__ + if not distributes_over(v_monoid, monoid): + continue + w_at_k = v_weight(k()) + weighted_body = v_monoid.plus(w_at_k, body) + new_streams = {**streams, k: v_stream} + return monoid.reduce(weighted_body, new_streams) + return fwd() + + +class ReduceCartesianWeightedStream(ObjectInterpretation): + """``CartesianProduct.reduce`` over a :func:`weighted` body whose + ``weight`` is independent of the plate (product-index) streams:: + + CartesianProduct.reduce(M.weighted(s, w), plates) + = M.weighted( + CartesianProduct.reduce(s, plates), + deffn(M.reduce(w, {e: row()}), row), + ) + + Reuses ``body``'s element binder ``e`` (already typed by construction); + introduces a fresh ``row`` binder typed as ``Iterable[elem_type]``. + + Only fires when ``w`` is independent of the plate vars. + """ + + @Operation.define + @staticmethod + def _iterable_elem[T](iter: Iterable[T]) -> T: + raise NotHandled + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid is not CartesianProduct: + return fwd() + if not (isinstance(body, Term) and _is_monoid_weighted(body.op)): + return fwd() + + s, w = body.args + if not isinstance(s, Term) and len(s) == 0: + return CartesianProduct.reduce([], streams) + + if set(streams.keys()) & fvsof(w): + return fwd() + + elem_typ = typeof(self._iterable_elem(s)) + elem_op = Operation.define(elem_typ, name="elem") + row_op = Operation.define(Iterable[elem_typ], name="row") + + weight_monoid = body.op.__self__ + joint_weight = deffn( + weight_monoid.reduce(w(elem_op()), {elem_op: row_op()}), row_op + ) + joint_stream = CartesianProduct.reduce(s, streams) + + return weight_monoid.weighted(joint_stream, joint_weight) + + class MonoidOverCallable(ObjectInterpretation): """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" @@ -751,6 +852,8 @@ def extend(self, *intps: Interpretation) -> typing.Self: ReduceSplit(), ReduceFactorization(), ReduceDistributeCartesianProduct(), + ReduceWeightedStream(), + ReduceCartesianWeightedStream(), PlusEmpty(), PlusSingle(), PlusIdentity(), diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index e54729bb..1a33d92d 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -229,6 +229,13 @@ def _evaluate_list_view(expr, **kwargs): def _simple_type(tp: type) -> type: """Convert a type object into a type that can be dispatched on.""" + + def _resolve_aliases(tp: type) -> type: + tp = typing.get_origin(tp) or tp + if isinstance(tp, typing.TypeAliasType): + return _resolve_aliases(tp.__value__) + return tp + if isinstance(tp, typing.TypeVar): tp = ( tp.__bound__ @@ -246,7 +253,7 @@ def _simple_type(tp: type) -> type: tp = functools.reduce(operator.or_, (type(arg) for arg in args)) if isinstance(tp, types.UnionType): raise TypeError(f"Union types are not supported: {tp}") - return typing.get_origin(tp) or tp + return _resolve_aliases(tp) def typeof[T](term: Expr[T]) -> type[T]: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index f15103e3..f8089bec 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,146 +1,22 @@ +import builtins import itertools -from collections.abc import Callable, Mapping, Sequence -from dataclasses import dataclass -from typing import Any, get_args, get_origin +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping +from typing import Any, Literal, overload import jax from hypothesis import given, settings from hypothesis import strategies as st +from hypothesis.strategies import SearchStrategy import effectful.handlers.jax.numpy as _jnp from effectful.internals.runtime import interpreter -from effectful.ops.monoid import NormalizeIntp -from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.monoid import NormalizeIntp, Stream, _is_monoid_weighted +from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term -_JAX_ARRAY_SHAPE = (2,) - - -def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: - return st.lists( - st.integers(min_value=-5, max_value=5), - min_size=_JAX_ARRAY_SHAPE[0], - max_size=_JAX_ARRAY_SHAPE[0], - ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) - - -# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` -# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. -_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ - lambda a: _jnp.stack([a, a + 1]), - lambda a: _jnp.stack([a, -a]), - lambda a: _jnp.stack([a, a + 1, 2 * a]), -] - -_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ - lambda a, b: a + b, - lambda a, b: a - b, - lambda a, b: a * b, -] - - -def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: - """Strategy for the value an *0-arg* Operation should return.""" - if annotation is int: - return st.integers(min_value=-100, max_value=100) - if annotation is float: - return st.floats(allow_nan=False) - if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) - if annotation is jax.Array: - return _jax_array_value_strategy() - if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): - return st.lists(_jax_array_value_strategy(), max_size=2) - raise NotImplementedError( - f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int], jax.Array, list[jax.Array]" - ) - - -_UNARY_NUM_FNS: list[Callable[[int], int]] = [ - lambda x: x, - lambda x: x + 1, - lambda x: x - 1, - lambda x: -x, - lambda x: 2 * x, - lambda x: 3 * x + 1, -] - -_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ - lambda x, y: x + y, - lambda x, y: x - y, - lambda x, y: x * y, - lambda x, y: x + 2 * y, - lambda x, y: 2 * x - y, -] - -_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], - lambda x: [0, x, x + 1], -] - -_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], -] - - -def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: - """Pick a strategy producing a callable suitable for binding `op` in an - interpretation. Inspects the operation's signature. - """ - sig = op.__signature__ - params = list(sig.parameters.values()) - ret = sig.return_annotation - param_types = tuple(p.annotation for p in params) - - if not params: - return _value_strategy_for(ret).map(deffn) - if ret in (int, float) and param_types == (int,): - return st.sampled_from(_UNARY_NUM_FNS) - if ret in (int, float) and param_types == (int, int): - return st.sampled_from(_BINARY_NUM_FNS) - if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): - return st.sampled_from(_UNARY_LIST_FNS) - if ret is jax.Array and param_types == (jax.Array,): - return st.sampled_from(_UNARY_JAX_FNS) - if ret is jax.Array and param_types == (jax.Array, jax.Array): - return st.sampled_from(_BINARY_JAX_FNS) - if ( - get_origin(ret) is list - and get_args(ret) == (jax.Array,) - and param_types == (jax.Array,) - ): - return st.sampled_from(_UNARY_JAX_LIST_FNS) - raise NotImplementedError( - f"No callable strategy for free var with return {ret!r}, params {param_types!r}" - ) - - -@st.composite -def random_interpretation( - draw: st.DrawFn, free_vars: Sequence[Operation] -) -> Mapping[Operation, Callable[..., Any]]: - """Draw an Interpretation binding every Operation in `case.free_vars` to - a randomly chosen value/callable. Keys are Operation identities. - """ - intp: dict[Operation, Callable[..., Any]] = {} - for op in free_vars: - intp[op] = draw(_strategy_for_op(op)) - return intp - - -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - def syntactic_eq_alpha(x, y) -> bool: """Alpha-equivalence-respecting variant of ``syntactic_eq``. @@ -251,8 +127,7 @@ def _apply_canonical(op, *args, **kwargs) -> Term: return evaluate(expr) -@dataclass(frozen=True) -class Backend: +class Backend(ABC): """A value-domain spec used to share monoid tests across int and jax.Array backends. Provides the concrete value type, the hypothesis strategy for drawing scalars in property tests, and an equality predicate that works @@ -262,10 +137,29 @@ class Backend: name: str scalar_typ: Any stream_typ: Any - scalar_strategy: st.SearchStrategy[Any] - eq: Callable[[Any, Any], bool] - - def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + strategy_for_op: dict[Operation, st.SearchStrategy[Callable[..., Any]]] + + def __init__(self): + self.strategy_for_op = {} + + @abstractmethod + def eq(self, a: Any, b: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + raise NotImplementedError + + def _fresh_op( + self, + name: str, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> Operation: """Build a fresh, unhandled Operation whose parameter and return annotations are derived from this backend. @@ -275,81 +169,245 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation """ scalar = self.scalar_typ out = self.stream_typ if ret == "stream" else scalar - params = ", ".join(f"_a{i}" for i in range(n_args)) + params = ", ".join(f"_a{i}" for i in range(len(arg_types))) ns: dict[str, Any] = {"NotHandled": NotHandled} exec(f"def _fn({params}):\n raise NotHandled\n", ns) fn = ns["_fn"] fn.__annotations__ = { - **{f"_a{i}": scalar for i in range(n_args)}, + **{f"_a{i}": t for i, t in enumerate(arg_types)}, "return": out, } - return Operation.define(fn, name=name) + op = Operation.define(fn, name=name) + self.strategy_for_op[op] = self.strategy(arg_types, ret) + return op + + @overload + def define_vars(self, name: str, /, **kwargs) -> Operation: ... + + @overload + def define_vars( + self, n1: str, n2: str, /, *names: str, **kwargs + ) -> tuple[Operation, ...]: ... + + def define_vars(self, *names: str, **kwargs) -> Operation | tuple[Operation, ...]: # type: ignore[misc] + if len(names) == 1: + return self._fresh_op(names[0], **kwargs) + return tuple(self._fresh_op(n, **kwargs) for n in names) + + def check_rewrite( + self, + lhs, + rhs, + rule, + *, + max_examples: int = 25, + deadline=None, + normalize=NormalizeIntp, + ) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + fvs = fvsof(lhs) | fvsof(rhs) + + @st.composite + def random_interpretation( + draw: st.DrawFn, + ) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op, strategy in self.strategy_for_op.items(): + if op in fvs: + intp[op] = draw(strategy) + return intp + + @given(intp=random_interpretation()) + @settings( + max_examples=max_examples, deadline=deadline, report_multiple_bugs=False + ) + def _check_semantics(intp): + with handler(normalize), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert self.eq(lhs_val, rhs_val) + _check_semantics() -def _int_eq(a: Any, b: Any) -> bool: - return not isinstance(a, Term) and not isinstance(b, Term) and a == b +def _is_weighted(x: Any) -> bool: + return isinstance(x, Term) and _is_monoid_weighted(x.op) -def _jax_eq(a: Any, b: Any) -> bool: - def _leaf_eq(x: Any, y: Any) -> bool: - return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) - try: - leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) - except (ValueError, TypeError): +def _weight_pairs(x: Any, monoid: Any) -> list[tuple[Any, Any]] | None: + """Return ``(element, weight)`` pairs for a stream. + + A weighted-monoid Term yields each element paired with its weight. A plain + (unweighted) stream yields each element paired with ``monoid.identity`` -- + the no-op weight -- so an unweighted stream compares equal to a weighted one + exactly when every weight reduces to the identity (e.g. ``[()]`` vs a + weighted ``[()]`` whose single empty row reduces to the identity, and, more + generally, whenever both streams are empty). Returns ``None`` for a + non-stream Term, which never compares equal to a weighted stream. + """ + if isinstance(x, Term): + if not _is_monoid_weighted(x.op): + return None + stream, weight = x.args + assert not isinstance(stream, Term) + return [(e, typing.cast(Callable, weight)(e)) for e in stream] + return [(e, monoid.identity) for e in x] + + +def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: + monoids = {x.op.__self__ for x in (a, b) if _is_weighted(x)} + # distinct weight monoids can never be equal + if len(monoids) != 1: return False - return all(leaves) - - -def check_rewrite( - lhs, - rhs, - rule, - *, - backend: Backend, - free_vars=[], - max_examples: int = 25, - deadline=None, -) -> None: - with handler(rule): - norm = evaluate(lhs) - assert syntactic_eq_alpha(norm, rhs) - - @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=deadline) - def _check_semantics(intp): - with handler(NormalizeIntp), handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) - - _check_semantics() - - -INT_BACKEND = Backend( - name="int", - scalar_typ=int, - stream_typ=list[int], - scalar_strategy=st.integers(min_value=-100, max_value=100), - eq=_int_eq, -) - - -JAX_BACKEND = Backend( - name="jax", - scalar_typ=jax.Array, - stream_typ=jax.Array, - scalar_strategy=_jax_array_value_strategy(), - eq=_jax_eq, -) - - -__all__ = [ - "Backend", - "INT_BACKEND", - "JAX_BACKEND", - "random_interpretation", - "define_vars", - "syntactic_eq_alpha", - "check_rewrite", -] + monoid = next(iter(monoids)) + + a_pairs = _weight_pairs(a, monoid) + b_pairs = _weight_pairs(b, monoid) + if a_pairs is None or b_pairs is None or len(a_pairs) != len(b_pairs): + return False + for (ea, wa), (eb, wb) in zip(a_pairs, b_pairs): + if not leaf_eq(ea, eb) or not leaf_eq(wa, wb): + return False + return True + + +class IntBackend(Backend): + name = "int" + scalar_typ = int + stream_typ = Stream[int] + + _unary_num_fns: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, + ] + + _binary_num_fns: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, + ] + + _unary_list_fns: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + match arg_types, ret: + case (), "scalar": + return st.integers(min_value=-100, max_value=100).map(deffn) + case (), "stream": + scalars = st.integers(min_value=-100, max_value=100) + return st.lists(scalars, max_size=2).map(deffn) + case (builtins.int,), "scalar": + return st.sampled_from(self._unary_num_fns) + case (builtins.int, builtins.int), "scalar": + return st.sampled_from(self._binary_num_fns) + case (builtins.int,), "stream": + return st.sampled_from(self._unary_list_fns) + raise NotImplementedError( + f"No int strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +class JaxBackend(Backend): + name = "jax" + scalar_typ = jax.Array + stream_typ = jax.Array + + _unary_jax_scalar_fns: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: a, + lambda a: a + 1, + lambda a: a - 1, + lambda a: -a, + lambda a: 2 * a, + ] + + _unary_jax_stream_fns: list[Callable[[jax.Array], Stream[jax.Array]]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), + ] + + _binary_jax_scalar_fns: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> st.SearchStrategy[Callable]: + match arg_types, ret: + case (), "scalar": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=2, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (), "stream": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=1, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (jax.Array,), "scalar": + return st.sampled_from(self._unary_jax_scalar_fns) + case (jax.Array, jax.Array), "scalar": + return st.sampled_from(self._binary_jax_scalar_fns) + case (jax.Array,), "stream": + return st.sampled_from(self._unary_jax_stream_fns) + + raise NotImplementedError( + f"No jax strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +__all__ = ["Backend", "IntBackend", "JaxBackend", "syntactic_eq_alpha"] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index fe888ad4..18df8401 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,3 +1,6 @@ +import functools +import typing + import jax import pytest from jax import random as random @@ -7,15 +10,24 @@ from effectful.handlers.jax.monoid import ( ArrayReduce, LogSumExp, + ProductPlusJax, ReduceDeltaIndependent, ReduceDependentRangeMask, delta, ) from effectful.handlers.jax.monoid import range as Range from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Sum -from effectful.ops.semantics import handler -from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars +from effectful.ops.monoid import ( + Max, + Min, + NormalizeIntp, + Product, + ReduceWeightedStream, + Sum, +) +from effectful.ops.semantics import coproduct, handler +from effectful.ops.types import Interpretation +from tests._monoid_helpers import JaxBackend MONOIDS = [ pytest.param(Sum, jnp.sum, id="Sum"), @@ -27,28 +39,27 @@ @pytest.fixture -def backend() -> Backend: - return JAX_BACKEND +def backend() -> JaxBackend: + return JaxBackend() @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_1(monoid, reductor, backend: Backend): - (x, k) = define_vars("x", "k", typ=jax.Array) - X = define_vars("X", typ=backend.stream_typ) +def test_reduce_array_1(monoid, reductor, backend: JaxBackend): + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(x(), {x: X()}) rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_2(monoid, reductor, backend: Backend): - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_array_2(monoid, reductor, backend: JaxBackend): + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + (X, Y) = backend.define_vars("X", "Y", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) rhs = reductor( @@ -61,25 +72,20 @@ def test_reduce_array_2(monoid, reductor, backend: Backend): ), axis=0, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_3(monoid, reductor, backend: Backend): +def test_reduce_array_3(monoid, reductor, backend: JaxBackend): """Stream `y` is `g(x())` — depends on the bound element of X. The reducer must inline ``g`` along the same named dim used to unbind `x`.""" - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + X = backend.define_vars("X", ret="stream") - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + g = backend.define_vars("g", arg_types=[backend.scalar_typ], ret="stream") lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) rhs = reductor( @@ -95,13 +101,37 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): ), axis=0, ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) - check_rewrite( + +def test_jax_weighted_reduce(backend: JaxBackend): + """Sum over a single stream with ``Product`` weights lowers to + ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. + + Verifies that the desugaring rule composes cleanly with the JAX lowering + so existing handlers need no changes to support weighted streams. + + """ + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") + body = backend.define_vars("body", arg_types=[backend.scalar_typ], ret="scalar") + w = backend.define_vars("w", arg_types=[backend.scalar_typ], ret="scalar") + + ws = Product.weighted(X(), w) + lhs = Sum.reduce(body(x()), {x: ws}) + rhs = jnp.sum( + bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 + ) + backend.check_rewrite( lhs=lhs, rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, f, g], + rule=functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()], + ), + ), ) @@ -112,62 +142,51 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_empty(monoid, reductor, backend: Backend): +def test_reduce_delta_empty(monoid, reductor, backend: JaxBackend): """An empty-index delta unwraps to its body. reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) """ - x = define_vars("x", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(delta((), x()), {x: X()}) rhs = monoid.reduce(x(), {x: X()}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[x, X], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_one(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_one(monoid, reductor, backend: JaxBackend): """One R1 step: peel the final preserved index off a delta. reduce(M, {y: Y()}, delta((y(),), f(y()))) ≡ reduce(M, {}, delta((), bind_dims(f(unbind_dims(Y(), k)), k))) """ - (y, k) = define_vars("y", "k", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") + (y, k) = backend.define_vars("y", "k", ret="scalar") + f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") # We use a concrete range here instead of an abstract one, because # unbind_dims is undefined on empty arrays (and the rewrite produces a # different rhs in this case) lhs = monoid.reduce(delta((y(),), f(y())), {y: Range(3)}) rhs = monoid.reduce(bind_dims(f(unbind_dims(jnp.arange(3), k)), k), {}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[y, k, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_preserves_others( + monoid, reductor, backend: JaxBackend +): """R1 peels only the final index. Streams not matching the peeled index op stay untouched, as do earlier entries in the index tuple. reduce(M, {x: X(), y: Y()}, delta((x(), y()), f(x(), y()))) ≡ reduce(M, {x: X()}, delta((x(),), bind_dims(f(x(), unbind_dims(Y(), k)), k))) """ - (x, y, k) = define_vars("x", "y", "k", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") + (x, y, k) = backend.define_vars("x", "y", "k", ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: Range(2), y: Range(3)}) rhs = monoid.reduce( @@ -179,27 +198,22 @@ def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Ba ), {}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask(monoid, reductor, backend: JaxBackend): """A dependent range stream gets rewritten to the referent's bbox stream, with the original constraint folded into the body as a where-guard. reduce(M, {u: range(0, N, 1), v: range(0, u(), 1)}, body) ≡ reduce(M, {u: range(0, N, 1), v: range(0, N, 1)}, where(v() < u(), body, M.identity)) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) body = f(u(), v()) @@ -208,18 +222,11 @@ def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): jnp.where(v() < u(), body, monoid.identity), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: JaxBackend): """When the body is a delta term, R4 folds the constraint into the delta's weight while leaving its index tuple untouched. @@ -227,9 +234,11 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe ≡ reduce(M, {u: range(N), v: range(N)}, delta((u(), v()), where(v() < u(), w, M.identity))) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) weight = f(u(), v()) idx = (u(), v()) @@ -239,17 +248,10 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe delta(idx, jnp.where(v() < u(), weight, monoid.identity)), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) -def test_reduce_matmul(): +def test_reduce_matmul(backend: JaxBackend): key = jax.random.PRNGKey(0) # Define dimensions B, I, J, K = 2, 3, 4, 5 @@ -257,7 +259,7 @@ def test_reduce_matmul(): # Create sample matrices X = random.normal(key, (B, I, J)) Y = random.normal(key, (B, J, K)) - (b, i, j, k) = define_vars("b", "i", "j", "k", typ=jax.Array) + (b, i, j, k) = backend.define_vars("b", "i", "j", "k", ret="scalar") with handler(NormalizeIntp): actual = Sum.reduce( diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index c7ee7567..fcd72f06 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,10 +1,13 @@ +import math import typing +from collections.abc import Iterable import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st import effectful.handlers.jax.monoid # noqa: F401 +import effectful.handlers.jax.numpy as jnp from effectful.ops.monoid import ( CartesianProduct, Max, @@ -22,29 +25,25 @@ PlusSingle, PlusZero, Product, + ReduceCartesianWeightedStream, ReduceDistributeCartesianProduct, ReduceFactorization, ReduceFusion, ReduceNoStreams, ReduceSplit, + ReduceWeightedStream, Sum, distributes_over, ) -from effectful.ops.semantics import fvsof, handler -from effectful.ops.types import Operation -from tests._monoid_helpers import ( - INT_BACKEND, - JAX_BACKEND, - Backend, - check_rewrite, - define_vars, - syntactic_eq_alpha, -) +from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.syntax import deffn +from effectful.ops.types import NotHandled, Operation, Term +from tests._monoid_helpers import Backend, IntBackend, JaxBackend, syntactic_eq_alpha -@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +@pytest.fixture(params=[IntBackend, JaxBackend], ids=["int", "jax"]) def backend(request) -> Backend: - return request.param + return request.param() ALL_MONOIDS = [ @@ -90,10 +89,10 @@ def backend(request) -> Backend: deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_associativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) - c = data.draw(backend.scalar_strategy) +def test_associativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() + c = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): left = monoid.plus(monoid.plus(a, b), c) right = monoid.plus(a, monoid.plus(b, c)) @@ -107,8 +106,8 @@ def test_associativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_identity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_identity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.identity, a), a) assert backend.eq(monoid.plus(a, monoid.identity), a) @@ -121,9 +120,9 @@ def test_identity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_commutativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) +def test_commutativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @@ -135,8 +134,8 @@ def test_commutativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_idempotence(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_idempotence(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, a), a) @@ -148,102 +147,86 @@ def test_idempotence(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_zero_absorbs(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_zero_absorbs(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid, backend): - check_rewrite( - lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend - ) +def test_plus_empty(monoid, backend: Backend): + backend.check_rewrite(lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) - check_rewrite( - lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] - ) +def test_plus_single(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + backend.check_rewrite(lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_right(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(x(), monoid.identity) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_left(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(monoid.identity, x()) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_right(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_left(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) - check_rewrite( +def test_plus_sequence(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + backend.check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), rule=MonoidOverSequence(), - backend=backend, - free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_mapping(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[a, b, c, d], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) -def test_plus_distributes(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) rhs = Product.plus( Sum.plus( @@ -253,13 +236,11 @@ def test_plus_distributes(backend): Product.plus(b(), d()), ) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_constant(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_constant(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -270,13 +251,11 @@ def test_plus_distributes_constant(backend): Product.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_multiple(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_multiple(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -297,238 +276,195 @@ def test_plus_distributes_multiple(backend): Sum.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid, backend): +def test_plus_idempotent_consecutive(monoid, backend: Backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), a(), b()) - return check_rewrite( - lhs=lhs, - rhs=monoid.plus(a(), b()), - rule=PlusConsecutiveDups(), - backend=backend, - free_vars=[a, b], + return backend.check_rewrite( + lhs=lhs, rhs=monoid.plus(a(), b()), rule=PlusConsecutiveDups() ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid, backend): +def test_plus_idempotent_non_consecutive(monoid, backend: Backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative PlusDups.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", [Min, Max]) -def test_plus_commutative_idempotent_long(monoid, backend): +def test_plus_commutative_idempotent_long(monoid, backend: Backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_plus_zero(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) rhs = monoid.zero - check_rewrite( - lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) - check_rewrite( - lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs_right, rhs=rhs, rule=PlusZero()) + backend.check_rewrite(lhs=lhs_left, rhs=rhs, rule=PlusZero()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_1(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) +def test_partial_1(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.reduce(x(), {x: []}) rhs = monoid.identity - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: []}) rhs = monoid.identity - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_3(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_3(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_4(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=1, ret="stream") +def test_partial_4(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - X, Y = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + X, Y = backend.define_vars("X", "Y", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, Y, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_mapping(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { 0: monoid.reduce(f(x()), {x: X()}), 1: monoid.reduce(g(x()), {x: X()}), } - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_reduce_no_streams(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") + lhs = monoid.reduce(a(), {}) rhs = monoid.identity - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceNoStreams()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_reduce(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFusion()) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_plus(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSplit()) -def test_reduce_independent_1(backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_independent_1(backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_2(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_2(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_3_negative(backend): +def test_reduce_independent_3_negative(backend: Backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, C = define_vars("A", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, C = backend.define_vars("A", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + g = backend.define_vars("g", arg_types=(backend.scalar_typ,), ret="stream") with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( @@ -542,10 +478,12 @@ def test_reduce_independent_3_negative(backend): assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_4(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( @@ -553,39 +491,44 @@ def test_reduce_independent_4(backend): Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_cartesian_3(): + backend = JaxBackend() + i = backend.define_vars("i", ret="scalar") + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(3)}) + assert value.shape == (2**3, 3) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(1)}) + assert value.shape == (2**1, 1) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(1), {i: jnp.arange(3)}) + assert value.shape == (1**3, 3) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner, backend): - a, i = define_vars("a", "i", typ=backend.scalar_typ) - A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_1(outer, inner, backend: Backend): + a, i = backend.define_vars("a", "i", ret="scalar") + A, N, A_domain = backend.define_vars("A", "N", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) - - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, A_domain, f], - ) + rhs = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) def test_reduce_cartesian_1(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -597,8 +540,9 @@ def test_reduce_cartesian_1(): def test_reduce_cartesian_2(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -610,46 +554,41 @@ def test_reduce_cartesian_2(): @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner, backend): - a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_multi_index(outer, inner, backend: Backend): + a, i, j = backend.define_vars("a", "i", "j", ret="scalar") + A, N, M, A_domain = backend.define_vars("A", "N", "M", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) - term2 = inner.reduce( - outer.reduce(inner.plus(f(a())), {a: A_domain()}), - {i: N(), j: M()}, - ) - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, M, A_domain, f], + rhs = inner.reduce( + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()} ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner, backend): +def test_reduce_lifted_2(outer, inner, backend: Backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) - A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) - A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") - f1 = backend.fresh_op("f1", n_args=2, ret="scalar") - f2 = backend.fresh_op("f2", n_args=2, ret="scalar") + a, i, s, t = backend.define_vars("a", "i", "s", "t", ret="scalar") + A, N, T = backend.define_vars("A", "N", "T", ret="stream") + A_domain = backend.define_vars( + "A_domain", arg_types=(backend.scalar_typ,), ret="stream" + ) + f1, f2 = backend.define_vars( + "f1", "f2", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, ) - - term2 = outer.reduce( + rhs = outer.reduce( inner.reduce( outer.reduce( inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} @@ -658,11 +597,143 @@ def test_reduce_lifted_2(outer, inner, backend): ), {t: T()}, ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) + + +# --------------------------------------------------------------------------- +# Weighted streams +# --------------------------------------------------------------------------- + + +def test_reduce_single_weighted_stream(backend: Backend): + """Single weighted stream desugars: + Sum.reduce(body, {a: WS(A, w, Product)}) + = Sum.reduce(Product.plus(w(a), body), {a: A}) + """ + a = backend.define_vars("a", ret="scalar") + A = backend.define_vars("A", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), w)}) + rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceWeightedStream()) + + +def test_reduce_weighted_factorization(backend: Backend): + """Two independent weighted streams under Sum with Product weights factor: + Sum.reduce(f(a)*g(b), {a: Product.weighted(A, a, w_a), b: Product.weighted(B, b, w_b)}) + = (Sum.reduce(w_a(a)*f(a), {a: A})) * (Sum.reduce(w_b(b)*g(b), {b: B})) + + Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` + inside ``NormalizeIntp``. + """ + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f, g, w_a, w_b = backend.define_vars( + "f", "g", "w_a", "w_b", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce( + Product.plus(f(a()), g(b())), + {a: Product.weighted(A(), w_a), b: Product.weighted(B(), w_b)}, + ) + rhs = Product.plus( + Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), + Sum.reduce(Product.plus(w_b(b()), Product.plus(g(b()))), {b: B()}), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceWeightedStream(), ReduceFactorization()) + ) + + +def test_reduce_cartesian_weighted_stream(backend: Backend): + """``CartesianProduct.reduce`` over a ``WeightedStream`` body whose weight + is independent of the plate var rewrites to a single joint + ``WeightedStream``: + + CartesianProduct.reduce(M.weighted(s, e, w(e)), {p: P}) + = M.weighted(CartesianProduct.reduce(s, {p: P}), row, M.reduce(w(e), {e: row()})) + """ + p, e_var = backend.define_vars("p", "e_var", ret="scalar") + S, P = backend.define_vars("S", "P", ret="stream") + w = backend.define_vars("w", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = CartesianProduct.reduce(Product.weighted(S(), w), {p: P()}) + row_var = Operation.define(Iterable[backend.scalar_typ], name="row") # type: ignore[name-defined] + rhs = Product.weighted( + CartesianProduct.reduce(S(), {p: P()}), + deffn(Product.reduce(w(e_var()), {e_var: row_var()}), row_var), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceCartesianWeightedStream()) + + +def test_lift_weighted_cartesian(backend: Backend): + """Compose ``ReduceCartesianWeightedStream`` + ``ReduceWeightedStream`` + + ``ReduceDistributeCartesianProduct`` on a Sum-of-Product-of-weighted shape: + + Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S, e, w(e)), {p: P})}, + ) + + The inner ``weighted`` becomes a joint ``weighted`` (rule 1), lifts its + per-element weight into the outer Sum body (rule 2), and the lifted form + matches the inversion pattern (rule 3), yielding:: + + Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S}), + {p: P}, + ) + """ + a, p = backend.define_vars("a", "p", ret="scalar") + A, S, P = backend.define_vars("A", "S", "P", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], + lhs = Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S(), w), {p: P()})}, + ) + rhs = Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), {p: P()} ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct( + coproduct(ReduceWeightedStream(), ReduceCartesianWeightedStream()), + ReduceDistributeCartesianProduct(), + ), + ) + + +def test_weighted_expectation_demo(): + """Demo: compute E[f(X)] = Σ_x w(x)·f(x) via a weighted reduce. + + X ranges over [1, 2, 3, 4] with weights w(x) = x/10 (a valid distribution + since the weights sum to 1) and f(x) = x*x. Expected value: + 0.1·1 + 0.2·4 + 0.3·9 + 0.4·16 = 10.0 + """ + weights = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4} + + def _w(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return weights[v] + + def _f(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return float(v * v) + + a = Operation.define(int, name="a") + w = Operation.define(_w, name="w") + f = Operation.define(_f, name="f") + + with handler(NormalizeIntp): + result = evaluate(Sum.reduce(f(a()), {a: Product.weighted([1, 2, 3, 4], w)})) + + assert math.isclose(result, 10.0) From 3b1ac2ade1178fdad347f3fc6bf15161c7782ab4 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Mon, 8 Jun 2026 12:08:22 -0400 Subject: [PATCH 07/10] Replace factorization rules with a push-based rule (#672) * more agressive factorization that hoists shared streams * reduce nesting * comment * replace with simpler push-based rule * format * drop unused disjoint set * remove unused * push multiple streams instead of one at a time --- effectful/internals/disjoint_set.py | 99 ------------- effectful/ops/monoid.py | 205 +++++++++++++++------------ tests/_monoid_helpers.py | 10 ++ tests/test_internals_disjoint_set.py | 124 ---------------- tests/test_ops_monoid.py | 77 +++++++++- 5 files changed, 198 insertions(+), 317 deletions(-) delete mode 100644 effectful/internals/disjoint_set.py delete mode 100644 tests/test_internals_disjoint_set.py diff --git a/effectful/internals/disjoint_set.py b/effectful/internals/disjoint_set.py deleted file mode 100644 index 73b5c5c5..00000000 --- a/effectful/internals/disjoint_set.py +++ /dev/null @@ -1,99 +0,0 @@ -class DisjointSet: - """Disjoint Set Union (Union-Find) data structure. - - Maintains a collection of disjoint sets over the integers 0..n-1, - supporting near-constant-time union and find operations via - path compression and union by rank. - - The amortized time complexity per operation is O(α(n)), where α - is the inverse Ackermann function (effectively constant for any - practical n). - - Example: - >>> dsu = DisjointSet(5) - >>> dsu.union(0, 1) - True - >>> dsu.union(1, 2) - True - >>> dsu.find(0) == dsu.find(2) - True - >>> dsu.find(0) == dsu.find(3) - False - """ - - def __init__(self, n): - """Initialize n singleton sets: {0}, {1}, ..., {n-1}. - - Args: - n: The number of elements. Elements are labeled 0..n-1. - """ - self.parent = list(range(n)) - self.rank = [0] * n - - def _validate(self, x): - if x < 0 or x >= len(self.parent): - raise IndexError(f"Element {x} out of bounds") - - def find(self, x): - """Return the representative (root) of the set containing x. - - Two elements belong to the same set if and only if they have - the same representative. Applies path compression: every node - traversed is re-parented directly to its grandparent, flattening - the tree to speed up future queries. - - Args: - x: The element to look up. - - Returns: - The root element of x's set. - """ - self._validate(x) - while self.parent[x] != x: - self.parent[x] = self.parent[self.parent[x]] # path compression - x = self.parent[x] - return x - - def union(self, *elements): - """Merge the sets containing all given elements into one. - - Accepts any number of elements and unions them all together. - Uses union by rank: shallower trees are attached under the root - of the deeper one, keeping the combined tree shallow. - - Args: - *elements: Two or more elements to merge into a single set. - Calling with 0 or 1 elements is a no-op and returns False. - - Returns: - True if any merging occurred (i.e., at least two of the - elements were in different sets); False if all elements - were already in the same set or fewer than 2 were given. - """ - if len(elements) < 2: - return False - - merged = False - first = elements[0] - - for y in elements[1:]: - if self._union_pair(first, y): - merged = True - - return merged - - def _union_pair(self, x, y): - rx = self.find(x) - ry = self.find(y) - - if rx == ry: - return False - - if self.rank[rx] < self.rank[ry]: - rx, ry = ry, rx - - self.parent[ry] = rx - if self.rank[rx] == self.rank[ry]: - self.rank[rx] += 1 - - return True diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 76351fa6..5f342f25 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -9,7 +9,6 @@ from graphlib import TopologicalSorter from typing import Annotated, Any -from effectful.internals.disjoint_set import DisjointSet from effectful.ops.semantics import ( coproduct, evaluate, @@ -62,6 +61,51 @@ def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams] ) +def inner_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: + """Returns the streams that can be ordered innermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + preds = fvsof(v) & stream_vars + if preds: + for pred in preds: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + topo.prepare() + return ( + ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) + for op in set(topo.get_ready()) | no_dependents + ) + + +def inner_streams_first(streams: dict[Operation, Expr]) -> Iterable[Operation]: + """Iterable over streams where dependent streams precede their dependencies.""" + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + preds = fvsof(v) & stream_vars + if preds: + for pred in preds: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + return topo.static_order() + + class Monoid[W]: """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" @@ -392,110 +436,87 @@ def reduce(self, monoid, body, streams): class ReduceFactorization(ObjectInterpretation): - """ - Implements factorization of independent terms. - For example, when having two independent distributions, - we can rewrite their marginalization as: - ∫p(x)⋅q(y)dxdy => ∫p(x)dx ⋅ ∫q(y)dy - - More specifically, in terms of reduces we are performing: - reduce(R, (S₁ × ... × Sₖ) , A₁ * ... * Aₖ) - => reduce(R, S₁, A₁) * ... * reduce(R, Sₖ, Aₖ) - where free(Aᵢ) ∩ free(Aⱼ) ∩ S = ∅ - and free(Aᵢ) ∩ S ⊆ Sᵢ + """reduce(⊗(F_v ∪ F_rest), {v} ∪ S) = reduce(⊗F_rest ⊗ reduce(⊗F_v, {v}), S) + + where F_v = factors mentioning v, F_rest = the others. Fires only when + v has no dependents among the remaining streams (so it can be innermost) + and F_rest is nonempty (universal variables stay in the outer core). """ @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if not is_commutative(monoid): - return fwd() - if ( - isinstance(body, Term) + if not ( + is_commutative(monoid) + and isinstance(body, Term) and _is_monoid_plus(body.op) and distributes_over(body.op.__self__, monoid) ): - inner_monoid: Monoid = body.op.__self__ - stream_vars = set(streams.keys()) - factors = [(arg, fvsof(arg)) for arg in body.args] - stream_ids = {v: i for (i, v) in enumerate(stream_vars)} - ds = DisjointSet(len(streams)) - - # streams are in the same partition as their dependencies - for stream_var, stream_id in stream_ids.items(): - stream_body = streams[stream_var] - deps = sorted([stream_ids[v] for v in fvsof(stream_body) & stream_vars]) - ds.union(stream_id, *deps) - - # factors are in the same partition as their dependencies - for _, factor_fvs in factors: - factor_streams = sorted( - [stream_ids[v] for v in (factor_fvs & stream_vars)] - ) - ds.union(*factor_streams) - - placed_streams = set() - new_reduces = [] - for stream_key in streams: - if stream_key in placed_streams: - continue - - partition = ds.find(stream_ids[stream_key]) - partition_streams = { - k: v - for (k, v) in streams.items() - if ds.find(stream_ids[k]) == partition - } - partition_stream_keys = set(partition_streams.keys()) - - partition_factors = [ - t for t in factors if (t[1] & partition_stream_keys) - ] - - assert all( - (t[1] & stream_vars) <= partition_stream_keys - for t in partition_factors - ), "partition contains all streams required by factor" - - partition_term = inner_monoid.plus(*(t[0] for t in partition_factors)) - new_reduces.append((partition_term, partition_streams)) - placed_streams |= partition_stream_keys - - constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] - - if len(new_reduces) > 1: - result = inner_monoid.plus( - *constant_factors, *(monoid.reduce(*args) for args in new_reduces) - ) - return result + return fwd() - return fwd() + inner = body.op.__self__ + stream_keys = set(streams) + factors = [(a, fvsof(a)) for a in body.args] + # candidates: innermost-eligible (no remaining stream depends on v), + # non-universal (some factor doesn't mention v) + support: dict = {} + for v in streams: + if any(v in fvsof(s) for k, s in streams.items() if k is not v): + continue + f_v = frozenset(i for i, (_, fvs) in enumerate(factors) if v in fvs) + if len(f_v) == len(factors): + continue # v is universal: leave it in the outer core + support[v] = f_v + + # eliminate a variable with subset-minimal factor support + # (leaves-first; canonical on hierarchical/laminar supports) + inner_stream = None + inner_factor_ids = None + for v, f_v in support.items(): + if any(u_sup < f_v for u, u_sup in support.items() if u is not v): + continue + inner_stream = v + inner_factor_ids = f_v + break -def inner_stream( - streams: dict[Operation, Expr], -) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: - """Returns the streams that can be ordered innermost in the loop nest as - well as the remaining streams in the nest. + if not inner_stream or not inner_factor_ids: + return fwd() - """ - stream_vars = set(streams.keys()) + inner_factors = [factors[i][0] for i in sorted(inner_factor_ids)] + inner_stream_keys = {inner_stream} + inner_deps = set().union( + *(factors[i][1] for i in f_v), fvsof(streams[v]) & stream_keys + ) - no_dependents = set() - succ = defaultdict(set) - for k, v in streams.items(): - preds = fvsof(v) & stream_vars - if preds: - for pred in preds: - succ[pred].add(k) - else: - no_dependents.add(k) + outer_factors = [a for i, (a, _) in enumerate(factors) if i not in f_v] + outer_stream_keys = stream_keys - inner_stream_keys + outer_factor_deps = set().union( + *(vars for i, (_, vars) in enumerate(factors) if i not in f_v) + ) - topo = TopologicalSorter(succ) - topo.prepare() - return ( - ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) - for op in set(topo.get_ready()) | no_dependents - ) + # find all streams that are used in the inner factors/streams and are + # not used by the outer factors/streams + # this has to be done iteratively, because moving a stream inward + # reduces the outer dependency set + # ensures that no future factorization application creates a reduce that + # fuses with with the inner reduce + for s in inner_streams_first(streams): + outer_stream_deps = ( + set().union(*(fvsof(streams[k]) for k in outer_stream_keys)) + & stream_keys + ) + outer_deps = outer_factor_deps | outer_stream_deps + if s in inner_deps and s not in outer_deps: + inner_stream_keys |= {s} + inner_deps |= stream_keys & fvsof(streams[s]) + outer_stream_keys -= {s} + + inner_streams = {k: v for (k, v) in streams.items() if k in inner_stream_keys} + inner_red = monoid.reduce(inner.plus(*inner_factors), inner_streams) + + rest_streams = {k: s for k, s in streams.items() if k in outer_stream_keys} + new_body = inner.plus(*outer_factors, inner_red) + return monoid.reduce(new_body, rest_streams) if rest_streams else new_body class ReduceDistributeCartesianProduct(ObjectInterpretation): diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index f8089bec..72787558 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -322,6 +322,11 @@ def strategy( return st.sampled_from(self._unary_num_fns) case (builtins.int, builtins.int), "scalar": return st.sampled_from(self._binary_num_fns) + case (builtins.int, builtins.int, builtins.int), "scalar": + return st.tuples( + st.sampled_from(self._binary_num_fns), + st.sampled_from(self._binary_num_fns), + ).map(lambda fg: lambda a, b, c: fg[0](a, fg[1](b, c))) case (builtins.int,), "stream": return st.sampled_from(self._unary_list_fns) raise NotImplementedError( @@ -389,6 +394,11 @@ def strategy( return st.sampled_from(self._unary_jax_scalar_fns) case (jax.Array, jax.Array), "scalar": return st.sampled_from(self._binary_jax_scalar_fns) + case (jax.Array, jax.Array, jax.Array), "scalar": + return st.tuples( + st.sampled_from(self._binary_jax_scalar_fns), + st.sampled_from(self._binary_jax_scalar_fns), + ).map(lambda fg: lambda a, b, c: fg[0](a, fg[1](b, c))) case (jax.Array,), "stream": return st.sampled_from(self._unary_jax_stream_fns) diff --git a/tests/test_internals_disjoint_set.py b/tests/test_internals_disjoint_set.py deleted file mode 100644 index 808b8d25..00000000 --- a/tests/test_internals_disjoint_set.py +++ /dev/null @@ -1,124 +0,0 @@ -import random - -import pytest - -from effectful.internals.disjoint_set import DisjointSet - - -@pytest.fixture -def dsu(): - return DisjointSet(10) - - -def test_initial_state(dsu): - for i in range(10): - assert dsu.find(i) == i - - -def test_simple_union(dsu): - assert dsu.union(1, 2) is True - assert dsu.find(1) == dsu.find(2) - - -def test_union_idempotent(dsu): - dsu.union(1, 2) - assert dsu.union(1, 2) is False - - -def test_union_chain(dsu): - dsu.union(1, 2) - dsu.union(2, 3) - assert dsu.find(1) == dsu.find(3) - - -def test_union_multiple_elements_all_connected(dsu): - dsu.union(1, 2, 3, 4, 5) - roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} - assert len(roots) == 1 - - -def test_union_multiple_elements_partial_overlap(dsu): - dsu.union(1, 2) - dsu.union(3, 4) - dsu.union(2, 3, 5) - - roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} - assert len(roots) == 1 - - -def test_union_multiple_elements_with_existing_connections(dsu): - dsu.union(1, 2) - dsu.union(2, 3) - dsu.union(3, 4, 5, 6) - - roots = {dsu.find(i) for i in [1, 2, 3, 4, 5, 6]} - assert len(roots) == 1 - - -def test_union_single_element(dsu): - assert dsu.union(1) is False - - -def test_union_no_elements(dsu): - assert dsu.union() is False - - -def test_union_self(dsu): - assert dsu.union(3, 3) is False - assert dsu.find(3) == 3 - - -def test_transitivity(dsu): - dsu.union(1, 2) - dsu.union(2, 3) - dsu.union(3, 4) - assert dsu.find(1) == dsu.find(4) - - -def test_disjoint_sets_remain_separate(dsu): - dsu.union(1, 2) - dsu.union(3, 4) - assert dsu.find(1) != dsu.find(3) - - -def test_randomized_unions(): - n = 50 - dsu = DisjointSet(n) - - groups = [{i} for i in range(n)] - - def find_group(x): - for g in groups: - if x in g: - return g - - for _ in range(100): - elems = random.sample(range(n), random.randint(2, 5)) - dsu.union(*elems) - - # merge ground-truth groups - merged = set() - for e in elems: - merged |= find_group(e) - - groups = [g for g in groups if g.isdisjoint(merged)] - groups.append(merged) - - # verify structure matches ground truth - for g in groups: - roots = {dsu.find(x) for x in g} - assert len(roots) == 1 - - -def test_path_compression_effect(): - dsu = DisjointSet(6) - dsu.union(0, 1) - dsu.union(1, 2) - dsu.union(2, 3) - dsu.union(3, 4) - - # Trigger compression - root_before = dsu.find(4) - root_after = dsu.find(4) - - assert root_before == root_after diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index fcd72f06..4d243ca1 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -451,7 +451,10 @@ def test_reduce_independent_2(backend: Backend): lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), - Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + Sum.reduce( + Product.plus(b(), Sum.reduce(Product.plus(f(b(), c())), {c: C()})), + {b: B()}, + ), ) backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) @@ -489,7 +492,77 @@ def test_reduce_independent_4(backend: Backend): rhs = Product.plus( 7, Sum.reduce(Product.plus(a()), {a: A()}), - Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + Sum.reduce( + Product.plus(b(), Sum.reduce(Product.plus(f(b(), c())), {c: C()})), + {b: B()}, + ), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_chain(backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + X, Y = backend.define_vars("X", "Y", ret="stream") + f, h = backend.define_vars("f", "h", arg_types=(backend.scalar_typ,), ret="scalar") + g = backend.define_vars( + "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(f(x()), g(x(), y()), h(y())), {x: X(), y: Y()}) + rhs = Sum.reduce( + Product.plus(h(y()), Sum.reduce(Product.plus(f(x()), g(x(), y())), {x: X()})), + {y: Y()}, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lift_shared(outer, inner, backend: Backend): + """A stream free in every factor is hoisted into an outer reduce: + Sum.reduce(f(a, c) * g(b, c), {a: A, b: B, c: C}) + = Sum.reduce(Sum.reduce(f(a, c), {a: A}) * Sum.reduce(g(b, c), {b: B}), {c: C}) + """ + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f, g = backend.define_vars( + "f", "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = outer.reduce(inner.plus(f(a(), c()), g(b(), c())), {a: A(), b: B(), c: C()}) + rhs = outer.reduce( + inner.plus( + outer.reduce(inner.plus(f(a(), c())), {a: A()}), + outer.reduce(inner.plus(g(b(), c())), {b: B()}), + ), + {c: C()}, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lift_shared_deps(outer, inner, backend: Backend): + """A shared stream is lifted together with its dependencies: both ``c`` + and ``d = h(c)`` appear in every factor, so both are hoisted.""" + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + h = backend.define_vars("h", arg_types=(backend.scalar_typ,), ret="stream") + f, g = backend.define_vars( + "f", + "g", + arg_types=(backend.scalar_typ, backend.scalar_typ, backend.scalar_typ), + ret="scalar", + ) + + lhs = outer.reduce( + inner.plus(f(a(), c(), d()), g(b(), c(), d())), + {a: A(), b: B(), c: C(), d: h(c())}, + ) + rhs = outer.reduce( + inner.plus( + outer.reduce(inner.plus(f(a(), c(), d())), {a: A()}), + outer.reduce(inner.plus(g(b(), c(), d())), {b: B()}), + ), + {c: C(), d: h(c())}, ) backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) From 6b262180f9ca56b14ffe3f46615cbca3f9764cbe Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 9 Jun 2026 17:40:57 -0400 Subject: [PATCH 08/10] Add weighted einsum implementation (#671) * more precise stream type * add tests for weighted rules * add reduction rule for weighted streams and tests * add test to demo expectation * add numpyro monoid module * add quadrature * add tests * wip * refactor tests * wip * test composition of lifting and weighting * drop numpyro changes * drop unused ops * lint * make weighted a Monoid method * fix typing of jax arrays * change weighted typing to take callable * fix test * fix test * resolve type aliases before dispatching * wip * wip * remove typeof_full * wip * wip * wip * format * refactor test harness * fix behavior of delta terms * add baseline einsum * rework einsum to work on shapes instead of concrete tensors * add einsum benchmark * wip * wip * finish sum/product contraction * allow bind_dims to bind nonexistent named dimensions * wip * add custom partial eval for reductions * working benchmarks * fix infinite loop * eliminate identity indexing when possible * wip * handle getitem where dimensions are created * treat any index with bare ops and slice(None) as canonical * simplify range op and add reduction rules * wip * remove old benchmark code * another try at removing identity gathers * refactor * fix test * lint * clean up comment * fix some test failures * drop sketchy bind_dims rule * drop more type-incompatible plus rules * format * fix reduction issue * drop dimension creating behavior from bind_dims * lint * simplify comment * drop partition * fix docstring * handle negative dimension indexing * fix creation of empty tensors * fully restore previous behavior for missing named dims * reduce any arraylike or named tensor * require at least one jax array to reduce * fix typing test * drop typing test * drop einsum parser in favor of opt_einsum * more agressive factorization that hoists shared streams * reduce nesting * comment * replace with simpler push-based rule * format * drop unused disjoint set * remove unused * push multiple streams instead of one at a time * drop contraction ordering handler * fold BindDimsBindDims into default behavior * handle Sum.reduce instead of Monoid.reduce * wip * wip * hacks * extract contraction heuristic * lint * fix test * use a named dimension einsum for contractions * lint * drop custom arange op * wip * simplify by targetting delta rules * wip * fixes * fixes * lint * drop unused * pick up constants but not rest of module * lint --- effectful/handlers/jax/_handlers.py | 114 +++++- effectful/handlers/jax/_terms.py | 71 +++- effectful/handlers/jax/monoid.py | 464 +++++++++++++++-------- effectful/handlers/jax/numpy/__init__.py | 41 +- effectful/handlers/jax/scipy/special.py | 4 +- effectful/ops/monoid.py | 202 +++++----- effectful/ops/syntax.py | 8 + pyproject.toml | 5 +- tests/test_handlers_jax_monoid.py | 431 ++++++++++++++++----- tests/test_ops_monoid.py | 38 +- 10 files changed, 982 insertions(+), 396 deletions(-) diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index 91fba369..7516adbf 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -1,9 +1,13 @@ import functools +import itertools import typing -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from types import EllipsisType from typing import Annotated +from opt_einsum import get_symbol +from opt_einsum.parser import parse_einsum_input + try: import jax import jax.numpy as jnp @@ -11,7 +15,7 @@ raise ImportError("JAX is required to use effectful.handlers.jax") from effectful.internals.runtime import interpreter -from effectful.ops.semantics import apply, evaluate, fvsof, typeof +from effectful.ops.semantics import apply, evaluate, fvsof, fwd, typeof from effectful.ops.syntax import ( Scoped, _CustomSingleDispatchCallable, @@ -66,9 +70,13 @@ def update_sizes(sizes, op, size): def _getitem_sizeof(x: jax.Array, key: tuple[Expr[IndexElement], ...]): if is_eager_array(x): - for i, k in enumerate(key): + i = 0 + for k in key: if isinstance(k, Term) and len(k.args) == 0 and len(k.kwargs) == 0: update_sizes(sizes, k.op, x.shape[i]) + if k is not None: + i += 1 + return defdata(jax_getitem, x, key) def _apply(op, *args, **kwargs): @@ -89,8 +97,9 @@ def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]: # if any dimension is zero sized, the result is empty if any(size == 0 for size in sized_fvs.values()): - key = tuple(sized_fvs.keys()) - shape = tuple(sized_fvs[k] for k in key) + ops = tuple(sized_fvs.keys()) + key = tuple(k() for k in ops) + shape = tuple(sized_fvs[k] for k in ops) return jax_getitem(jnp.empty(shape), key) def _is_eager(t): @@ -142,7 +151,11 @@ def _jax_op(*args, **kwargs) -> jax.Array: and not isinstance(args[0], Term) and sized_fvs and args[1] - and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1]) + and all( + (isinstance(k, Term) and k.op in sized_fvs) + or (isinstance(k, slice) and k == slice(None)) + for k in args[1] + ) ): raise NotHandled elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {jax_getitem, _jax_op}: @@ -182,6 +195,93 @@ def _jax_op(*args, **kwargs) -> jax.Array: return _jax_op +def _named_dims(term: Expr[jax.Array]) -> tuple[Operation, ...]: + if not (isinstance(term, Term) and term.op == jax_getitem): + return () + index = term.args[1] + assert isinstance(index, Iterable) + return tuple(i.op for i in index if isinstance(i, Term) and not i.args) + + +def _reduce_named(array, axis=None, **kwargs) -> jax.Array: + if axis is None: + return fwd() + + named_dims = _named_dims(array) + if not named_dims: + return fwd() + + bound_arr = bind_dims(array, *named_dims) + + if isinstance(axis, int): + axis = (axis,) + shifted_axis = tuple(a + len(named_dims) if a >= 0 else a for a in axis) + + reduced = fwd(bound_arr, axis=shifted_axis, **kwargs) + return unbind_dims(reduced, *named_dims) + + +def _einsum_named(subscripts, *operands, **kwargs) -> jax.Array: + # only the string-subscripts form is handled; forward the interleaved form + if not isinstance(subscripts, str): + if any(isinstance(x, Term) for x in (subscripts, *operands)): + raise ValueError("Interleaved einsum is not implemented with named tensors") + return jax.numpy.einsum(subscripts, *operands, **kwargs) + + # forward if any operand has a symbolic (Term) shape + if any(isinstance(arr.shape, Term) for arr in operands): + raise NotHandled + + named = [_named_dims(op) for op in operands] + + # normalize: expand ellipses and make the output explicit, using the + # positional shapes (shapes=True avoids materializing the operands) + shapes = [op.shape for op in operands] + in_part, out_part, _ = parse_einsum_input([subscripts, *shapes], shapes=True) + in_specs = in_part.split(",") + assert len(in_specs) == len(operands) + + # fresh symbols for named dims, avoiding every symbol already in use; + # get_symbol gives an effectively unlimited supply (spills into unicode) + used = {c for c in (in_part + out_part) if c not in ",->"} + counter = itertools.count() + + def next_symbol(): + while True: + s = get_symbol(next(counter)) + if s not in used: + used.add(s) + return s + + # assign a letter per unique named dim; shared names reuse the same letter + # so einsum aligns them as batch dims rather than contracting + letter_of, order = {}, [] + for dims in named: + for d in dims: + if d not in letter_of: + letter_of[d] = next_symbol() + order.append(d) + + # bind named dims to leading positional axes and prepend their letters + bound, new_in_specs = [], [] + for op, dims, spec in zip(operands, named, in_specs): + bound.append(bind_dims(op, *dims) if dims else op) + new_in_specs.append("".join(letter_of[d] for d in dims) + spec) + + # add every named dim to the front of the output as passthrough + out_prefix = "".join(letter_of[d] for d in order) + new_subscripts = ",".join(new_in_specs) + "->" + out_prefix + out_part + + result = jax.numpy.einsum(new_subscripts, *bound, **kwargs) + + # unbind: leading axes correspond to `order`, reindex them back to named + reindexed = jax_getitem( + result, + tuple(d() for d in order) + tuple(slice(None) for _ in range(len(out_part))), + ) + return reindexed + + @_register_jax_op def jax_getitem(x: jax.Array, key: tuple[IndexElement, ...]) -> jax.Array: """Operation for indexing an array. Unlike the standard __getitem__ method, @@ -215,6 +315,8 @@ def bind_dims[T, A, B]( >>> bind_dims(t, b, a).shape (3, 2) """ + if isinstance(value, Term) and value.op == bind_dims: + return bind_dims(value.args[0], *(names + tuple(value.args[1:]))) if jax.tree_util.treedef_is_leaf(jax.tree.structure(value)): return __dispatch(typeof(value))(value, *names) return jax.tree.map(lambda v: bind_dims(v, *names), value) diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index 05a5390e..53c4c094 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -467,31 +467,66 @@ def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array array = t.args[0] dims = t.args[1] assert isinstance(dims, Sequence) + ndim = len(array.shape) # ensure that the order is a subset of the named dimensions order_set = set(args) if not order_set <= set(a.op for a in dims if isinstance(a, Term)): raise NotHandled - # permute the inner array so that the leading dimensions are in the order - # specified and the trailing dimensions are the remaining named dimensions - # (or slices) - reindex_dims = [ - i - for i, o in enumerate(dims) - if not isinstance(o, Term) or o.op not in order_set - ] - dim_ops = [a.op if isinstance(a, Term) else None for a in dims] - perm = ( - [dim_ops.index(o) for o in args] - + reindex_dims - + list(range(len(dims), len(array.shape))) + def axis_op(ax: int) -> Operation | None: + """The named op of a bare index term at axis ``ax``, else ``None``.""" + if ax < len(dims): + d = dims[ax] + if isinstance(d, Term) and not d.args and not d.kwargs: + return d.op + return None + + # Assign an einsum id to every axis of ``array``. Axes that share a named op + # get the *same* id — a repeated op (e.g. ``arr[i(), i()]``) ties its axes + # together, which einsum reads as a diagonal. Every other axis (slices, ints, + # fancy indices, compound terms, and trailing positional axes) gets a unique + # id, so einsum simply carries it through to be reindexed below. + op_ids: dict[Operation, int] = {} + in_ids: list[int] = [] + next_id = 0 + for ax in range(ndim): + op = axis_op(ax) + if op is not None: + if op not in op_ids: + op_ids[op] = next_id + next_id += 1 + in_ids.append(op_ids[op]) + else: + in_ids.append(next_id) + next_id += 1 + + # Output order: bound args that actually appear (in the requested order, + # deduplicated by the diagonal merge), then every remaining axis in + # first-appearance order. einsum does the permutation and the diagonals; with + # all distinct ids retained in the output it performs no reduction. + present_arg_ids = [op_ids[o] for o in args if o in op_ids] + seen = set(present_arg_ids) + rest_ids: list[int] = [] + for i in in_ids: + if i not in seen: + seen.add(i) + rest_ids.append(i) + + array = jnp.einsum(array, in_ids, present_arg_ids + rest_ids) + + # Re-apply the original index for each carried axis and re-name unbound op + # axes. Trailing positional axes (first appearance beyond ``dims``) are left + # for jax_getitem to carry implicitly. + first_pos: dict[int, int] = {} + for ax, i in enumerate(in_ids): + first_pos.setdefault(i, ax) + + index_expr = (slice(None),) * len(present_arg_ids) + tuple( + dims[first_pos[i]] if first_pos[i] < len(dims) else slice(None) + for i in rest_ids ) - array = jnp.transpose(array, perm) - reindexed = jax_getitem( - array, (slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims) - ) - return reindexed + return jax_getitem(array, index_expr) @unbind_dims.register(jax.Array) # type: ignore diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 3f6273be..fbd845d8 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -1,11 +1,16 @@ import functools +import logging import typing -from collections.abc import Iterable +from typing import Protocol import jax +import jax.core +import opt_einsum +from opt_einsum import get_symbol import effectful.handlers.jax.numpy as jnp -from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax import bind_dims, jax_getitem, unbind_dims +from effectful.handlers.jax._handlers import is_eager_array from effectful.handlers.jax.scipy.special import logsumexp from effectful.ops.monoid import ( CartesianProduct, @@ -16,13 +21,16 @@ Product, Streams, Sum, + _is_monoid_plus, + choose_contraction, distributes_over, - outer_stream, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Interpretation, NotHandled, Operation, Term +logger = logging.getLogger(__name__) + def cartesian_prod(x, y): if x.ndim == 1: @@ -46,16 +54,39 @@ def cartesian_prod(x, y): def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete - :class:`jax.Array` (no Terms). + :class:`jax.typing.ArrayLike` or named tensor. At least one argument must be + a jax-related type. + """ - typs = (typeof(a) for a in args) return ( bool(args) - and any(issubclass(t, jax.Array) for t in typs) - and all(issubclass(t, jax.typing.ArrayLike) for t in typs) + and all(is_eager_array(a) or isinstance(a, jax.typing.ArrayLike) for a in args) + and any(is_eager_array(a) or isinstance(a, jax.Array) for a in args) ) +class PlusJaxUpcast(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + arg_types = [typeof(a) for a in args] + + def _is_jax(t): + return issubclass(t, jax.Array | jax.core.Tracer) + + # exists array valued and non-array-valued args + if any(_is_jax(t) for t in arg_types) and any( + not _is_jax(t) for t in arg_types + ): + return monoid.plus( + *( + a if _is_jax(t) else jnp.asarray(a) + for (a, t) in zip(args, arg_types, strict=True) + ) + ) + + return fwd() + + class SumPlusJax(ObjectInterpretation): @implements(Sum.plus) def plus(self, *args): @@ -124,125 +155,144 @@ def plus(self, *args): return result -ARRAY_REDUCTORS = { - Sum: jnp.sum, - Product: jnp.prod, - Min: jnp.min, - Max: jnp.max, - LogSumExp: logsumexp, -} +class ReduceArrayGather(ObjectInterpretation): + """M.reduce(body, {k: a} ∪ S) ≡ M.reduce(body[k := a[k']], {k': range(a.shape[0])} ∪ S)""" - -class ArrayReduce(ObjectInterpretation): @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if monoid not in ARRAY_REDUCTORS or typeof(body) is not jax.Array: + if typeof(body) is not jax.Array: return fwd() - if not streams: - return monoid.identity - reductor = ARRAY_REDUCTORS[monoid] - index = Operation.define(jax.Array) - for stream_key, stream_body, streams_tail in outer_stream(streams): - if not issubclass(typeof(stream_body), jax.Array): - continue + if isinstance(body, Term) and body.op is delta: + return fwd() - if stream_key in fvsof(body): - with handler({stream_key: deffn(unbind_dims(stream_body, index))}): - eval_body = evaluate(body) - eval_streams_tail = evaluate(streams_tail) - assert isinstance(eval_streams_tail, dict) - reduce_tail = ( - monoid.reduce(eval_body, eval_streams_tail) - if len(eval_streams_tail) > 0 - else eval_body - ) - return reductor(bind_dims(reduce_tail, index), axis=0) + body_fvs = fvsof(body) + stream_keys = set(streams) + + body_subst = {} + streams_subst = {} + range_streams = {} + progress = False + for k, v in streams.items(): + if is_eager_array(v) and k in body_fvs and not (fvsof(v) & stream_keys): + kk = Operation.define(k) + body_subst[k] = deffn(unbind_dims(v, kk)) + streams_subst[k] = kk + range_streams[kk] = range(v.shape[0]) + progress = True else: - # TODO: In this case, the stream is unused in the body. The body - # should be multiplied by the length of the stream. The current - # behavior is not efficient. - return fwd() + range_streams[k] = v - return fwd() + if not progress: + return fwd() + subst_body = handler(body_subst)(evaluate)(body) + subst_streams = handler(streams_subst)(evaluate)(range_streams) + return monoid.reduce(subst_body, subst_streams) -@Operation.define -def delta(_index: tuple[int, ...], _weight: jax.Array) -> jax.Array: - raise NotHandled +class Reductor(Protocol): + def __call__( + self, arr: jax.Array, axis: int | tuple[int, ...] | None = None + ) -> jax.Array: ... -py_range = range +ARRAY_REDUCTORS: dict[Monoid, Reductor] = {} +for monoid, func in [ + (Sum, jnp.sum), + (Product, jnp.prod), + (Min, jnp.min), + (Max, jnp.max), +]: + assert isinstance(monoid, Monoid) + assert callable(func) + ARRAY_REDUCTORS[monoid] = functools.partial(func, initial=monoid.identity) -@Operation.define -def range(*args: int) -> Iterable[jax.Array]: - raise NotHandled +ARRAY_REDUCTORS[LogSumExp] = logsumexp -def _range_start(term: Term): - assert term.op == range - if len(term.args) < 2: - return 0 - return term.args[0] +class ReduceArray(ObjectInterpretation): + """Reduce an array body over range streams.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + reductor = ARRAY_REDUCTORS.get(monoid, None) + if reductor is None: + return fwd() + + if typeof(body) is not jax.Array: + return fwd() + + pos_dims = {} + if isinstance(body, Term): + if body.op == delta: + pos_dims = { + d.op + for d in body.args[0] + if isinstance(d, Term) and d.op in streams + } + elif _is_monoid_plus(body.op) and distributes_over( + body.op.__self__, monoid + ): + # delegate to factorization + return fwd() + + body_fvs = fvsof(body) + used = { + k + for k, v in streams.items() + if k in body_fvs and k not in pos_dims and isinstance(v, range) + } + if not used: + return fwd() + + delta_key = tuple(k() for k in streams if k in used) + arr = monoid.reduce(delta(delta_key, body), streams) + reduced_body = reductor(arr, axis=tuple(range(len(used)))) + return reduced_body + + +@Operation.define +def delta(_index: tuple[int, ...], _weight: jax.Array) -> jax.Array: + raise NotHandled def _range_stop(term: Term): - assert term.op == range + assert term.op == jnp.arange + if "stop" in term.kwargs: + return term.kwargs["stop"] if len(term.args) < 2: return term.args[0] return term.args[1] -def _range_step(term: Term): - assert term.op == range - if len(term.args) < 3: - return 1 - return term.args[2] +class DeltaEmpty(ObjectInterpretation): + """delta((), weight) ≡ weight""" + + @implements(delta) + def _(self, index, weight): + if not index: + return weight + return fwd() -def _is_simple_range(term: Term) -> bool: - if term.op != range: - return False +class DeltaFusion(ObjectInterpretation): + """delta(i1, delta(i2, weight)) ≡ delta(i1 ++ i2, weight)""" - start = _range_start(term) - step = _range_step(term) - return ( - not isinstance(start, Term) - and start == 0 - and not isinstance(step, Term) - and step == 1 - ) + @implements(delta) + def _(self, index, weight): + if isinstance(weight, Term) and weight.op == delta: + return delta(index + weight.args[0], weight.args[1]) + return fwd() -class ReduceDeltaIndependent(ObjectInterpretation): +class ReduceDeltaSimpleRange(ObjectInterpretation): """Eliminate a Delta that has independent, dense index arguments. - reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) - reduce(M, streams ∪ {v: range(N)}, delta(idx' ++ (v(),), body)) + reduce(M, streams ∪ {v: range(N)}, delta((v(),) ++ idx', body)) ═══════════════════════════════════════════════════════════════════════════ - reduce(M, streams, delta(idx', bind_dims(body[v() := unbind_dims(streams[v], fv)], fv))) - - Not yet supported: - - - **Strided index streams** (``range(0, N, k)`` for ``k != 1``): the - premise ``_is_simple_range`` requires ``start == 0`` and ``step == 1``. - A strided extension would substitute ``v() := unbind_dims(jnp.arange( - start, stop, step), fv)`` and otherwise follow the same shape — the - change is purely in the recognised range form, the bind/unbind cycle - below is unchanged. - - **Non-zero start** (``range(a, b, 1)`` with ``a != 0``): same template - as the strided case; only the recognised range form changes. - - **Non-bare index expressions** (``delta((2*v(),), w)``, - ``delta((f(v()),), w)``, etc.): currently requires the final index - entry to be a bare call ``v()`` of a stream var op. Generalizing to - arbitrary index expressions is a scatter, not a bind: materialize the - index expression and the weight separately over ``v``, then - ``jnp.zeros(N).at[indices].set(values)`` (for Sum; analogous for - other monoids using ``.add``/``.min``/``.max``/...). This is a - different leaf operation from ``bind_dims`` and warrants a sibling - rule rather than an extension of this one. + bind_dims(reduce(M, streams, delta(idx', body[v() := unbind_dims(streams[v], fv)])), fv) """ @implements(Monoid.reduce) @@ -250,31 +300,76 @@ def _(self, monoid: Monoid, body, streams: Streams): if not (isinstance(body, Term) and body.op == delta): return fwd() - indices, weight = body.args - assert isinstance(indices, tuple) + index, weight = body.args + assert isinstance(index, tuple) - if not indices: - return monoid.reduce(weight, streams) + if not index: + return fwd() - head_indices, tail_index = indices[:-1], indices[-1] - if not (isinstance(tail_index, Term) and tail_index.op in streams): + head_index, tail_index = index[0], index[1:] + if not (isinstance(head_index, Term) and head_index.op in streams): return fwd() - tail_op: Operation = tail_index.op - tail_stream = streams[tail_op] - if not (isinstance(tail_stream, Term) and _is_simple_range(tail_stream)): + head_op: Operation = head_index.op + head_stream = streams[head_op] + if not ( + isinstance(head_stream, range) + and head_stream.start == 0 + and head_stream.step == 1 + ): return fwd() - fresh_op = Operation.define(tail_op) - indices = jnp.arange(_range_stop(tail_stream)) - if isinstance(indices, jax.Array) and len(indices) == 0: - return monoid.identity + tail_streams = {k: v for (k, v) in streams.items() if k != head_op} + + # peel the head index: substitute it into the weight (slicing direct + # uses, materializing the rest) along a fresh named dim, but bind that + # dim only *after* the surrounding reduce -- see the class docstring. + + fresh_op = Operation.define(head_op) + + def _jax_getitem(arr, index): + inner_index, outer_index = [], [] + progress = False + for i in index: + if isinstance(i, Term) and i.op == head_op: + inner_index.append( + slice(head_stream.start, head_stream.stop, head_stream.step) + ) + outer_index.append(fresh_op()) + progress = True + else: + inner_index.append(slice(None)) + outer_index.append(i) + if progress: + return jax_getitem(jax_getitem(arr, inner_index), outer_index) + return fwd(arr, index) + + slice_subst = typing.cast(Interpretation, {jax_getitem: _jax_getitem}) + sliced_weight = handler(slice_subst)(evaluate)(weight) + sliced_streams = handler(slice_subst)(evaluate)(tail_streams) + + gather_subst = typing.cast( + Interpretation, + { + head_op: deffn( + unbind_dims( + jnp.arange( + head_stream.start, head_stream.stop, head_stream.step + ), + fresh_op, + ) + ) + }, + ) + gathered_weight = handler(gather_subst)(evaluate)(sliced_weight) + gathered_streams = handler(gather_subst)(evaluate)(sliced_streams) - fresh_stream = unbind_dims(indices, fresh_op) - subst_intp = typing.cast(Interpretation, {tail_op: deffn(fresh_stream)}) - fresh_body = bind_dims(handler(subst_intp)(evaluate)(weight), fresh_op) - fresh_streams = {k: v for (k, v) in streams.items() if k != tail_op} - return monoid.reduce(delta(head_indices, fresh_body), fresh_streams) + inner = ( + monoid.reduce(delta(tail_index, gathered_weight), gathered_streams) + if gathered_streams + else gathered_weight + ) + return bind_dims(inner, fresh_op) class ReduceDependentRangeMask(ObjectInterpretation): @@ -324,15 +419,16 @@ def _(self, monoid: Monoid, body, streams: Streams): simple_ranges = { k: v for (k, v) in streams.items() - if isinstance(v, Term) and _is_simple_range(v) + if isinstance(v, range) and v.start == 0 and v.step == 1 } for u, u_stream in simple_ranges.items(): if fvsof(u_stream) & stream_vars: continue - for v, v_stream in simple_ranges.items(): + for v, v_stream in streams.items(): if ( isinstance(v_stream, Term) + and v_stream.op == jnp.arange and isinstance(_range_stop(v_stream), Term) and _range_stop(v_stream).op == u ): @@ -355,38 +451,6 @@ def _(self, monoid: Monoid, body, streams: Streams): return fwd() -class ReduceRange(ObjectInterpretation): - """Replace concrete-range stream values with materialized ``jnp.arange``. - - reduce(M, streams ∪ {v: range(a, b, s)}, body) - ≡ reduce(M, streams ∪ {v: jnp.arange(a, b, s)}, body) - - when ``a``, ``b``, ``s`` are concrete and ``body`` is not a delta term. - Delegates the actual reduction to whichever handler picks up the - materialized ``jax.Array`` streams. - """ - - @implements(Monoid.reduce) - def _(self, monoid: Monoid, body, streams: Streams): - if isinstance(body, Term) and body.op == delta: - return fwd() - - new_streams: dict = {} - any_replaced = False - for k, v in streams.items(): - if isinstance(v, Term) and v.op == range: - new_streams[k] = jnp.arange( - _range_start(v), _range_stop(v), _range_step(v) - ) - any_replaced = True - else: - new_streams[k] = v - - if not any_replaced: - return fwd() - return monoid.reduce(body, new_streams) - - # Cross-cutting delta rules not yet implemented: # # - **Delta-commuting** (DC-hoist): for any pure op ``f`` (no Scoped binders @@ -406,23 +470,115 @@ def _(self, monoid: Monoid, body, streams: Streams): # a subsequence of ``idx_b`` (or vice versa). Refuse to fire when neither # is a subsequence of the other, since that would silently insert an # outer-product broadcast. -# -# - **Empty-domain detection at the term level**: currently size-0 named -# dims must be resolved by leaf consumers (``bind_dims``, reductors with -# ``initial=monoid.identity``). The empty-domain check is intentionally -# NOT a rule on its own — rewrites stay size-polymorphic and leaf ops -# carry the burden. See the conversation in monoid.py's history for why. + + +class ContractLongestArrayStream(ObjectInterpretation): + @implements(choose_contraction) + def _(self, factors, streams): + lengths = { + k: v.shape[0] if isinstance(v, jax.Array) and v.shape else 0 + for (k, v) in streams.items() + } + longest = max(lengths.values()) + return fwd( + factors, {k: v for (k, v) in streams.items() if lengths[k] == longest} + ) + + +class ReduceSumProductContraction(ObjectInterpretation): + """Fast-path a sum-of-products contraction.""" + + @implements(Sum.reduce) + def _(self, body, streams: Streams): + if not ( + isinstance(body, Term) + and _is_monoid_plus(body.op) + and body.op.__self__ is Product + ): + return fwd() + + factors = body.args + if len(factors) != 2 or not all( + issubclass(typeof(f), jax.Array) for f in factors + ): + return fwd() + + (lhs, rhs) = factors + stream_vars = set(streams.keys()) + + # a fully factored reduce only has streams that are used by all factors + shared = fvsof(lhs) & fvsof(rhs) & stream_vars + if shared != stream_vars: + return fwd() + + if not all(isinstance(v, range) for v in streams.values()): + return fwd() + + # create leading reduction dimensions + delta_key = tuple(k() for k in streams) + pos_lhs = Sum.reduce(delta(delta_key, lhs), streams) + pos_rhs = Sum.reduce(delta(delta_key, rhs), streams) + + dims = "".join(get_symbol(i) for i in range(len(streams))) + contraction = jnp.einsum(f"{dims}...,{dims}...->...", pos_lhs, pos_rhs) + return contraction + + +@jax.jit(static_argnums=(0,)) +def einsum(subscripts: str, /, *operands: jax.Array) -> jax.Array: + """Evaluate an einsum expression using monoid reductions.""" + if not operands: + raise ValueError("einsum requires at least one operand") + + in_spec, out_spec, _ = opt_einsum.parser.parse_einsum_input( + [subscripts, *(op.shape for op in operands)], shapes=True + ) + in_specs = in_spec.split(",") + + all_letters = set(out_spec) | {c for s in in_specs for c in s} + ops = {c: Operation.define(jax.Array, name=c) for c in all_letters} + + sizes: dict[str, int] = {} + for spec, op in zip(in_specs, operands, strict=True): + for l, s in zip(spec, op.shape, strict=True): + if l in sizes and sizes[l] != s: + raise ValueError(f"Dimension {l} given sizes {s} and {sizes[l]}") + else: + sizes[l] = s + for c in out_spec: + if c not in sizes: + raise ValueError(f"einsum: output index {c!r} not present in any input") + + arrays = [Operation.define(jax.Array) for _ in operands] + factors = [ + unbind_dims(arr(), *(ops[c] for c in spec)) + for arr, spec in zip(arrays, in_specs, strict=True) + ] + body = Product.plus(*factors) + + out_tuple = tuple(ops[c]() for c in out_spec) + streams = {op: range(sizes[c]) for c, op in ops.items()} + with handler(NormalizeIntp): + norm = deffn(Sum.reduce(delta(out_tuple, body), streams), *arrays) + result = norm(*operands) + assert isinstance(result, jax.Array) + return result NormalizeIntp.extend( - ArrayReduce(), - ReduceRange(), - ReduceDeltaIndependent(), + ReduceArray(), + ReduceSumProductContraction(), + ReduceArrayGather(), + ReduceDeltaSimpleRange(), ReduceDependentRangeMask(), + DeltaEmpty(), + DeltaFusion(), SumPlusJax(), ProductPlusJax(), MinPlusJax(), MaxPlusJax(), LogSumExpPlusJax(), CartesianProductPlusJax(), + ContractLongestArrayStream(), + PlusJaxUpcast(), ) diff --git a/effectful/handlers/jax/numpy/__init__.py b/effectful/handlers/jax/numpy/__init__.py index 990830d2..f2d2affa 100644 --- a/effectful/handlers/jax/numpy/__init__.py +++ b/effectful/handlers/jax/numpy/__init__.py @@ -1,24 +1,43 @@ +from types import NoneType from typing import TYPE_CHECKING import jax.numpy -from .._handlers import _register_jax_op, _register_jax_op_no_partial_eval +from effectful.handlers.jax._handlers import ( + _einsum_named, + _reduce_named, + _register_jax_op, + _register_jax_op_no_partial_eval, +) +from effectful.ops.semantics import handler +from effectful.ops.types import Operation -_no_overload = ["array", "asarray"] +_NO_OVERLOAD = ["array", "asarray"] +_REDUCTION = ["sum", "prod", "min", "max", "any", "all", "mean", "argmax"] for name, op in jax.numpy.__dict__.items(): - if not callable(op): + wrapped_value = None + if type(op) in (float, NoneType): + wrapped_value = op + elif name in _NO_OVERLOAD: + wrapped_value = _register_jax_op_no_partial_eval(op) + elif callable(op): + wrapped_value = _register_jax_op(op) + else: continue - jax_op = ( - _register_jax_op_no_partial_eval(op) - if name in _no_overload - else _register_jax_op(op) - ) - globals()[name] = jax_op + globals()[name] = wrapped_value -pi = jax.numpy.pi +for name in _REDUCTION: + op = globals()[name] + globals()[name] = handler({op: _reduce_named})(op) + +# einsum = effectful.handlers.jax._handlers.einsum +# tensordot = handler({tensordot: _tensordot_named})(tensordot) + + +einsum = Operation.define(_einsum_named) # Tell mypy about our wrapped functions. if TYPE_CHECKING: - from jax.numpy import * # noqa: F403 + from jax.numpy import * # type: ignore[assignment] # noqa: F403 diff --git a/effectful/handlers/jax/scipy/special.py b/effectful/handlers/jax/scipy/special.py index afe1334b..67b99621 100644 --- a/effectful/handlers/jax/scipy/special.py +++ b/effectful/handlers/jax/scipy/special.py @@ -2,9 +2,11 @@ import jax.scipy.special -from effectful.handlers.jax._handlers import _register_jax_op +from effectful.handlers.jax._handlers import _reduce_named, _register_jax_op +from effectful.ops.semantics import handler logsumexp = _register_jax_op(jax.scipy.special.logsumexp) +logsumexp = handler({logsumexp: _reduce_named})(logsumexp) # Tell mypy about our wrapped functions. if TYPE_CHECKING: diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 5f342f25..42180e91 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,34 +4,22 @@ import operator import typing from collections import Counter, UserDict, defaultdict -from collections.abc import Callable, Generator, Iterable, Mapping +from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any -from effectful.ops.semantics import ( - coproduct, - evaluate, - fvsof, - fwd, - handler, - typeof, -) +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ( ObjectInterpretation, Scoped, + defdata, deffn, implements, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import ( - Expr, - Interpretation, - NotHandled, - Operation, - Term, -) +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term type Stream[T] = Iterable[T] @@ -125,16 +113,21 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) - # the weak typing allows us to write monoid.plus(monoid.identity, ) - # and monoid.plus(monoid.identity, ) @Operation.define - def plus(self, *args: Any) -> Any: + def plus(self, *args: W) -> W: """Monoid addition. Handlers supply per-monoid and broadcasting - behavior; the default rule only handles empty / Term cases. + behavior; the default rule only handles identity and zero cases (for + monoids that have a zero). + """ - if not args: - return self.identity - raise NotHandled + if hasattr(self, "zero") and any(a is self.zero for a in args): + return self.zero + + nonident_args = [a for a in args if a is not self.identity] + if len(nonident_args) != len(args): + return self.plus(*nonident_args) + + return defdata(self.plus, *nonident_args) # type: ignore[return-value] @Operation.define def reduce[A, B, U: Body]( @@ -146,24 +139,6 @@ def reduce[A, B, U: Body]( broadcasting behavior; the default rule only handles the empty-stream case. """ - for stream_key, stream_body, streams_tail in outer_stream(streams): - if isinstance(stream_body, Term): - continue - stream_values_iter = iter(stream_body) - - # if we iterate and get a term instead of a real iterator, skip - if isinstance(stream_values_iter, Term): - continue - - new_reduces = [] - for stream_val in stream_values_iter: - with handler({stream_key: deffn(stream_val)}): - eval_args = evaluate((body, streams_tail)) - assert isinstance(eval_args, tuple) - new_reduces.append( - self.reduce(*eval_args) if streams_tail else eval_args[0] - ) - return self.plus(*new_reduces) raise NotHandled @Operation.define @@ -268,16 +243,6 @@ def plus(self, _, *args): return fwd() -class PlusIdentity(ObjectInterpretation): - """x₁ + ... + 0 + ... + xₙ = x₁ + ... + xₙ""" - - @implements(Monoid.plus) - def plus(self, monoid, *args): - if any(x is monoid.identity for x in args): - return monoid.plus(*(x for x in args if x is not monoid.identity)) - return fwd() - - class PlusAssoc(ObjectInterpretation): """x + (y + z) = (x + y) + z = x + y + z""" @@ -338,18 +303,6 @@ def plus(self, monoid: Monoid, *args): return fwd() -class PlusZero(ObjectInterpretation): - """x₁ * ... * 0 * ... * xₙ = 0""" - - @implements(Monoid.plus) - def plus(self, monoid, *args): - if not (isinstance(monoid, MonoidWithZero)): - return fwd() - if any(x is monoid.zero for x in args): - return monoid.zero - return fwd() - - class PlusConsecutiveDups(ObjectInterpretation): """x ⊕ x ⊕ y = x ⊕ y""" @@ -397,15 +350,30 @@ def plus(self, monoid, *args): return fwd() -class ReduceNoStreams(ObjectInterpretation): - """Implements the identity - reduce(R, ∅, body) = 0 - """ - +class ReducePartial(ObjectInterpretation): @implements(Monoid.reduce) - def reduce(self, monoid, _, streams): - if len(streams) == 0: + def _(self, monoid, body, streams): + if not streams: return monoid.identity + + for stream_key, stream_body, streams_tail in outer_stream(streams): + if isinstance(stream_body, Term): + continue + stream_values_iter = iter(stream_body) + + # if we iterate and get a term instead of a real iterator, skip + if isinstance(stream_values_iter, Term): + continue + + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append( + monoid.reduce(*eval_args) if streams_tail else eval_args[0] + ) + return monoid.plus(*new_reduces) return fwd() @@ -435,6 +403,30 @@ def reduce(self, monoid, body, streams): return fwd() +@Operation.define +def choose_contraction(factors: Sequence[Any], streams: Streams) -> Operation: + """Used by `ReduceFactorization` to choose a contraction when there is + ambiguity. Takes the factors and streams that are eligible for contraction + (innermost and non-universal). + + The default behavior is to return the first support-minimal stream in the + streams dictionary. + + """ + assert len(streams) > 0 + + factors = [(a, fvsof(a)) for a in factors] + support: dict = { + k: frozenset(i for i, (_, fvs) in enumerate(factors) if k in fvs) + for k in streams + } + for v, f_v in support.items(): + if any(u_sup < f_v for u, u_sup in support.items() if u is not v): + continue + return v + assert False, "expected at least one subset-minimal stream" + + class ReduceFactorization(ObjectInterpretation): """reduce(⊗(F_v ∪ F_rest), {v} ∪ S) = reduce(⊗F_rest ⊗ reduce(⊗F_v, {v}), S) @@ -459,39 +451,40 @@ def reduce(self, monoid, body, streams): # candidates: innermost-eligible (no remaining stream depends on v), # non-universal (some factor doesn't mention v) - support: dict = {} - for v in streams: - if any(v in fvsof(s) for k, s in streams.items() if k is not v): + eligible = {} + for k, v in streams.items(): + if any(k in fvsof(vv) for kk, vv in streams.items() if k is not kk): continue - f_v = frozenset(i for i, (_, fvs) in enumerate(factors) if v in fvs) - if len(f_v) == len(factors): + if len({i for i, (_, fvs) in enumerate(factors) if k in fvs}) == len( + factors + ): continue # v is universal: leave it in the outer core - support[v] = f_v - - # eliminate a variable with subset-minimal factor support - # (leaves-first; canonical on hierarchical/laminar supports) - inner_stream = None - inner_factor_ids = None - for v, f_v in support.items(): - if any(u_sup < f_v for u, u_sup in support.items() if u is not v): - continue - inner_stream = v - inner_factor_ids = f_v - break + eligible[k] = v - if not inner_stream or not inner_factor_ids: + if not eligible: return fwd() + if len(eligible) == 1: + inner_stream = next(iter(eligible)) + else: + inner_stream = choose_contraction(body.args, eligible) + + inner_factor_ids = frozenset( + i for i, (_, fvs) in enumerate(factors) if inner_stream in fvs + ) inner_factors = [factors[i][0] for i in sorted(inner_factor_ids)] inner_stream_keys = {inner_stream} inner_deps = set().union( - *(factors[i][1] for i in f_v), fvsof(streams[v]) & stream_keys + *(factors[i][1] for i in inner_factor_ids), + fvsof(streams[inner_stream]) & stream_keys, ) - outer_factors = [a for i, (a, _) in enumerate(factors) if i not in f_v] + outer_factors = [ + a for i, (a, _) in enumerate(factors) if i not in inner_factor_ids + ] outer_stream_keys = stream_keys - inner_stream_keys outer_factor_deps = set().union( - *(vars for i, (_, vars) in enumerate(factors) if i not in f_v) + *(vars for i, (_, vars) in enumerate(factors) if i not in inner_factor_ids) ) # find all streams that are used in the inner factors/streams and are @@ -857,6 +850,28 @@ def reduce(self, monoid, body, streams): return result +@Operation.define +def as_float(x: int) -> float: + if isinstance(x, Term): + raise NotHandled + return float(x) + + +class PlusCastFloat(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + typs = [typeof(a) for a in args] + if any(issubclass(t, float) for t in typs) and any( + issubclass(t, int) for t in typs + ): + args = [ + as_float(a) if issubclass(t, int) else a + for (a, t) in zip(args, typs, strict=True) + ] + return monoid.plus(*args) + return fwd() + + class _ExtensibleInterpretation(UserDict, Interpretation): def extend(self, *intps: Interpretation) -> typing.Self: for intp in intps: @@ -865,10 +880,10 @@ def extend(self, *intps: Interpretation) -> typing.Self: NormalizeIntp = _ExtensibleInterpretation().extend( + ReducePartial(), MonoidOverSequence(), MonoidOverMapping(), MonoidOverCallable(), - ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization(), @@ -877,10 +892,8 @@ def extend(self, *intps: Interpretation) -> typing.Self: ReduceCartesianWeightedStream(), PlusEmpty(), PlusSingle(), - PlusIdentity(), PlusAssoc(), PlusDistr(), - PlusZero(), PlusConsecutiveDups(), PlusDups(), SumPlus(), @@ -890,6 +903,7 @@ def extend(self, *intps: Interpretation) -> typing.Self: ArgMinPlus(), ArgMaxPlus(), CartesianProductPlus(), + PlusCastFloat(), ) """``NormalizeIntp``applies pure-Term rewrites (associativity, distributivity, identity elimination, fusion, factorization, etc.). diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 5ea04fcb..2e198bf3 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -854,6 +854,14 @@ def _(x: object, other) -> bool: return x == other +@syntactic_eq.register(int | float) +def _(x: int | float, other) -> bool: + # Terms often override __eq__ + if isinstance(other, Term) or not isinstance(other, int | float): + return False + return x == other + + @_CustomSingleDispatchCallable def syntactic_hash(__dispatch: Callable[[type], Callable[[Any], int]], x) -> int: """Structural hash compatible with :func:`syntactic_eq`. diff --git a/pyproject.toml b/pyproject.toml index 685aaf55..e3ede785 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,10 @@ Source = "https://github.com/BasisResearch/effectful" [project.optional-dependencies] torch = ["torch"] pyro = ["pyro-ppl>=1.9.1"] -jax = ["jax"] +jax = [ + "jax", + "opt_einsum" +] numpyro = [ "numpyro>=0.19", "jax<0.10" diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 18df8401..48913d80 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,78 +1,108 @@ import functools -import typing import jax import pytest from jax import random as random import effectful.handlers.jax.numpy as jnp -from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax import bind_dims, jax_getitem, unbind_dims from effectful.handlers.jax.monoid import ( - ArrayReduce, - LogSumExp, - ProductPlusJax, - ReduceDeltaIndependent, + ARRAY_REDUCTORS, + DeltaEmpty, + ReduceArray, + ReduceArrayGather, + ReduceDeltaSimpleRange, ReduceDependentRangeMask, + ReduceSumProductContraction, delta, + einsum, ) -from effectful.handlers.jax.monoid import range as Range -from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import ( - Max, - Min, - NormalizeIntp, - Product, - ReduceWeightedStream, - Sum, -) +from effectful.ops.monoid import NormalizeIntp, Product, Sum from effectful.ops.semantics import coproduct, handler -from effectful.ops.types import Interpretation from tests._monoid_helpers import JaxBackend MONOIDS = [ - pytest.param(Sum, jnp.sum, id="Sum"), - pytest.param(Product, jnp.prod, id="Product"), - pytest.param(Min, jnp.min, id="Min"), - pytest.param(Max, jnp.max, id="Max"), - pytest.param(LogSumExp, logsumexp, id="LogSumExp"), + pytest.param(monoid, reductor, id=monoid._name) + for (monoid, reductor) in ARRAY_REDUCTORS.items() ] +@pytest.fixture(scope="module") +def rng_key(): + return random.PRNGKey(0) + + @pytest.fixture def backend() -> JaxBackend: return JaxBackend() +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_gather(monoid, reductor, backend: JaxBackend): + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = jnp.arange(3) + + lhs = monoid.reduce(x(), {x: X}) + rhs = monoid.reduce(unbind_dims(X, k), {k: range(X.shape[0])}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceArrayGather()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_gather_dep(monoid, reductor, backend: JaxBackend): + (x, y) = backend.define_vars("x", "y", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") + g = backend.define_vars( + "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + X = jnp.arange(3) + + lhs = monoid.reduce(g(x(), y()), {y: f(x()), x: X}) + rhs = monoid.reduce( + g(unbind_dims(X[:3], x), y()), {y: f(x()), x: range(X.shape[0])} + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceArrayGather(), ReduceDeltaSimpleRange()) + ) + + @pytest.mark.parametrize("monoid,reductor", MONOIDS) def test_reduce_array_1(monoid, reductor, backend: JaxBackend): (x, k) = backend.define_vars("x", "k", ret="scalar") - X = backend.define_vars("X", ret="stream") + X = jnp.arange(5) - lhs = monoid.reduce(x(), {x: X()}) - rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) + lhs = monoid.reduce(x(), {x: X}) + rhs = reductor(bind_dims(unbind_dims(X, k), k), axis=(0,)) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=functools.reduce( + coproduct, # type: ignore[arg-type] + [ReduceArrayGather(), ReduceArray(), ReduceDeltaSimpleRange()], + ), + ) @pytest.mark.parametrize("monoid,reductor", MONOIDS) def test_reduce_array_2(monoid, reductor, backend: JaxBackend): (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") - (X, Y) = backend.define_vars("X", "Y", ret="stream") + X = jnp.arange(5) + Y = jnp.arange(7) f = backend.define_vars( "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" ) - lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) + lhs = monoid.reduce(f(x(), y()), {x: X, y: Y}) rhs = reductor( - bind_dims( - reductor( - bind_dims(f(unbind_dims(X(), k1), unbind_dims(Y(), k2)), k2), - axis=0, - ), - k1, + bind_dims(f(unbind_dims(X, k1), unbind_dims(Y, k2)), k1, k2), axis=(0, 1) + ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=functools.reduce( + coproduct, # type: ignore[arg-type] + [ReduceArrayGather(), ReduceArray(), ReduceDeltaSimpleRange()], ), - axis=0, ) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) @@ -80,58 +110,93 @@ def test_reduce_array_3(monoid, reductor, backend: JaxBackend): """Stream `y` is `g(x())` — depends on the bound element of X. The reducer must inline ``g`` along the same named dim used to unbind `x`.""" (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") - X = backend.define_vars("X", ret="stream") + X = jnp.arange(5) f = backend.define_vars( "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" ) g = backend.define_vars("g", arg_types=[backend.scalar_typ], ret="stream") - lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) + lhs = monoid.reduce(f(x(), y()), {x: X, y: g(x())}) rhs = reductor( bind_dims( - reductor( - bind_dims( - f(unbind_dims(X(), k1), unbind_dims(g(unbind_dims(X(), k1)), k2)), - k2, - ), - axis=0, - ), - k1, + monoid.reduce(f(unbind_dims(X, x), y()), {y: g(unbind_dims(X, x))}), x + ), + axis=(0,), + ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=functools.reduce( + coproduct, # type: ignore[arg-type] + [ + ReduceArrayGather(), + ReduceArray(), + ReduceDeltaSimpleRange(), + DeltaEmpty(), + ], ), - axis=0, ) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) -def test_jax_weighted_reduce(backend: JaxBackend): - """Sum over a single stream with ``Product`` weights lowers to - ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_arange_reduce_direct_full(monoid, reductor, backend: JaxBackend): + """A full-range direct index ``A[v()]`` over ``v: arange(N)`` slices the + whole axis (``A[0:N:1]``) and reduces it -- no materialized-arange gather. + """ + (v, k) = backend.define_vars("v", "k", ret="scalar") + A = backend.define_vars("A", ret="stream") - Verifies that the desugaring rule composes cleanly with the JAX lowering - so existing handlers need no changes to support weighted streams. + lhs = monoid.reduce(jax_getitem(A(), [v()]), {v: range(7)}) + rhs = reductor( + bind_dims(jax_getitem(jax_getitem(A(), [slice(0, 7, 1)]), [k()]), k), + axis=(0,), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceArray(), ReduceDeltaSimpleRange()) + ) - """ - (x, k) = backend.define_vars("x", "k", ret="scalar") - X = backend.define_vars("X", ret="stream") - body = backend.define_vars("body", arg_types=[backend.scalar_typ], ret="scalar") - w = backend.define_vars("w", arg_types=[backend.scalar_typ], ret="scalar") - ws = Product.weighted(X(), w) - lhs = Sum.reduce(body(x()), {x: ws}) - rhs = jnp.sum( - bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_arange_reduce_indirect(monoid, reductor, backend: JaxBackend): + """When the range var is used both as a direct index and as a value + (``A[v()] + v()``), the direct use slices and the indirect use materializes + the range, both aligned on the same fresh dim.""" + (v, k) = backend.define_vars("v", "k", ret="scalar") + A = jnp.arange(10) + + lhs = monoid.reduce(jax_getitem(A, [v()]) + v(), {v: range(5)}) + rhs = reductor( + bind_dims( + jax_getitem(jax_getitem(A, [slice(0, 5, 1)]), [k()]) + + unbind_dims(jnp.arange(5), k), + k, + ), + axis=(0,), ) backend.check_rewrite( - lhs=lhs, - rhs=rhs, - rule=functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()], - ), + lhs=lhs, rhs=rhs, rule=coproduct(ReduceArray(), ReduceDeltaSimpleRange()) + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_arange_reduce_two_streams(monoid, reductor, backend: JaxBackend): + """Two arange streams indexing a 2-D array slice both axes and reduce over + both at once.""" + (u, w, k1, k2) = backend.define_vars("u", "w", "k1", "k2", ret="scalar") + A = jnp.arange(8 * 9).reshape((8, 9)) + + lhs = monoid.reduce(jax_getitem(A, [u(), w()]), {u: range(4), w: range(5)}) + rhs = reductor( + bind_dims( + jax_getitem(jax_getitem(A, [slice(0, 4, 1), slice(0, 5, 1)]), [k1(), k2()]), + k1, + k2, ), + axis=(0, 1), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceArray(), ReduceDeltaSimpleRange()) ) @@ -152,15 +217,26 @@ def test_reduce_delta_empty(monoid, reductor, backend: JaxBackend): lhs = monoid.reduce(delta((), x()), {x: X()}) rhs = monoid.reduce(x(), {x: X()}) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceDeltaSimpleRange(), DeltaEmpty()) + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_empty_arange(monoid, reductor, backend: JaxBackend): + x = backend.define_vars("x", ret="scalar") + f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") + + lhs = monoid.reduce(delta((x(),), f(x())), {x: range(0)}) + rhs = bind_dims(f(unbind_dims(jnp.array([]), x)), x) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) def test_reduce_delta_independent_one(monoid, reductor, backend: JaxBackend): """One R1 step: peel the final preserved index off a delta. - reduce(M, {y: Y()}, delta((y(),), f(y()))) - ≡ reduce(M, {}, delta((), bind_dims(f(unbind_dims(Y(), k)), k))) + reduce(M, {y: Y()}, delta((y(),), f(y()))) ≡ bind_dims(f(unbind_dims(Y(), k)), k) """ (y, k) = backend.define_vars("y", "k", ret="scalar") f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") @@ -168,9 +244,9 @@ def test_reduce_delta_independent_one(monoid, reductor, backend: JaxBackend): # We use a concrete range here instead of an abstract one, because # unbind_dims is undefined on empty arrays (and the rewrite produces a # different rhs in this case) - lhs = monoid.reduce(delta((y(),), f(y())), {y: Range(3)}) - rhs = monoid.reduce(bind_dims(f(unbind_dims(jnp.arange(3), k)), k), {}) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) + lhs = monoid.reduce(delta((y(),), f(y())), {y: range(3)}) + rhs = bind_dims(f(unbind_dims(jnp.arange(3), k)), k) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) @@ -188,17 +264,34 @@ def test_reduce_delta_independent_preserves_others( "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" ) - lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: Range(2), y: Range(3)}) - rhs = monoid.reduce( - bind_dims( - bind_dims( - f(unbind_dims(jnp.arange(2), x), unbind_dims(jnp.arange(3), k)), k - ), - x, + lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: range(2), y: range(3)}) + rhs = bind_dims( + bind_dims(f(unbind_dims(jnp.arange(2), x), unbind_dims(jnp.arange(3), k)), k), x + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_simple_dep(monoid, reductor, backend: JaxBackend): + (x, y) = backend.define_vars("x", "y", ret="scalar") + X = jnp.arange(3) + + lhs = monoid.reduce( + delta((x(),), unbind_dims(X, x) + y()), + {x: range(3), y: jnp.stack([x(), x() + 1])}, + ) + rhs = bind_dims( + monoid.reduce( + delta((), unbind_dims(X, x) + y()), + { + y: jnp.stack( + [unbind_dims(jnp.arange(3), x), unbind_dims(jnp.arange(3), x) + 1] + ) + }, ), - {}, + x, ) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) @@ -217,10 +310,9 @@ def test_reduce_dependent_range_mask(monoid, reductor, backend: JaxBackend): body = f(u(), v()) - lhs = monoid.reduce(body, {u: Range(0, N, 1), v: Range(0, u(), 1)}) + lhs = monoid.reduce(body, {u: range(N), v: jnp.arange(u())}) rhs = monoid.reduce( - jnp.where(v() < u(), body, monoid.identity), - {u: Range(0, N, 1), v: Range(0, N, 1)}, + jnp.where(v() < u(), body, monoid.identity), {u: range(N), v: range(N)} ) backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) @@ -243,14 +335,44 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: JaxBa weight = f(u(), v()) idx = (u(), v()) - lhs = monoid.reduce(delta(idx, weight), {u: Range(0, N, 1), v: Range(0, u(), 1)}) + lhs = monoid.reduce(delta(idx, weight), {u: range(N), v: jnp.arange(u())}) rhs = monoid.reduce( delta(idx, jnp.where(v() < u(), weight, monoid.identity)), - {u: Range(0, N, 1), v: Range(0, N, 1)}, + {u: range(N), v: range(N)}, ) backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) +def test_reduce_contraction_single(backend: JaxBackend): + i = backend.define_vars("i", ret="scalar") + (A, B) = backend.define_vars( + "A", "B", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(A(i()), B(i())), {i: range(5)}) + rhs = jnp.einsum( + "a...,a...->...", + Sum.reduce(delta((i(),), A(i())), {i: range(5)}), + Sum.reduce(delta((i(),), B(i())), {i: range(5)}), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSumProductContraction()) + + +def test_reduce_contraction_double(backend: JaxBackend): + i, j = backend.define_vars("i", "j", ret="scalar") + (A, B) = backend.define_vars( + "A", "B", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(A(i(), j()), B(i(), j())), {i: range(5), j: range(7)}) + rhs = jnp.einsum( + "ab...,ab...->...", + Sum.reduce(delta((i(), j()), A(i(), j())), {i: range(5), j: range(7)}), + Sum.reduce(delta((i(), j()), B(i(), j())), {i: range(5), j: range(7)}), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSumProductContraction()) + + def test_reduce_matmul(backend: JaxBackend): key = jax.random.PRNGKey(0) # Define dimensions @@ -264,8 +386,135 @@ def test_reduce_matmul(backend: JaxBackend): with handler(NormalizeIntp): actual = Sum.reduce( delta((b(), i(), k()), unbind_dims(X, b, i, j) * unbind_dims(Y, b, j, k)), - {b: Range(B), i: Range(I), j: Range(J), k: Range(K)}, + {b: range(B), i: range(I), j: range(J), k: range(K)}, ) expected = jnp.einsum("bij,bjk->bik", X, Y) assert jnp.allclose(actual, expected) + + +EINSUM_CASES = [ + pytest.param("ij,jk->ik", {"i": 64, "j": 64, "k": 64}, id="matmul"), + pytest.param( + "bij,bjk->bik", + {"b": 16, "i": 32, "j": 32, "k": 32}, + id="batched_matmul", + ), + pytest.param( + "a,abi,bcij,cdij->ij", + {"a": 4, "b": 4, "c": 4, "d": 4, "i": 8, "j": 8}, + id="mixed_rank", + ), + # ───────────────────────── single-operand reshuffles ───────────────────── + # No contraction across operands — these stress the diagonal/transpose/sum + # rewrites rather than any pairwise product ordering. + pytest.param("ij->ji", {"i": 256, "j": 256}, id="transpose"), + pytest.param("ijk->", {"i": 96, "j": 96, "k": 96}, id="full_reduce"), + pytest.param("ijk->k", {"i": 96, "j": 96, "k": 96}, id="partial_reduce"), + # Repeated index *within* one operand — exercises the implicit-diagonal path + # in ReduceDeltaSimpleRange (no explicit jnp.diagonal step). + pytest.param("ii->", {"i": 1024}, id="trace"), + pytest.param("ii->i", {"i": 1024}, id="diagonal"), + pytest.param("bii->b", {"b": 256, "i": 128}, id="batched_trace"), + pytest.param("iij->ij", {"i": 128, "j": 128}, id="diagonal_keep"), + # ───────────────────────── no-shared-index blowups ─────────────────────── + # Output is the full outer product — nothing contracts, so the result tensor + # is as large as the dense intermediate. Pure broadcast cost. + pytest.param("i,j->ij", {"i": 1024, "j": 1024}, id="outer_product"), + pytest.param("ij,kl->ijkl", {"i": 32, "j": 32, "k": 32, "l": 32}, id="outer_4d"), + # Element-wise: every index shared, none contracted. + pytest.param("ij,ij->ij", {"i": 512, "j": 512}, id="hadamard"), + # ───────────────────────── ordering-sensitive products ─────────────────── + # Skewed matrix chain: contracting middle-first (b,d small) is orders of + # magnitude cheaper than the left-to-right order, which materializes a big + # a×c intermediate. The classic "matrix chain order matters" case. + pytest.param( + "ab,bc,cd->ad", {"a": 256, "b": 2, "c": 256, "d": 2}, id="skewed_chain" + ), + pytest.param( + "ab,bc,cd,de->ae", + {"a": 50, "b": 40, "c": 30, "d": 20, "e": 10}, + id="chain_4", + ), + pytest.param( + "ab,bc,cd,de,ef->af", + {"a": 12, "b": 11, "c": 10, "d": 9, "e": 8, "f": 7}, + id="chain_5", + ), + # ───────────────────────── tensor-network shapes ───────────────────────── + # Cyclic / hyperedge contractions with no tree decomposition into matmuls; + # every operand shares indices with two others. + pytest.param("ij,jk,ki->", {"i": 64, "j": 64, "k": 64}, id="trace_of_product"), + pytest.param("ij,jk,ik->", {"i": 48, "j": 48, "k": 48}, id="triangle"), + pytest.param("ijk,jl,kl->il", {"i": 24, "j": 24, "k": 24, "l": 24}, id="hyperedge"), + # Star: many operands share one contracted index, fanning into a large + # outer-product output. + pytest.param( + "ai,bi,ci,di->abcd", + {"a": 8, "b": 8, "c": 8, "d": 8, "i": 32}, + id="star_contraction", + ), + # Bilinear / quadratic form over a batch (attention-score flavored). + pytest.param("bi,ij,bj->b", {"b": 128, "i": 64, "j": 64}, id="bilinear"), + # Batched matrix chain — batch axis rides through three contractions. + pytest.param( + "bij,bjk,bkl->bil", + {"b": 16, "i": 24, "j": 24, "k": 24, "l": 24}, + id="batched_chain", + ), + # Multi-index contraction surface: a whole axis-group (c) contracts at once. + pytest.param( + "abc,cde->abde", + {"a": 12, "b": 12, "c": 12, "d": 12, "e": 12}, + id="tensor_contraction", + ), + # Leading scalar factor plus an element-wise reduce — checks that the + # rank-0 operand threads through without spawning a degenerate axis. + pytest.param(",ij,ij->", {"i": 256, "j": 256}, id="scalar_scaled_reduce"), +] + + +def _make_operands(spec: str, sizes: dict[str, int], key: jax.Array) -> list[jax.Array]: + in_part = spec.split("->")[0] + in_specs = in_part.split(",") + keys = random.split(key, len(in_specs)) + return [ + random.normal(k, tuple(sizes[c] for c in s) if s else ()) + for k, s in zip(keys, in_specs, strict=True) + ] + + +@pytest.mark.parametrize( + "impl", [pytest.param(jnp.einsum, id="jax"), pytest.param(einsum, id="effectful")] +) +@pytest.mark.parametrize("spec,sizes", EINSUM_CASES) +@pytest.mark.benchmark(warmup=True, warmup_iterations=1) +def test_einsum_bench(benchmark, impl, spec, sizes, rng_key): + """Time one ``(spec, impl)`` pair. Group by ``spec`` to compare ``jnp`` + against ``effectful`` for the same subscript pattern (see module docstring). + """ + operands = _make_operands(spec, sizes, rng_key) + + @jax.jit + def f(*operands): + return impl(spec, *operands) + + @benchmark + def _run(): + return f(*operands).block_until_ready() + + +@pytest.mark.parametrize("spec,sizes", EINSUM_CASES) +def test_einsum_matches_jnp(spec: str, sizes, rng_key): + """``einsum`` returns the same result as ``jnp.einsum`` for every spec + in ``EINSUM_EXAMPLES``. + """ + operands = _make_operands(spec, sizes, rng_key) + actual = einsum(spec, *operands) + expected = jnp.einsum(spec, *operands) + assert actual.shape == expected.shape, ( + f"shape mismatch for {spec!r}: got {actual.shape}, expected {expected.shape}" + ) + assert jnp.allclose(actual, expected, atol=1e-4, rtol=1e-4), ( + f"value mismatch for {spec!r}" + ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 4d243ca1..484a8c4e 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -21,15 +21,13 @@ PlusDistr, PlusDups, PlusEmpty, - PlusIdentity, PlusSingle, - PlusZero, Product, ReduceCartesianWeightedStream, ReduceDistributeCartesianProduct, ReduceFactorization, ReduceFusion, - ReduceNoStreams, + ReducePartial, ReduceSplit, ReduceWeightedStream, Sum, @@ -172,7 +170,7 @@ def test_plus_identity_right(monoid, backend: Backend): lhs = monoid.plus(x(), monoid.identity) rhs = monoid.plus(x()) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -182,7 +180,7 @@ def test_plus_identity_left(monoid, backend: Backend): lhs = monoid.plus(monoid.identity, x()) rhs = monoid.plus(x()) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -240,10 +238,10 @@ def test_plus_distributes(backend: Backend): def test_plus_distributes_constant(backend: Backend): - a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") - lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) + a, b, c, d, e = backend.define_vars("a", "b", "c", "d", "e", ret="scalar") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), e()) rhs = Product.plus( - 5, + e(), Sum.plus( Product.plus(a(), c()), Product.plus(a(), d()), @@ -314,16 +312,16 @@ def test_plus_zero(monoid, backend: Backend): lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) rhs = monoid.zero - backend.check_rewrite(lhs=lhs_right, rhs=rhs, rule=PlusZero()) - backend.check_rewrite(lhs=lhs_left, rhs=rhs, rule=PlusZero()) + backend.check_rewrite(lhs=lhs_right, rhs=rhs, rule={}) + backend.check_rewrite(lhs=lhs_left, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_partial_1(monoid, backend: Backend): x = backend.define_vars("x", ret="scalar") lhs = monoid.reduce(x(), {x: []}) - rhs = monoid.identity - backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) + rhs = monoid.plus() + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -332,8 +330,8 @@ def test_partial_2(monoid, backend: Backend): Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: []}) - rhs = monoid.identity - backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) + rhs = monoid.plus() + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -343,7 +341,7 @@ def test_partial_3(monoid, backend: Backend): lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -353,7 +351,7 @@ def test_partial_4(monoid, backend: Backend): lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -401,7 +399,7 @@ def test_reduce_no_streams(monoid, backend: Backend): lhs = monoid.reduce(a(), {}) rhs = monoid.identity - backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceNoStreams()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -482,15 +480,15 @@ def test_reduce_independent_3_negative(backend: Backend): def test_reduce_independent_4(backend: Backend): - a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") A, B, C = backend.define_vars("A", "B", "C", ret="stream") f = backend.define_vars( "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" ) - lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), d()), {a: A(), b: B(), c: C()}) rhs = Product.plus( - 7, + d(), Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce( Product.plus(b(), Sum.reduce(Product.plus(f(b(), c())), {c: C()})), From a6617e887b18de3193913b28a624fdd37532cc31 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Thu, 11 Jun 2026 14:35:51 -0400 Subject: [PATCH 09/10] Break ReduceArrayGather rule into 2 steps (#681) --- effectful/handlers/jax/monoid.py | 29 ++++++++++------- effectful/ops/monoid.py | 52 +++++++++++++++++++++++++++++++ tests/test_handlers_jax_monoid.py | 48 +++++++++++++++++++++++++--- tests/test_ops_monoid.py | 28 +++++++++++++++++ 4 files changed, 140 insertions(+), 17 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index fbd845d8..96b708e8 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -156,7 +156,17 @@ def plus(self, *args): class ReduceArrayGather(ObjectInterpretation): - """M.reduce(body, {k: a} ∪ S) ≡ M.reduce(body[k := a[k']], {k': range(a.shape[0])} ∪ S)""" + """Split an array-valued stream into an index range and a length-1 stream: + + M.reduce(body, {k: a} ∪ S) ≡ M.reduce(body, {i: range(a.shape[0]), k: (a[i()],)} ∪ S) + + where ``i`` is fresh and ``a[i()] = unbind_dims(a, i)``. The length-1 stream + ``{k: (a[i()],)}`` is then eliminated by + :class:`~effectful.ops.monoid.EliminateSingletonStreams`, which substitutes + ``k := a[i()]`` into the body and the remaining streams. Together the two + steps perform the gather + ``M.reduce(body[k := a[i()]], {i: range(a.shape[0])} ∪ S)``. + """ @implements(Monoid.reduce) def reduce(self, monoid, body, streams): @@ -169,26 +179,21 @@ def reduce(self, monoid, body, streams): body_fvs = fvsof(body) stream_keys = set(streams) - body_subst = {} - streams_subst = {} - range_streams = {} + new_streams: dict = {} progress = False for k, v in streams.items(): if is_eager_array(v) and k in body_fvs and not (fvsof(v) & stream_keys): - kk = Operation.define(k) - body_subst[k] = deffn(unbind_dims(v, kk)) - streams_subst[k] = kk - range_streams[kk] = range(v.shape[0]) + index = Operation.define(k) + new_streams[index] = range(v.shape[0]) + new_streams[k] = (unbind_dims(v, index),) progress = True else: - range_streams[k] = v + new_streams[k] = v if not progress: return fwd() - subst_body = handler(body_subst)(evaluate)(body) - subst_streams = handler(streams_subst)(evaluate)(range_streams) - return monoid.reduce(subst_body, subst_streams) + return monoid.reduce(body, new_streams) class Reductor(Protocol): diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 42180e91..067f29e5 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -872,6 +872,57 @@ def plus(self, monoid, *args): return fwd() +class EliminateSingletonStreams(ObjectInterpretation): + """Eliminate a length-1 stream by substituting its sole element. + + reduce(M, body, {k: (v,)} ∪ S) = reduce(M, body[k := v], S[k := v]) + + Fires only when the sole element ``v`` is a :class:`Term`, i.e. a *symbolic* + singleton. This is exactly the form ``ReduceArrayGather`` produces (a gather + ``(a[i()],)``) and, more generally, every dependent singleton that + :class:`ReducePartial` cannot peel -- a non-outermost stream whose element + references another stream var. Concrete enumerated streams (``[0]``, + ``range(1)``) and monoid sentinels (``CartesianProduct.identity == [()]``) + have non-``Term`` elements and are left to ``ReducePartial`` / the + per-monoid rules. + + Unlike ``ReducePartial``, this peels the stream wherever it sits in the loop + nest and substitutes symbolically rather than unrolling, leaving a + vectorized index range (e.g. the gather's range) intact instead of + materializing it. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + # Eliminate *all* symbolic length-1 streams in one pass via a + # simultaneous substitution. Doing them together (rather than one per + # invocation) keeps an interleaving reduction rule -- e.g. + # ``ReduceArray`` consuming a now-live index range -- from firing + # between eliminations, so sibling index ranges stay together and fuse + # into a single reduction. + singletons = { + k: vs[0] + for k, vs in streams.items() + if not isinstance(vs, Term) + and isinstance(vs, collections.abc.Sequence) + and len(vs) == 1 + and isinstance(vs[0], Term) + } + if not singletons: + return fwd() + + subs = {k: deffn(v) for k, v in singletons.items()} + new_body = handler(subs)(evaluate)(body) + new_streams = { + kk: handler(subs)(evaluate)(vv) + for kk, vv in streams.items() + if kk not in singletons + } + # reduce over no streams is a single (empty) assignment, i.e. the body + # itself -- not the monoid identity. + return monoid.reduce(new_body, new_streams) if new_streams else new_body + + class _ExtensibleInterpretation(UserDict, Interpretation): def extend(self, *intps: Interpretation) -> typing.Self: for intp in intps: @@ -881,6 +932,7 @@ def extend(self, *intps: Interpretation) -> typing.Self: NormalizeIntp = _ExtensibleInterpretation().extend( ReducePartial(), + EliminateSingletonStreams(), MonoidOverSequence(), MonoidOverMapping(), MonoidOverCallable(), diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 48913d80..e410a3e6 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -17,7 +17,12 @@ delta, einsum, ) -from effectful.ops.monoid import NormalizeIntp, Product, Sum +from effectful.ops.monoid import ( + EliminateSingletonStreams, + NormalizeIntp, + Product, + Sum, +) from effectful.ops.semantics import coproduct, handler from tests._monoid_helpers import JaxBackend @@ -44,6 +49,23 @@ def test_reduce_array_gather(monoid, reductor, backend: JaxBackend): lhs = monoid.reduce(x(), {x: X}) rhs = monoid.reduce(unbind_dims(X, k), {k: range(X.shape[0])}) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct(ReduceArrayGather(), EliminateSingletonStreams()), + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_gather_step1(monoid, reductor, backend: JaxBackend): + """Step 1 alone: an array stream becomes an index range plus a length-1 + stream holding the gathered element. ``ReduceArrayGather`` does not perform + the gather substitution itself -- that is ``EliminateSingletonStreams``.""" + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = jnp.arange(3) + + lhs = monoid.reduce(x(), {x: X}) + rhs = monoid.reduce(x(), {k: range(X.shape[0]), x: (unbind_dims(X, k),)}) backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceArrayGather()) @@ -56,12 +78,17 @@ def test_reduce_array_gather_dep(monoid, reductor, backend: JaxBackend): ) X = jnp.arange(3) + # The dependent stream ``y: f(x())`` gets the *gathered element* X[x] + # substituted for x -- i.e. ``f(X[x])`` -- not the bare index. lhs = monoid.reduce(g(x(), y()), {y: f(x()), x: X}) rhs = monoid.reduce( - g(unbind_dims(X[:3], x), y()), {y: f(x()), x: range(X.shape[0])} + g(unbind_dims(X, x), y()), + {y: f(unbind_dims(X, x)), x: range(X.shape[0])}, ) backend.check_rewrite( - lhs=lhs, rhs=rhs, rule=coproduct(ReduceArrayGather(), ReduceDeltaSimpleRange()) + lhs=lhs, + rhs=rhs, + rule=coproduct(ReduceArrayGather(), EliminateSingletonStreams()), ) @@ -77,7 +104,12 @@ def test_reduce_array_1(monoid, reductor, backend: JaxBackend): rhs=rhs, rule=functools.reduce( coproduct, # type: ignore[arg-type] - [ReduceArrayGather(), ReduceArray(), ReduceDeltaSimpleRange()], + [ + ReduceArrayGather(), + EliminateSingletonStreams(), + ReduceArray(), + ReduceDeltaSimpleRange(), + ], ), ) @@ -100,7 +132,12 @@ def test_reduce_array_2(monoid, reductor, backend: JaxBackend): rhs=rhs, rule=functools.reduce( coproduct, # type: ignore[arg-type] - [ReduceArrayGather(), ReduceArray(), ReduceDeltaSimpleRange()], + [ + ReduceArrayGather(), + EliminateSingletonStreams(), + ReduceArray(), + ReduceDeltaSimpleRange(), + ], ), ) @@ -131,6 +168,7 @@ def test_reduce_array_3(monoid, reductor, backend: JaxBackend): coproduct, # type: ignore[arg-type] [ ReduceArrayGather(), + EliminateSingletonStreams(), ReduceArray(), ReduceDeltaSimpleRange(), DeltaEmpty(), diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 484a8c4e..9976fd6d 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -10,6 +10,7 @@ import effectful.handlers.jax.numpy as jnp from effectful.ops.monoid import ( CartesianProduct, + EliminateSingletonStreams, Max, Min, Monoid, @@ -354,6 +355,33 @@ def test_partial_4(monoid, backend: Backend): backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_eliminate_singleton_into_sibling(monoid, backend: Backend): + """A length-1 stream substitutes its element into the body *and* into a + sibling stream's definition, then drops out of the nest.""" + x, y, a = backend.define_vars("x", "y", "a", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") + g = backend.define_vars( + "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = monoid.reduce(g(x(), y()), {x: (a(),), y: f(x())}) + rhs = monoid.reduce(g(a(), y()), {y: f(a())}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=EliminateSingletonStreams()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_eliminate_singleton_only_stream(monoid, backend: Backend): + """When the length-1 stream is the only stream, reducing over the now-empty + nest yields the substituted body itself (not the monoid identity).""" + x, a = backend.define_vars("x", "a", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = monoid.reduce(f(x()), {x: (a(),)}) + rhs = f(a()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=EliminateSingletonStreams()) + + @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_reduce_body_sequence(monoid, backend: Backend): x = backend.define_vars("x", ret="scalar") From 7043dbadc661eacea02df8fff41a1a4fcb4da3d8 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Thu, 11 Jun 2026 14:36:23 -0400 Subject: [PATCH 10/10] Remove dead code (#683) --- effectful/ops/types.py | 67 ------------------------------------------ 1 file changed, 67 deletions(-) diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 77823b7b..505deffb 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -42,59 +42,6 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return self.func(self.dispatch, *args, **kwargs) -class _CustomSingleDispatchMethod[**P, **Q, S, T]: - """Method analog of :class:`_CustomSingleDispatchCallable`. - - The wrapped function has signature ``(self, dispatch, *args, **kwargs)``, - where ``dispatch`` is :meth:`functools.singledispatch.dispatch`. As a - descriptor, it binds ``self`` on attribute access, so callers invoke it - as ``instance.method(*args, **kwargs)``. - """ - - def __init__( - self, - func: Callable[Concatenate[Any, Callable[[type], Callable[Q, S]], P], T], - ): - self.func = func - self._registry = functools.singledispatch(func) - self.__signature__ = inspect.signature( - functools.partial(func, None, None) # type: ignore[arg-type] - ) - functools.update_wrapper(self, func) # type: ignore[arg-type] - - @property - def dispatch(self): - return self._registry.dispatch - - @property - def register(self): - return self._registry.register - - def __get__(self, instance, owner=None): - if instance is None: - return self - return _BoundCustomSingleDispatchMethod(self, instance) - - -class _BoundCustomSingleDispatchMethod: - __slots__ = ("_method", "_instance") - - def __init__(self, method: _CustomSingleDispatchMethod, instance: Any): - self._method = method - self._instance = instance - - @property - def dispatch(self): - return self._method.dispatch - - @property - def register(self): - return self._method.register - - def __call__(self, *args, **kwargs): - return self._method.func(self._instance, self._method.dispatch, *args, **kwargs) - - class _ClassMethodOpDescriptor(classmethod): def __init__(self, define, *args, **kwargs): super().__init__(*args, **kwargs) @@ -412,20 +359,6 @@ def func(*args, **kwargs): op.register = default._registry.register # type: ignore[attr-defined] return op - @define.register(_CustomSingleDispatchMethod) - @classmethod - def _define_customsingledispatchmethod( - cls, default: _CustomSingleDispatchMethod, **kwargs - ): - @functools.wraps(default.func) - def _wrapper(obj, *args, **kwargs): - return default.__get__(obj)(*args, **kwargs) - - op = cls.define(_wrapper, **kwargs) - op.register = default.register # type: ignore[attr-defined] - op.dispatch = default.dispatch # type: ignore[attr-defined] - return op - @typing.final def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": """The default rule is used when the operation is not handled.