diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 0f887291d..4131af79c 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -138,6 +138,10 @@ class PyMCStateSpace: Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument to all sampling methods. + name : str, optional + Prefix used to namespace internal graph variable and data names so multiple state space models can coexist + in the same PyMC model without naming collisions. If ``None``, the default naming behavior is used. + Notes ----- Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py, @@ -261,6 +265,8 @@ def __init__( verbose: bool = True, measurement_error: bool = False, mode: str | None = None, + name: str | None = None, + data_name: str = "data", ): self._fit_coords: dict[str, Sequence[str]] | None = None self._fit_dims: dict[str, Sequence[str]] | None = None @@ -274,6 +280,8 @@ def __init__( self.k_endog = k_endog self.k_states = k_states self.k_posdef = k_posdef + self.name = name + self.data_name = data_name self.measurement_error = measurement_error self.mode = mode @@ -305,6 +313,12 @@ def __init__( console = Console() console.print(self.requirement_table) + def prefixed_name(self, base_name: str) -> str: + if not self.name: + return base_name + prefix = f"{self.name}_" + return base_name if base_name.startswith(prefix) else f"{self.name}_{base_name}" + def _populate_properties(self) -> None: self._set_parameters() self._set_states() @@ -614,7 +628,7 @@ def add_default_priors(self) -> None: raise NotImplementedError("The add_default_priors property has not been implemented!") def make_and_register_variable( - self, name, shape: int | tuple[int, ...] | None = None, dtype=floatX + self, base_name, shape: int | tuple[int, ...] | None = None, dtype=floatX ) -> pt.TensorVariable: """ Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary @@ -643,12 +657,14 @@ def make_and_register_variable( An error is raised if the provided name has already been registered, or if the name is not present in the ``param_names`` property. """ - if name not in self.param_names: + if base_name not in self.param_names: raise ValueError( - f"{name} is not a model parameter. All placeholder variables should correspond to model " + f"{base_name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." ) + name = self.prefixed_name(base_name) + if name in self._tensor_variable_info: raise ValueError( f"{name} is already a registered placeholder variable with shape " @@ -661,7 +677,7 @@ def make_and_register_variable( return placeholder def make_and_register_data( - self, name: str, shape: int | tuple[int], dtype: str = floatX + self, base_name: str, shape: int | tuple[int], dtype: str = floatX ) -> Variable: r""" Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary @@ -683,12 +699,14 @@ def make_and_register_data( An error is raised if the provided name has already been registered, or if the name is not present in the ``data_names`` property. """ - if name not in self.data_names: + if base_name not in self.data_names: raise ValueError( - f"{name} is not a model parameter. All placeholder variables should correspond to model " - f"parameters." + f"{base_name} is not a model data variable. All placeholder variables should correspond to model " + f"data variables." ) + name = self.prefixed_name(base_name) + if name in self._tensor_data_info: raise ValueError( f"{name} is already a registered placeholder variable with shape " @@ -800,11 +818,12 @@ def _save_exogenous_data_info(self): """ pymc_mod = modelcontext(None) for data_name in self.data_names: - data = pymc_mod[data_name] + name = self.prefixed_name(data_name) + data = pymc_mod[name] self._fit_exog_data[data_name] = { - "name": data_name, + "name": name, "value": data.get_value(), - "dims": pymc_mod.named_vars_to_dims.get(data_name, None), + "dims": pymc_mod.named_vars_to_dims.get(name, None), } def _insert_random_variables(self): @@ -843,9 +862,9 @@ def _insert_random_variables(self): found_params = [] with pymc_model: for param_name in self.param_names: - param = getattr(pymc_model, param_name, None) - if param is not None: - found_params.append(param.name) + name = self.prefixed_name(param_name) + if name in pymc_model: + found_params.append(param_name) missing_params = list(set(self.param_names) - set(found_params)) if len(missing_params) > 0: @@ -880,9 +899,9 @@ def _insert_data_variables(self): found_data = [] with pymc_model: for data_name in data_names: - data = getattr(pymc_model, data_name, None) - if data is not None: - found_data.append(data.name) + name = self.prefixed_name(data_name) + if name in pymc_model: + found_data.append(data_name) missing_data = list(set(data_names) - set(found_data)) if len(missing_data) > 0: @@ -1046,6 +1065,7 @@ def build_statespace_graph( obs_coords=obs_coords, register_data=register_data, missing_fill_value=missing_fill_value, + data_name=self.prefixed_name(self.data_name), ) filter_outputs = self.kalman_filter.build_graph( @@ -1070,7 +1090,7 @@ def build_statespace_graph( obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None SequenceMvNormal( - "obs", + self.prefixed_name("obs"), mus=observed_states, covs=observed_covariances, logp=logp, @@ -1144,7 +1164,8 @@ def _build_dummy_graph(self) -> None: A list of pm.Flat variables representing all parameters estimated by the model. """ - def infer_variable_shape(name): + def infer_variable_shape(base_name): + name = self.prefixed_name(base_name) shape = self._name_to_variable[name].type.shape if not any(dim is None for dim in shape): return shape @@ -1152,7 +1173,7 @@ def infer_variable_shape(name): dim_names = self._fit_dims.get(name, None) if dim_names is None: raise ValueError( - f"Could not infer shape for {name}, because it was not given coords during model" + f"Could not infer shape for {base_name}, because it was not given coords during model" f"fitting" ) @@ -1164,11 +1185,11 @@ def infer_variable_shape(name): ] ) - for name in self.param_names: + for base_name in self.param_names: pm.Flat( - name, - shape=infer_variable_shape(name), - dims=self._fit_dims.get(name, None), + self.prefixed_name(base_name), + shape=infer_variable_shape(base_name), + dims=self._fit_dims.get(self.prefixed_name(base_name), None), ) def _kalman_filter_outputs_from_dummy_graph( @@ -1208,14 +1229,14 @@ def _kalman_filter_outputs_from_dummy_graph( self._insert_random_variables() for name in self.data_names: - if name not in pm_mod: + if self.prefixed_name(name) not in pm_mod: pm.Data(**self._fit_exog_data[name]) self._insert_data_variables() for name in self.data_names: if name in scenario.keys(): - pm.set_data({name: scenario[name]}) + pm.set_data({self.prefixed_name(name): scenario[name]}) x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() @@ -1230,6 +1251,7 @@ def _kalman_filter_outputs_from_dummy_graph( obs_coords=obs_coords, data_dims=data_dims, register_data=True, + data_name=self.prefixed_name(self.data_name), ) filter_outputs = self.kalman_filter.build_graph( @@ -1786,7 +1808,7 @@ def sample_statespace_matrices( self._insert_random_variables() for name in self.data_names: - pm.Data(**self.data_info[name]) + pm.Data(name=self.prefixed_name(name), **self.data_info[name]) self._insert_data_variables() matrices = self.unpack_statespace() @@ -1852,6 +1874,7 @@ def sample_filter_outputs( n_obs=self.ssm.k_endog, obs_coords=obs_coords, register_data=True, + data_name=self.prefixed_name(self.data_name), ) filter_outputs = self.kalman_filter.build_graph( @@ -2283,12 +2306,18 @@ def _build_forecast_model( mu, cov = grouped_outputs[group_idx] sub_dict = { - data_var: pt.as_tensor_variable(data_var.get_value(), name="data") + data_var: pt.as_tensor_variable( + data_var.get_value(), name=self.prefixed_name(self.data_name) + ) for data_var in forecast_model.data_vars } missing_data_vars = np.setdiff1d( - ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] + ar1=[ + *[self.prefixed_name(name) for name in self.data_names], + self.prefixed_name(self.data_name), + ], + ar2=[k.name for k, _ in sub_dict.items()], ) if missing_data_vars.size > 0: raise ValueError(f"{missing_data_vars} data used for fitting not found!") @@ -2466,8 +2495,11 @@ def forecast( with forecast_model: if scenario is not None: dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) + scoped_scenario = { + self.prefixed_name(name): value for name, value in scenario.items() + } pm.set_data( - scenario | {"data": dummy_obs_data}, + scoped_scenario | {self.prefixed_name(self.data_name): dummy_obs_data}, coords={"data_time": np.arange(len(forecast_index))}, ) diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index dda5269c0..a393fd463 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -7,7 +7,7 @@ import numpy as np import xarray as xr -from pytensor import Mode, Variable, config +from pytensor import Mode, Variable, config, graph_replace from pytensor import tensor as pt from pymc_extras.statespace.core.properties import ( @@ -117,7 +117,7 @@ class StructuralTimeSeries(PyMCStateSpace): def __init__( self, ssm: PytensorRepresentation, - name: str, + name: str | None, coords_info: CoordInfo, param_info: ParameterInfo, data_info: DataInfo, @@ -140,8 +140,10 @@ def __init__( ---------- ssm : PytensorRepresentation The state space representation containing system matrices. - name : str - Name of the model. If None, defaults to "StructuralTimeSeries". + name : str, optional + Prefix applied to all internal graph variable and data names, allowing multiple + state space models to coexist in the same PyMC model without naming collisions. + If ``None`` (default), no prefix is applied and variable names are unchanged. coords_info : CoordInfo Coordinate specifications for model dimensions. param_info : ParameterInfo @@ -167,7 +169,6 @@ def __init__( mode : str | Mode | None, default None PyTensor compilation mode. """ - self._name = name or "StructuralTimeSeries" self.measurement_error = measurement_error k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog @@ -184,8 +185,14 @@ def __init__( verbose=verbose, measurement_error=measurement_error, mode=mode, + name=name, ) + if name is not None: + ssm, tensor_variable_info, tensor_data_info = self._prefix_placeholder_variables( + ssm, tensor_variable_info, tensor_data_info + ) + self._tensor_variable_info = tensor_variable_info self._tensor_data_info = tensor_data_info self._component_info = component_info.copy() @@ -234,6 +241,37 @@ def strip(names): self._shock_names = strip(self._shock_info.names) self._param_names = strip(self._param_info.names) + def _prefix_placeholder_variables(self, ssm, tensor_variable_info, tensor_data_info): + """Replace placeholder variables in SSM matrices with prefixed-name copies.""" + replacements = {} + new_variables = [] + for sv in tensor_variable_info: + old_var = sv.symbolic_variable + new_var = old_var.type(name=self.prefixed_name(sv.name)) + replacements[old_var] = new_var + new_variables.append(SymbolicVariable(name=new_var.name, symbolic_variable=new_var)) + + new_data_entries = [] + for sd in tensor_data_info: + old_var = sd.symbolic_data + new_var = old_var.type(name=self.prefixed_name(sd.name)) + replacements[old_var] = new_var + new_data_entries.append(SymbolicData(name=new_var.name, symbolic_data=new_var)) + + if not replacements: + return ssm, tensor_variable_info, tensor_data_info + + matrices = [getattr(ssm, name) for name in LONG_MATRIX_NAMES] + replaced_matrices = graph_replace(matrices, replace=replacements, strict=False) + + for mat_name, new_mat in zip(LONG_MATRIX_NAMES, replaced_matrices): + setattr(ssm, mat_name, new_mat) + + new_variable_info = SymbolicVariableInfo(symbolic_variables=tuple(new_variables)) + new_data_info = SymbolicDataInfo(symbolic_data=tuple(new_data_entries)) + + return ssm, new_variable_info, new_data_info + def _init_ssm(self, ssm: PytensorRepresentation, k_posdef: int) -> None: """Initialize state space model representation.""" self.ssm = ssm.copy() @@ -988,8 +1026,13 @@ def build( Parameters ---------- - name: str, optional - Name of the exogenous data being modeled. Default is "data" + name : str, optional + Prefix applied to all internal graph variable and data names, allowing multiple + structural models to coexist in the same PyMC model without naming collisions. + If ``None`` (default), no prefix is applied. When a name is provided, prior + variables must be named with the prefix, e.g. ``pm.Normal("m1_initial_trend", ...)`` + for ``name="m1"``. Use ``model.prefixed_name(p)`` for each ``p`` in + ``model.param_names`` to get the expected names. filter_type : str, optional The type of Kalman filter to use. Valid options are "standard", "univariate", "single", "cholesky", and diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index cbc5d517c..204021529 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -121,7 +121,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals return preprocess_numpy_data(data.values, n_obs, obs_coords) -def add_data_to_active_model(values, index, data_dims=None): +def add_data_to_active_model(values, index, data_dims=None, data_name="data"): pymc_mod = modelcontext(None) if data_dims is None: data_dims = [TIME_DIM, OBS_STATE_DIM] @@ -146,7 +146,7 @@ def add_data_to_active_model(values, index, data_dims=None): else: data_shape = (None, values.shape[-1]) - data = pm.Data("data", values, dims=data_dims, shape=data_shape) + data = pm.Data(data_name, values, dims=data_dims, shape=data_shape) return data @@ -178,7 +178,13 @@ def mask_missing_values_in_data(values, missing_fill_value=None): def register_data_with_pymc( - data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None + data, + n_obs, + obs_coords, + register_data=True, + missing_fill_value=None, + data_dims=None, + data_name="data", ): if isinstance(data, pt.TensorVariable | TensorSharedVariable): values, index = preprocess_tensor_data(data, n_obs, obs_coords) @@ -192,7 +198,7 @@ def register_data_with_pymc( data, nan_mask = mask_missing_values_in_data(values, missing_fill_value) if register_data: - data = add_data_to_active_model(data, index, data_dims) + data = add_data_to_active_model(data, index, data_dims, data_name=data_name) else: - data = pytensor.shared(data, name="data") + data = pytensor.shared(data, name=data_name) return data, nan_mask diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index c8e343609..012a0ac6a 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -440,6 +440,35 @@ def test_base_class_raises(): ) +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") +def test_two_named_statespace_models_coexist_end_to_end(mock_pymc_sample): + data = pd.DataFrame( + np.random.default_rng(42).normal(size=(100, 1)).astype(floatX), columns=["y"] + ) + + ss_a = st.LevelTrend(name="trend", order=1, innovations_order=1).build(name="a", verbose=False) + ss_b = st.LevelTrend(name="trend", order=1, innovations_order=1).build(name="b", verbose=False) + + with pm.Model(coords={**ss_a.coords, **ss_b.coords}) as m: + pm.Normal("a_initial_trend", dims=["state_trend"]) + a_P0_sigma = pm.Exponential("a_P0_sigma", 1) + pm.Deterministic("a_P0", pt.eye(ss_a.k_states) * a_P0_sigma, dims=["state", "state_aux"]) + pm.Exponential("a_sigma_trend", 1, dims=["shock_trend"]) + ss_a.build_statespace_graph(data) + + pm.Normal("b_initial_trend", dims=["state_trend"]) + b_P0_sigma = pm.Exponential("b_P0_sigma", 1) + pm.Deterministic("b_P0", pt.eye(ss_b.k_states) * b_P0_sigma, dims=["state", "state_aux"]) + pm.Exponential("b_sigma_trend", 1, dims=["shock_trend"]) + ss_b.build_statespace_graph(data) + + idata = pm.sample(draws=10, tune=0, chains=1) + + assert "a_obs" in m.named_vars + assert "b_obs" in m.named_vars + assert "posterior" in idata.groups() + + def test_update_raises_if_missing_variables(ss_mod): with pm.Model() as mod: rho = pm.Normal("rho") diff --git a/tests/statespace/models/structural/test_core.py b/tests/statespace/models/structural/test_core.py index cdca5d4fa..44fd30036 100644 --- a/tests/statespace/models/structural/test_core.py +++ b/tests/statespace/models/structural/test_core.py @@ -18,6 +18,42 @@ RTOL = 0 if floatX.endswith("64") else 1e-6 +def _build_named_structural_model(name: str): + return ( + st.LevelTrend(order=1, innovations_order=1) + + st.Regression(name="reg", state_names=["x"]) + + st.MeasurementError(name="obs") + ).build(name=name, verbose=False) + + +def test_structural_name_propagates_to_base_and_scopes_p0(): + ss_mod = _build_named_structural_model(name="m1") + + assert ss_mod.name == "m1" + assert "P0" in ss_mod.param_names + assert ss_mod.prefixed_name("P0") in ss_mod._name_to_variable + assert "P0" not in ss_mod._name_to_variable + + +def test_named_structural_models_do_not_collide_in_placeholder_registries(): + with pm.Model(): + m1 = _build_named_structural_model(name="m1") + m2 = _build_named_structural_model(name="m2") + + var_keys_1 = set(m1._name_to_variable) + var_keys_2 = set(m2._name_to_variable) + data_keys_1 = set(m1._name_to_data) + data_keys_2 = set(m2._name_to_data) + + assert var_keys_1.isdisjoint(var_keys_2) + assert data_keys_1.isdisjoint(data_keys_2) + + assert var_keys_1 == {m1.prefixed_name(name) for name in m1.param_names} + assert var_keys_2 == {m2.prefixed_name(name) for name in m2.param_names} + assert data_keys_1 == {m1.prefixed_name(name) for name in m1.data_names} + assert data_keys_2 == {m2.prefixed_name(name) for name in m2.data_names} + + def test_add_components(): ll = st.LevelTrend(order=2) se = st.TimeSeasonality(name="seasonal", season_length=12) @@ -195,3 +231,116 @@ def test_sequence_type_component_arguments(arg_type): assert ss_mod.k_endog == len(state_names) assert sorted(ss_mod.observed_states) == sorted(list(state_names)) + + +class TestGraphReplacePlaceholderNamespacing: + @staticmethod + def _variable_names_match(mod): + """Check that all variable metadata names carry the expected model prefix.""" + for sv in mod._tensor_variable_info: + if mod.name and not sv.name.startswith(f"{mod.name}_"): + raise ValueError(f"Variable {sv.name} missing expected prefix {mod.name}_") + + for sd in mod._tensor_data_info: + if sd.name != sd.symbolic_data.name: + raise ValueError( + f"Data name mismatch: metadata={sd.name}, data={sd.symbolic_data.name}" + ) + if mod.name and not sd.name.startswith(f"{mod.name}_"): + raise ValueError(f"Data {sd.name} missing expected prefix {mod.name}_") + + var_ids = [id(sv.symbolic_variable) for sv in mod._tensor_variable_info] + if len(var_ids) != len(set(var_ids)): + raise ValueError("Duplicate Variable objects in tensor_variable_info") + + data_ids = [id(sd.symbolic_data) for sd in mod._tensor_data_info] + if len(data_ids) != len(set(data_ids)): + raise ValueError("Duplicate Variable objects in tensor_data_info") + + def test_same_component_reused_in_two_named_models_no_aliasing(self): + trend = st.LevelTrend(order=1, innovations_order=1) + m1 = trend.build(name="m1", verbose=False) + m2 = trend.build(name="m2", verbose=False) + + for sv in m1._tensor_variable_info: + assert sv.name.startswith("m1_") + + for sv in m2._tensor_variable_info: + assert sv.name.startswith("m2_") + + m1_var_ids = {id(sv.symbolic_variable) for sv in m1._tensor_variable_info} + m2_var_ids = {id(sv.symbolic_variable) for sv in m2._tensor_variable_info} + assert m1_var_ids.isdisjoint(m2_var_ids) + + def test_reused_component_with_data_placeholders(self): + comp = st.LevelTrend(order=1, innovations_order=1) + st.Regression( + name="reg", state_names=["x"] + ) + m1 = comp.build(name="m1", verbose=False) + m2 = comp.build(name="m2", verbose=False) + + m1_var_ids = {id(sv.symbolic_variable) for sv in m1._tensor_variable_info} + m2_var_ids = {id(sv.symbolic_variable) for sv in m2._tensor_variable_info} + assert m1_var_ids.isdisjoint(m2_var_ids) + + m1_data_ids = {id(sd.symbolic_data) for sd in m1._tensor_data_info} + m2_data_ids = {id(sd.symbolic_data) for sd in m2._tensor_data_info} + assert m1_data_ids.isdisjoint(m2_data_ids) + + for sd in m1._tensor_data_info: + assert sd.name.startswith("m1_") + assert sd.name == sd.symbolic_data.name + for sd in m2._tensor_data_info: + assert sd.name.startswith("m2_") + assert sd.name == sd.symbolic_data.name + + def test_metadata_names_match_variable_names(self): + mod = (st.LevelTrend(order=1, innovations_order=1) + st.MeasurementError(name="obs")).build( + name="test_model", verbose=False + ) + + for sd in mod._tensor_data_info: + assert sd.name == sd.symbolic_data.name + + self._variable_names_match(mod) + + def test_unnamed_model_has_no_prefix_on_variable_names(self): + mod = st.LevelTrend(order=1, innovations_order=1).build(name=None, verbose=False) + + registered_names = {sv.name for sv in mod._tensor_variable_info} + assert registered_names == set(mod.param_names) + + def test_prefixed_placeholders_are_in_ssm_graph(self): + from pytensor.graph.traversal import explicit_graph_inputs + + from pymc_extras.statespace.utils.constants import LONG_MATRIX_NAMES + + mod = st.LevelTrend(order=1, innovations_order=1).build(name="ns", verbose=False) + + all_matrices = [getattr(mod.ssm, name) for name in LONG_MATRIX_NAMES] + graph_input_names = { + v.name for v in explicit_graph_inputs(all_matrices) if hasattr(v, "name") and v.name + } + + p0_name = mod.prefixed_name("P0") + expected_names = {sv.name for sv in mod._tensor_variable_info if sv.name != p0_name} + assert expected_names <= graph_input_names + + unprefixed_names = {"initial_level_trend", "sigma_level_trend"} + assert unprefixed_names.isdisjoint(graph_input_names) + + def test_corrupted_metadata_name_raises(self): + from pymc_extras.statespace.core.properties import SymbolicVariable, SymbolicVariableInfo + + mod = st.LevelTrend(order=1, innovations_order=1).build(name="v", verbose=False) + + mod._tensor_variable_info = SymbolicVariableInfo( + symbolic_variables=tuple( + SymbolicVariable(name="WRONG_NAME", symbolic_variable=sv.symbolic_variable) + if i == 0 + else sv + for i, sv in enumerate(mod._tensor_variable_info) + ) + ) + with pytest.raises(ValueError, match="missing expected prefix"): + self._variable_names_match(mod) diff --git a/tests/statespace/utils/test_coord_assignment.py b/tests/statespace/utils/test_coord_assignment.py index 17a1a5e91..89b6d4ce9 100644 --- a/tests/statespace/utils/test_coord_assignment.py +++ b/tests/statespace/utils/test_coord_assignment.py @@ -69,7 +69,7 @@ def _generate_timeseries(freq): @pytest.fixture() def create_model(load_dataset): - ss_mod = structural.LevelTrend(order=2).build("data", verbose=False) + ss_mod = structural.LevelTrend(order=2).build(verbose=False) def _create_model(f): data = load_dataset(f) @@ -122,7 +122,7 @@ def make_model(index): a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4)) mod = LevelTrend(order=2, innovations_order=[0, 1]) - ss_mod = mod.build(name="a", verbose=False) + ss_mod = mod.build(verbose=False) initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values() coords = ss_mod.coords