@@ -138,6 +138,10 @@ class PyMCStateSpace:
138138 Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
139139 to all sampling methods.
140140
141+ name : str, optional
142+ Prefix used to namespace internal graph variable and data names so multiple state space models can coexist
143+ in the same PyMC model without naming collisions. If ``None``, the default naming behavior is used.
144+
141145 Notes
142146 -----
143147 Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -261,6 +265,7 @@ def __init__(
261265 verbose : bool = True ,
262266 measurement_error : bool = False ,
263267 mode : str | None = None ,
268+ name : str | None = None ,
264269 ):
265270 self ._fit_coords : dict [str , Sequence [str ]] | None = None
266271 self ._fit_dims : dict [str , Sequence [str ]] | None = None
@@ -274,6 +279,7 @@ def __init__(
274279 self .k_endog = k_endog
275280 self .k_states = k_states
276281 self .k_posdef = k_posdef
282+ self .name = name
277283 self .measurement_error = measurement_error
278284 self .mode = mode
279285
@@ -305,6 +311,12 @@ def __init__(
305311 console = Console ()
306312 console .print (self .requirement_table )
307313
314+ def graph_name (self , base : str ) -> str :
315+ if not self .name :
316+ return base
317+ prefix = f"{ self .name } _"
318+ return base if base .startswith (prefix ) else f"{ self .name } _{ base } "
319+
308320 def _populate_properties (self ) -> None :
309321 self ._set_parameters ()
310322 self ._set_states ()
@@ -649,14 +661,16 @@ def make_and_register_variable(
649661 f"parameters."
650662 )
651663
652- if name in self ._tensor_variable_info :
664+ gname = self .graph_name (name )
665+
666+ if gname in self ._tensor_variable_info :
653667 raise ValueError (
654- f"{ name } is already a registered placeholder variable with shape "
655- f"{ self ._tensor_variable_info [name ].type .shape } "
668+ f"{ gname } is already a registered placeholder variable with shape "
669+ f"{ self ._tensor_variable_info [gname ].type .shape } "
656670 )
657671
658- placeholder = pt .tensor (name , shape = shape , dtype = dtype )
659- tensor_var = SymbolicVariable (name = name , symbolic_variable = placeholder )
672+ placeholder = pt .tensor (gname , shape = shape , dtype = dtype )
673+ tensor_var = SymbolicVariable (name = gname , symbolic_variable = placeholder )
660674 self ._tensor_variable_info = self ._tensor_variable_info .add (tensor_var )
661675 return placeholder
662676
@@ -685,18 +699,20 @@ def make_and_register_data(
685699 """
686700 if name not in self .data_names :
687701 raise ValueError (
688- f"{ name } is not a model parameter . All placeholder variables should correspond to model "
689- f"parameters ."
702+ f"{ name } is not a model data-variable . All placeholder variables should correspond to model "
703+ f"data-variables ."
690704 )
691705
692- if name in self ._tensor_data_info :
706+ gname = self .graph_name (name )
707+
708+ if gname in self ._tensor_data_info :
693709 raise ValueError (
694- f"{ name } is already a registered placeholder variable with shape "
695- f"{ self ._tensor_data_info [name ].type .shape } "
710+ f"{ gname } is already a registered placeholder variable with shape "
711+ f"{ self ._tensor_data_info [gname ].type .shape } "
696712 )
697713
698- placeholder = pt .tensor (name , shape = shape , dtype = dtype )
699- tensor_data = SymbolicData (name = name , symbolic_data = placeholder )
714+ placeholder = pt .tensor (gname , shape = shape , dtype = dtype )
715+ tensor_data = SymbolicData (name = gname , symbolic_data = placeholder )
700716 self ._tensor_data_info = self ._tensor_data_info .add (tensor_data )
701717 return placeholder
702718
@@ -800,11 +816,12 @@ def _save_exogenous_data_info(self):
800816 """
801817 pymc_mod = modelcontext (None )
802818 for data_name in self .data_names :
803- data = pymc_mod [data_name ]
819+ gname = self .graph_name (data_name )
820+ data = pymc_mod [gname ]
804821 self ._fit_exog_data [data_name ] = {
805- "name" : data_name ,
822+ "name" : gname ,
806823 "value" : data .get_value (),
807- "dims" : pymc_mod .named_vars_to_dims .get (data_name , None ),
824+ "dims" : pymc_mod .named_vars_to_dims .get (gname , None ),
808825 }
809826
810827 def _insert_random_variables (self ):
@@ -843,9 +860,10 @@ def _insert_random_variables(self):
843860 found_params = []
844861 with pymc_model :
845862 for param_name in self .param_names :
846- param = getattr (pymc_model , param_name , None )
863+ gname = self .graph_name (param_name )
864+ param = getattr (pymc_model , gname , None )
847865 if param is not None :
848- found_params .append (param . name )
866+ found_params .append (param_name )
849867
850868 missing_params = list (set (self .param_names ) - set (found_params ))
851869 if len (missing_params ) > 0 :
@@ -880,9 +898,10 @@ def _insert_data_variables(self):
880898 found_data = []
881899 with pymc_model :
882900 for data_name in data_names :
883- data = getattr (pymc_model , data_name , None )
901+ gname = self .graph_name (data_name )
902+ data = getattr (pymc_model , gname , None )
884903 if data is not None :
885- found_data .append (data . name )
904+ found_data .append (data_name )
886905
887906 missing_data = list (set (data_names ) - set (found_data ))
888907 if len (missing_data ) > 0 :
@@ -1046,6 +1065,7 @@ def build_statespace_graph(
10461065 obs_coords = obs_coords ,
10471066 register_data = register_data ,
10481067 missing_fill_value = missing_fill_value ,
1068+ data_name = self .graph_name ("data" ),
10491069 )
10501070
10511071 filter_outputs = self .kalman_filter .build_graph (
@@ -1145,11 +1165,12 @@ def _build_dummy_graph(self) -> None:
11451165 """
11461166
11471167 def infer_variable_shape (name ):
1148- shape = self ._name_to_variable [name ].type .shape
1168+ gname = self .graph_name (name )
1169+ shape = self ._name_to_variable [gname ].type .shape
11491170 if not any (dim is None for dim in shape ):
11501171 return shape
11511172
1152- dim_names = self ._fit_dims .get (name , None )
1173+ dim_names = self ._fit_dims .get (gname , None )
11531174 if dim_names is None :
11541175 raise ValueError (
11551176 f"Could not infer shape for { name } , because it was not given coords during model"
@@ -1166,9 +1187,9 @@ def infer_variable_shape(name):
11661187
11671188 for name in self .param_names :
11681189 pm .Flat (
1169- name ,
1190+ self . graph_name ( name ) ,
11701191 shape = infer_variable_shape (name ),
1171- dims = self ._fit_dims .get (name , None ),
1192+ dims = self ._fit_dims .get (self . graph_name ( name ) , None ),
11721193 )
11731194
11741195 def _kalman_filter_outputs_from_dummy_graph (
@@ -1208,14 +1229,14 @@ def _kalman_filter_outputs_from_dummy_graph(
12081229 self ._insert_random_variables ()
12091230
12101231 for name in self .data_names :
1211- if name not in pm_mod :
1232+ if self . graph_name ( name ) not in pm_mod :
12121233 pm .Data (** self ._fit_exog_data [name ])
12131234
12141235 self ._insert_data_variables ()
12151236
12161237 for name in self .data_names :
12171238 if name in scenario .keys ():
1218- pm .set_data ({name : scenario [name ]})
1239+ pm .set_data ({self . graph_name ( name ) : scenario [name ]})
12191240
12201241 x0 , P0 , c , d , T , Z , R , H , Q = self .unpack_statespace ()
12211242
@@ -1230,6 +1251,7 @@ def _kalman_filter_outputs_from_dummy_graph(
12301251 obs_coords = obs_coords ,
12311252 data_dims = data_dims ,
12321253 register_data = True ,
1254+ data_name = self .graph_name ("data" ),
12331255 )
12341256
12351257 filter_outputs = self .kalman_filter .build_graph (
@@ -1786,7 +1808,7 @@ def sample_statespace_matrices(
17861808 self ._insert_random_variables ()
17871809
17881810 for name in self .data_names :
1789- pm .Data (** self .data_info [name ])
1811+ pm .Data (name = self . graph_name ( name ), ** self .data_info [name ])
17901812
17911813 self ._insert_data_variables ()
17921814 matrices = self .unpack_statespace ()
@@ -1852,6 +1874,7 @@ def sample_filter_outputs(
18521874 n_obs = self .ssm .k_endog ,
18531875 obs_coords = obs_coords ,
18541876 register_data = True ,
1877+ data_name = self .graph_name ("data" ),
18551878 )
18561879
18571880 filter_outputs = self .kalman_filter .build_graph (
@@ -2283,12 +2306,15 @@ def _build_forecast_model(
22832306 mu , cov = grouped_outputs [group_idx ]
22842307
22852308 sub_dict = {
2286- data_var : pt .as_tensor_variable (data_var .get_value (), name = "data" )
2309+ data_var : pt .as_tensor_variable (
2310+ data_var .get_value (), name = self .graph_name ("data" )
2311+ )
22872312 for data_var in forecast_model .data_vars
22882313 }
22892314
22902315 missing_data_vars = np .setdiff1d (
2291- ar1 = [* self .data_names , "data" ], ar2 = [k .name for k , _ in sub_dict .items ()]
2316+ ar1 = [* [self .graph_name (name ) for name in self .data_names ], self .graph_name ("data" )],
2317+ ar2 = [k .name for k , _ in sub_dict .items ()],
22922318 )
22932319 if missing_data_vars .size > 0 :
22942320 raise ValueError (f"{ missing_data_vars } data used for fitting not found!" )
@@ -2466,8 +2492,9 @@ def forecast(
24662492 with forecast_model :
24672493 if scenario is not None :
24682494 dummy_obs_data = np .zeros ((len (forecast_index ), self .k_endog ))
2495+ scoped_scenario = {self .graph_name (name ): value for name , value in scenario .items ()}
24692496 pm .set_data (
2470- scenario | {"data" : dummy_obs_data },
2497+ scoped_scenario | {self . graph_name ( "data" ) : dummy_obs_data },
24712498 coords = {"data_time" : np .arange (len (forecast_index ))},
24722499 )
24732500
0 commit comments