Fix moment_match integration in loo #344
base: main
Changes from all commits
b9a1290
f91d03b
dbfa0e3
a241ee6
31edb22
9f52247
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import warnings | ||
|
|
||
| from .loo_moment_match import loo_moment_match | ||
| from arviz_base import rcParams | ||
| from xarray_einstats.stats import logsumexp | ||
|
|
||
|
|
@@ -28,6 +29,11 @@ def loo( | |
| pareto_k=None, | ||
| log_jacobian=None, | ||
| mixture=False, | ||
|
Member
Notice that we allow the user to pass both `moment_match=True` and `mixture=True`. Implicitly adaptive moment matching and mixture importance sampling are different correction strategies and shouldn't be combined. Right now, if a user passes both flags, the mixture branch returns early and moment matching is never applied. We should add something like the following early in the function:

if moment_match and mixture:
    raise ValueError(
        "`moment_match` and `mixture` cannot both be True. "
        "These are different correction strategies for PSIS-LOO-CV."
    ) |
||
| moment_match=False, | ||
| log_prob_upars_fn=None, | ||
| log_lik_i_upars_fn=None, | ||
| upars=None, | ||
| **moment_match_kwargs, | ||
|
Member
All of these inputs are new to the […] For the […] |
||
| ): | ||
| r"""Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV). | ||
|
|
||
|
|
@@ -159,6 +165,8 @@ def loo( | |
| """ | ||
| loo_inputs = _prepare_loo_inputs(data, var_name) | ||
| pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise | ||
| if moment_match: | ||
| pointwise = True | ||
|
Member
It is correct that moment matching needs pointwise results, but with this override `loo(data, pointwise=False, moment_match=True)` would still return pointwise data. I think we should remove these two lines since this should be handled with our original […] |
||
|
|
||
| if reff is None: | ||
| reff = _get_r_eff(data, loo_inputs.n_samples) | ||
|
|
@@ -218,18 +226,35 @@ def loo( | |
| log_weights=mix_log_weights, | ||
| ) | ||
|
|
||
| return _compute_loo_results( | ||
| log_likelihood=loo_inputs.log_likelihood, | ||
| var_name=loo_inputs.var_name, | ||
| pointwise=pointwise, | ||
| sample_dims=loo_inputs.sample_dims, | ||
| n_samples=loo_inputs.n_samples, | ||
| n_data_points=loo_inputs.n_data_points, | ||
| log_weights=log_weights, | ||
| pareto_k=pareto_k, | ||
| approx_posterior=False, | ||
| log_jacobian=jacobian_da, | ||
| ) | ||
| loo_res = _compute_loo_results( | ||
| log_likelihood=loo_inputs.log_likelihood, | ||
| var_name=loo_inputs.var_name, | ||
| pointwise=pointwise, | ||
| sample_dims=loo_inputs.sample_dims, | ||
| n_samples=loo_inputs.n_samples, | ||
| n_data_points=loo_inputs.n_data_points, | ||
| log_weights=log_weights, | ||
| pareto_k=pareto_k, | ||
| approx_posterior=False, | ||
| log_jacobian=jacobian_da, | ||
| ) | ||
|
Member
Indentation here is not correct, but should be fixed if we follow the other suggestions anyway. |
||
| if moment_match: | ||
| if log_prob_upars_fn is None or log_lik_i_upars_fn is None: | ||
| raise ValueError( | ||
| "moment_match=True requires log_prob_upars_fn and log_lik_i_upars_fn" | ||
| ) | ||
|
Member
I like the checks for the required callback functions. However, this runs after the PSIS-LOO-CV computation (`_compute_loo_results`) has already finished, so the user only sees the error once that work is done. This could have pretty major implications for a user with a particularly large log-likelihood, as the PSIS-LOO-CV could take a while. |
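To illustrate the ordering concern, here is a minimal sketch of validating the callbacks before any costly work starts. The names `loo_fail_fast` and `expensive_psis_loo` are hypothetical stand-ins, not part of arviz-stats, and the "computation" is placeholder arithmetic rather than real PSIS-LOO-CV.

```python
def expensive_psis_loo(data):
    """Hypothetical stand-in for the costly PSIS-LOO-CV computation."""
    expensive_psis_loo.calls += 1
    return {"elpd": float(sum(data))}


# Call counter, used only to show that validation happens before the costly step.
expensive_psis_loo.calls = 0


def loo_fail_fast(
    data,
    moment_match=False,
    log_prob_upars_fn=None,
    log_lik_i_upars_fn=None,
):
    """Hypothetical sketch: check cheap preconditions first, compute second."""
    if moment_match and (log_prob_upars_fn is None or log_lik_i_upars_fn is None):
        raise ValueError(
            "moment_match=True requires log_prob_upars_fn and log_lik_i_upars_fn"
        )
    # Only reached once the arguments are known to be usable.
    return expensive_psis_loo(data)
```

With this ordering, a call such as `loo_fail_fast(data, moment_match=True)` raises immediately and the costly stand-in is never invoked.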
||
|
|
||
| loo_res = loo_moment_match ( | ||
| data=data, | ||
| loo_orig=loo_res, | ||
| log_prob_upars_fn=log_prob_upars_fn, | ||
| log_lik_i_upars_fn=log_lik_i_upars_fn, | ||
| upars=upars, | ||
| var_name=var_name, | ||
| reff=reff, | ||
| **moment_match_kwargs, | ||
| ) | ||
| return loo_res | ||
|
|
||
|
|
||
| def loo_i( | ||
|
|
||
|
Member
There are two main problems with these tests: where they live, and what they test. These tests should not be in […] There's also […]

As a nudge in the right direction for testing this new functionality, we want something like the following in spirit:

import pytest
from numpy.testing import assert_almost_equal, assert_array_equal

from arviz_stats import loo, loo_moment_match
from .test_loo_mm import load_roaches_r_example


@pytest.fixture(scope="module")
def roaches_r_example():
    return load_roaches_r_example()


# Does the one-step path produce the same answer as the two-step path?
def test_loo_moment_match_matches_two_step(roaches_r_example):
    data = roaches_r_example["data_tree"]
    mm_kwargs = dict(
        log_prob_upars_fn=roaches_r_example["log_prob_fn"],
        log_lik_i_upars_fn=roaches_r_example["log_lik_i_fn"],
        upars=roaches_r_example["upars"],
        var_name="log_lik",
        cov=True,
    )
    # Two-step path: standard PSIS-LOO-CV first, then moment matching on top of it.
    loo_orig = loo(data, pointwise=True, var_name="log_lik")
    result_two_step = loo_moment_match(data, loo_orig, pointwise=True, **mm_kwargs)
    # One-step path: loo() wires the two steps together internally.
    result_one_step = loo(data, moment_match=True, pointwise=True, **mm_kwargs)
    assert_almost_equal(result_one_step.elpd, result_two_step.elpd, decimal=10)
    assert_almost_equal(result_one_step.se, result_two_step.se, decimal=10)
    assert_almost_equal(result_one_step.p, result_two_step.p, decimal=10)
    assert_array_equal(result_one_step.pareto_k.values, result_two_step.pareto_k.values)

Also, as a general principle for testing, we should not just be testing that a result exists. That tells us nothing about how the functionality behaves or whether it is robust. For some examples of what to look for when testing general PSIS-LOO-CV functions, see https://github.com/arviz-devs/arviz-stats/blob/main/tests/loo/test_loo.py. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -387,3 +387,24 @@ def test_hdi_invalid_prob(data_c0d1): | |
| def test_thin_invalid_factor(data_c0d1): | ||
| with pytest.raises(ValueError, match="must be greater than 1"): | ||
| array_stats.thin(data_c0d1, factor=0, chain_axis=0, draw_axis=1) | ||
|
|
||
|
|
||
| def test_loo_moment_match_flag_runs(): | ||
| from arviz_stats import loo | ||
| from arviz_base import load_arviz_data | ||
|
Member
Imports should not be nested inside the test functions. These should live at the top of the module. |
||
|
|
||
| data = load_arviz_data("centered_eight") | ||
|
|
||
| result = loo(data, moment_match=False) | ||
|
|
||
| assert result is not None | ||
|
Member
This calls `loo()` with `moment_match=False`, so the moment matching path is never exercised. |
||
|
|
||
| def test_loo_moment_match_requires_functions(): | ||
|
Member
Use the `match` parameter to verify the error message, e.g., `pytest.raises(ValueError, match="log_prob_upars_fn")`. |
||
| from arviz_stats import loo | ||
| from arviz_base import load_arviz_data | ||
| import pytest | ||
|
Member
This is already imported at the top of the module. Remove this and move the other imports to the top of the module as well. |
||
|
|
||
| data = load_arviz_data("centered_eight") | ||
|
|
||
| with pytest.raises(ValueError): | ||
| loo(data, moment_match=True) | ||
Keep in mind that moment matching is a two-step algorithm. First compute standard PSIS-LOO-CV, then improve the estimates for observations with high Pareto k. So `loo_moment_match()` expects a fully computed PSIS-LOO-CV result (with `pointwise=True` so it has per-observation Pareto k values) as its `loo_orig` argument so it knows which observations need fixing.

Since we're backend-agnostic and can't auto-derive the required functions from a model object the way R can with stanfit, `loo()` needs to handle this wiring explicitly. I think the cleanest way is to compute the base PSIS-LOO-CV via a recursive `loo()` call with `pointwise=True`, then pass that result directly to `loo_moment_match()` and return.

So we should probably have something like the following at the entry point of this function:
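The reviewer's own snippet is not part of this excerpt. As a rough sketch of the wiring described above, and not the actual arviz-stats implementation, the entry point might look like the following. `loo_entry_sketch` and `loo_moment_match_stub` are hypothetical stand-ins, the base result is placeholder arithmetic rather than real PSIS-LOO-CV, and how the final result should honor the caller's `pointwise` setting is left open.

```python
def loo_moment_match_stub(data, loo_orig, **kwargs):
    """Hypothetical stand-in for loo_moment_match(); it only tags the result."""
    return {**loo_orig, "moment_matched": True}


def loo_entry_sketch(
    data,
    pointwise=False,
    mixture=False,
    moment_match=False,
    log_prob_upars_fn=None,
    log_lik_i_upars_fn=None,
    **moment_match_kwargs,
):
    """Hypothetical sketch of handling moment matching at the entry point."""
    if moment_match:
        # Cheap argument checks first, before any computation.
        if mixture:
            raise ValueError("`moment_match` and `mixture` cannot both be True.")
        if log_prob_upars_fn is None or log_lik_i_upars_fn is None:
            raise ValueError(
                "moment_match=True requires log_prob_upars_fn and log_lik_i_upars_fn"
            )
        # Step 1: base result via a recursive call, always pointwise so that
        # per-observation values are available to the second step.
        loo_orig = loo_entry_sketch(data, pointwise=True)
        # Step 2: hand the base result to the moment matching step and return.
        return loo_moment_match_stub(
            data,
            loo_orig,
            log_prob_upars_fn=log_prob_upars_fn,
            log_lik_i_upars_fn=log_lik_i_upars_fn,
            **moment_match_kwargs,
        )

    # Plain path: placeholder arithmetic standing in for PSIS-LOO-CV.
    result = {"elpd": float(sum(data))}
    if pointwise:
        result["elpd_i"] = list(data)
    return result
```

Because the moment matching branch returns before the plain path runs, the rest of the function body stays unchanged for callers who never set `moment_match`.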