Skip to content

Commit d68e837

Browse files
authored
Merge branch 'main' into dependabot/github_actions/actions-0c2b297cc4
2 parents 587242a + 8fe2c06 commit d68e837

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

causalpy/experiments/panel_regression.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def get_plot_data_bayesian(self, **kwargs: Any) -> pd.DataFrame:
557557
"""
558558
# Get posterior predictions
559559
if isinstance(self.model, PyMCModel):
560-
mu = self.model.idata.posterior["mu"] # type: ignore[attr-defined]
560+
mu = self.model.idata.posterior["mu"] # type: ignore[union-attr]
561561
pred_mean = mu.mean(dim=["chain", "draw"]).values.flatten()
562562
pred_lower = mu.quantile(0.025, dim=["chain", "draw"]).values.flatten()
563563
pred_upper = mu.quantile(0.975, dim=["chain", "draw"]).values.flatten()
@@ -673,14 +673,17 @@ def plot_unit_effects(
673673

674674
if isinstance(self.model, PyMCModel):
675675
# Bayesian: get posterior means
676-
beta = self.model.idata.posterior["beta"] # type: ignore[attr-defined]
676+
beta = self.model.idata.posterior["beta"] # type: ignore[union-attr]
677677
unit_fe_indices = [self.labels.index(name) for name in unit_fe_names]
678678

679679
# Get mean and std for each unit FE
680680
fe_means = []
681681
for idx in unit_fe_indices:
682682
fe_means.append(
683-
float(beta.sel(coeffs=self.labels[idx]).mean(dim=["chain", "draw"]))
683+
beta.sel(coeffs=self.labels[idx])
684+
.mean(dim=["chain", "draw"])
685+
.squeeze("treated_units", drop=True)
686+
.item()
684687
)
685688

686689
ax.hist(

causalpy/tests/test_panel_regression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ def test_panel_regression_plot_coefficients(mock_pymc_sample, small_panel_data):
290290

291291

292292
@pytest.mark.integration
293+
@pytest.mark.filterwarnings(
294+
"error:Conversion of an array with ndim > 0 to a scalar:DeprecationWarning"
295+
)
293296
def test_panel_regression_plot_unit_effects(mock_pymc_sample, small_panel_data):
294297
"""Test plot_unit_effects method."""
295298
result = cp.PanelRegression(

0 commit comments

Comments
 (0)