@@ -85,6 +85,7 @@ def custom_transform(x):
8585
8686import copy
8787import typing
88+ import warnings
8889
8990from collections .abc import Callable
9091from functools import partial
@@ -493,6 +494,39 @@ def create_variable(self, name: str) -> "TensorVariable":
493494 ).prior
494495
495496
497+ def _param_value_with_dims (param , value , dims : Dims | None ):
498+ """Infer parameter dims positionally.
499+
500+ This is a transition helper to guide users into defining DataArray parameters explicitly.
501+ """
502+ if hasattr (value , "dims" ):
503+ return value
504+
505+ if isinstance (value , list | tuple | Number ):
506+ value = np .asarray (value )
507+
508+ if value .ndim > 0 :
509+ if dims is None :
510+ raise ValueError (
511+ f"Cannot infer dims of array-like parameter { param } . Use DataArray with explicit dims"
512+ )
513+ else :
514+ parameter_dims = dims [::- 1 ][: value .ndim ]
515+ warnings .warn (
516+ f"Implicit conversion of array-like parameter { param } to DataArray with dims { parameter_dims } . "
517+ "Use DataArray with explicit dims to avoid this warning" ,
518+ stacklevel = 2 ,
519+ )
520+ if isinstance (value , Variable ):
521+ from pytensor .xtensor import as_xtensor
522+
523+ value = as_xtensor (value , dims = parameter_dims )
524+ else :
525+ value = DataArray (value , dims = parameter_dims )
526+
527+ return value
528+
529+
496530class Prior :
497531 """A class to represent a prior distribution.
498532
@@ -777,7 +811,10 @@ def __repr__(self) -> str:
777811
778812 def _create_parameter (self , param , value , name , xdist : bool = False ):
779813 if not hasattr (value , "create_variable" ):
780- return value
814+ if xdist :
815+ return _param_value_with_dims (param , value , dims = self .dims )
816+ else :
817+ return value
781818
782819 child_name = f"{ name } _{ param } "
783820 if xdist :
@@ -803,7 +840,10 @@ def _create_non_centered_variable(
803840 def handle_variable (var_name : str ):
804841 parameter = self .parameters [var_name ]
805842 if not hasattr (parameter , "create_variable" ):
806- return parameter
843+ if xdist :
844+ return _param_value_with_dims (var_name , parameter , dims = self .dims )
845+ else :
846+ return parameter
807847
808848 if xdist :
809849 return parameter .create_variable (f"{ name } _{ var_name } " , xdist = True )
0 commit comments