Fix moment_match integration in loo#344
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #344 +/- ##
==========================================
+ Coverage 84.48% 84.64% +0.15%
==========================================
Files 42 43 +1
Lines 5794 6088 +294
==========================================
+ Hits 4895 5153 +258
- Misses 899 935 +36 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
jordandeklerk
left a comment
There was a problem hiding this comment.
Thanks for taking this on! You've made some good progress here and you've got the key pieces in place for the most part. That's the hardest part.
I'm taking a first pass at this and might come back with more changes after this first round. I will need to dive a little deeper and ensure everything is in the right place. The following items need to be addressed first though:
-
There are some unrelated
histogram()function changes on this branch that have nothing to do with the scope of this PR and the issue attached to it. All of these changes need to be reverted. I'm not sure where they came from. -
One of the main things to address is the overall structure of how the moment match path fits into
loo(). Right now the moment-match logic runs at the end of the function after the full PSIS-LOO-CV computation, which means validation happens too late, the pointwise flag gets permanently overridden, and the non-moment-match code path gets modified unnecessarily, among other things. -
The tests also need attention. They're currently in
tests/test_minimal.pybut should be intests/loo/test_loo.pyprobably, and the existingroaches_r_examplefixture fromtests/loo/test_loo_mm.pyshould be reused to add end-to-end tests that prove the wiring is correct.
See the comments below for specifics on each point.
Remember to review the contributing tutorial here https://python.arviz.org/en/latest/contributing/pr_tutorial.html. You'll need to address linting issues and test CI failures as well. But we can iron those out with the new set of changes that need to be implemented.
There was a problem hiding this comment.
Keep in mind that moment matching is a two-step algorithm. First compute standard PSIS-LOO-CV, then improve the estimates for observations with high Pareto-k. So loo_moment_match() expects a fully computed PSIS-LOO-CV result (with pointwise=True so it has per-observation Pareto k values) as its loo_orig argument so it knows which observations need fixing.
Since we're backend-agnostic and can't auto-derive the required functions from a model object the way R can with stanfit, loo() needs to handle this wiring explicitly. I think the cleanest way is to compute the base PSIS-LOO-CV via a recursive loo() call with pointwise=True, then pass that result directly to loo_moment_match() and return.
So we should probably have something like the following at the entry point of this function:
if moment_match:
# early validation goes here
loo_orig = loo(
data,
pointwise=True,
var_name=var_name,
reff=reff,
log_weights=log_weights,
pareto_k=pareto_k,
log_jacobian=log_jacobian,
)
return loo_moment_match(
data,
loo_orig,
log_prob_upars_fn=log_prob_upars_fn,
log_lik_i_upars_fn=log_lik_i_upars_fn,
upars=upars,
var_name=var_name,
reff=reff,
pointwise=pointwise,
**moment_match_kwargs,
)
loo_inputs = _prepare_loo_inputs(data, var_name)
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise| log_weights=None, | ||
| pareto_k=None, | ||
| log_jacobian=None, | ||
| mixture=False, |
There was a problem hiding this comment.
Notice that we allow the user to pass True here for mixture importance sampling, which uses mixture estimators for high dimensional Bayesian settings where standard PSIS-LOO-CV struggles (see https://mc-stan.org/loo/articles/loo2-mixis.html). When this argument is True, the user is expected to pass a mixture posterior, which is fundamentally different than the posterior that loo_moment_match() expects and operates on.
Implicitly adaptive moment matching and mixture importance sampling are different correction strategies and shouldn't be combined. Right now if a user passes both flags, the mixture branch returns early and moment_match=True gets silently ignored with no error or warning.
We should add something like the following early in the loo() function to avoid this:
if moment_match and mixture:
raise ValueError(
"`moment_match` and `mixture` cannot both be True. "
"These are different correction strategies for PSIS-LOO-CV."
)| log_prob_upars_fn=None, | ||
| log_lik_i_upars_fn=None, | ||
| upars=None, | ||
| **moment_match_kwargs, |
There was a problem hiding this comment.
All of these inputs are new to the loo() function and should have proper documentation in the docstring parameters section. We should be following the same verbiage as what is already included in the loo_moment_match() function for consistency.
For the moment_match_kwargs argument, we should be explicit about what these are and probably add a reference to the loo_moment_match() function itself at the end.
| loo_inputs = _prepare_loo_inputs(data, var_name) | ||
| pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise | ||
| if moment_match: | ||
| pointwise = True |
There was a problem hiding this comment.
This is correct that moment matching needs pointwise=True internally so the base PSIS-LOO-CV result includes per-observation Pareto k values to iterate over. But loo_moment_match() also accepts its own pointwise parameter which controls whether pointwise fields like elpd_i, pareto_k, and influence_pareto_k are included in the final result or stripped out before returning. By overwriting the variable here, the user's original choice is lost, so, for example,
loo(data, pointwise=False, moment_match=True)would still return pointwise data.
I think we should remove these two lines since this should be handled with our original loo computations.
| pareto_k=pareto_k, | ||
| approx_posterior=False, | ||
| log_jacobian=jacobian_da, | ||
| ) |
There was a problem hiding this comment.
Indentation here is not correct, but should be fixed if we follow the other suggestions anyway.
There was a problem hiding this comment.
There are two main problems here for these tests. Where the tests live, and what they test.
These tests should not be in tests/test_minimal.py. That file is for tests of the array-level API. The moment-match integration is a change to the loo() function, so the tests belong in tests/loo/test_loo.py alongside the existing PSIS-LOO tests.
There's also tests/loo/test_loo_mm.py which already has the roaches_r_example fixture, callback functions, R reference values, and everything needed to write meaningful tests for moment matching. The new tests should reuse that existing infrastructure.
As a nudge toward the right direction for testing this new functionality, we want something like the following in spirit:
from .test_loo_mm import load_roaches_r_example
from arviz_stats import loo, loo_moment_match
@pytest.fixture(scope="module")
def roaches_r_example():
return load_roaches_r_example()
# Does the one-step path produce the same answer as the two-step path?
def test_loo_moment_match_matches_two_step(roaches_r_example):
data = roaches_r_example["data_tree"]
mm_kwargs = dict(
log_prob_upars_fn=roaches_r_example["log_prob_fn"],
log_lik_i_upars_fn=roaches_r_example["log_lik_i_fn"],
upars=roaches_r_example["upars"],
var_name="log_lik",
cov=True,
)
loo_orig = loo(data, pointwise=True, var_name="log_lik")
result_two_step = loo_moment_match(data, loo_orig, pointwise=True, **mm_kwargs)
result_one_step = loo(data, moment_match=True, pointwise=True, **mm_kwargs)
assert_almost_equal(result_one_step.elpd, result_two_step.elpd, decimal=10)
assert_almost_equal(result_one_step.se, result_two_step.se, decimal=10)
assert_almost_equal(result_one_step.p, result_two_step.p, decimal=10)
assert_array_equal(result_one_step.pareto_k.values, result_two_step.pareto_k.values)Also, as a general principle for testing, we should not just be testing that a result exists. This doesn't really tell us anything about the nature of the functionality and if it is robust or not. For some examples about what to be looking for in testing general PSIS-LOO-CV functions, see https://github.com/arviz-devs/arviz-stats/blob/main/tests/loo/test_loo.py.
|
|
||
| def test_loo_moment_match_flag_runs(): | ||
| from arviz_stats import loo | ||
| from arviz_base import load_arviz_data |
There was a problem hiding this comment.
Imports should not be nested inside the test functions. These should live at the top of the module.
|
|
||
| result = loo(data, moment_match=False) | ||
|
|
||
| assert result is not None |
There was a problem hiding this comment.
This calls loo(data, moment_match=False) which is identical to loo(data), so it doesn't actually test anything new. This should be removed.
|
|
||
| assert result is not None | ||
|
|
||
| def test_loo_moment_match_requires_functions(): |
There was a problem hiding this comment.
Use the match parameter to verify the error message, e.g.,
pytest.raises(ValueError, match="log_prob_upars_fn"))| def test_loo_moment_match_requires_functions(): | ||
| from arviz_stats import loo | ||
| from arviz_base import load_arviz_data | ||
| import pytest |
There was a problem hiding this comment.
This is already imported at the top of the module. Remove this and move the other imports to the top of the module as well.
|
Hi, thanks a lot for the detailed review and guidance — this is really helpful. Apologies for the confusion on this PR. Really appreciate your patience — I’ll update this soon. |
All tests passing locally.