Skip to content

Commit cbea84d

Browse files
committed
Add name parameter to PyMCStateSpace and update related methods for unique graph naming
Create a test to ensure multiple state space models can coexist with distinct names.
1 parent 7319012 commit cbea84d

3 files changed

Lines changed: 76 additions & 34 deletions

File tree

pymc_extras/statespace/core/statespace.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def __init__(
261261
verbose: bool = True,
262262
measurement_error: bool = False,
263263
mode: str | None = None,
264+
name: str | None = None,
264265
):
265266
self._fit_coords: dict[str, Sequence[str]] | None = None
266267
self._fit_dims: dict[str, Sequence[str]] | None = None
@@ -274,6 +275,7 @@ def __init__(
274275
self.k_endog = k_endog
275276
self.k_states = k_states
276277
self.k_posdef = k_posdef
278+
self.name = name
277279
self.measurement_error = measurement_error
278280
self.mode = mode
279281

@@ -305,6 +307,12 @@ def __init__(
305307
console = Console()
306308
console.print(self.requirement_table)
307309

310+
def graph_name(self, base: str) -> str:
    """Return the name under which ``base`` is registered in the model graph.

    When this state space model was given a ``name``, graph-level names are
    namespaced as ``"{name}_{base}"`` so that names coming from differently
    named models do not collide. A model without a name uses ``base``
    unchanged.

    Parameters
    ----------
    base : str
        Un-prefixed name of a parameter, data variable, or placeholder.

    Returns
    -------
    str
        ``base`` itself when ``self.name`` is falsy (``None`` or ``""``) or
        when ``base`` already starts with ``"{name}_"``; otherwise
        ``"{name}_{base}"``.
    """
    if not self.name:
        return base

    prefix = f"{self.name}_"

    # Idempotent on purpose: a name that already carries the prefix is
    # returned as-is, so applying graph_name twice never double-prefixes.
    # NOTE: this also means a base name that merely happens to start with
    # "{name}_" is treated as already namespaced.
    if base.startswith(prefix):
        return base
    return prefix + base
315+
308316
def _populate_properties(self) -> None:
309317
self._set_parameters()
310318
self._set_states()
@@ -649,14 +657,16 @@ def make_and_register_variable(
649657
f"parameters."
650658
)
651659

652-
if name in self._tensor_variable_info:
660+
gname = self.graph_name(name)
661+
662+
if gname in self._tensor_variable_info:
653663
raise ValueError(
654-
f"{name} is already a registered placeholder variable with shape "
655-
f"{self._tensor_variable_info[name].type.shape}"
664+
f"{gname} is already a registered placeholder variable with shape "
665+
f"{self._tensor_variable_info[gname].type.shape}"
656666
)
657667

658-
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
659-
tensor_var = SymbolicVariable(name=name, symbolic_variable=placeholder)
668+
placeholder = pt.tensor(gname, shape=shape, dtype=dtype)
669+
tensor_var = SymbolicVariable(name=gname, symbolic_variable=placeholder)
660670
self._tensor_variable_info = self._tensor_variable_info.add(tensor_var)
661671
return placeholder
662672

@@ -685,18 +695,20 @@ def make_and_register_data(
685695
"""
686696
if name not in self.data_names:
687697
raise ValueError(
688-
f"{name} is not a model parameter. All placeholder variables should correspond to model "
689-
f"parameters."
698+
f"{name} is not a model data-variable. All placeholder variables should correspond to model "
699+
f"data-variables."
690700
)
691701

692-
if name in self._tensor_data_info:
702+
gname = self.graph_name(name)
703+
704+
if gname in self._tensor_data_info:
693705
raise ValueError(
694-
f"{name} is already a registered placeholder variable with shape "
695-
f"{self._tensor_data_info[name].type.shape}"
706+
f"{gname} is already a registered placeholder variable with shape "
707+
f"{self._tensor_data_info[gname].type.shape}"
696708
)
697709

698-
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
699-
tensor_data = SymbolicData(name=name, symbolic_data=placeholder)
710+
placeholder = pt.tensor(gname, shape=shape, dtype=dtype)
711+
tensor_data = SymbolicData(name=gname, symbolic_data=placeholder)
700712
self._tensor_data_info = self._tensor_data_info.add(tensor_data)
701713
return placeholder
702714

@@ -800,11 +812,12 @@ def _save_exogenous_data_info(self):
800812
"""
801813
pymc_mod = modelcontext(None)
802814
for data_name in self.data_names:
803-
data = pymc_mod[data_name]
815+
gname = self.graph_name(data_name)
816+
data = pymc_mod[gname]
804817
self._fit_exog_data[data_name] = {
805-
"name": data_name,
818+
"name": gname,
806819
"value": data.get_value(),
807-
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
820+
"dims": pymc_mod.named_vars_to_dims.get(gname, None),
808821
}
809822

810823
def _insert_random_variables(self):
@@ -843,9 +856,10 @@ def _insert_random_variables(self):
843856
found_params = []
844857
with pymc_model:
845858
for param_name in self.param_names:
846-
param = getattr(pymc_model, param_name, None)
859+
gname = self.graph_name(param_name)
860+
param = getattr(pymc_model, gname, None)
847861
if param is not None:
848-
found_params.append(param.name)
862+
found_params.append(param_name)
849863

850864
missing_params = list(set(self.param_names) - set(found_params))
851865
if len(missing_params) > 0:
@@ -880,9 +894,10 @@ def _insert_data_variables(self):
880894
found_data = []
881895
with pymc_model:
882896
for data_name in data_names:
883-
data = getattr(pymc_model, data_name, None)
897+
gname = self.graph_name(data_name)
898+
data = getattr(pymc_model, gname, None)
884899
if data is not None:
885-
found_data.append(data.name)
900+
found_data.append(data_name)
886901

887902
missing_data = list(set(data_names) - set(found_data))
888903
if len(missing_data) > 0:
@@ -1046,6 +1061,7 @@ def build_statespace_graph(
10461061
obs_coords=obs_coords,
10471062
register_data=register_data,
10481063
missing_fill_value=missing_fill_value,
1064+
data_name=self.graph_name("data"),
10491065
)
10501066

10511067
filter_outputs = self.kalman_filter.build_graph(
@@ -1145,11 +1161,12 @@ def _build_dummy_graph(self) -> None:
11451161
"""
11461162

11471163
def infer_variable_shape(name):
1148-
shape = self._name_to_variable[name].type.shape
1164+
gname = self.graph_name(name)
1165+
shape = self._name_to_variable[gname].type.shape
11491166
if not any(dim is None for dim in shape):
11501167
return shape
11511168

1152-
dim_names = self._fit_dims.get(name, None)
1169+
dim_names = self._fit_dims.get(gname, None)
11531170
if dim_names is None:
11541171
raise ValueError(
11551172
f"Could not infer shape for {name}, because it was not given coords during model"
@@ -1166,9 +1183,9 @@ def infer_variable_shape(name):
11661183

11671184
for name in self.param_names:
11681185
pm.Flat(
1169-
name,
1186+
self.graph_name(name),
11701187
shape=infer_variable_shape(name),
1171-
dims=self._fit_dims.get(name, None),
1188+
dims=self._fit_dims.get(self.graph_name(name), None),
11721189
)
11731190

11741191
def _kalman_filter_outputs_from_dummy_graph(
@@ -1208,14 +1225,14 @@ def _kalman_filter_outputs_from_dummy_graph(
12081225
self._insert_random_variables()
12091226

12101227
for name in self.data_names:
1211-
if name not in pm_mod:
1228+
if self.graph_name(name) not in pm_mod:
12121229
pm.Data(**self._fit_exog_data[name])
12131230

12141231
self._insert_data_variables()
12151232

12161233
for name in self.data_names:
12171234
if name in scenario.keys():
1218-
pm.set_data({name: scenario[name]})
1235+
pm.set_data({self.graph_name(name): scenario[name]})
12191236

12201237
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
12211238

@@ -1230,6 +1247,7 @@ def _kalman_filter_outputs_from_dummy_graph(
12301247
obs_coords=obs_coords,
12311248
data_dims=data_dims,
12321249
register_data=True,
1250+
data_name=self.graph_name("data"),
12331251
)
12341252

12351253
filter_outputs = self.kalman_filter.build_graph(
@@ -1786,7 +1804,7 @@ def sample_statespace_matrices(
17861804
self._insert_random_variables()
17871805

17881806
for name in self.data_names:
1789-
pm.Data(**self.data_info[name])
1807+
pm.Data(name=self.graph_name(name), **self.data_info[name])
17901808

17911809
self._insert_data_variables()
17921810
matrices = self.unpack_statespace()
@@ -1852,6 +1870,7 @@ def sample_filter_outputs(
18521870
n_obs=self.ssm.k_endog,
18531871
obs_coords=obs_coords,
18541872
register_data=True,
1873+
data_name=self.graph_name("data"),
18551874
)
18561875

18571876
filter_outputs = self.kalman_filter.build_graph(
@@ -2283,12 +2302,15 @@ def _build_forecast_model(
22832302
mu, cov = grouped_outputs[group_idx]
22842303

22852304
sub_dict = {
2286-
data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
2305+
data_var: pt.as_tensor_variable(
2306+
data_var.get_value(), name=self.graph_name("data")
2307+
)
22872308
for data_var in forecast_model.data_vars
22882309
}
22892310

22902311
missing_data_vars = np.setdiff1d(
2291-
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
2312+
ar1=[*[self.graph_name(name) for name in self.data_names], self.graph_name("data")],
2313+
ar2=[k.name for k, _ in sub_dict.items()],
22922314
)
22932315
if missing_data_vars.size > 0:
22942316
raise ValueError(f"{missing_data_vars} data used for fitting not found!")
@@ -2466,8 +2488,9 @@ def forecast(
24662488
with forecast_model:
24672489
if scenario is not None:
24682490
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2491+
scoped_scenario = {self.graph_name(name): value for name, value in scenario.items()}
24692492
pm.set_data(
2470-
scenario | {"data": dummy_obs_data},
2493+
scoped_scenario | {self.graph_name("data"): dummy_obs_data},
24712494
coords={"data_time": np.arange(len(forecast_index))},
24722495
)
24732496

pymc_extras/statespace/utils/data_tools.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
121121
return preprocess_numpy_data(data.values, n_obs, obs_coords)
122122

123123

124-
def add_data_to_active_model(values, index, data_dims=None):
124+
def add_data_to_active_model(values, index, data_dims=None, data_name="data"):
125125
pymc_mod = modelcontext(None)
126126
if data_dims is None:
127127
data_dims = [TIME_DIM, OBS_STATE_DIM]
@@ -146,7 +146,7 @@ def add_data_to_active_model(values, index, data_dims=None):
146146
else:
147147
data_shape = (None, values.shape[-1])
148148

149-
data = pm.Data("data", values, dims=data_dims, shape=data_shape)
149+
data = pm.Data(data_name, values, dims=data_dims, shape=data_shape)
150150

151151
return data
152152

@@ -178,7 +178,13 @@ def mask_missing_values_in_data(values, missing_fill_value=None):
178178

179179

180180
def register_data_with_pymc(
181-
data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None
181+
data,
182+
n_obs,
183+
obs_coords,
184+
register_data=True,
185+
missing_fill_value=None,
186+
data_dims=None,
187+
data_name="data",
182188
):
183189
if isinstance(data, pt.TensorVariable | TensorSharedVariable):
184190
values, index = preprocess_tensor_data(data, n_obs, obs_coords)
@@ -192,7 +198,7 @@ def register_data_with_pymc(
192198
data, nan_mask = mask_missing_values_in_data(values, missing_fill_value)
193199

194200
if register_data:
195-
data = add_data_to_active_model(data, index, data_dims)
201+
data = add_data_to_active_model(data, index, data_dims, data_name=data_name)
196202
else:
197-
data = pytensor.shared(data, name="data")
203+
data = pytensor.shared(data, name=data_name)
198204
return data, nan_mask

tests/statespace/test_namespace.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pymc as pm
2+
3+
from pymc_extras.statespace.core.statespace import PyMCStateSpace
4+
5+
6+
def test_two_statespace_models_can_coexist_with_names(monkeypatch):
    """Two named state space models built in one PyMC model get distinct graph names."""
    # Replace graph construction with a no-op so the base class can be
    # instantiated directly.
    # NOTE(review): presumably PyMCStateSpace.make_symbolic_graph is abstract
    # on the base class -- confirm against statespace.py.
    monkeypatch.setattr(PyMCStateSpace, "make_symbolic_graph", lambda self: None)

    with pm.Model():
        ssm_a = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="a")
        ssm_b = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="b")

    # Pin the exact namespaced names, not just their inequality, so a broken
    # prefixing scheme that still yields different strings is caught.
    assert ssm_a.graph_name("data") == "a_data"
    assert ssm_b.graph_name("data") == "b_data"
    assert ssm_a.graph_name("data") != ssm_b.graph_name("data")

0 commit comments

Comments
 (0)