diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro.py index f0369d37..329cf471 100644 --- a/effectful/handlers/numpyro.py +++ b/effectful/handlers/numpyro.py @@ -226,10 +226,12 @@ def _pos_base_dist(self) -> dist.Distribution: @functools.cached_property def _is_eager(self) -> bool: - return all( - (not isinstance(x, Term) or is_eager_array(x)) - for x in (*self.args, *self.kwargs.values()) - ) + def _arg_is_eager(x): + if isinstance(x, _DistributionTerm): + return x._is_eager + return not isinstance(x, Term) or is_eager_array(x) + + return all(_arg_is_eager(x) for x in (*self.args, *self.kwargs.values())) @property def op(self): @@ -357,7 +359,7 @@ def to_event(self, reinterpreted_batch_ndims=None) -> dist.Distribution: raise NotHandled @defop - def expand(self, batch_shape) -> jax.Array: + def expand(self, batch_shape) -> dist.Distribution: if not self._is_eager: raise NotHandled @@ -396,6 +398,33 @@ def __str__(self): expand = _DistributionTerm.expand +@defdata.register(dist.Distribution) +class _DistributionMethodTerm(_DistributionTerm): + """Term for distribution-method ops returning the abstract ``dist.Distribution`` + (``expand``, ``to_event``). Catches the ``defdata`` fallthrough that would + otherwise hit ``_CallableTerm``. See #666.""" + + def __init__(self, ty, op, *args, **kwargs): + receiver = args[0] if args else None + constr = ( + receiver._constr + if isinstance(receiver, _DistributionTerm) + else dist.Distribution + ) + super().__init__(constr, op, *args, **kwargs) + + @functools.cached_property + def _pos_base_dist(self) -> dist.Distribution: + # Delegate to NumPyro's method of the same name on the materialised receiver. + receiver = self._args[0] + base = ( + receiver._pos_base_dist + if isinstance(receiver, _DistributionTerm) + else receiver + ) + return getattr(base, self._op.__name__)(*self._args[1:], **self._kwargs) + + @defop def Cauchy(loc=0.0, scale=1.0, **kwargs) -> dist.Cauchy: raise NotHandled diff --git a/tests/test_handlers_numpyro.py b/tests/test_handlers_numpyro.py index 270def25..befd5cbe 100644 --- a/tests/test_handlers_numpyro.py +++ b/tests/test_handlers_numpyro.py @@ -480,7 +480,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ("concentration0", f"exp(rand({batch_shape + indep_shape}))"), ), batch_shape, - xfail="to_event not implemented", + xfail="to_event composed with expand_by on indexed dims not implemented", ) # Dirichlet.to_event @@ -494,7 +494,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ), ), batch_shape, - xfail="to_event not implemented", + xfail="to_event composed with expand_by on indexed dims not implemented", ) # TransformedDistribution.to_event @@ -513,7 +513,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None): ("high", f"2. + rand({batch_shape + indep_shape})"), ), batch_shape, - xfail="to_event not implemented", + xfail="TransformedDistribution not implemented", ) @@ -929,3 +929,94 @@ def test_distribution_typeof(): typeof(dist.Normal(jax_getitem(jnp.array([0, 1, 2]), [defop(jax.Array)()]))) is numpyro.distributions.continuous.Normal ) + + +def test_distribution_method_chain_on_non_eager_term(): + """Regression test for #666 (narrow). + + ``Normal(mu_term, 1.0).expand([J]).to_event(1)`` must not raise + ``AttributeError`` mid-chain. Previously ``_DistributionTerm.expand`` was + ``@defop``-annotated to return ``jax.Array``, routing ``.expand([J])``'s + result through ``_ArrayTerm`` (no ``.to_event``). The fix annotates + ``expand`` to return ``dist.Distribution`` and registers a fallback + ``_DistributionMethodTerm`` for ``defdata`` dispatch on the abstract base, + so the chain stays in the distribution-term surface. + """ + mu = defop(jax.Array, name="mu") + + expanded = dist.Normal(mu(), 1.0).expand([3]) + assert isinstance(expanded, numpyro.distributions.Distribution) + + chained = expanded.to_event(1) + assert isinstance(chained, numpyro.distributions.Distribution) + + +def test_expand_to_event_shape_laws(): + """Equational laws for ``.expand`` and ``.to_event`` on a distribution term + whose free-variable arg has been bound by an effectful handler. + + These hold for any NumPyro distribution and should survive any future + refactor of how deferred method ops are encoded: + + d.expand(s).batch_shape == tuple(s) + d.expand(s).event_shape == d.event_shape + d.to_event(k).event_shape == d.batch_shape[-k:] + d.event_shape + d.to_event(k).batch_shape == d.batch_shape[:-k] + """ + import jax.numpy as jnp + + from effectful.ops.semantics import handler + + mu = defop(jax.Array, name="mu") + + with handler({mu: lambda: jnp.array(0.0)}): + d = dist.Normal(mu(), 1.0) + assert d.batch_shape == () + assert d.event_shape == () + + expanded = d.expand([3, 4]) + assert expanded.batch_shape == (3, 4) + assert expanded.event_shape == () + + indep = expanded.to_event(1) + assert indep.batch_shape == (3,) + assert indep.event_shape == (4,) + + chained = d.expand([3]).to_event(1) + assert chained.batch_shape == () + assert chained.event_shape == (3,) + assert not chained.support.is_discrete + + +def test_expand_to_event_chain_end_to_end_mcmc(): + """End-to-end regression: the literal #666 idiom — ``Normal(mu_term, 1.0) + .expand([J]).to_event(1)`` with ``mu_term`` bound by an effectful handler — + must trace, build a potential, and run MCMC to completion. + + Before the fix this raised ``AttributeError: '_ArrayTerm' object has no + attribute 'to_event'`` at chain construction. After the fix, the chain + constructs a ``_DistributionMethodTerm`` whose materialised + ``_pos_base_dist`` resolves to a real ``dist.Independent`` wrapping the + handler-bound receiver, so NumPyro's downstream property/sample/log_prob + accesses all resolve. + """ + import jax.numpy as jnp + import jax.random as jr + + from effectful.ops.semantics import handler + + mu = defop(jax.Array, name="mu") + + def model(): + numpyro.sample("theta", dist.Normal(mu(), 1.0).expand([3]).to_event(1)) + + with handler({mu: lambda: jnp.array(0.0)}): + mcmc = numpyro.infer.MCMC( + numpyro.infer.NUTS(model), + num_warmup=20, + num_samples=20, + progress_bar=False, + ) + mcmc.run(jr.PRNGKey(0)) + + assert mcmc.get_samples()["theta"].shape == (20, 3)