Skip to content
90 changes: 61 additions & 29 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ class PyMCStateSpace:
Regardless of whether a mode is specified, it can always be overridden via the ``compile_kwargs`` argument
to all sampling methods.

name : str, optional
Prefix used to namespace internal graph variable and data names so multiple state space models can coexist
in the same PyMC model without naming collisions. If ``None``, no prefix is applied and names are left unchanged.

Notes
-----
Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
Expand Down Expand Up @@ -261,6 +265,8 @@ def __init__(
verbose: bool = True,
measurement_error: bool = False,
mode: str | None = None,
name: str | None = None,
data_name: str = "data",
):
self._fit_coords: dict[str, Sequence[str]] | None = None
self._fit_dims: dict[str, Sequence[str]] | None = None
Expand All @@ -274,6 +280,8 @@ def __init__(
self.k_endog = k_endog
self.k_states = k_states
self.k_posdef = k_posdef
self.name = name
self.data_name = data_name
self.measurement_error = measurement_error
self.mode = mode

Expand Down Expand Up @@ -305,6 +313,12 @@ def __init__(
console = Console()
console.print(self.requirement_table)

def prefixed_name(self, base_name: str) -> str:
    """
    Return ``base_name`` namespaced with this model's ``name``.

    Parameters
    ----------
    base_name : str
        Un-prefixed (or already prefixed) variable or data name.

    Returns
    -------
    str
        ``base_name`` unchanged when ``self.name`` is falsy (``None`` or empty) or when ``base_name``
        already begins with ``"<name>_"``; otherwise ``"<name>_<base_name>"``.

    Notes
    -----
    Applying this method repeatedly yields the same result as applying it once.
    NOTE(review): a base name that naturally begins with ``"<name>_"`` is treated as already prefixed and
    is returned as-is -- confirm that this cannot cause collisions for the names used by callers.
    """
    if self.name:
        namespace = f"{self.name}_"
        # Only prepend the namespace when it is not already there.
        if not base_name.startswith(namespace):
            return f"{namespace}{base_name}"
    return base_name

def _populate_properties(self) -> None:
self._set_parameters()
self._set_states()
Expand Down Expand Up @@ -614,7 +628,7 @@ def add_default_priors(self) -> None:
raise NotImplementedError("The add_default_priors property has not been implemented!")

def make_and_register_variable(
self, name, shape: int | tuple[int, ...] | None = None, dtype=floatX
self, base_name, shape: int | tuple[int, ...] | None = None, dtype=floatX
) -> pt.TensorVariable:
"""
Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary
Expand Down Expand Up @@ -643,12 +657,14 @@ def make_and_register_variable(
An error is raised if the provided name has already been registered, or if the name is not present in the
``param_names`` property.
"""
if name not in self.param_names:
if base_name not in self.param_names:
raise ValueError(
f"{name} is not a model parameter. All placeholder variables should correspond to model "
f"{base_name} is not a model parameter. All placeholder variables should correspond to model "
f"parameters."
)

name = self.prefixed_name(base_name)

if name in self._tensor_variable_info:
raise ValueError(
f"{name} is already a registered placeholder variable with shape "
Expand All @@ -661,7 +677,7 @@ def make_and_register_variable(
return placeholder

def make_and_register_data(
self, name: str, shape: int | tuple[int], dtype: str = floatX
self, base_name: str, shape: int | tuple[int], dtype: str = floatX
) -> Variable:
r"""
Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary
Expand All @@ -683,12 +699,14 @@ def make_and_register_data(
An error is raised if the provided name has already been registered, or if the name is not present in the
``data_names`` property.
"""
if name not in self.data_names:
if base_name not in self.data_names:
raise ValueError(
f"{name} is not a model parameter. All placeholder variables should correspond to model "
f"parameters."
f"{base_name} is not a model data variable. All placeholder variables should correspond to model "
f"data variables."
)

name = self.prefixed_name(base_name)

if name in self._tensor_data_info:
raise ValueError(
f"{name} is already a registered placeholder variable with shape "
Expand Down Expand Up @@ -800,11 +818,12 @@ def _save_exogenous_data_info(self):
"""
pymc_mod = modelcontext(None)
for data_name in self.data_names:
data = pymc_mod[data_name]
name = self.prefixed_name(data_name)
data = pymc_mod[name]
self._fit_exog_data[data_name] = {
"name": data_name,
"name": name,
"value": data.get_value(),
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
"dims": pymc_mod.named_vars_to_dims.get(name, None),
}

def _insert_random_variables(self):
Expand Down Expand Up @@ -843,9 +862,9 @@ def _insert_random_variables(self):
found_params = []
with pymc_model:
for param_name in self.param_names:
param = getattr(pymc_model, param_name, None)
if param is not None:
found_params.append(param.name)
name = self.prefixed_name(param_name)
if name in pymc_model:
found_params.append(param_name)

missing_params = list(set(self.param_names) - set(found_params))
if len(missing_params) > 0:
Expand Down Expand Up @@ -880,9 +899,9 @@ def _insert_data_variables(self):
found_data = []
with pymc_model:
for data_name in data_names:
data = getattr(pymc_model, data_name, None)
if data is not None:
found_data.append(data.name)
name = self.prefixed_name(data_name)
if name in pymc_model:
found_data.append(data_name)

missing_data = list(set(data_names) - set(found_data))
if len(missing_data) > 0:
Expand Down Expand Up @@ -1046,6 +1065,7 @@ def build_statespace_graph(
obs_coords=obs_coords,
register_data=register_data,
missing_fill_value=missing_fill_value,
data_name=self.prefixed_name(self.data_name),
)

filter_outputs = self.kalman_filter.build_graph(
Expand All @@ -1070,7 +1090,7 @@ def build_statespace_graph(
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None

SequenceMvNormal(
"obs",
self.prefixed_name("obs"),
mus=observed_states,
covs=observed_covariances,
logp=logp,
Expand Down Expand Up @@ -1144,15 +1164,16 @@ def _build_dummy_graph(self) -> None:
A list of pm.Flat variables representing all parameters estimated by the model.
"""

def infer_variable_shape(name):
def infer_variable_shape(base_name):
name = self.prefixed_name(base_name)
shape = self._name_to_variable[name].type.shape
if not any(dim is None for dim in shape):
return shape

dim_names = self._fit_dims.get(name, None)
if dim_names is None:
raise ValueError(
f"Could not infer shape for {name}, because it was not given coords during model"
f"Could not infer shape for {base_name}, because it was not given coords during model"
f"fitting"
)

Expand All @@ -1164,11 +1185,11 @@ def infer_variable_shape(name):
]
)

for name in self.param_names:
for base_name in self.param_names:
pm.Flat(
name,
shape=infer_variable_shape(name),
dims=self._fit_dims.get(name, None),
self.prefixed_name(base_name),
shape=infer_variable_shape(base_name),
dims=self._fit_dims.get(self.prefixed_name(base_name), None),
)

def _kalman_filter_outputs_from_dummy_graph(
Expand Down Expand Up @@ -1208,14 +1229,14 @@ def _kalman_filter_outputs_from_dummy_graph(
self._insert_random_variables()

for name in self.data_names:
if name not in pm_mod:
if self.prefixed_name(name) not in pm_mod:
pm.Data(**self._fit_exog_data[name])

self._insert_data_variables()

for name in self.data_names:
if name in scenario.keys():
pm.set_data({name: scenario[name]})
pm.set_data({self.prefixed_name(name): scenario[name]})

x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()

Expand All @@ -1230,6 +1251,7 @@ def _kalman_filter_outputs_from_dummy_graph(
obs_coords=obs_coords,
data_dims=data_dims,
register_data=True,
data_name=self.prefixed_name(self.data_name),
)

filter_outputs = self.kalman_filter.build_graph(
Expand Down Expand Up @@ -1786,7 +1808,7 @@ def sample_statespace_matrices(
self._insert_random_variables()

for name in self.data_names:
pm.Data(**self.data_info[name])
pm.Data(name=self.prefixed_name(name), **self.data_info[name])

self._insert_data_variables()
matrices = self.unpack_statespace()
Expand Down Expand Up @@ -1852,6 +1874,7 @@ def sample_filter_outputs(
n_obs=self.ssm.k_endog,
obs_coords=obs_coords,
register_data=True,
data_name=self.prefixed_name(self.data_name),
)

filter_outputs = self.kalman_filter.build_graph(
Expand Down Expand Up @@ -2283,12 +2306,18 @@ def _build_forecast_model(
mu, cov = grouped_outputs[group_idx]

sub_dict = {
data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
data_var: pt.as_tensor_variable(
data_var.get_value(), name=self.prefixed_name(self.data_name)
)
for data_var in forecast_model.data_vars
}

missing_data_vars = np.setdiff1d(
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
ar1=[
*[self.prefixed_name(name) for name in self.data_names],
self.prefixed_name(self.data_name),
],
ar2=[k.name for k, _ in sub_dict.items()],
)
if missing_data_vars.size > 0:
raise ValueError(f"{missing_data_vars} data used for fitting not found!")
Expand Down Expand Up @@ -2466,8 +2495,11 @@ def forecast(
with forecast_model:
if scenario is not None:
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
scoped_scenario = {
self.prefixed_name(name): value for name, value in scenario.items()
}
pm.set_data(
scenario | {"data": dummy_obs_data},
scoped_scenario | {self.prefixed_name(self.data_name): dummy_obs_data},
coords={"data_time": np.arange(len(forecast_index))},
)

Expand Down
57 changes: 50 additions & 7 deletions pymc_extras/statespace/models/structural/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import xarray as xr

from pytensor import Mode, Variable, config
from pytensor import Mode, Variable, config, graph_replace
from pytensor import tensor as pt

from pymc_extras.statespace.core.properties import (
Expand Down Expand Up @@ -117,7 +117,7 @@ class StructuralTimeSeries(PyMCStateSpace):
def __init__(
self,
ssm: PytensorRepresentation,
name: str,
name: str | None,
coords_info: CoordInfo,
param_info: ParameterInfo,
data_info: DataInfo,
Expand All @@ -140,8 +140,10 @@ def __init__(
----------
ssm : PytensorRepresentation
The state space representation containing system matrices.
name : str
Name of the model. If None, defaults to "StructuralTimeSeries".
name : str, optional
Prefix applied to all internal graph variable and data names, allowing multiple
state space models to coexist in the same PyMC model without naming collisions.
If ``None`` (default), no prefix is applied and variable names are unchanged.
coords_info : CoordInfo
Coordinate specifications for model dimensions.
param_info : ParameterInfo
Expand All @@ -167,7 +169,6 @@ def __init__(
mode : str | Mode | None, default None
PyTensor compilation mode.
"""
self._name = name or "StructuralTimeSeries"
self.measurement_error = measurement_error

k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog
Expand All @@ -184,8 +185,14 @@ def __init__(
verbose=verbose,
measurement_error=measurement_error,
mode=mode,
name=name,
)

if name is not None:
ssm, tensor_variable_info, tensor_data_info = self._prefix_placeholder_variables(
ssm, tensor_variable_info, tensor_data_info
)

self._tensor_variable_info = tensor_variable_info
self._tensor_data_info = tensor_data_info
self._component_info = component_info.copy()
Expand Down Expand Up @@ -234,6 +241,37 @@ def strip(names):
self._shock_names = strip(self._shock_info.names)
self._param_names = strip(self._param_info.names)

def _prefix_placeholder_variables(self, ssm, tensor_variable_info, tensor_data_info):
    """
    Swap the placeholder variables used by the SSM matrices for copies carrying prefixed names.

    Parameters
    ----------
    ssm
        Representation object exposing one attribute per entry of ``LONG_MATRIX_NAMES``. Those attributes
        are reassigned in place with the rewritten graphs.
    tensor_variable_info
        Iterable of entries exposing ``name`` and ``symbolic_variable``.
    tensor_data_info
        Iterable of entries exposing ``name`` and ``symbolic_data``.

    Returns
    -------
    tuple
        ``(ssm, variable_info, data_info)``. When neither info container holds any placeholder, the three
        inputs are returned untouched. Otherwise ``ssm`` is the same (mutated) object and the two info
        containers are freshly built ``SymbolicVariableInfo`` / ``SymbolicDataInfo`` instances describing
        the renamed placeholders.
    """

    def renamed_copy(placeholder, base_name):
        # A new variable of the same type, differing from the original only by its (prefixed) name.
        return placeholder.type(name=self.prefixed_name(base_name))

    variable_pairs = [
        (entry.symbolic_variable, renamed_copy(entry.symbolic_variable, entry.name))
        for entry in tensor_variable_info
    ]
    data_pairs = [
        (entry.symbolic_data, renamed_copy(entry.symbolic_data, entry.name))
        for entry in tensor_data_info
    ]

    # Old placeholder -> renamed placeholder, parameters first and data second.
    replacements = dict(variable_pairs + data_pairs)
    if not replacements:
        return ssm, tensor_variable_info, tensor_data_info

    original_matrices = [getattr(ssm, matrix_name) for matrix_name in LONG_MATRIX_NAMES]
    # strict=False: presumably so that a placeholder missing from the matrix graphs is not an
    # error -- TODO confirm against the pytensor documentation.
    rewritten_matrices = graph_replace(original_matrices, replace=replacements, strict=False)
    for matrix_name, rewritten in zip(LONG_MATRIX_NAMES, rewritten_matrices):
        setattr(ssm, matrix_name, rewritten)

    prefixed_variable_info = SymbolicVariableInfo(
        symbolic_variables=tuple(
            SymbolicVariable(name=renamed.name, symbolic_variable=renamed)
            for _, renamed in variable_pairs
        )
    )
    prefixed_data_info = SymbolicDataInfo(
        symbolic_data=tuple(
            SymbolicData(name=renamed.name, symbolic_data=renamed) for _, renamed in data_pairs
        )
    )

    return ssm, prefixed_variable_info, prefixed_data_info

def _init_ssm(self, ssm: PytensorRepresentation, k_posdef: int) -> None:
"""Initialize state space model representation."""
self.ssm = ssm.copy()
Expand Down Expand Up @@ -988,8 +1026,13 @@ def build(

Parameters
----------
name: str, optional
Name of the exogenous data being modeled. Default is "data"
name : str, optional
Prefix applied to all internal graph variable and data names, allowing multiple
structural models to coexist in the same PyMC model without naming collisions.
If ``None`` (default), no prefix is applied. When a name is provided, prior
variables must be named with the prefix, e.g. ``pm.Normal("m1_initial_trend", ...)``
for ``name="m1"``. Use ``model.prefixed_name(p)`` for each ``p`` in
``model.param_names`` to get the expected names.

filter_type : str, optional
The type of Kalman filter to use. Valid options are "standard", "univariate", "single", "cholesky", and
Expand Down
Loading
Loading