Skip to content

Commit 1bf3670

Browse files
committed
Add name parameter to PyMCStateSpace and update related methods for unique graph naming
Create a test to ensure multiple state space models can coexist with distinct names.
1 parent 7319012 commit 1bf3670

3 files changed

Lines changed: 80 additions & 34 deletions

File tree

pymc_extras/statespace/core/statespace.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ class PyMCStateSpace:
138138
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
139139
to all sampling methods.
140140
141+
name : str, optional
142+
Prefix used to namespace internal graph variable and data names so multiple state space models can coexist
143+
in the same PyMC model without naming collisions. If ``None`` or an empty string, no prefix is applied and the original names are used.
144+
141145
Notes
142146
-----
143147
Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -261,6 +265,7 @@ def __init__(
261265
verbose: bool = True,
262266
measurement_error: bool = False,
263267
mode: str | None = None,
268+
name: str | None = None,
264269
):
265270
self._fit_coords: dict[str, Sequence[str]] | None = None
266271
self._fit_dims: dict[str, Sequence[str]] | None = None
@@ -274,6 +279,7 @@ def __init__(
274279
self.k_endog = k_endog
275280
self.k_states = k_states
276281
self.k_posdef = k_posdef
282+
self.name = name
277283
self.measurement_error = measurement_error
278284
self.mode = mode
279285

@@ -305,6 +311,12 @@ def __init__(
305311
console = Console()
306312
console.print(self.requirement_table)
307313

314+
def graph_name(self, base: str) -> str:
315+
if not self.name:
316+
return base
317+
prefix = f"{self.name}_"
318+
return base if base.startswith(prefix) else f"{self.name}_{base}"
319+
308320
def _populate_properties(self) -> None:
309321
self._set_parameters()
310322
self._set_states()
@@ -649,14 +661,16 @@ def make_and_register_variable(
649661
f"parameters."
650662
)
651663

652-
if name in self._tensor_variable_info:
664+
gname = self.graph_name(name)
665+
666+
if gname in self._tensor_variable_info:
653667
raise ValueError(
654-
f"{name} is already a registered placeholder variable with shape "
655-
f"{self._tensor_variable_info[name].type.shape}"
668+
f"{gname} is already a registered placeholder variable with shape "
669+
f"{self._tensor_variable_info[gname].type.shape}"
656670
)
657671

658-
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
659-
tensor_var = SymbolicVariable(name=name, symbolic_variable=placeholder)
672+
placeholder = pt.tensor(gname, shape=shape, dtype=dtype)
673+
tensor_var = SymbolicVariable(name=gname, symbolic_variable=placeholder)
660674
self._tensor_variable_info = self._tensor_variable_info.add(tensor_var)
661675
return placeholder
662676

@@ -685,18 +699,20 @@ def make_and_register_data(
685699
"""
686700
if name not in self.data_names:
687701
raise ValueError(
688-
f"{name} is not a model parameter. All placeholder variables should correspond to model "
689-
f"parameters."
702+
f"{name} is not a model data-variable. All placeholder variables should correspond to model "
703+
f"data-variables."
690704
)
691705

692-
if name in self._tensor_data_info:
706+
gname = self.graph_name(name)
707+
708+
if gname in self._tensor_data_info:
693709
raise ValueError(
694-
f"{name} is already a registered placeholder variable with shape "
695-
f"{self._tensor_data_info[name].type.shape}"
710+
f"{gname} is already a registered placeholder variable with shape "
711+
f"{self._tensor_data_info[gname].type.shape}"
696712
)
697713

698-
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
699-
tensor_data = SymbolicData(name=name, symbolic_data=placeholder)
714+
placeholder = pt.tensor(gname, shape=shape, dtype=dtype)
715+
tensor_data = SymbolicData(name=gname, symbolic_data=placeholder)
700716
self._tensor_data_info = self._tensor_data_info.add(tensor_data)
701717
return placeholder
702718

@@ -800,11 +816,12 @@ def _save_exogenous_data_info(self):
800816
"""
801817
pymc_mod = modelcontext(None)
802818
for data_name in self.data_names:
803-
data = pymc_mod[data_name]
819+
gname = self.graph_name(data_name)
820+
data = pymc_mod[gname]
804821
self._fit_exog_data[data_name] = {
805-
"name": data_name,
822+
"name": gname,
806823
"value": data.get_value(),
807-
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
824+
"dims": pymc_mod.named_vars_to_dims.get(gname, None),
808825
}
809826

810827
def _insert_random_variables(self):
@@ -843,9 +860,10 @@ def _insert_random_variables(self):
843860
found_params = []
844861
with pymc_model:
845862
for param_name in self.param_names:
846-
param = getattr(pymc_model, param_name, None)
863+
gname = self.graph_name(param_name)
864+
param = getattr(pymc_model, gname, None)
847865
if param is not None:
848-
found_params.append(param.name)
866+
found_params.append(param_name)
849867

850868
missing_params = list(set(self.param_names) - set(found_params))
851869
if len(missing_params) > 0:
@@ -880,9 +898,10 @@ def _insert_data_variables(self):
880898
found_data = []
881899
with pymc_model:
882900
for data_name in data_names:
883-
data = getattr(pymc_model, data_name, None)
901+
gname = self.graph_name(data_name)
902+
data = getattr(pymc_model, gname, None)
884903
if data is not None:
885-
found_data.append(data.name)
904+
found_data.append(data_name)
886905

887906
missing_data = list(set(data_names) - set(found_data))
888907
if len(missing_data) > 0:
@@ -1046,6 +1065,7 @@ def build_statespace_graph(
10461065
obs_coords=obs_coords,
10471066
register_data=register_data,
10481067
missing_fill_value=missing_fill_value,
1068+
data_name=self.graph_name("data"),
10491069
)
10501070

10511071
filter_outputs = self.kalman_filter.build_graph(
@@ -1145,11 +1165,12 @@ def _build_dummy_graph(self) -> None:
11451165
"""
11461166

11471167
def infer_variable_shape(name):
1148-
shape = self._name_to_variable[name].type.shape
1168+
gname = self.graph_name(name)
1169+
shape = self._name_to_variable[gname].type.shape
11491170
if not any(dim is None for dim in shape):
11501171
return shape
11511172

1152-
dim_names = self._fit_dims.get(name, None)
1173+
dim_names = self._fit_dims.get(gname, None)
11531174
if dim_names is None:
11541175
raise ValueError(
11551176
f"Could not infer shape for {name}, because it was not given coords during model"
@@ -1166,9 +1187,9 @@ def infer_variable_shape(name):
11661187

11671188
for name in self.param_names:
11681189
pm.Flat(
1169-
name,
1190+
self.graph_name(name),
11701191
shape=infer_variable_shape(name),
1171-
dims=self._fit_dims.get(name, None),
1192+
dims=self._fit_dims.get(self.graph_name(name), None),
11721193
)
11731194

11741195
def _kalman_filter_outputs_from_dummy_graph(
@@ -1208,14 +1229,14 @@ def _kalman_filter_outputs_from_dummy_graph(
12081229
self._insert_random_variables()
12091230

12101231
for name in self.data_names:
1211-
if name not in pm_mod:
1232+
if self.graph_name(name) not in pm_mod:
12121233
pm.Data(**self._fit_exog_data[name])
12131234

12141235
self._insert_data_variables()
12151236

12161237
for name in self.data_names:
12171238
if name in scenario.keys():
1218-
pm.set_data({name: scenario[name]})
1239+
pm.set_data({self.graph_name(name): scenario[name]})
12191240

12201241
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
12211242

@@ -1230,6 +1251,7 @@ def _kalman_filter_outputs_from_dummy_graph(
12301251
obs_coords=obs_coords,
12311252
data_dims=data_dims,
12321253
register_data=True,
1254+
data_name=self.graph_name("data"),
12331255
)
12341256

12351257
filter_outputs = self.kalman_filter.build_graph(
@@ -1786,7 +1808,7 @@ def sample_statespace_matrices(
17861808
self._insert_random_variables()
17871809

17881810
for name in self.data_names:
1789-
pm.Data(**self.data_info[name])
1811+
pm.Data(name=self.graph_name(name), **self.data_info[name])
17901812

17911813
self._insert_data_variables()
17921814
matrices = self.unpack_statespace()
@@ -1852,6 +1874,7 @@ def sample_filter_outputs(
18521874
n_obs=self.ssm.k_endog,
18531875
obs_coords=obs_coords,
18541876
register_data=True,
1877+
data_name=self.graph_name("data"),
18551878
)
18561879

18571880
filter_outputs = self.kalman_filter.build_graph(
@@ -2283,12 +2306,15 @@ def _build_forecast_model(
22832306
mu, cov = grouped_outputs[group_idx]
22842307

22852308
sub_dict = {
2286-
data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
2309+
data_var: pt.as_tensor_variable(
2310+
data_var.get_value(), name=self.graph_name("data")
2311+
)
22872312
for data_var in forecast_model.data_vars
22882313
}
22892314

22902315
missing_data_vars = np.setdiff1d(
2291-
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
2316+
ar1=[*[self.graph_name(name) for name in self.data_names], self.graph_name("data")],
2317+
ar2=[k.name for k, _ in sub_dict.items()],
22922318
)
22932319
if missing_data_vars.size > 0:
22942320
raise ValueError(f"{missing_data_vars} data used for fitting not found!")
@@ -2466,8 +2492,9 @@ def forecast(
24662492
with forecast_model:
24672493
if scenario is not None:
24682494
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2495+
scoped_scenario = {self.graph_name(name): value for name, value in scenario.items()}
24692496
pm.set_data(
2470-
scenario | {"data": dummy_obs_data},
2497+
scoped_scenario | {self.graph_name("data"): dummy_obs_data},
24712498
coords={"data_time": np.arange(len(forecast_index))},
24722499
)
24732500

pymc_extras/statespace/utils/data_tools.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
121121
return preprocess_numpy_data(data.values, n_obs, obs_coords)
122122

123123

124-
def add_data_to_active_model(values, index, data_dims=None):
124+
def add_data_to_active_model(values, index, data_dims=None, data_name="data"):
125125
pymc_mod = modelcontext(None)
126126
if data_dims is None:
127127
data_dims = [TIME_DIM, OBS_STATE_DIM]
@@ -146,7 +146,7 @@ def add_data_to_active_model(values, index, data_dims=None):
146146
else:
147147
data_shape = (None, values.shape[-1])
148148

149-
data = pm.Data("data", values, dims=data_dims, shape=data_shape)
149+
data = pm.Data(data_name, values, dims=data_dims, shape=data_shape)
150150

151151
return data
152152

@@ -178,7 +178,13 @@ def mask_missing_values_in_data(values, missing_fill_value=None):
178178

179179

180180
def register_data_with_pymc(
181-
data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None
181+
data,
182+
n_obs,
183+
obs_coords,
184+
register_data=True,
185+
missing_fill_value=None,
186+
data_dims=None,
187+
data_name="data",
182188
):
183189
if isinstance(data, pt.TensorVariable | TensorSharedVariable):
184190
values, index = preprocess_tensor_data(data, n_obs, obs_coords)
@@ -192,7 +198,7 @@ def register_data_with_pymc(
192198
data, nan_mask = mask_missing_values_in_data(values, missing_fill_value)
193199

194200
if register_data:
195-
data = add_data_to_active_model(data, index, data_dims)
201+
data = add_data_to_active_model(data, index, data_dims, data_name=data_name)
196202
else:
197-
data = pytensor.shared(data, name="data")
203+
data = pytensor.shared(data, name=data_name)
198204
return data, nan_mask

tests/statespace/test_namespace.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pymc as pm
2+
3+
from pymc_extras.statespace.core.statespace import PyMCStateSpace
4+
5+
6+
def test_two_statespace_models_can_coexist_with_names(monkeypatch):
7+
monkeypatch.setattr(PyMCStateSpace, "make_symbolic_graph", lambda self: None)
8+
9+
with pm.Model():
10+
ssm_a = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="a")
11+
ssm_b = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="b")
12+
13+
assert ssm_a.graph_name("data") != ssm_b.graph_name("data")

0 commit comments

Comments
 (0)