Skip to content

Commit 41880d3

Browse files
committed
Add hierarchical regression unit tests
1 parent dca1534 commit 41880d3

1 file changed

Lines changed: 217 additions & 0 deletions

File tree

causalpy/tests/test_pymc_models.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Any, ClassVar
15+
1416
import arviz as az
1517
import numpy as np
1618
import pandas as pd
@@ -20,7 +22,10 @@
2022
from pymc_extras.prior import Prior
2123

2224
import causalpy as cp
25+
26+
from causalpy.data.simulate_data import generate_hlr_data
2327
from causalpy.pymc_models import (
28+
HierarchicalLinearRegression,
2429
LinearRegression,
2530
PyMCModel,
2631
SoftmaxWeightedSumFitter,
@@ -761,6 +766,218 @@ def prior_test_data():
761766
return X, y, coords
762767

763768

769+
class TestHierarchicalLinearRegression:
    """Unit tests for ``HierarchicalLinearRegression``.

    ``setup_class`` simulates one data set with ``generate_hlr_data``, builds
    the model and samples it once. The individual tests reuse the model,
    inputs and posterior stored on the class.
    """

    # Full dictionary returned by ``_setup_data``.
    data: ClassVar[dict[str, Any]]
    # Fixed-effects design matrix, dims ("obs_ind", "coeffs").
    X: ClassVar[xr.DataArray]
    # Random-effects design matrix, dims ("obs_ind", "random_coeffs").
    Z: ClassVar[xr.DataArray]
    # Outcome, dims ("obs_ind", "treated_units").
    y: ClassVar[xr.DataArray]
    # Group index of every observation (the "group_idx" column of the data).
    group_idx: ClassVar[np.ndarray]
    # Coordinates handed to ``build_model``.
    coords: ClassVar[dict[str, Any]]
    # Keyword arguments used for both the model and ``pm.sample``.
    sample_kwargs: ClassVar[dict[str, Any]]
    # Priors handed to the model constructor.
    priors: ClassVar[dict[str, Prior]]
    # Ground-truth parameters returned by the data generator.
    model_true: ClassVar[dict[str, np.ndarray]]
    # Model built once in ``setup_class``.
    model: ClassVar[HierarchicalLinearRegression]
    # Posterior samples drawn once in ``setup_class``.
    idata: ClassVar[az.InferenceData]

    @staticmethod
    def _expect_value_error(fn, expected_message: str) -> None:
        """Assert that calling ``fn()`` raises a matching ``ValueError``.

        Parameters
        ----------
        fn
            Zero-argument callable expected to raise.
        expected_message
            Passed to ``pytest.raises(match=...)``, so it is applied as a
            regular-expression search against the error text.
        """
        with pytest.raises(ValueError, match=expected_message):
            fn()

    @staticmethod
    def _prepare_data(
        XY: pd.DataFrame,
        Z_: pd.DataFrame,
        params: dict[str, np.ndarray],
        *,
        seed: int = 42,
        sample_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Map HLR data into HierarchicalLinearRegression inputs.

        Parameters
        ----------
        XY
            Frame providing the columns ``"group_idx"``, ``"1"``, ``"x1"``
            and ``"y"``; its index becomes the ``"obs_ind"`` coordinate.
        Z_
            Random-effects design matrix; its column labels become the
            ``"random_coeffs"`` coordinate.
        params
            Ground-truth parameters, returned unchanged as ``"model_true"``.
        seed
            Default ``"random_seed"`` placed in the sampler settings.
        sample_kwargs
            Optional overrides merged on top of the default sampler settings.

        Returns
        -------
        dict[str, Any]
            Dictionary with the keys ``"model_kwargs"`` (arguments for
            ``build_model``), ``"sample_kwargs"`` (defaults merged with the
            overrides), ``"priors"`` and ``"model_true"``.
        """
        obs_ind = XY.index.to_numpy()
        group_idx = XY["group_idx"].to_numpy()
        n_groups = int(XY["group_idx"].nunique())

        coeff_names = ["1", "x1"]
        random_coeff_names = list(Z_.columns)
        unit_names = ["unit_0"]

        X = xr.DataArray(
            XY[coeff_names].to_numpy(),
            dims=["obs_ind", "coeffs"],
            coords={"obs_ind": obs_ind, "coeffs": coeff_names},
        )
        Z = xr.DataArray(
            Z_.to_numpy(),
            dims=["obs_ind", "random_coeffs"],
            coords={"obs_ind": obs_ind, "random_coeffs": random_coeff_names},
        )
        # Add a trailing axis so the outcome carries a "treated_units" dim.
        y = xr.DataArray(
            XY["y"].to_numpy()[:, None],
            dims=["obs_ind", "treated_units"],
            coords={"obs_ind": obs_ind, "treated_units": unit_names},
        )

        coords = {
            "obs_ind": obs_ind,
            "coeffs": coeff_names,
            "random_coeffs": random_coeff_names,
            "treated_units": unit_names,
            "groups": np.arange(n_groups),
        }

        resolved_sample_kwargs: dict[str, Any] = {
            "draws": 500,
            "tune": 500,
            "chains": 2,
            "target_accept": 0.90,
            "progressbar": False,
            "random_seed": seed,
        }
        if sample_kwargs is not None:
            resolved_sample_kwargs.update(sample_kwargs)

        return {
            "model_kwargs": {
                "X": X,
                "Z": Z,
                "y": y,
                "group_idx": group_idx,
                "coords": coords,
            },
            # Return the merged settings, not the raw argument: otherwise the
            # defaults (including the seed) are lost and a ``None`` argument
            # would reach ``pm.sample(**...)``.
            "sample_kwargs": resolved_sample_kwargs,
            "priors": {
                "beta_fixed": Prior(
                    "Normal", mu=0, sigma=10, dims=["treated_units", "coeffs"]
                ),
                "sigma_fixed": Prior("HalfNormal", sigma=1, dims=["treated_units"]),
                "sigma_random": Prior(
                    "HalfNormal", sigma=1, dims=["treated_units", "random_coeffs"]
                ),
                "beta_group": Prior(
                    "Normal",
                    mu=0,
                    sigma=1,
                    dims=["groups", "treated_units", "random_coeffs"],
                ),
            },
            "model_true": params,
        }

    @classmethod
    def _setup_data(
        cls,
        *,
        seed: int = 42,
        n_groups: int = 12,
        n_obs_per_group: int = 40,
        beta_fixed_true: tuple[float, float] = (2.0, 1.4),
        sigma_random_true: tuple[float, float] = (0.7, 0.5),
        sigma_noise: float = 0.8,
        sample_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Simulate HLR data and convert it into model inputs.

        The simulation arguments are forwarded to ``generate_hlr_data``;
        ``seed`` and ``sample_kwargs`` are also forwarded to
        ``_prepare_data``, whose return value is returned unchanged.
        """
        XY, Z, params = generate_hlr_data(
            seed=seed,
            n_groups=n_groups,
            n_obs_per_group=n_obs_per_group,
            beta_fixed_true=beta_fixed_true,
            sigma_random_true=sigma_random_true,
            sigma_noise=sigma_noise,
        )
        return cls._prepare_data(
            XY,
            Z,
            params,
            seed=seed,
            sample_kwargs=sample_kwargs,
        )

    @classmethod
    def setup_class(cls) -> None:
        """Build the shared model and draw the shared posterior once."""
        # Short chains on a single core keep the suite fast.
        cls.data = cls._setup_data(
            sample_kwargs={
                "draws": 120,
                "tune": 120,
                "chains": 2,
                "cores": 1,
            }
        )
        model_kwargs = cls.data["model_kwargs"]
        cls.X = model_kwargs["X"]
        cls.Z = model_kwargs["Z"]
        cls.y = model_kwargs["y"]
        cls.group_idx = model_kwargs["group_idx"]
        cls.coords = model_kwargs["coords"]
        cls.sample_kwargs = cls.data["sample_kwargs"]
        cls.priors = cls.data["priors"]
        cls.model_true = cls.data["model_true"]

        cls.model = HierarchicalLinearRegression(
            sample_kwargs=cls.sample_kwargs, priors=cls.priors
        )
        cls.model.build_model(
            X=cls.X,
            Z=cls.Z,
            y=cls.y,
            group_idx=cls.group_idx,
            coords=cls.coords,
        )
        with cls.model:
            cls.idata = pm.sample(**cls.sample_kwargs)

    def test_centered_requires_sigma_random_prior(self) -> None:
        """Centered build without a ``sigma_random`` prior raises ``ValueError``."""
        priors = {key: self.model.priors[key] for key in ["beta_fixed", "sigma_fixed"]}
        self._expect_value_error(
            lambda: HierarchicalLinearRegression(priors=priors).build_model(
                X=self.X,
                Z=self.Z,
                y=self.y,
                group_idx=self.group_idx,
                coords=self.coords,
                non_centered=False,
            ),
            "Missing required priors for centered parameterization: sigma_random.",
        )

    def test_noncentered_requires_beta_group_prior(self) -> None:
        """Non-centered build without a ``beta_group`` prior raises ``ValueError``."""
        priors = {
            key: self.model.priors[key]
            for key in ["beta_fixed", "sigma_fixed", "sigma_random"]
        }
        self._expect_value_error(
            lambda: HierarchicalLinearRegression(priors=priors).build_model(
                X=self.X,
                Z=self.Z,
                y=self.y,
                group_idx=self.group_idx,
                coords=self.coords,
                non_centered=True,
            ),
            "Missing required priors for non-centered parameterization: beta_group.",
        )

    def test_fixed_effect_parameter_recovery(self) -> None:
        """Posterior mean of ``beta_fixed`` is within 0.50 of the true value."""
        # Average over chains and draws, then take the first entry of the
        # leading remaining axis.
        beta_fixed_mean = (
            self.idata.posterior["beta_fixed"].mean(dim=("chain", "draw")).to_numpy()[0]
        )
        np.testing.assert_allclose(
            beta_fixed_mean,
            self.model_true["beta_fixed_true"],
            atol=0.50,
        )

    def test_random_scale_parameter_recovery(self) -> None:
        """Posterior mean of ``sigma_random`` is within 0.50 of the true value."""
        sigma_random_mean = (
            self.idata.posterior["sigma_random"]
            .mean(dim=("chain", "draw"))
            .to_numpy()[0]
        )
        np.testing.assert_allclose(
            sigma_random_mean,
            self.model_true["sigma_random_true"],
            atol=0.50,
        )
979+
980+
764981
class TestPriorIntegration:
765982
"""
766983
Test suite for Prior class integration with PyMC models.

0 commit comments

Comments
 (0)