|
22 | 22 | from patsy import dmatrices |
23 | 23 | from sklearn.linear_model import LinearRegression as sk_lin_reg |
24 | 24 |
|
| 25 | +import arviz as az |
| 26 | + |
| 27 | +from causalpy.constants import HDI_PROB |
25 | 28 | from causalpy.custom_exceptions import DataException |
26 | 29 | from causalpy.pymc_models import InstrumentalVariableRegression |
| 30 | +from causalpy.utils import round_num |
27 | 31 |
|
28 | 32 | from .base import BaseExperiment |
29 | 33 | from causalpy.reporting import EffectSummary |
@@ -268,13 +272,48 @@ def plot(self, *args, **kwargs) -> None: # type: ignore[override] |
268 | 272 | """ |
269 | 273 | raise NotImplementedError("Plot method not implemented.") |
270 | 274 |
|
271 | | - def summary(self, round_to: int | None = None) -> None: |
| 275 | + def summary(self, round_to: int | None = 2) -> None: |
272 | 276 | """Print summary of main results and model coefficients. |
273 | 277 |
|
274 | 278 | :param round_to: |
275 | 279 | Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers |
276 | 280 | """ |
277 | | - raise NotImplementedError("Summary method not implemented.") |
| 281 | + round_to = round_to or 2 |
| 282 | + print(f"{self.expt_type:=^80}") |
| 283 | + print(f"Formula: {self.formula}") |
| 284 | + print(f"Instruments formula: {self.instruments_formula}") |
| 285 | + |
| 286 | + print("\nNaive OLS coefficients:") |
| 287 | + for name, val in self.ols_beta_params.items(): |
| 288 | + print(f" {name: <20} {round_num(val, round_to)}") |
| 289 | + |
| 290 | + print("\n2SLS coefficients:") |
| 291 | + print(" First stage:") |
| 292 | + for name, val in zip( |
| 293 | + self.labels_instruments, self.ols_beta_first_params, strict=False |
| 294 | + ): |
| 295 | + print(f" {name: <20} {round_num(val, round_to)}") |
| 296 | + print(" Second stage:") |
| 297 | + for name, val in zip(self.labels, self.ols_beta_second_params, strict=False): |
| 298 | + print(f" {name: <20} {round_num(val, round_to)}") |
| 299 | + |
| 300 | + print("\nBayesian coefficients:") |
| 301 | + posterior = self.idata.posterior |
| 302 | + for var, dim, labels, stage in [ |
| 303 | + ("beta_t", "instruments", self.labels_instruments, "Instrument stage"), |
| 304 | + ("beta_z", "covariates", self.labels, "Outcome stage"), |
| 305 | + ]: |
| 306 | + print(f" {stage}:") |
| 307 | + coeffs = az.extract(posterior, var_names=var) |
| 308 | + for name in labels: |
| 309 | + samples = coeffs.sel({dim: name}) |
| 310 | + lo = samples.quantile((1 - HDI_PROB) / 2).item() |
| 311 | + hi = samples.quantile(1 - (1 - HDI_PROB) / 2).item() |
| 312 | + print( |
| 313 | + f" {name: <20} {round_num(samples.mean().item(), round_to)}, " |
| 314 | + f"{HDI_PROB * 100:.0f}% HDI [{round_num(lo, round_to)}, " |
| 315 | + f"{round_num(hi, round_to)}]" |
| 316 | + ) |
278 | 317 |
|
279 | 318 | def effect_summary( |
280 | 319 | self, |
|
0 commit comments