Skip to content

Commit 0b22a39

Browse files
committed
feat: implement summary method for InstrumentalVariable
Fixes #360
1 parent 711970c commit 0b22a39

2 files changed

Lines changed: 42 additions & 2 deletions

File tree

causalpy/experiments/instrumental_variable.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
from patsy import dmatrices
2323
from sklearn.linear_model import LinearRegression as sk_lin_reg
2424

25+
import arviz as az
26+
27+
from causalpy.constants import HDI_PROB
2528
from causalpy.custom_exceptions import DataException
2629
from causalpy.pymc_models import InstrumentalVariableRegression
30+
from causalpy.utils import round_num
2731

2832
from .base import BaseExperiment
2933
from causalpy.reporting import EffectSummary
@@ -268,13 +272,48 @@ def plot(self, *args, **kwargs) -> None: # type: ignore[override]
268272
"""
269273
raise NotImplementedError("Plot method not implemented.")
270274

271-
def summary(self, round_to: int | None = None) -> None:
275+
def summary(self, round_to: int | None = 2) -> None:
272276
"""Print summary of main results and model coefficients.
273277
274278
:param round_to:
275279
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
276280
"""
277-
raise NotImplementedError("Summary method not implemented.")
281+
round_to = round_to or 2
282+
print(f"{self.expt_type:=^80}")
283+
print(f"Formula: {self.formula}")
284+
print(f"Instruments formula: {self.instruments_formula}")
285+
286+
print("\nNaive OLS coefficients:")
287+
for name, val in self.ols_beta_params.items():
288+
print(f" {name: <20} {round_num(val, round_to)}")
289+
290+
print("\n2SLS coefficients:")
291+
print(" First stage:")
292+
for name, val in zip(
293+
self.labels_instruments, self.ols_beta_first_params, strict=False
294+
):
295+
print(f" {name: <20} {round_num(val, round_to)}")
296+
print(" Second stage:")
297+
for name, val in zip(self.labels, self.ols_beta_second_params, strict=False):
298+
print(f" {name: <20} {round_num(val, round_to)}")
299+
300+
print("\nBayesian coefficients:")
301+
posterior = self.idata.posterior
302+
for var, dim, labels, stage in [
303+
("beta_t", "instruments", self.labels_instruments, "Instrument stage"),
304+
("beta_z", "covariates", self.labels, "Outcome stage"),
305+
]:
306+
print(f" {stage}:")
307+
coeffs = az.extract(posterior, var_names=var)
308+
for name in labels:
309+
samples = coeffs.sel({dim: name})
310+
lo = samples.quantile((1 - HDI_PROB) / 2).item()
311+
hi = samples.quantile(1 - (1 - HDI_PROB) / 2).item()
312+
print(
313+
f" {name: <20} {round_num(samples.mean().item(), round_to)}, "
314+
f"{HDI_PROB * 100:.0f}% HDI [{round_num(lo, round_to)}, "
315+
f"{round_num(hi, round_to)}]"
316+
)
278317

279318
def effect_summary(
280319
self,

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,7 @@ def test_iv_reg(mock_pymc_sample):
637637
assert isinstance(result, cp.InstrumentalVariable)
638638
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
639639
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
640+
result.summary()
640641
with pytest.raises(NotImplementedError):
641642
result.get_plot_data()
642643

0 commit comments

Comments
 (0)