@@ -557,7 +557,7 @@ def get_plot_data_bayesian(self, **kwargs: Any) -> pd.DataFrame:
557557 """
558558 # Get posterior predictions
559559 if isinstance (self .model , PyMCModel ):
560- mu = self .model .idata .posterior ["mu" ] # type: ignore[attr-defined ]
560+ mu = self .model .idata .posterior ["mu" ] # type: ignore[union-attr ]
561561 pred_mean = mu .mean (dim = ["chain" , "draw" ]).values .flatten ()
562562 pred_lower = mu .quantile (0.025 , dim = ["chain" , "draw" ]).values .flatten ()
563563 pred_upper = mu .quantile (0.975 , dim = ["chain" , "draw" ]).values .flatten ()
@@ -673,14 +673,17 @@ def plot_unit_effects(
673673
674674 if isinstance (self .model , PyMCModel ):
675675 # Bayesian: get posterior means
676- beta = self .model .idata .posterior ["beta" ] # type: ignore[attr-defined ]
676+ beta = self .model .idata .posterior ["beta" ] # type: ignore[union-attr ]
677677 unit_fe_indices = [self .labels .index (name ) for name in unit_fe_names ]
678678
679679 # Get mean and std for each unit FE
680680 fe_means = []
681681 for idx in unit_fe_indices :
682682 fe_means .append (
683- float (beta .sel (coeffs = self .labels [idx ]).mean (dim = ["chain" , "draw" ]))
683+ beta .sel (coeffs = self .labels [idx ])
684+ .mean (dim = ["chain" , "draw" ])
685+ .squeeze ("treated_units" , drop = True )
686+ .item ()
684687 )
685688
686689 ax .hist (
0 commit comments