Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 88 additions & 11 deletions pymc_extras/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def custom_transform(x):
from __future__ import annotations

import copy
import inspect
import warnings

from collections.abc import Callable, Sequence
Expand Down Expand Up @@ -147,7 +148,47 @@ def _remove_leading_xs(args: list[str | int]) -> list[str | int]:
return args


def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
def _get_dist_param_names(dist_class) -> list[str]:
"""Return ordered parameter names from a PyMC distribution's dist() signature."""
sig = inspect.signature(dist_class.dist)
return [p for p in sig.parameters if p not in ("size", "kwargs")]


def _get_consumed_dims(distribution: str, parameters: dict) -> set[str]:
    """Return dims consumed by a distribution's rightmost core axes but not re-emitted in the output.

    Parameters
    ----------
    distribution : str
        Name of a PyMC distribution looked up on the ``pm`` namespace.
    parameters : dict
        Mapping of parameter name to value; only :class:`Prior`-valued
        entries with dims contribute consumed dims.

    Returns
    -------
    set[str]
        Dim names that appear in a parameter's trailing core axes but do
        not survive into the distribution's output. Empty when the
        distribution or its ``rv_op`` metadata is unavailable.
    """
    dist_class = getattr(pm, distribution, None)
    if dist_class is None:
        return set()

    rv_op = getattr(dist_class, "rv_op", None)
    ndims_params = getattr(rv_op, "ndims_params", None)
    if not ndims_params:
        return set()
    ndim_supp = getattr(rv_op, "ndim_supp", 0) or 0

    try:
        param_names = _get_dist_param_names(dist_class)
    except (ValueError, TypeError):
        # No inspectable signature; nothing reliable to report.
        return set()

    consumed: set[str] = set()
    # zip truncates to the shorter sequence, matching the original bound check
    # against len(ndims_params).
    for pname, param_core_ndim in zip(param_names, ndims_params):
        # Core axes of the parameter beyond those re-emitted in the output.
        extra_core = param_core_ndim - ndim_supp
        if extra_core <= 0:
            continue
        value = parameters.get(pname)
        if isinstance(value, Prior) and value.dims:
            consumed.update(value.dims[-extra_core:])
    return consumed


def handle_dims(
x: pt.TensorLike,
dims: Dims,
desired_dims: Dims,
consumed_dims: frozenset[str] = frozenset(),
) -> pt.TensorVariable:
"""Take a tensor of dims `dims` and align it to `desired_dims`.

Doesn't check for validity of the dims
Expand All @@ -160,6 +201,10 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
The current dimensions of the tensor.
desired_dims : Dims
The desired dimensions of the tensor.
consumed_dims : frozenset[str], optional
Rightmost dims of the parameter consumed by the distribution and not
re-emitted in the output. Excluded from batch alignment but preserved
as trailing axes for the distribution op.

Returns
-------
Expand Down Expand Up @@ -192,20 +237,28 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
dims = dims if isinstance(dims, tuple) else (dims,)
desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)

if difference := set(dims).difference(desired_dims):
batch_dims = tuple(d for d in dims if d not in consumed_dims)

if difference := set(batch_dims).difference(desired_dims):
raise UnsupportedShapeError(
f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
f"{difference} is missing from the desired dims."
)

aligned_dims = np.array(dims)[:, None] == np.array(desired_dims)
if not batch_dims:
return x

n_core_dims = len(dims) - len(batch_dims)

aligned_dims = np.array(batch_dims)[:, None] == np.array(desired_dims)

missing_dims = aligned_dims.sum(axis=0) == 0
new_idx = aligned_dims.argmax(axis=0)

args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims, strict=False)]
args = _remove_leading_xs(args)
return x.dimshuffle(*args)
core_dim_indices = list(range(len(batch_dims), len(batch_dims) + n_core_dims))
return x.dimshuffle(*args, *core_dim_indices)


DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
Expand Down Expand Up @@ -244,8 +297,10 @@ def create_dim_handler(desired_dims: Dims) -> DimHandler:

"""

def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
return handle_dims(x, dims, desired_dims)
def func(
x: pt.TensorLike, dims: Dims, consumed_dims: frozenset[str] = frozenset()
) -> pt.TensorVariable:
return handle_dims(x, dims, desired_dims, consumed_dims)

return func

Expand Down Expand Up @@ -551,6 +606,10 @@ class Prior:
created, by default None or no transform. The transformation must
be registered with `register_tensor_transform` function or
be available in either `pytensor.tensor` or `pymc.math`.
core_dims : Dims, optional
The rightmost support dimensions of the distribution output. Inferred
automatically from ``rv_op`` metadata; only set explicitly for custom
distributions.

Examples
--------
Expand Down Expand Up @@ -665,6 +724,17 @@ def __init__(
core_dims = tuple(core_dims)
self.core_dims = core_dims

# Auto-compute core_dims when not explicitly set. See the class
# docstring for the two cases handled.
if not self.core_dims and self.dims:
dist_class = getattr(pm, self.distribution, None)
rv_op = getattr(dist_class, "rv_op", None) if dist_class else None
ndim_supp = getattr(rv_op, "ndim_supp", 0) or 0
if ndim_supp > 0:
self.core_dims = tuple(self.dims[-ndim_supp:])
else:
self.core_dims = tuple(_get_consumed_dims(self.distribution, self.parameters))

self._checks()

@property
Expand Down Expand Up @@ -797,9 +867,12 @@ def _param_dims_work(self) -> None:
if (other_dims := getattr(value, "dims", None)) is not None:
other_dims_set.update(other_dims)

if not other_dims_set.issubset(self.dims):
consumed = _get_consumed_dims(self.distribution, self.parameters)
remaining = other_dims_set - consumed

if not remaining.issubset(self.dims):
raise UnsupportedShapeError(
f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}"
f"Parameter dims {remaining} are not a subset of the prior dims {self.dims}"
)

def __str__(self) -> str:
Expand Down Expand Up @@ -827,7 +900,11 @@ def _create_parameter(self, param, value, name, xdist: bool = False):
if xdist:
return value.create_variable(child_name, xdist=True)
else:
return self.dim_handler(value.create_variable(child_name), value.dims or ())
consumed = _get_consumed_dims(self.distribution, self.parameters)
child_consumed = frozenset(d for d in (value.dims or ()) if d in consumed)
return self.dim_handler(
value.create_variable(child_name), value.dims or (), child_consumed
)

def _create_centered_variable(self, name: str, xdist: bool = False):
parameters = {
Expand All @@ -836,7 +913,7 @@ def _create_centered_variable(self, name: str, xdist: bool = False):
}
if xdist:
pymc_distribution = _get_pymc_dim_distribution(self.distribution)
core_dims_kwargs = {"core_dims": self.core_dims}
core_dims_kwargs = {"core_dims": self.core_dims if self.core_dims else None}
else:
pymc_distribution = self.pymc_distribution
core_dims_kwargs = {}
Expand Down Expand Up @@ -870,7 +947,7 @@ def handle_variable(var_name: str):
}
if xdist:
pymc_distribution = _get_pymc_dim_distribution(self.distribution)
core_dims_kwargs = {"core_dims": self.core_dims}
core_dims_kwargs = {"core_dims": self.core_dims if self.core_dims else None}
else:
pymc_distribution = self.pymc_distribution
core_dims_kwargs = {}
Expand Down
112 changes: 112 additions & 0 deletions tests/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,3 +1402,115 @@ def test_core_dims(self):
ip = m.initial_point()
with pytensor.config.change_flags(mode="FAST_COMPILE"):
assert m.compile_logp()(ip) == ref_m.compile_logp()(ip)


class TestConsumedDims:
    """Distributions whose parameter dims vanish from the output.

    A distribution such as Categorical maps a vector parameter ``p`` to a
    scalar draw: the ``p`` axis is consumed by the op and never re-emitted.
    Such consumed dims must not be required to appear in the Prior's output
    dims, and must not trigger a shape error.

    See: https://github.com/pymc-devs/pymc-extras/issues/499
    """

    # --- Categorical: (p)->(), ndim_supp=0, ndims_params=[1] ---

    def test_categorical_construction(self):
        """Categorical with a dim-carrying Prior for p must construct cleanly.

        Categorical: (p)->(), ndim_supp=0, ndims_params=[1].
        The 'probs' dim of p is fully consumed and need not appear in dims.
        """
        probs_prior = pr.Prior("Dirichlet", a=[1, 1, 1], dims="probs")
        categorical = pr.Prior("Categorical", p=probs_prior, dims="trial")
        assert categorical.dims == ("trial",)

    def test_categorical_create_variable(self):
        """Non-xdist create_variable drops the consumed 'probs' axis.

        Categorical: (p)->(), ndim_supp=0, ndims_params=[1].
        Output shape is (trial,) only.
        """
        probs_prior = pr.Prior("Dirichlet", a=[1, 1, 1], dims="probs")
        categorical = pr.Prior("Categorical", p=probs_prior, dims="trial")
        coords = {"probs": range(3), "trial": range(10)}
        with pm.Model(coords=coords) as model:
            categorical.create_variable("y")
        assert fast_eval(model["y"]).shape == (10,)

    @pytest.mark.filterwarnings("ignore:The `pymc.dims` module is experimental")
    def test_categorical_create_variable_xdist(self):
        """xdist=True create_variable yields only the non-consumed named dims.

        Categorical: (p)->(), ndim_supp=0, ndims_params=[1].
        core_dims is auto-computed as ('probs',); output dims are ('trial',).
        """
        probs_prior = pr.Prior("Dirichlet", a=DataArray([1, 1, 1], dims="probs"), dims="probs")
        categorical = pr.Prior("Categorical", p=probs_prior, dims="trial")
        coords = {"probs": range(3), "trial": range(10)}
        with pm.Model(coords=coords) as model:
            categorical.create_variable("y", xdist=True)
        assert model.named_vars_to_dims["y"] == ("trial",)

    def test_categorical_batch_broadcast(self):
        """Non-xdist: a leading batch dim on p survives; its core dim does not.

        Categorical: (p)->(), ndim_supp=0, ndims_params=[1].
        p carries ('geo', 'probs'): 'probs' is consumed, 'geo' broadcasts.
        """
        probs_prior = pr.Prior("Dirichlet", a=np.ones((2, 3)), dims=("geo", "probs"))
        categorical = pr.Prior("Categorical", p=probs_prior, dims=("geo", "trial"))
        coords = {"geo": range(2), "probs": range(3), "trial": range(10)}
        with pm.Model(coords=coords) as model:
            categorical.create_variable("y")
        assert fast_eval(model["y"]).shape == (2, 10)

    @pytest.mark.filterwarnings("ignore:The `pymc.dims` module is experimental")
    def test_categorical_batch_broadcast_xdist(self):
        """xdist=True: the 'geo' batch dim propagates while 'probs' is dropped.

        Categorical: (p)->(), ndim_supp=0, ndims_params=[1].
        p carries ('geo', 'probs'): 'probs' is consumed, 'geo' broadcasts.
        """
        dirichlet_a = DataArray(np.ones((2, 3)), dims=("geo", "probs"))
        probs_prior = pr.Prior("Dirichlet", a=dirichlet_a, dims=("geo", "probs"))
        categorical = pr.Prior("Categorical", p=probs_prior, dims=("geo", "trial"))
        coords = {"geo": range(2), "probs": range(3), "trial": range(10)}
        with pm.Model(coords=coords) as model:
            categorical.create_variable("y", xdist=True)
        assert model.named_vars_to_dims["y"] == ("geo", "trial")

    # --- Interpolated: (x),(x),(x)->(), ndim_supp=0, ndims_params=[1,1,1] ---

    def test_interpolated_construction(self):
        """Interpolated with a dim-carrying x_points Prior must construct cleanly.

        Interpolated: (x),(x),(x)->(), ndim_supp=0, ndims_params=[1,1,1].
        Every parameter dim is consumed — none reach the scalar output.
        """
        knot_locations = pr.Prior("Normal", mu=0, sigma=1, dims="knots")
        interpolated = pr.Prior(
            "Interpolated",
            x_points=knot_locations,
            pdf_points=np.ones(5) / 5,
            dims="obs",
        )
        assert interpolated.dims == ("obs",)

    # --- MvNormal: (n),(n,n)->(n), ndim_supp=1, ndims_params=[1,2] ---

    def test_mvnormal_cov_extra_dim(self):
        """cov's second named dim is consumed; its first is re-emitted.

        MvNormal: (n),(n,n)->(n), ndim_supp=1, ndims_params=[1,2].
        cov has two core dims but the output has one, so the extra cov dim
        ('component_') is consumed while 'component' is re-emitted.
        """
        covariance = pr.Prior("Wishart", nu=5, V=np.eye(3), dims=("component", "component_"))
        mv_normal = pr.Prior("MvNormal", mu=np.zeros(3), cov=covariance, dims="component")
        assert mv_normal.dims == ("component",)

    # --- Guard: non-consumed incompatible dims still raise ---

    def test_incompatible_dims_still_raise(self):
        """A mismatched dim that is not consumed must still raise UnsupportedShapeError."""
        mismatched = pr.Prior("Normal", dims="other")
        with pytest.raises(pr.UnsupportedShapeError):
            pr.Prior("Normal", mu=mismatched, dims="channel")
Loading