From 76c6762114e2a149dc18e343c28928e843c81f74 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 19 Apr 2026 20:33:21 -0400 Subject: [PATCH 1/4] Support consumed dims in Prior for distributions like Categorical Distributions such as Categorical(p)->() consume the parameter dimension entirely without re-emitting it in the output. Prior previously raised UnsupportedShapeError for these cases. - Add _get_dist_param_names and _get_consumed_dims helpers - Fix _param_dims_work to exclude consumed dims from subset check - Fix handle_dims to append core dim indices and handle all-consumed case - Auto-compute core_dims from rv_op metadata when not explicitly set - Fix xdist path to pass core_dims=None (not ()) when absent - Add TestConsumedDims with 8 tests covering Categorical, Interpolated, MvNormal cov, and the guard for non-consumed incompatible dims Closes #499 --- pymc_extras/prior.py | 125 +++++++++++++++++++++++++++++++++++++++---- tests/test_prior.py | 112 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+), 11 deletions(-) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index c30101e1d..6f39f8160 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -84,6 +84,7 @@ def custom_transform(x): from __future__ import annotations import copy +import inspect import warnings from collections.abc import Callable, Sequence @@ -147,7 +148,62 @@ def _remove_leading_xs(args: list[str | int]) -> list[str | int]: return args -def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable: +def _get_dist_param_names(dist_class) -> list[str]: + """Return ordered parameter names from a PyMC distribution's dist() signature.""" + sig = inspect.signature(dist_class.dist) + return [p for p in sig.parameters if p not in ("size", "kwargs")] + + +def _get_consumed_dims(distribution: str, parameters: dict) -> set[str]: + """Return the set of dims consumed as core dims but not re-emitted in the output. + + Uses rv_op metadata (ndims_params, ndim_supp) to determine which dims of + Prior-valued parameters are "consumed" by the distribution (i.e., they are + core dims of the parameter that do not appear in the output). + + The number of purely consumed dims for parameter i is: + max(0, ndims_params[i] - ndim_supp) + + Examples + -------- + - Categorical(p)->(): ndims_params=[1], ndim_supp=0 -> p's dim is consumed + - MvNormal(n),(n,n)->(n): ndims_params=[1,2], ndim_supp=1 + -> mu: 1-1=0 (not consumed), cov: 2-1=1 (last dim consumed) + - Dirichlet(a)->(a): ndims_params=[1], ndim_supp=1 -> 1-1=0 (not consumed) + """ + dist_class = getattr(pm, distribution, None) + if dist_class is None: + return set() + rv_op = getattr(dist_class, "rv_op", None) + ndims_params = getattr(rv_op, "ndims_params", None) + ndim_supp = getattr(rv_op, "ndim_supp", 0) or 0 + if not ndims_params: + return set() + + try: + param_names = _get_dist_param_names(dist_class) + except (ValueError, TypeError): + return set() + + consumed: set[str] = set() + for i, pname in enumerate(param_names): + if i >= len(ndims_params): + break + n_consumed = ndims_params[i] - ndim_supp + if n_consumed <= 0: + continue + val = parameters.get(pname) + if isinstance(val, Prior) and val.dims: + consumed.update(val.dims[-n_consumed:]) + return consumed + + +def handle_dims( + x: pt.TensorLike, + dims: Dims, + desired_dims: Dims, + consumed_dims: frozenset[str] = frozenset(), +) -> pt.TensorVariable: """Take a tensor of dims `dims` and align it to `desired_dims`. Doesn't check for validity of the dims @@ -160,6 +216,11 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa The current dimensions of the tensor. desired_dims : Dims The desired dimensions of the tensor. + consumed_dims : frozenset[str], optional + Dimensions that are core dims of the parameter and will be consumed + internally by the distribution op. These are excluded from the batch-dim + alignment check and dimshuffle, but preserved as trailing axes so the + distribution op can consume them. Defaults to an empty frozenset. Returns ------- @@ -192,20 +253,28 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa dims = dims if isinstance(dims, tuple) else (dims,) desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,) - if difference := set(dims).difference(desired_dims): + batch_dims = tuple(d for d in dims if d not in consumed_dims) + + if difference := set(batch_dims).difference(desired_dims): raise UnsupportedShapeError( f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. " f"{difference} is missing from the desired dims." ) - aligned_dims = np.array(dims)[:, None] == np.array(desired_dims) + if not batch_dims: + return x + + n_core_dims = len(dims) - len(batch_dims) + + aligned_dims = np.array(batch_dims)[:, None] == np.array(desired_dims) missing_dims = aligned_dims.sum(axis=0) == 0 new_idx = aligned_dims.argmax(axis=0) args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims, strict=False)] args = _remove_leading_xs(args) - return x.dimshuffle(*args) + core_dim_indices = list(range(len(batch_dims), len(batch_dims) + n_core_dims)) + return x.dimshuffle(*args, *core_dim_indices) DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike] @@ -244,8 +313,10 @@ def create_dim_handler(desired_dims: Dims) -> DimHandler: """ - def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable: - return handle_dims(x, dims, desired_dims) + def func( + x: pt.TensorLike, dims: Dims, consumed_dims: frozenset[str] = frozenset() + ) -> pt.TensorVariable: + return handle_dims(x, dims, desired_dims, consumed_dims) return func @@ -551,6 +622,20 @@ class Prior: created, by default None or no transform. The transformation must be registered with `register_tensor_transform` function or be available in either `pytensor.tensor` or `pymc.math`. + core_dims : Dims, optional + The core (support) dimensions of the distribution. When not provided, + core_dims is inferred automatically from the distribution's ``rv_op`` + metadata: + + - Distributions whose output has support dims (``ndim_supp > 0``, e.g. + ``Dirichlet``) use the last ``ndim_supp`` dims of ``dims``. + - Distributions that consume input dims without re-emitting them + (``ndim_supp == 0``, e.g. ``Categorical``) use the dims identified as + consumed by :func:`_get_consumed_dims`. + + Only needs to be set explicitly when the automatic inference is + insufficient or when using the ``xdist=True`` path with a custom + distribution. Examples -------- @@ -665,6 +750,17 @@ def __init__( core_dims = tuple(core_dims) self.core_dims = core_dims + # Auto-compute core_dims when not explicitly set. See the class + # docstring for the two cases handled. + if not self.core_dims and self.dims: + dist_class = getattr(pm, self.distribution, None) + rv_op = getattr(dist_class, "rv_op", None) if dist_class else None + ndim_supp = getattr(rv_op, "ndim_supp", 0) or 0 + if ndim_supp > 0: + self.core_dims = tuple(self.dims[-ndim_supp:]) + else: + self.core_dims = tuple(_get_consumed_dims(self.distribution, self.parameters)) + self._checks() @property @@ -797,9 +893,12 @@ def _param_dims_work(self) -> None: if (other_dims := getattr(value, "dims", None)) is not None: other_dims_set.update(other_dims) - if not other_dims_set.issubset(self.dims): + consumed = _get_consumed_dims(self.distribution, self.parameters) + remaining = other_dims_set - consumed + + if not remaining.issubset(self.dims): raise UnsupportedShapeError( - f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}" + f"Parameter dims {remaining} are not a subset of the prior dims {self.dims}" ) def __str__(self) -> str: @@ -827,7 +926,11 @@ def _create_parameter(self, param, value, name, xdist: bool = False): if xdist: return value.create_variable(child_name, xdist=True) else: - return self.dim_handler(value.create_variable(child_name), value.dims or ()) + consumed = _get_consumed_dims(self.distribution, self.parameters) + child_consumed = frozenset(d for d in (value.dims or ()) if d in consumed) + return self.dim_handler( + value.create_variable(child_name), value.dims or (), child_consumed + ) def _create_centered_variable(self, name: str, xdist: bool = False): parameters = { @@ -836,7 +939,7 @@ def _create_centered_variable(self, name: str, xdist: bool = False): } if xdist: pymc_distribution = _get_pymc_dim_distribution(self.distribution) - core_dims_kwargs = {"core_dims": self.core_dims} + core_dims_kwargs = {"core_dims": self.core_dims if self.core_dims else None} else: pymc_distribution = self.pymc_distribution core_dims_kwargs = {} @@ -870,7 +973,7 @@ def handle_variable(var_name: str): } if xdist: pymc_distribution = _get_pymc_dim_distribution(self.distribution) - core_dims_kwargs = {"core_dims": self.core_dims} + core_dims_kwargs = {"core_dims": self.core_dims if self.core_dims else None} else: pymc_distribution = self.pymc_distribution core_dims_kwargs = {} diff --git a/tests/test_prior.py b/tests/test_prior.py index 9f5c151f3..06123c726 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -1402,3 +1402,115 @@ def test_core_dims(self): ip = m.initial_point() with pytensor.config.change_flags(mode="FAST_COMPILE"): assert m.compile_logp()(ip) == ref_m.compile_logp()(ip) + + +class TestConsumedDims: + """Tests for distributions that 'consume' dims in their parameters. + + Some distributions take vector/matrix parameters but produce lower-rank + output. For example, Categorical(p)->() consumes the 'p' dimension entirely. + These consumed dims must not be required to appear in the output dims. + + See: https://github.com/pymc-devs/pymc-extras/issues/499 + """ + + # --- Categorical: (p)->(), ndim_supp=0, ndims_params=[1] --- + + def test_categorical_construction(self): + """Prior("Categorical", p=Prior("Dirichlet", ..., dims="probs")) should not raise. + + Categorical: (p)->(), ndim_supp=0, ndims_params=[1] + p's dim is fully consumed — it must not be required to appear in output dims. + """ + p = pr.Prior("Dirichlet", a=[1, 1, 1], dims="probs") + cat = pr.Prior("Categorical", p=p, dims="trial") + assert cat.dims == ("trial",) + + def test_categorical_create_variable(self): + """Non-xdist create_variable produces correct shape. + + Categorical: (p)->(), ndim_supp=0, ndims_params=[1] + The consumed 'probs' dim disappears; output shape is (trial,). + """ + p = pr.Prior("Dirichlet", a=[1, 1, 1], dims="probs") + cat = pr.Prior("Categorical", p=p, dims="trial") + with pm.Model(coords={"probs": range(3), "trial": range(10)}) as m: + cat.create_variable("y") + assert fast_eval(m["y"]).shape == (10,) + + @pytest.mark.filterwarnings("ignore:The `pymc.dims` module is experimental") + def test_categorical_create_variable_xdist(self): + """xdist=True create_variable produces correct named dims. + + Categorical: (p)->(), ndim_supp=0, ndims_params=[1] + core_dims is auto-computed as ('probs',); output dims are ('trial',). + """ + p = pr.Prior("Dirichlet", a=DataArray([1, 1, 1], dims="probs"), dims="probs") + cat = pr.Prior("Categorical", p=p, dims="trial") + with pm.Model(coords={"probs": range(3), "trial": range(10)}) as m: + cat.create_variable("y", xdist=True) + assert m.named_vars_to_dims["y"] == ("trial",) + + def test_categorical_batch_broadcast(self): + """Non-xdist: batch dim 'geo' passes through, core dim 'probs' is consumed. + + Categorical: (p)->(), ndim_supp=0, ndims_params=[1] + p has dims ('geo', 'probs'): 'probs' is consumed, 'geo' broadcasts into output. + """ + p = pr.Prior("Dirichlet", a=np.ones((2, 3)), dims=("geo", "probs")) + cat = pr.Prior("Categorical", p=p, dims=("geo", "trial")) + with pm.Model(coords={"geo": range(2), "probs": range(3), "trial": range(10)}) as m: + cat.create_variable("y") + assert fast_eval(m["y"]).shape == (2, 10) + + @pytest.mark.filterwarnings("ignore:The `pymc.dims` module is experimental") + def test_categorical_batch_broadcast_xdist(self): + """xdist=True: batch dim 'geo' passes through, core dim 'probs' is consumed. + + Categorical: (p)->(), ndim_supp=0, ndims_params=[1] + p has dims ('geo', 'probs'): 'probs' is consumed, 'geo' broadcasts into output. + """ + a = DataArray(np.ones((2, 3)), dims=("geo", "probs")) + p = pr.Prior("Dirichlet", a=a, dims=("geo", "probs")) + cat = pr.Prior("Categorical", p=p, dims=("geo", "trial")) + with pm.Model(coords={"geo": range(2), "probs": range(3), "trial": range(10)}) as m: + cat.create_variable("y", xdist=True) + assert m.named_vars_to_dims["y"] == ("geo", "trial") + + # --- Interpolated: (x),(x),(x)->(), ndim_supp=0, ndims_params=[1,1,1] --- + + def test_interpolated_construction(self): + """Interpolated with x_points as a Prior with dims should not raise. + + Interpolated: (x),(x),(x)->(), ndim_supp=0, ndims_params=[1,1,1] + All three parameter dims are consumed — none appear in the scalar output. + """ + x_pts = pr.Prior("Normal", mu=0, sigma=1, dims="knots") + interp = pr.Prior( + "Interpolated", + x_points=x_pts, + pdf_points=np.ones(5) / 5, + dims="obs", + ) + assert interp.dims == ("obs",) + + # --- MvNormal: (n),(n,n)->(n), ndim_supp=1, ndims_params=[1,2] --- + + def test_mvnormal_cov_extra_dim(self): + """MvNormal cov with 2 named dims: the 2nd dim is consumed, 1st re-emitted. + + MvNormal: (n),(n,n)->(n), ndim_supp=1, ndims_params=[1,2] + cov has 2 core dims but output only has 1 — so cov's 2nd dim is "extra consumed". + 'component' is re-emitted in the output; 'component_' is consumed. + """ + cov = pr.Prior("Wishart", nu=5, V=np.eye(3), dims=("component", "component_")) + mv = pr.Prior("MvNormal", mu=np.zeros(3), cov=cov, dims="component") + assert mv.dims == ("component",) + + # --- Guard: non-consumed incompatible dims still raise --- + + def test_incompatible_dims_still_raise(self): + """A truly incompatible (non-consumed) dim still raises UnsupportedShapeError.""" + inner = pr.Prior("Normal", dims="other") + with pytest.raises(pr.UnsupportedShapeError): + pr.Prior("Normal", mu=inner, dims="channel") From 02fe2d085d351e55aa82a27c2b42f0821836e5dc Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 19 Apr 2026 20:38:30 -0400 Subject: [PATCH 2/4] Document rightmost-axes convention for consumed and core dims --- pymc_extras/prior.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 6f39f8160..fa2f35901 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -161,14 +161,19 @@ def _get_consumed_dims(distribution: str, parameters: dict) -> set[str]: Prior-valued parameters are "consumed" by the distribution (i.e., they are core dims of the parameter that do not appear in the output). + Following the NumPy generalized ufunc (gufunc) convention that PyMC inherits, + core dims are always the **rightmost** axes of a parameter tensor. Consumed + dims are therefore always taken from the trailing (rightmost) end of a + parameter's ``dims`` tuple — never from the left. + The number of purely consumed dims for parameter i is: max(0, ndims_params[i] - ndim_supp) Examples -------- - - Categorical(p)->(): ndims_params=[1], ndim_supp=0 -> p's dim is consumed - - MvNormal(n),(n,n)->(n): ndims_params=[1,2], ndim_supp=1 - -> mu: 1-1=0 (not consumed), cov: 2-1=1 (last dim consumed) + - Categorical(p)->(): ndims_params=[1], ndim_supp=0 -> p's rightmost dim is consumed + - MvNormal(mu,cov)->(n): ndims_params=[1,2], ndim_supp=1 + -> mu: 1-1=0 (not consumed), cov: 2-1=1 (rightmost dim consumed) - Dirichlet(a)->(a): ndims_params=[1], ndim_supp=1 -> 1-1=0 (not consumed) """ dist_class = getattr(pm, distribution, None) @@ -218,9 +223,11 @@ def handle_dims( The desired dimensions of the tensor. consumed_dims : frozenset[str], optional Dimensions that are core dims of the parameter and will be consumed - internally by the distribution op. These are excluded from the batch-dim - alignment check and dimshuffle, but preserved as trailing axes so the - distribution op can consume them. Defaults to an empty frozenset. + internally by the distribution op. Following the NumPy gufunc convention + inherited by PyMC, consumed dims are always the **rightmost** axes of the + parameter tensor — never the leftmost. These are excluded from the + batch-dim alignment check and dimshuffle, but preserved as trailing axes + so the distribution op can consume them. Defaults to an empty frozenset. Returns ------- @@ -623,9 +630,10 @@ class Prior: be registered with `register_tensor_transform` function or be available in either `pytensor.tensor` or `pymc.math`. core_dims : Dims, optional - The core (support) dimensions of the distribution. When not provided, - core_dims is inferred automatically from the distribution's ``rv_op`` - metadata: + The core (support) dimensions of the distribution. Following the NumPy + gufunc convention inherited by PyMC, core dims are always the + **rightmost** axes of the output tensor. When not provided, core_dims is + inferred automatically from the distribution's ``rv_op`` metadata: - Distributions whose output has support dims (``ndim_supp > 0``, e.g. ``Dirichlet``) use the last ``ndim_supp`` dims of ``dims``. From 8dcd301fbab316d0f6d5e82b08c172b39818c78d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 19 Apr 2026 20:40:33 -0400 Subject: [PATCH 3/4] Trim core_dims docstring to be brief --- pymc_extras/prior.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index fa2f35901..0b6b4614e 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -630,20 +630,9 @@ class Prior: be registered with `register_tensor_transform` function or be available in either `pytensor.tensor` or `pymc.math`. core_dims : Dims, optional - The core (support) dimensions of the distribution. Following the NumPy - gufunc convention inherited by PyMC, core dims are always the - **rightmost** axes of the output tensor. When not provided, core_dims is - inferred automatically from the distribution's ``rv_op`` metadata: - - - Distributions whose output has support dims (``ndim_supp > 0``, e.g. - ``Dirichlet``) use the last ``ndim_supp`` dims of ``dims``. - - Distributions that consume input dims without re-emitting them - (``ndim_supp == 0``, e.g. ``Categorical``) use the dims identified as - consumed by :func:`_get_consumed_dims`. - - Only needs to be set explicitly when the automatic inference is - insufficient or when using the ``xdist=True`` path with a custom - distribution. + The rightmost support dimensions of the distribution output. Inferred + automatically from ``rv_op`` metadata; only set explicitly for custom + distributions. Examples -------- From 4f83d614fb320a00c2e79d6e8efaaa40700838c1 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 19 Apr 2026 20:43:53 -0400 Subject: [PATCH 4/4] Trim all new docstrings to be brief --- pymc_extras/prior.py | 31 ++++--------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 0b6b4614e..e2a3bc0a9 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -155,27 +155,7 @@ def _get_dist_param_names(dist_class) -> list[str]: def _get_consumed_dims(distribution: str, parameters: dict) -> set[str]: - """Return the set of dims consumed as core dims but not re-emitted in the output. - - Uses rv_op metadata (ndims_params, ndim_supp) to determine which dims of - Prior-valued parameters are "consumed" by the distribution (i.e., they are - core dims of the parameter that do not appear in the output). - - Following the NumPy generalized ufunc (gufunc) convention that PyMC inherits, - core dims are always the **rightmost** axes of a parameter tensor. Consumed - dims are therefore always taken from the trailing (rightmost) end of a - parameter's ``dims`` tuple — never from the left. - - The number of purely consumed dims for parameter i is: - max(0, ndims_params[i] - ndim_supp) - - Examples - -------- - - Categorical(p)->(): ndims_params=[1], ndim_supp=0 -> p's rightmost dim is consumed - - MvNormal(mu,cov)->(n): ndims_params=[1,2], ndim_supp=1 - -> mu: 1-1=0 (not consumed), cov: 2-1=1 (rightmost dim consumed) - - Dirichlet(a)->(a): ndims_params=[1], ndim_supp=1 -> 1-1=0 (not consumed) - """ + """Return dims consumed by a distribution's rightmost core axes but not re-emitted in the output.""" dist_class = getattr(pm, distribution, None) if dist_class is None: return set() @@ -222,12 +202,9 @@ def handle_dims( desired_dims : Dims The desired dimensions of the tensor. consumed_dims : frozenset[str], optional - Dimensions that are core dims of the parameter and will be consumed - internally by the distribution op. Following the NumPy gufunc convention - inherited by PyMC, consumed dims are always the **rightmost** axes of the - parameter tensor — never the leftmost. These are excluded from the - batch-dim alignment check and dimshuffle, but preserved as trailing axes - so the distribution op can consume them. Defaults to an empty frozenset. + Rightmost dims of the parameter consumed by the distribution and not + re-emitted in the output. Excluded from batch alignment but preserved + as trailing axes for the distribution op. Returns -------