Skip to content

Commit 7579acb

Browse files
committed
fix: clip extreme propensity scores in plot_ate and plot_balance_ecdf
1 parent deb8774 commit 7579acb

2 files changed

Lines changed: 111 additions & 2 deletions

File tree

causalpy/experiments/inverse_propensity_weighting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def plot_weights(bins, top0, top1, ax, color="population"):
578578
bar.set_edgecolor("black")
579579

580580
def make_hists(idata, i, axs, method=method):
581-
p_i = az.extract(idata)["p"][:, i].values
581+
p_i = self._prepare_ps(az.extract(idata)["p"][:, i].values)
582582
if method == "raw":
583583
weight0 = 1 / (1 - p_i[self.t.flatten() == 0])
584584
weight1 = 1 / (p_i[self.t.flatten() == 1])
@@ -749,7 +749,7 @@ def plot_balance_ecdf(
749749
if weighting_scheme is None:
750750
weighting_scheme = self.weighting_scheme
751751

752-
ps = az.extract(idata)["p"].mean(dim="sample").values
752+
ps = self._prepare_ps(az.extract(idata)["p"].mean(dim="sample").values)
753753
X = pd.DataFrame(self.X, columns=self.labels)
754754
X["ps"] = ps
755755
t = self.t.flatten()
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2022 - 2026 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Tests for IPW plotting with extreme propensity scores.
16+
17+
Regression tests for issue #645: plot_ate() and plot_balance_ecdf() crash
18+
with ValueError when propensity scores include 0.0 or 1.0 due to
19+
unguarded division.
20+
"""
21+
22+
import matplotlib.pyplot as plt
23+
import numpy as np
24+
import pytest
25+
26+
import causalpy as cp
27+
28+
sample_kwargs = {
29+
"tune": 50,
30+
"draws": 100,
31+
"chains": 2,
32+
"cores": 2,
33+
"random_seed": 42,
34+
}
35+
36+
37+
@pytest.fixture(scope="module")
38+
def ipw_result(mock_pymc_sample):
39+
"""Create a fitted IPW result for testing."""
40+
df = cp.load_data("nhefs")
41+
return cp.InversePropensityWeighting(
42+
df,
43+
formula="trt ~ 1 + age + race",
44+
outcome_variable="outcome",
45+
weighting_scheme="robust",
46+
model=cp.pymc_models.PropensityScore(sample_kwargs=sample_kwargs),
47+
)
48+
49+
50+
@pytest.fixture
51+
def extreme_idata(ipw_result):
52+
"""Create idata with some propensity scores at 0.0 and 1.0."""
53+
import copy
54+
55+
idata = copy.deepcopy(ipw_result.idata)
56+
idata.posterior["p"][:, :, :5] = 0.0
57+
idata.posterior["p"][:, :, 5:10] = 1.0
58+
return idata
59+
60+
61+
class TestPlotAteExtremeScores:
62+
"""plot_ate must not crash when propensity scores hit 0 or 1."""
63+
64+
@pytest.mark.parametrize("method", ["raw", "robust", "overlap"])
65+
def test_plot_ate_no_crash(self, ipw_result, extreme_idata, method):
66+
"""Verify plot_ate renders without error for each weighting scheme."""
67+
fig, axs = ipw_result.plot_ate(
68+
idata=extreme_idata, method=method, prop_draws=1, ate_draws=5
69+
)
70+
assert isinstance(fig, plt.Figure)
71+
plt.close(fig)
72+
73+
74+
class TestPlotBalanceEcdfExtremeScores:
75+
"""plot_balance_ecdf must not crash when propensity scores hit 0 or 1."""
76+
77+
@pytest.mark.parametrize("scheme", ["raw", "robust", "overlap"])
78+
def test_plot_balance_ecdf_no_crash(self, ipw_result, extreme_idata, scheme):
79+
"""Verify plot_balance_ecdf renders without error for each weighting scheme."""
80+
fig, axs = ipw_result.plot_balance_ecdf(
81+
"age", idata=extreme_idata, weighting_scheme=scheme
82+
)
83+
assert isinstance(fig, plt.Figure)
84+
plt.close(fig)
85+
86+
87+
class TestPreparePs:
88+
"""Unit tests for _prepare_ps clipping behavior."""
89+
90+
def test_clips_zeros(self, ipw_result):
91+
"""Scores at 0.0 are clipped to eps."""
92+
ps = np.array([0.0, 0.5, 1.0])
93+
clipped = ipw_result._prepare_ps(ps)
94+
assert clipped[0] > 0.0
95+
assert clipped[2] < 1.0
96+
assert clipped[1] == 0.5
97+
98+
def test_warns_on_extreme(self, ipw_result):
99+
"""A warning is emitted when extreme scores are detected."""
100+
ps = np.array([0.0, 0.5, 1.0])
101+
with pytest.warns(UserWarning, match="Extreme propensity scores"):
102+
ipw_result._prepare_ps(ps)
103+
104+
def test_no_warn_on_safe(self, ipw_result):
105+
"""No warning when all scores are within bounds."""
106+
ps = np.array([0.3, 0.5, 0.7])
107+
# Should not warn
108+
clipped = ipw_result._prepare_ps(ps)
109+
np.testing.assert_array_equal(ps, clipped)

0 commit comments

Comments
 (0)