@@ -266,6 +266,7 @@ def __init__(
266266 measurement_error : bool = False ,
267267 mode : str | None = None ,
268268 name : str | None = None ,
269+ data_name : str = "data" ,
269270 ):
270271 self ._fit_coords : dict [str , Sequence [str ]] | None = None
271272 self ._fit_dims : dict [str , Sequence [str ]] | None = None
@@ -280,6 +281,7 @@ def __init__(
280281 self .k_states = k_states
281282 self .k_posdef = k_posdef
282283 self .name = name
284+ self .data_name = data_name
283285 self .measurement_error = measurement_error
284286 self .mode = mode
285287
@@ -311,11 +313,11 @@ def __init__(
311313 console = Console ()
312314 console .print (self .requirement_table )
313315
314- def graph_name (self , base : str ) -> str :
316+ def prefixed_name (self , base_name : str ) -> str :
315317 if not self .name :
316- return base
318+ return base_name
317319 prefix = f"{ self .name } _"
318- return base if base .startswith (prefix ) else f"{ self .name } _{ base } "
320+ return base_name if base_name .startswith (prefix ) else f"{ self .name } _{ base_name } "
319321
320322 def _populate_properties (self ) -> None :
321323 self ._set_parameters ()
@@ -626,7 +628,7 @@ def add_default_priors(self) -> None:
626628 raise NotImplementedError ("The add_default_priors property has not been implemented!" )
627629
628630 def make_and_register_variable (
629- self , name , shape : int | tuple [int , ...] | None = None , dtype = floatX
631+ self , base_name , shape : int | tuple [int , ...] | None = None , dtype = floatX
630632 ) -> pt .TensorVariable :
631633 """
632634 Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary
@@ -655,27 +657,27 @@ def make_and_register_variable(
655657 An error is raised if the provided name has already been registered, or if the name is not present in the
656658 ``param_names`` property.
657659 """
658- if name not in self .param_names :
660+ if base_name not in self .param_names :
659661 raise ValueError (
660- f"{ name } is not a model parameter. All placeholder variables should correspond to model "
662+ f"{ base_name } is not a model parameter. All placeholder variables should correspond to model "
661663 f"parameters."
662664 )
663665
664- gname = self .graph_name ( name )
666+ name = self .prefixed_name ( base_name )
665667
666- if gname in self ._tensor_variable_info :
668+ if name in self ._tensor_variable_info :
667669 raise ValueError (
668- f"{ gname } is already a registered placeholder variable with shape "
669- f"{ self ._tensor_variable_info [gname ].type .shape } "
670+ f"{ name } is already a registered placeholder variable with shape "
671+ f"{ self ._tensor_variable_info [name ].type .shape } "
670672 )
671673
672- placeholder = pt .tensor (gname , shape = shape , dtype = dtype )
673- tensor_var = SymbolicVariable (name = gname , symbolic_variable = placeholder )
674+ placeholder = pt .tensor (name , shape = shape , dtype = dtype )
675+ tensor_var = SymbolicVariable (name = name , symbolic_variable = placeholder )
674676 self ._tensor_variable_info = self ._tensor_variable_info .add (tensor_var )
675677 return placeholder
676678
677679 def make_and_register_data (
678- self , name : str , shape : int | tuple [int ], dtype : str = floatX
680+ self , base_name : str , shape : int | tuple [int ], dtype : str = floatX
679681 ) -> Variable :
680682 r"""
681683 Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary
@@ -697,22 +699,22 @@ def make_and_register_data(
697699 An error is raised if the provided name has already been registered, or if the name is not present in the
698700 ``data_names`` property.
699701 """
700- if name not in self .data_names :
702+ if base_name not in self .data_names :
701703 raise ValueError (
702- f"{ name } is not a model data- variable. All placeholder variables should correspond to model "
703- f"data- variables."
704+ f"{ base_name } is not a model data variable. All placeholder variables should correspond to model "
705+ f"data variables."
704706 )
705707
706- gname = self .graph_name ( name )
708+ name = self .prefixed_name ( base_name )
707709
708- if gname in self ._tensor_data_info :
710+ if name in self ._tensor_data_info :
709711 raise ValueError (
710- f"{ gname } is already a registered placeholder variable with shape "
711- f"{ self ._tensor_data_info [gname ].type .shape } "
712+ f"{ name } is already a registered placeholder variable with shape "
713+ f"{ self ._tensor_data_info [name ].type .shape } "
712714 )
713715
714- placeholder = pt .tensor (gname , shape = shape , dtype = dtype )
715- tensor_data = SymbolicData (name = gname , symbolic_data = placeholder )
716+ placeholder = pt .tensor (name , shape = shape , dtype = dtype )
717+ tensor_data = SymbolicData (name = name , symbolic_data = placeholder )
716718 self ._tensor_data_info = self ._tensor_data_info .add (tensor_data )
717719 return placeholder
718720
@@ -816,12 +818,12 @@ def _save_exogenous_data_info(self):
816818 """
817819 pymc_mod = modelcontext (None )
818820 for data_name in self .data_names :
819- gname = self .graph_name (data_name )
820- data = pymc_mod [gname ]
821+ name = self .prefixed_name (data_name )
822+ data = pymc_mod [name ]
821823 self ._fit_exog_data [data_name ] = {
822- "name" : gname ,
824+ "name" : name ,
823825 "value" : data .get_value (),
824- "dims" : pymc_mod .named_vars_to_dims .get (gname , None ),
826+ "dims" : pymc_mod .named_vars_to_dims .get (name , None ),
825827 }
826828
827829 def _insert_random_variables (self ):
@@ -860,8 +862,8 @@ def _insert_random_variables(self):
860862 found_params = []
861863 with pymc_model :
862864 for param_name in self .param_names :
863- gname = self .graph_name (param_name )
864- param = getattr (pymc_model , gname , None )
865+ name = self .prefixed_name (param_name )
866+ param = getattr (pymc_model , name , None )
865867 if param is not None :
866868 found_params .append (param_name )
867869
@@ -898,8 +900,8 @@ def _insert_data_variables(self):
898900 found_data = []
899901 with pymc_model :
900902 for data_name in data_names :
901- gname = self .graph_name (data_name )
902- data = getattr (pymc_model , gname , None )
903+ name = self .prefixed_name (data_name )
904+ data = getattr (pymc_model , name , None )
903905 if data is not None :
904906 found_data .append (data_name )
905907
@@ -1065,7 +1067,7 @@ def build_statespace_graph(
10651067 obs_coords = obs_coords ,
10661068 register_data = register_data ,
10671069 missing_fill_value = missing_fill_value ,
1068- data_name = self .graph_name ( "data" ),
1070+ data_name = self .prefixed_name ( self . data_name ),
10691071 )
10701072
10711073 filter_outputs = self .kalman_filter .build_graph (
@@ -1164,16 +1166,16 @@ def _build_dummy_graph(self) -> None:
11641166 A list of pm.Flat variables representing all parameters estimated by the model.
11651167 """
11661168
1167- def infer_variable_shape (name ):
1168- gname = self .graph_name ( name )
1169- shape = self ._name_to_variable [gname ].type .shape
1169+ def infer_variable_shape (base_name ):
1170+ name = self .prefixed_name ( base_name )
1171+ shape = self ._name_to_variable [name ].type .shape
11701172 if not any (dim is None for dim in shape ):
11711173 return shape
11721174
1173- dim_names = self ._fit_dims .get (gname , None )
1175+ dim_names = self ._fit_dims .get (name , None )
11741176 if dim_names is None :
11751177 raise ValueError (
1176- f"Could not infer shape for { name } , because it was not given coords during model"
1178+ f"Could not infer shape for { base_name } , because it was not given coords during model"
11771179 f"fitting"
11781180 )
11791181
@@ -1185,11 +1187,11 @@ def infer_variable_shape(name):
11851187 ]
11861188 )
11871189
1188- for name in self .param_names :
1190+ for base_name in self .param_names :
11891191 pm .Flat (
1190- self .graph_name ( name ),
1191- shape = infer_variable_shape (name ),
1192- dims = self ._fit_dims .get (self .graph_name ( name ), None ),
1192+ self .prefixed_name ( base_name ),
1193+ shape = infer_variable_shape (base_name ),
1194+ dims = self ._fit_dims .get (self .prefixed_name ( base_name ), None ),
11931195 )
11941196
11951197 def _kalman_filter_outputs_from_dummy_graph (
@@ -1229,14 +1231,14 @@ def _kalman_filter_outputs_from_dummy_graph(
12291231 self ._insert_random_variables ()
12301232
12311233 for name in self .data_names :
1232- if self .graph_name (name ) not in pm_mod :
1234+ if self .prefixed_name (name ) not in pm_mod :
12331235 pm .Data (** self ._fit_exog_data [name ])
12341236
12351237 self ._insert_data_variables ()
12361238
12371239 for name in self .data_names :
12381240 if name in scenario .keys ():
1239- pm .set_data ({self .graph_name (name ): scenario [name ]})
1241+ pm .set_data ({self .prefixed_name (name ): scenario [name ]})
12401242
12411243 x0 , P0 , c , d , T , Z , R , H , Q = self .unpack_statespace ()
12421244
@@ -1251,7 +1253,7 @@ def _kalman_filter_outputs_from_dummy_graph(
12511253 obs_coords = obs_coords ,
12521254 data_dims = data_dims ,
12531255 register_data = True ,
1254- data_name = self .graph_name ( "data" ),
1256+ data_name = self .prefixed_name ( self . data_name ),
12551257 )
12561258
12571259 filter_outputs = self .kalman_filter .build_graph (
@@ -1808,7 +1810,7 @@ def sample_statespace_matrices(
18081810 self ._insert_random_variables ()
18091811
18101812 for name in self .data_names :
1811- pm .Data (name = self .graph_name (name ), ** self .data_info [name ])
1813+ pm .Data (name = self .prefixed_name (name ), ** self .data_info [name ])
18121814
18131815 self ._insert_data_variables ()
18141816 matrices = self .unpack_statespace ()
@@ -1874,7 +1876,7 @@ def sample_filter_outputs(
18741876 n_obs = self .ssm .k_endog ,
18751877 obs_coords = obs_coords ,
18761878 register_data = True ,
1877- data_name = self .graph_name ( "data" ),
1879+ data_name = self .prefixed_name ( self . data_name ),
18781880 )
18791881
18801882 filter_outputs = self .kalman_filter .build_graph (
@@ -2307,13 +2309,16 @@ def _build_forecast_model(
23072309
23082310 sub_dict = {
23092311 data_var : pt .as_tensor_variable (
2310- data_var .get_value (), name = self .graph_name ( "data" )
2312+ data_var .get_value (), name = self .prefixed_name ( self . data_name )
23112313 )
23122314 for data_var in forecast_model .data_vars
23132315 }
23142316
23152317 missing_data_vars = np .setdiff1d (
2316- ar1 = [* [self .graph_name (name ) for name in self .data_names ], self .graph_name ("data" )],
2318+ ar1 = [
2319+ * [self .prefixed_name (name ) for name in self .data_names ],
2320+ self .prefixed_name (self .data_name ),
2321+ ],
23172322 ar2 = [k .name for k , _ in sub_dict .items ()],
23182323 )
23192324 if missing_data_vars .size > 0 :
@@ -2492,9 +2497,11 @@ def forecast(
24922497 with forecast_model :
24932498 if scenario is not None :
24942499 dummy_obs_data = np .zeros ((len (forecast_index ), self .k_endog ))
2495- scoped_scenario = {self .graph_name (name ): value for name , value in scenario .items ()}
2500+ scoped_scenario = {
2501+ self .prefixed_name (name ): value for name , value in scenario .items ()
2502+ }
24962503 pm .set_data (
2497- scoped_scenario | {self .graph_name ( "data" ): dummy_obs_data },
2504+ scoped_scenario | {self .prefixed_name ( self . data_name ): dummy_obs_data },
24982505 coords = {"data_time" : np .arange (len (forecast_index ))},
24992506 )
25002507
0 commit comments