Skip to content

Commit b4ee3c1

Browse files
committed
Prior: Help users transition from non-scalar parameters to DataArray
1 parent c1fdd52 commit b4ee3c1

2 files changed

Lines changed: 64 additions & 2 deletions

File tree

pymc_extras/prior.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def custom_transform(x):
8585

8686
import copy
8787
import typing
88+
import warnings
8889

8990
from collections.abc import Callable
9091
from functools import partial
@@ -493,6 +494,39 @@ def create_variable(self, name: str) -> "TensorVariable":
493494
).prior
494495

495496

497+
def _param_value_with_dims(param, value, dims: Dims | None):
498+
"""Infer parameter dims positionally.
499+
500+
This is a transition helper to guide users into defining DataArray parameters explicitly.
501+
"""
502+
if hasattr(value, "dims"):
503+
return value
504+
505+
if isinstance(value, list | tuple | Number):
506+
value = np.asarray(value)
507+
508+
if value.ndim > 0:
509+
if dims is None:
510+
raise ValueError(
511+
f"Cannot infer dims of array-like parameter {param}. Use DataArray with explicit dims"
512+
)
513+
else:
514+
parameter_dims = dims[::-1][: value.ndim]
515+
warnings.warn(
516+
f"Implicit conversion of array-like parameter {param} to DataArray with dims {parameter_dims}. "
517+
"Use DataArray with explicit dims to avoid this warning",
518+
stacklevel=2,
519+
)
520+
if isinstance(value, Variable):
521+
from pytensor.xtensor import as_xtensor
522+
523+
value = as_xtensor(value, dims=parameter_dims)
524+
else:
525+
value = DataArray(value, dims=parameter_dims)
526+
527+
return value
528+
529+
496530
class Prior:
497531
"""A class to represent a prior distribution.
498532
@@ -777,7 +811,10 @@ def __repr__(self) -> str:
777811

778812
def _create_parameter(self, param, value, name, xdist: bool = False):
779813
if not hasattr(value, "create_variable"):
780-
return value
814+
if xdist:
815+
return _param_value_with_dims(param, value, dims=self.dims)
816+
else:
817+
return value
781818

782819
child_name = f"{name}_{param}"
783820
if xdist:
@@ -803,7 +840,10 @@ def _create_non_centered_variable(
803840
def handle_variable(var_name: str):
804841
parameter = self.parameters[var_name]
805842
if not hasattr(parameter, "create_variable"):
806-
return parameter
843+
if xdist:
844+
return _param_value_with_dims(var_name, parameter, dims=self.dims)
845+
else:
846+
return parameter
807847

808848
if xdist:
809849
return parameter.create_variable(f"{name}_{var_name}", xdist=True)

tests/test_prior.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,3 +1327,25 @@ def test_dims(self) -> None:
13271327
# This is always invalid
13281328
with pytest.raises(UnsupportedShapeError):
13291329
p.dims = ()
1330+
1331+
@pytest.mark.parametrize(
1332+
"mu",
1333+
(
1334+
[10, 20],
1335+
np.array([10, 20]),
1336+
pt.as_tensor([10, 20]),
1337+
),
1338+
)
1339+
def test_implicit_conversion_with_dims(self, mu):
1340+
# When xdist=True, list/numpy parameters should be converted to DataArray if dims are present
1341+
p_wo_dims = Prior("Normal", mu=mu, dims=None)
1342+
p_with_dims = Prior("Normal", mu=mu, dims=("test_dim",))
1343+
1344+
coords = {"test_dim": ["a", "b"]}
1345+
with pm.Model(coords=coords):
1346+
with pytest.warns(UserWarning, match="Implicit conversion"):
1347+
res = p_with_dims.create_variable("v", xdist=True)
1348+
assert res.dims == ("test_dim",)
1349+
1350+
with pytest.raises(ValueError, match="Cannot infer dims"):
1351+
p_wo_dims.create_variable("v", xdist=True)

0 commit comments

Comments
 (0)