Skip to content

Commit 27a28ea

Browse files
committed
Address PR review: rename name helper, add configurable data_name, and move namespace test
1 parent 1bf3670 commit 27a28ea

3 files changed

Lines changed: 66 additions & 62 deletions

File tree

pymc_extras/statespace/core/statespace.py

Lines changed: 56 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def __init__(
266266
measurement_error: bool = False,
267267
mode: str | None = None,
268268
name: str | None = None,
269+
data_name: str = "data",
269270
):
270271
self._fit_coords: dict[str, Sequence[str]] | None = None
271272
self._fit_dims: dict[str, Sequence[str]] | None = None
@@ -280,6 +281,7 @@ def __init__(
280281
self.k_states = k_states
281282
self.k_posdef = k_posdef
282283
self.name = name
284+
self.data_name = data_name
283285
self.measurement_error = measurement_error
284286
self.mode = mode
285287

@@ -311,11 +313,11 @@ def __init__(
311313
console = Console()
312314
console.print(self.requirement_table)
313315

314-
def graph_name(self, base: str) -> str:
316+
def prefixed_name(self, base_name: str) -> str:
315317
if not self.name:
316-
return base
318+
return base_name
317319
prefix = f"{self.name}_"
318-
return base if base.startswith(prefix) else f"{self.name}_{base}"
320+
return base_name if base_name.startswith(prefix) else f"{self.name}_{base_name}"
319321

320322
def _populate_properties(self) -> None:
321323
self._set_parameters()
@@ -626,7 +628,7 @@ def add_default_priors(self) -> None:
626628
raise NotImplementedError("The add_default_priors property has not been implemented!")
627629

628630
def make_and_register_variable(
629-
self, name, shape: int | tuple[int, ...] | None = None, dtype=floatX
631+
self, base_name, shape: int | tuple[int, ...] | None = None, dtype=floatX
630632
) -> pt.TensorVariable:
631633
"""
632634
Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary
@@ -655,27 +657,27 @@ def make_and_register_variable(
655657
An error is raised if the provided name has already been registered, or if the name is not present in the
656658
``param_names`` property.
657659
"""
658-
if name not in self.param_names:
660+
if base_name not in self.param_names:
659661
raise ValueError(
660-
f"{name} is not a model parameter. All placeholder variables should correspond to model "
662+
f"{base_name} is not a model parameter. All placeholder variables should correspond to model "
661663
f"parameters."
662664
)
663665

664-
gname = self.graph_name(name)
666+
name = self.prefixed_name(base_name)
665667

666-
if gname in self._tensor_variable_info:
668+
if name in self._tensor_variable_info:
667669
raise ValueError(
668-
f"{gname} is already a registered placeholder variable with shape "
669-
f"{self._tensor_variable_info[gname].type.shape}"
670+
f"{name} is already a registered placeholder variable with shape "
671+
f"{self._tensor_variable_info[name].type.shape}"
670672
)
671673

672-
placeholder = pt.tensor(gname, shape=shape, dtype=dtype)
673-
tensor_var = SymbolicVariable(name=gname, symbolic_variable=placeholder)
674+
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
675+
tensor_var = SymbolicVariable(name=name, symbolic_variable=placeholder)
674676
self._tensor_variable_info = self._tensor_variable_info.add(tensor_var)
675677
return placeholder
676678

677679
def make_and_register_data(
678-
self, name: str, shape: int | tuple[int], dtype: str = floatX
680+
self, base_name: str, shape: int | tuple[int], dtype: str = floatX
679681
) -> Variable:
680682
r"""
681683
Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary
@@ -697,22 +699,22 @@ def make_and_register_data(
697699
An error is raised if the provided name has already been registered, or if the name is not present in the
698700
``data_names`` property.
699701
"""
700-
if name not in self.data_names:
702+
if base_name not in self.data_names:
701703
raise ValueError(
702-
f"{name} is not a model data-variable. All placeholder variables should correspond to model "
703-
f"data-variables."
704+
f"{base_name} is not a model data variable. All placeholder variables should correspond to model "
705+
f"data variables."
704706
)
705707

706-
gname = self.graph_name(name)
708+
name = self.prefixed_name(base_name)
707709

708-
if gname in self._tensor_data_info:
710+
if name in self._tensor_data_info:
709711
raise ValueError(
710-
f"{gname} is already a registered placeholder variable with shape "
711-
f"{self._tensor_data_info[gname].type.shape}"
712+
f"{name} is already a registered placeholder variable with shape "
713+
f"{self._tensor_data_info[name].type.shape}"
712714
)
713715

714-
placeholder = pt.tensor(gname, shape=shape, dtype=dtype)
715-
tensor_data = SymbolicData(name=gname, symbolic_data=placeholder)
716+
placeholder = pt.tensor(name, shape=shape, dtype=dtype)
717+
tensor_data = SymbolicData(name=name, symbolic_data=placeholder)
716718
self._tensor_data_info = self._tensor_data_info.add(tensor_data)
717719
return placeholder
718720

@@ -816,12 +818,12 @@ def _save_exogenous_data_info(self):
816818
"""
817819
pymc_mod = modelcontext(None)
818820
for data_name in self.data_names:
819-
gname = self.graph_name(data_name)
820-
data = pymc_mod[gname]
821+
name = self.prefixed_name(data_name)
822+
data = pymc_mod[name]
821823
self._fit_exog_data[data_name] = {
822-
"name": gname,
824+
"name": name,
823825
"value": data.get_value(),
824-
"dims": pymc_mod.named_vars_to_dims.get(gname, None),
826+
"dims": pymc_mod.named_vars_to_dims.get(name, None),
825827
}
826828

827829
def _insert_random_variables(self):
@@ -860,8 +862,8 @@ def _insert_random_variables(self):
860862
found_params = []
861863
with pymc_model:
862864
for param_name in self.param_names:
863-
gname = self.graph_name(param_name)
864-
param = getattr(pymc_model, gname, None)
865+
name = self.prefixed_name(param_name)
866+
param = getattr(pymc_model, name, None)
865867
if param is not None:
866868
found_params.append(param_name)
867869

@@ -898,8 +900,8 @@ def _insert_data_variables(self):
898900
found_data = []
899901
with pymc_model:
900902
for data_name in data_names:
901-
gname = self.graph_name(data_name)
902-
data = getattr(pymc_model, gname, None)
903+
name = self.prefixed_name(data_name)
904+
data = getattr(pymc_model, name, None)
903905
if data is not None:
904906
found_data.append(data_name)
905907

@@ -1065,7 +1067,7 @@ def build_statespace_graph(
10651067
obs_coords=obs_coords,
10661068
register_data=register_data,
10671069
missing_fill_value=missing_fill_value,
1068-
data_name=self.graph_name("data"),
1070+
data_name=self.prefixed_name(self.data_name),
10691071
)
10701072

10711073
filter_outputs = self.kalman_filter.build_graph(
@@ -1164,16 +1166,16 @@ def _build_dummy_graph(self) -> None:
11641166
A list of pm.Flat variables representing all parameters estimated by the model.
11651167
"""
11661168

1167-
def infer_variable_shape(name):
1168-
gname = self.graph_name(name)
1169-
shape = self._name_to_variable[gname].type.shape
1169+
def infer_variable_shape(base_name):
1170+
name = self.prefixed_name(base_name)
1171+
shape = self._name_to_variable[name].type.shape
11701172
if not any(dim is None for dim in shape):
11711173
return shape
11721174

1173-
dim_names = self._fit_dims.get(gname, None)
1175+
dim_names = self._fit_dims.get(name, None)
11741176
if dim_names is None:
11751177
raise ValueError(
1176-
f"Could not infer shape for {name}, because it was not given coords during model"
1178+
f"Could not infer shape for {base_name}, because it was not given coords during model"
11771179
f"fitting"
11781180
)
11791181

@@ -1185,11 +1187,11 @@ def infer_variable_shape(name):
11851187
]
11861188
)
11871189

1188-
for name in self.param_names:
1190+
for base_name in self.param_names:
11891191
pm.Flat(
1190-
self.graph_name(name),
1191-
shape=infer_variable_shape(name),
1192-
dims=self._fit_dims.get(self.graph_name(name), None),
1192+
self.prefixed_name(base_name),
1193+
shape=infer_variable_shape(base_name),
1194+
dims=self._fit_dims.get(self.prefixed_name(base_name), None),
11931195
)
11941196

11951197
def _kalman_filter_outputs_from_dummy_graph(
@@ -1229,14 +1231,14 @@ def _kalman_filter_outputs_from_dummy_graph(
12291231
self._insert_random_variables()
12301232

12311233
for name in self.data_names:
1232-
if self.graph_name(name) not in pm_mod:
1234+
if self.prefixed_name(name) not in pm_mod:
12331235
pm.Data(**self._fit_exog_data[name])
12341236

12351237
self._insert_data_variables()
12361238

12371239
for name in self.data_names:
12381240
if name in scenario.keys():
1239-
pm.set_data({self.graph_name(name): scenario[name]})
1241+
pm.set_data({self.prefixed_name(name): scenario[name]})
12401242

12411243
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
12421244

@@ -1251,7 +1253,7 @@ def _kalman_filter_outputs_from_dummy_graph(
12511253
obs_coords=obs_coords,
12521254
data_dims=data_dims,
12531255
register_data=True,
1254-
data_name=self.graph_name("data"),
1256+
data_name=self.prefixed_name(self.data_name),
12551257
)
12561258

12571259
filter_outputs = self.kalman_filter.build_graph(
@@ -1808,7 +1810,7 @@ def sample_statespace_matrices(
18081810
self._insert_random_variables()
18091811

18101812
for name in self.data_names:
1811-
pm.Data(name=self.graph_name(name), **self.data_info[name])
1813+
pm.Data(name=self.prefixed_name(name), **self.data_info[name])
18121814

18131815
self._insert_data_variables()
18141816
matrices = self.unpack_statespace()
@@ -1874,7 +1876,7 @@ def sample_filter_outputs(
18741876
n_obs=self.ssm.k_endog,
18751877
obs_coords=obs_coords,
18761878
register_data=True,
1877-
data_name=self.graph_name("data"),
1879+
data_name=self.prefixed_name(self.data_name),
18781880
)
18791881

18801882
filter_outputs = self.kalman_filter.build_graph(
@@ -2307,13 +2309,16 @@ def _build_forecast_model(
23072309

23082310
sub_dict = {
23092311
data_var: pt.as_tensor_variable(
2310-
data_var.get_value(), name=self.graph_name("data")
2312+
data_var.get_value(), name=self.prefixed_name(self.data_name)
23112313
)
23122314
for data_var in forecast_model.data_vars
23132315
}
23142316

23152317
missing_data_vars = np.setdiff1d(
2316-
ar1=[*[self.graph_name(name) for name in self.data_names], self.graph_name("data")],
2318+
ar1=[
2319+
*[self.prefixed_name(name) for name in self.data_names],
2320+
self.prefixed_name(self.data_name),
2321+
],
23172322
ar2=[k.name for k, _ in sub_dict.items()],
23182323
)
23192324
if missing_data_vars.size > 0:
@@ -2492,9 +2497,11 @@ def forecast(
24922497
with forecast_model:
24932498
if scenario is not None:
24942499
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2495-
scoped_scenario = {self.graph_name(name): value for name, value in scenario.items()}
2500+
scoped_scenario = {
2501+
self.prefixed_name(name): value for name, value in scenario.items()
2502+
}
24962503
pm.set_data(
2497-
scoped_scenario | {self.graph_name("data"): dummy_obs_data},
2504+
scoped_scenario | {self.prefixed_name(self.data_name): dummy_obs_data},
24982505
coords={"data_time": np.arange(len(forecast_index))},
24992506
)
25002507

tests/statespace/core/test_statespace.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,16 @@ def test_base_class_raises():
440440
)
441441

442442

443+
def test_two_statespace_models_can_coexist_with_names(monkeypatch):
444+
monkeypatch.setattr(PyMCStateSpace, "make_symbolic_graph", lambda self: None)
445+
446+
with pm.Model():
447+
ssm_a = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="a")
448+
ssm_b = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="b")
449+
450+
assert ssm_a.prefixed_name("data") != ssm_b.prefixed_name("data")
451+
452+
443453
def test_update_raises_if_missing_variables(ss_mod):
444454
with pm.Model() as mod:
445455
rho = pm.Normal("rho")

tests/statespace/test_namespace.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)