Skip to content

Commit e950db5

Browse files
committed
Fix structural model placeholder aliasing via graph_replace
StructuralTimeSeries.__init__ now uses pytensor.graph_replace to create fresh, prefixed placeholder variables in all SSM matrices when a model name is provided. This replaces the previous metadata-only rename approach, which left the actual graph placeholders unchanged and caused silent aliasing when the same Component instance was reused across multiple named StructuralTimeSeries models.

Key changes:
- Add _prefix_placeholder_variables() method that builds a replacement mapping from old to new (prefixed) placeholder variables, applies graph_replace across all SSM matrices, and rebuilds SymbolicVariableInfo and SymbolicDataInfo with aligned names.
- Add _validate_symbolic_info() diagnostic helper.
- Pass name=name through to PyMCStateSpace.__init__.
- Add 8 focused tests covering the aliasing bug, name alignment, graph correctness, and the validation helper.
1 parent 43c8d29 commit e950db5

2 files changed

Lines changed: 252 additions & 2 deletions

File tree

pymc_extras/statespace/models/structural/core.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import xarray as xr
99

10-
from pytensor import Mode, Variable, config
10+
from pytensor import Mode, Variable, config, graph_replace
1111
from pytensor import tensor as pt
1212

1313
from pymc_extras.statespace.core.properties import (
@@ -117,7 +117,7 @@ class StructuralTimeSeries(PyMCStateSpace):
117117
def __init__(
118118
self,
119119
ssm: PytensorRepresentation,
120-
name: str,
120+
name: str | None,
121121
coords_info: CoordInfo,
122122
param_info: ParameterInfo,
123123
data_info: DataInfo,
@@ -184,8 +184,14 @@ def __init__(
184184
verbose=verbose,
185185
measurement_error=measurement_error,
186186
mode=mode,
187+
name=name,
187188
)
188189

190+
if name is not None:
191+
ssm, tensor_variable_info, tensor_data_info = self._prefix_placeholder_variables(
192+
ssm, tensor_variable_info, tensor_data_info
193+
)
194+
189195
self._tensor_variable_info = tensor_variable_info
190196
self._tensor_data_info = tensor_data_info
191197
self._component_info = component_info.copy()
@@ -234,6 +240,78 @@ def strip(names):
234240
self._shock_names = strip(self._shock_info.names)
235241
self._param_names = strip(self._param_info.names)
236242

243+
def _prefix_placeholder_variables(self, ssm, tensor_variable_info, tensor_data_info):
    """Swap every placeholder in the SSM matrices for a model-prefixed copy.

    For each entry of ``tensor_variable_info`` and ``tensor_data_info`` a new
    variable is created from the old placeholder's type and named
    ``self.prefixed_name(entry.name)``. The old placeholders are then replaced
    by the new ones in every matrix listed in ``LONG_MATRIX_NAMES`` via
    ``graph_replace``, and the rewritten matrices are set on a copy of ``ssm``.

    Parameters
    ----------
    ssm
        State space representation whose matrices reference the placeholders.
    tensor_variable_info
        Iterable of entries exposing ``name`` and ``symbolic_variable``.
    tensor_data_info
        Iterable of entries exposing ``name`` and ``symbolic_data``.

    Returns
    -------
    tuple
        ``(ssm, variable_info, data_info)``. When there is nothing to replace,
        the three arguments are handed back untouched; otherwise a copy of
        ``ssm`` carrying the rewritten matrices, plus newly built
        ``SymbolicVariableInfo`` and ``SymbolicDataInfo`` objects whose entry
        names equal the names of the new placeholders.
    """
    # Maps each original placeholder to its prefixed stand-in.
    swap = {}

    def prefixed_clone(placeholder, base_name):
        # Same type as the original, new identity, model-prefixed name.
        clone = placeholder.type(name=self.prefixed_name(base_name))
        swap[placeholder] = clone
        return clone

    variable_entries = []
    for entry in tensor_variable_info:
        clone = prefixed_clone(entry.symbolic_variable, entry.name)
        variable_entries.append(SymbolicVariable(name=clone.name, symbolic_variable=clone))

    data_entries = []
    for entry in tensor_data_info:
        clone = prefixed_clone(entry.symbolic_data, entry.name)
        data_entries.append(SymbolicData(name=clone.name, symbolic_data=clone))

    if not swap:
        # No placeholders registered: leave the caller's objects as they are.
        return ssm, tensor_variable_info, tensor_data_info

    # All matrices are rewritten in a single graph_replace call.
    # NOTE(review): strict=False presumably tolerates placeholders that do not
    # appear in any matrix graph -- confirm against the graph_replace docs.
    original_matrices = [getattr(ssm, matrix_name) for matrix_name in LONG_MATRIX_NAMES]
    rewritten_matrices = graph_replace(original_matrices, replace=swap, strict=False)

    prefixed_ssm = ssm.copy()
    for matrix_name, rewritten in zip(LONG_MATRIX_NAMES, rewritten_matrices):
        setattr(prefixed_ssm, matrix_name, rewritten)

    return (
        prefixed_ssm,
        SymbolicVariableInfo(symbolic_variables=tuple(variable_entries)),
        SymbolicDataInfo(symbolic_data=tuple(data_entries)),
    )
279+
280+
def _validate_symbolic_info(self):
281+
"""Validate that symbolic info metadata names match actual Variable names.
282+
283+
Note: P0 is excluded from the name-match check because
284+
PytensorRepresentation.__setitem__ mutates the tensor name to the
285+
matrix name ("initial_state_cov") after registration.
286+
"""
287+
for sv in self._tensor_variable_info:
288+
# P0 is assigned as an entire matrix via ssm["initial_state_cov"] = P0,
289+
# which mutates the tensor's .name to "initial_state_cov".
290+
is_p0 = sv.name == "P0" or (self.name and sv.name == self.prefixed_name("P0"))
291+
if not is_p0 and sv.name != sv.symbolic_variable.name:
292+
raise ValueError(
293+
f"Variable name mismatch: metadata={sv.name}, "
294+
f"variable={sv.symbolic_variable.name}"
295+
)
296+
if self.name and not sv.name.startswith(f"{self.name}_"):
297+
raise ValueError(f"Variable {sv.name} missing expected prefix {self.name}_")
298+
299+
for sd in self._tensor_data_info:
300+
if sd.name != sd.symbolic_data.name:
301+
raise ValueError(
302+
f"Data name mismatch: metadata={sd.name}, " f"data={sd.symbolic_data.name}"
303+
)
304+
if self.name and not sd.name.startswith(f"{self.name}_"):
305+
raise ValueError(f"Data {sd.name} missing expected prefix {self.name}_")
306+
307+
var_ids = [id(sv.symbolic_variable) for sv in self._tensor_variable_info]
308+
if len(var_ids) != len(set(var_ids)):
309+
raise ValueError("Duplicate Variable objects in tensor_variable_info")
310+
311+
data_ids = [id(sd.symbolic_data) for sd in self._tensor_data_info]
312+
if len(data_ids) != len(set(data_ids)):
313+
raise ValueError("Duplicate Variable objects in tensor_data_info")
314+
237315
def _init_ssm(self, ssm: PytensorRepresentation, k_posdef: int) -> None:
238316
"""Initialize state space model representation."""
239317
self.ssm = ssm.copy()

tests/statespace/models/structural/test_core.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,42 @@
1818
RTOL = 0 if floatX.endswith("64") else 1e-6
1919

2020

21+
def _build_named_structural_model(name: str):
    """Build a level-trend + regression + measurement-error model called ``name``.

    The regression component is named ``"reg"`` with a single state ``"x"``;
    the measurement error component is named ``"obs"``. The model is built
    with ``verbose=False``.
    """
    components = st.LevelTrend(order=1, innovations_order=1)
    components = components + st.Regression(name="reg", state_names=["x"])
    components = components + st.MeasurementError(name="obs")
    return components.build(name=name, verbose=False)
27+
28+
29+
def test_structural_name_propagates_to_base_and_scopes_p0():
    """The model keeps its name, and P0 is registered only under its prefixed key."""
    model = _build_named_structural_model(name="m1")

    assert model.name == "m1"
    # param_names reports the bare name...
    assert "P0" in model.param_names
    # ...while the placeholder registry is keyed by the prefixed name only.
    scoped_p0 = model.prefixed_name("P0")
    assert scoped_p0 in model._name_to_variable
    assert "P0" not in model._name_to_variable
36+
37+
38+
def test_named_structural_models_do_not_collide_in_placeholder_registries():
    """Two named models built in one PyMC model keep disjoint placeholder registries."""
    with pm.Model():
        models = (
            _build_named_structural_model(name="m1"),
            _build_named_structural_model(name="m2"),
        )

    first_vars, second_vars = (set(model._name_to_variable) for model in models)
    first_data, second_data = (set(model._name_to_data) for model in models)

    assert first_vars.isdisjoint(second_vars)
    assert first_data.isdisjoint(second_data)

    # Each registry holds exactly the prefixed form of the model's bare names.
    for model, registered in zip(models, (first_vars, second_vars)):
        assert registered == {model.prefixed_name(name) for name in model.param_names}
    for model, registered in zip(models, (first_data, second_data)):
        assert registered == {model.prefixed_name(name) for name in model.data_names}
55+
56+
2157
def test_add_components():
2258
ll = st.LevelTrend(order=2)
2359
se = st.TimeSeasonality(name="seasonal", season_length=12)
@@ -195,3 +231,139 @@ def test_sequence_type_component_arguments(arg_type):
195231

196232
assert ss_mod.k_endog == len(state_names)
197233
assert sorted(ss_mod.observed_states) == sorted(list(state_names))
234+
235+
236+
class TestGraphReplacePlaceholderNamespacing:
    """Checks for graph_replace-based placeholder namespacing in StructuralTimeSeries.

    Several tests skip the name comparison for the ``P0`` entry: per the
    original comments, ``PytensorRepresentation.__setitem__`` renames that
    tensor to ``"initial_state_cov"``, so its metadata name and tensor name
    are expected to differ.
    """

    def test_same_component_reused_in_two_named_models_no_aliasing(self):
        """One Component built into two named models yields independent placeholders."""
        trend = st.LevelTrend(order=1, innovations_order=1)

        m1 = trend.build(name="m1", verbose=False)
        m2 = trend.build(name="m2", verbose=False)

        # Every placeholder carries its own model's prefix.
        for prefix, model in (("m1_", m1), ("m2_", m2)):
            for entry in model._tensor_variable_info:
                assert entry.name.startswith(prefix), f"Expected {prefix} prefix, got {entry.name}"
                if entry.name.endswith("_P0"):
                    continue
                assert entry.name == entry.symbolic_variable.name

        # The two models share no placeholder Variable objects.
        ids_m1 = {id(entry.symbolic_variable) for entry in m1._tensor_variable_info}
        ids_m2 = {id(entry.symbolic_variable) for entry in m2._tensor_variable_info}
        assert ids_m1.isdisjoint(ids_m2)

    def test_reused_component_with_data_placeholders(self):
        """Data placeholders (Regression) also get independent prefixed copies."""
        trend = st.LevelTrend(order=1, innovations_order=1)
        comp = trend + st.Regression(name="reg", state_names=["x"])

        m1 = comp.build(name="m1", verbose=False)
        m2 = comp.build(name="m2", verbose=False)

        # Variable placeholders are not shared.
        ids_m1 = {id(entry.symbolic_variable) for entry in m1._tensor_variable_info}
        ids_m2 = {id(entry.symbolic_variable) for entry in m2._tensor_variable_info}
        assert ids_m1.isdisjoint(ids_m2)

        # Data placeholders are not shared either.
        data_ids_m1 = {id(entry.symbolic_data) for entry in m1._tensor_data_info}
        data_ids_m2 = {id(entry.symbolic_data) for entry in m2._tensor_data_info}
        assert data_ids_m1.isdisjoint(data_ids_m2)

        for prefix, model in (("m1_", m1), ("m2_", m2)):
            for entry in model._tensor_data_info:
                assert entry.name.startswith(prefix)
                assert entry.name == entry.symbolic_data.name

    def test_symbolic_info_name_matches_variable_name(self):
        """After prefixing, metadata names equal the actual Variable.name."""
        components = st.LevelTrend(order=1, innovations_order=1) + st.MeasurementError(name="obs")
        mod = components.build(name="test_model", verbose=False)

        for entry in mod._tensor_variable_info:
            if entry.name.endswith("_P0"):
                continue
            placeholder = entry.symbolic_variable
            assert (
                entry.name == placeholder.name
            ), f"Mismatch: metadata={entry.name}, variable={entry.symbolic_variable.name}"

        for entry in mod._tensor_data_info:
            placeholder = entry.symbolic_data
            assert (
                entry.name == placeholder.name
            ), f"Mismatch: metadata={entry.name}, data={entry.symbolic_data.name}"

        # The dedicated helper must agree with the checks above.
        mod._validate_symbolic_info()

    def test_unnamed_model_preserves_original_placeholders(self):
        """With name=None, metadata names still equal the placeholder names."""
        trend = st.LevelTrend(order=1, innovations_order=1)
        mod = trend.build(name=None, verbose=False)

        for entry in mod._tensor_variable_info:
            if entry.name == "P0":
                continue
            assert entry.name == entry.symbolic_variable.name

    def test_prefixed_placeholders_are_in_ssm_graph(self):
        """A named model's SSM matrices use the prefixed placeholders, not the old ones."""
        from pytensor.graph.traversal import explicit_graph_inputs

        from pymc_extras.statespace.utils.constants import LONG_MATRIX_NAMES

        trend = st.LevelTrend(order=1, innovations_order=1)
        mod = trend.build(name="ns", verbose=False)

        # Names of all explicit inputs feeding any SSM matrix.
        all_matrices = [getattr(mod.ssm, matrix_name) for matrix_name in LONG_MATRIX_NAMES]
        graph_input_names = {
            node.name
            for node in set(explicit_graph_inputs(all_matrices))
            if getattr(node, "name", None)
        }

        # Every registered non-P0 variable must show up as a graph input.
        expected_names = set()
        for entry in mod._tensor_variable_info:
            if not entry.name.endswith("_P0"):
                expected_names.add(entry.name)
        assert (
            expected_names <= graph_input_names
        ), f"Missing from graph: {expected_names - graph_input_names}"

        # The component's original, unprefixed names must be gone.
        original_names = {"initial_level_trend", "sigma_level_trend"}
        assert original_names.isdisjoint(
            graph_input_names
        ), f"Old unprefixed names still in graph: {original_names & graph_input_names}"

    def test_validate_symbolic_info_catches_mismatch(self):
        """_validate_symbolic_info raises when a metadata name disagrees with its variable."""
        from pymc_extras.statespace.core.properties import SymbolicVariable, SymbolicVariableInfo

        mod = st.LevelTrend(order=1, innovations_order=1).build(name="v", verbose=False)

        # Rename only the first entry's metadata; its placeholder stays the same.
        entries = []
        for position, entry in enumerate(mod._tensor_variable_info):
            if position == 0:
                entry = SymbolicVariable(
                    name="WRONG_NAME", symbolic_variable=entry.symbolic_variable
                )
            entries.append(entry)

        mod._tensor_variable_info = SymbolicVariableInfo(symbolic_variables=tuple(entries))
        with pytest.raises(ValueError, match="Variable name mismatch"):
            mod._validate_symbolic_info()

0 commit comments

Comments
 (0)