|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +from typing import Any, ClassVar |
| 15 | + |
14 | 16 | import arviz as az |
15 | 17 | import numpy as np |
16 | 18 | import pandas as pd |
|
20 | 22 | from pymc_extras.prior import Prior |
21 | 23 |
|
22 | 24 | import causalpy as cp |
| 25 | + |
| 26 | +from causalpy.data.simulate_data import generate_hlr_data |
23 | 27 | from causalpy.pymc_models import ( |
| 28 | + HierarchicalLinearRegression, |
24 | 29 | LinearRegression, |
25 | 30 | PyMCModel, |
26 | 31 | SoftmaxWeightedSumFitter, |
@@ -761,6 +766,218 @@ def prior_test_data(): |
761 | 766 | return X, y, coords |
762 | 767 |
|
763 | 768 |
|
class TestHierarchicalLinearRegression:
    """Tests for ``HierarchicalLinearRegression``.

    A single model is built and sampled once in ``setup_class`` on data
    produced by ``generate_hlr_data``; the individual tests then check
    prior validation errors and recovery of the generating parameters.
    """

    # Class-level fixtures, populated once by ``setup_class``.
    data: ClassVar[dict[str, Any]]
    X: ClassVar[xr.DataArray]
    Z: ClassVar[xr.DataArray]
    y: ClassVar[xr.DataArray]
    group_idx: ClassVar[np.ndarray]
    coords: ClassVar[dict[str, Any]]
    sample_kwargs: ClassVar[dict[str, Any]]
    priors: ClassVar[dict[str, Prior]]
    model_true: ClassVar[dict[str, np.ndarray]]
    model: ClassVar[HierarchicalLinearRegression]
    idata: ClassVar[az.InferenceData]

    @staticmethod
    def _expect_value_error(fn, expected_message: str) -> None:
        """Call ``fn`` and assert that it raises a matching ``ValueError``.

        Parameters
        ----------
        fn
            Zero-argument callable expected to raise.
        expected_message
            Passed to ``pytest.raises(match=...)``, so it is interpreted as a
            regular expression searched for in the exception text.
        """
        with pytest.raises(ValueError, match=expected_message):
            fn()

    @staticmethod
    def _prepare_data(
        XY: pd.DataFrame,
        Z_: pd.DataFrame,
        params: dict[str, np.ndarray],
        *,
        seed: int = 42,
        sample_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Map HLR data into HierarchicalLinearRegression inputs.

        Parameters
        ----------
        XY
            Frame providing the columns ``"group_idx"``, ``"1"``, ``"x1"``
            and ``"y"``; its index becomes the ``obs_ind`` coordinate.
        Z_
            Random-effect design matrix; its columns become the
            ``random_coeffs`` coordinate.
        params
            Generating parameters, returned unchanged under ``"model_true"``.
        seed
            Default ``random_seed`` for sampling.
        sample_kwargs
            Optional overrides merged on top of the default sampler settings.

        Returns
        -------
        dict
            ``"model_kwargs"`` (``X``, ``Z``, ``y``, ``group_idx``,
            ``coords``), ``"sample_kwargs"`` (defaults merged with the
            overrides), ``"priors"`` and ``"model_true"``.
        """
        obs_ind = XY.index.to_numpy()
        group_idx = XY["group_idx"].to_numpy()
        n_groups = int(XY["group_idx"].nunique())

        _X = XY[["1", "x1"]].to_numpy()
        _Z = Z_.to_numpy()
        _y = XY["y"].to_numpy()

        X = xr.DataArray(
            _X,
            dims=["obs_ind", "coeffs"],
            coords={"obs_ind": obs_ind, "coeffs": ["1", "x1"]},
        )
        Z = xr.DataArray(
            _Z,
            dims=["obs_ind", "random_coeffs"],
            coords={"obs_ind": obs_ind, "random_coeffs": list(Z_.columns)},
        )
        # The outcome gets a trailing singleton axis for the single treated unit.
        y = xr.DataArray(
            _y[:, None],
            dims=["obs_ind", "treated_units"],
            coords={"obs_ind": obs_ind, "treated_units": ["unit_0"]},
        )

        coords = {
            "obs_ind": obs_ind,
            "coeffs": ["1", "x1"],
            "random_coeffs": list(Z_.columns),
            "treated_units": ["unit_0"],
            # NOTE(review): assumes group_idx values are 0..n_groups-1 -- confirm
            # against generate_hlr_data.
            "groups": np.arange(n_groups),
        }

        resolved_sample_kwargs: dict[str, Any] = {
            "draws": 500,
            "tune": 500,
            "chains": 2,
            "target_accept": 0.90,
            "progressbar": False,
            "random_seed": seed,
        }
        if sample_kwargs is not None:
            resolved_sample_kwargs.update(sample_kwargs)

        return {
            "model_kwargs": {
                "X": X,
                "Z": Z,
                "y": y,
                "group_idx": group_idx,
                "coords": coords,
            },
            # Return the merged settings, not the raw overrides: otherwise the
            # seed/target_accept/progressbar defaults are silently dropped and
            # a ``None`` override breaks ``pm.sample(**sample_kwargs)``.
            "sample_kwargs": resolved_sample_kwargs,
            "priors": {
                "beta_fixed": Prior(
                    "Normal", mu=0, sigma=10, dims=["treated_units", "coeffs"]
                ),
                "sigma_fixed": Prior("HalfNormal", sigma=1, dims=["treated_units"]),
                "sigma_random": Prior(
                    "HalfNormal", sigma=1, dims=["treated_units", "random_coeffs"]
                ),
                "beta_group": Prior(
                    "Normal",
                    mu=0,
                    sigma=1,
                    dims=["groups", "treated_units", "random_coeffs"],
                ),
            },
            "model_true": params,
        }

    @classmethod
    def _setup_data(
        cls,
        *,
        seed: int = 42,
        n_groups: int = 12,
        n_obs_per_group: int = 40,
        beta_fixed_true: tuple[float, float] = (2.0, 1.4),
        sigma_random_true: tuple[float, float] = (0.7, 0.5),
        sigma_noise: float = 0.8,
        sample_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Simulate HLR data and convert it into model inputs.

        All generation arguments are forwarded to ``generate_hlr_data``; the
        result is passed to ``_prepare_data`` together with ``seed`` and
        ``sample_kwargs``. Returns the dictionary built by ``_prepare_data``.
        """
        XY, Z, params = generate_hlr_data(
            seed=seed,
            n_groups=n_groups,
            n_obs_per_group=n_obs_per_group,
            beta_fixed_true=beta_fixed_true,
            sigma_random_true=sigma_random_true,
            sigma_noise=sigma_noise,
        )
        return cls._prepare_data(
            XY,
            Z,
            params,
            seed=seed,
            sample_kwargs=sample_kwargs,
        )

    @classmethod
    def setup_class(cls) -> None:
        """Build the model and draw posterior samples once for all tests."""
        # Short chains on a single core keep the shared fixture cheap.
        cls.data = cls._setup_data(
            sample_kwargs={
                "draws": 120,
                "tune": 120,
                "chains": 2,
                "cores": 1,
            }
        )
        cls.X = cls.data["model_kwargs"]["X"]
        cls.Z = cls.data["model_kwargs"]["Z"]
        cls.y = cls.data["model_kwargs"]["y"]
        cls.group_idx = cls.data["model_kwargs"]["group_idx"]
        cls.coords = cls.data["model_kwargs"]["coords"]
        cls.sample_kwargs = cls.data["sample_kwargs"]
        cls.priors = cls.data["priors"]
        cls.model_true = cls.data["model_true"]

        cls.model = HierarchicalLinearRegression(
            sample_kwargs=cls.sample_kwargs, priors=cls.priors
        )
        cls.model.build_model(
            X=cls.X,
            Z=cls.Z,
            y=cls.y,
            group_idx=cls.group_idx,
            coords=cls.coords,
        )
        with cls.model:
            cls.idata = pm.sample(**cls.sample_kwargs)

    def test_centered_requires_sigma_random_prior(self) -> None:
        """Centered build without a ``sigma_random`` prior raises ValueError."""
        priors = {key: self.model.priors[key] for key in ["beta_fixed", "sigma_fixed"]}
        self._expect_value_error(
            lambda: HierarchicalLinearRegression(priors=priors).build_model(
                X=self.X,
                Z=self.Z,
                y=self.y,
                group_idx=self.group_idx,
                coords=self.coords,
                non_centered=False,
            ),
            "Missing required priors for centered parameterization: sigma_random.",
        )

    def test_noncentered_requires_beta_group_prior(self) -> None:
        """Non-centered build without a ``beta_group`` prior raises ValueError."""
        priors = {
            key: self.model.priors[key]
            for key in ["beta_fixed", "sigma_fixed", "sigma_random"]
        }
        self._expect_value_error(
            lambda: HierarchicalLinearRegression(priors=priors).build_model(
                X=self.X,
                Z=self.Z,
                y=self.y,
                group_idx=self.group_idx,
                coords=self.coords,
                non_centered=True,
            ),
            "Missing required priors for non-centered parameterization: beta_group.",
        )

    def test_fixed_effect_parameter_recovery(self) -> None:
        """Posterior mean of ``beta_fixed`` is within 0.5 of the true value."""
        # ``[0]`` takes the first entry along the leading remaining axis after
        # averaging over chains and draws.
        beta_fixed_mean = (
            self.idata.posterior["beta_fixed"].mean(dim=("chain", "draw")).to_numpy()[0]
        )
        np.testing.assert_allclose(
            beta_fixed_mean,
            self.model_true["beta_fixed_true"],
            atol=0.50,
        )

    def test_random_scale_parameter_recovery(self) -> None:
        """Posterior mean of ``sigma_random`` is within 0.5 of the true value."""
        sigma_random_mean = (
            self.idata.posterior["sigma_random"]
            .mean(dim=("chain", "draw"))
            .to_numpy()[0]
        )
        np.testing.assert_allclose(
            sigma_random_mean,
            self.model_true["sigma_random_true"],
            atol=0.50,
        )
| 980 | + |
764 | 981 | class TestPriorIntegration: |
765 | 982 | """ |
766 | 983 | Test suite for Prior class integration with PyMC models. |
|
0 commit comments