Fix moment_match integration in loo #344

Open
aatifjunaid wants to merge 6 commits into arviz-devs:main from aatifjunaid:aatif-branch

Conversation

@aatifjunaid
Contributor

  • Ensures loo_moment_match is only called when moment_match=True
  • Adds validation for required functions
  • Adds minimal tests for flag behavior

All tests passing locally.

@codecov-commenter

codecov-commenter commented Mar 29, 2026

Codecov Report

❌ Patch coverage is 77.77778% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.64%. Comparing base (5d8858a) to head (9f52247).
⚠️ Report is 26 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/arviz_stats/loo/loo.py | 83.33% | 1 Missing ⚠️ |
| src/arviz_stats/numba/array.py | 0.00% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #344      +/-   ##
==========================================
+ Coverage   84.48%   84.64%   +0.15%     
==========================================
  Files          42       43       +1     
  Lines        5794     6088     +294     
==========================================
+ Hits         4895     5153     +258     
- Misses        899      935      +36     


Member

@jordandeklerk jordandeklerk left a comment

Thanks for taking this on! You've made some good progress here and you've got the key pieces in place for the most part. That's the hardest part.

I'm taking a first pass at this and might come back with more changes after this first round. I will need to dive a little deeper and ensure everything is in the right place. The following items need to be addressed first though:

  • There are some unrelated histogram() function changes on this branch that have nothing to do with the scope of this PR and the issue attached to it. All of these changes need to be reverted. I'm not sure where they came from.

  • One of the main things to address is the overall structure of how the moment match path fits into loo(). Right now the moment-match logic runs at the end of the function after the full PSIS-LOO-CV computation, which means validation happens too late, the pointwise flag gets permanently overridden, and the non-moment-match code path gets modified unnecessarily, among other things.

  • The tests also need attention. They're currently in tests/test_minimal.py but should probably be in tests/loo/test_loo.py, and the existing roaches_r_example fixture from tests/loo/test_loo_mm.py should be reused to add end-to-end tests that prove the wiring is correct.

See the comments below for specifics on each point.

Remember to review the contributing tutorial here https://python.arviz.org/en/latest/contributing/pr_tutorial.html. You'll need to address linting issues and test CI failures as well. But we can iron those out with the new set of changes that need to be implemented.

Member

Keep in mind that moment matching is a two-step algorithm: first compute standard PSIS-LOO-CV, then improve the estimates for observations with high Pareto k values. So loo_moment_match() expects a fully computed PSIS-LOO-CV result (with pointwise=True so it has per-observation Pareto k values) as its loo_orig argument so it knows which observations need fixing.

Since we're backend-agnostic and can't auto-derive the required functions from a model object the way R can with stanfit, loo() needs to handle this wiring explicitly. I think the cleanest way is to compute the base PSIS-LOO-CV via a recursive loo() call with pointwise=True, then pass that result directly to loo_moment_match() and return.

So we should probably have something like the following at the entry point of this function:

    if moment_match:
        # early validation goes here

        loo_orig = loo(
            data,
            pointwise=True,
            var_name=var_name,
            reff=reff,
            log_weights=log_weights,
            pareto_k=pareto_k,
            log_jacobian=log_jacobian,
        )

        return loo_moment_match(
            data,
            loo_orig,
            log_prob_upars_fn=log_prob_upars_fn,
            log_lik_i_upars_fn=log_lik_i_upars_fn,
            upars=upars,
            var_name=var_name,
            reff=reff,
            pointwise=pointwise,
            **moment_match_kwargs,
        )

    loo_inputs = _prepare_loo_inputs(data, var_name)
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
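For the "early validation goes here" placeholder, a minimal sketch could look like the following. The helper name `validate_moment_match_inputs` is hypothetical, and treating only the two callbacks as required (with `upars` left optional) is an assumption that should be checked against what loo_moment_match() actually requires.

```python
def validate_moment_match_inputs(log_prob_upars_fn=None, log_lik_i_upars_fn=None):
    """Fail fast, before any PSIS-LOO-CV work, if a required callback is unusable.

    Hypothetical helper: the name and the exact set of required arguments
    are assumptions, not part of arviz-stats.
    """
    required = {
        "log_prob_upars_fn": log_prob_upars_fn,
        "log_lik_i_upars_fn": log_lik_i_upars_fn,
    }

    # Report every missing callback at once so the user can fix them together.
    missing = [name for name, fn in required.items() if fn is None]
    if missing:
        names = " and ".join(f"`{name}`" for name in missing)
        raise ValueError(f"`moment_match=True` requires {names} to be provided.")

    # Anything that was provided must actually be callable.
    not_callable = [name for name, fn in required.items() if not callable(fn)]
    if not_callable:
        names = " and ".join(f"`{name}`" for name in not_callable)
        raise TypeError(f"{names} must be callable.")
```

Because the error message names the missing arguments, a test can assert on it with `pytest.raises(ValueError, match="log_prob_upars_fn")` rather than only checking that something was raised.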

log_weights=None,
pareto_k=None,
log_jacobian=None,
mixture=False,
Member

Notice that we allow the user to pass True here for mixture importance sampling, which uses mixture estimators for high-dimensional Bayesian settings where standard PSIS-LOO-CV struggles (see https://mc-stan.org/loo/articles/loo2-mixis.html). When this argument is True, the user is expected to pass a mixture posterior, which is fundamentally different from the posterior that loo_moment_match() expects and operates on.

Implicitly adaptive moment matching and mixture importance sampling are different correction strategies and shouldn't be combined. Right now if a user passes both flags, the mixture branch returns early and moment_match=True gets silently ignored with no error or warning.

We should add something like the following early in the loo() function to avoid this:

if moment_match and mixture:
    raise ValueError(
        "`moment_match` and `mixture` cannot both be True. "
        "These are different correction strategies for PSIS-LOO-CV."
    )
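The ordering matters here: the guard only helps if it runs before the mixture branch can return. A toy stand-in illustrates this. `loo_stub` is a hypothetical function that models only the branch order, not the real loo() computation, and its string return values are placeholders.

```python
def loo_stub(moment_match=False, mixture=False):
    """Hypothetical stand-in for loo() that models only the order of the branches."""
    # The guard must come first; if the mixture branch ran before it,
    # moment_match=True would be dropped without any error.
    if moment_match and mixture:
        raise ValueError(
            "`moment_match` and `mixture` cannot both be True. "
            "These are different correction strategies for PSIS-LOO-CV."
        )
    if mixture:
        return "mixture"  # placeholder for the mixture importance sampling path
    if moment_match:
        return "moment_match"  # placeholder for the moment-matching path
    return "psis"  # placeholder for the standard PSIS-LOO-CV path
```

A regression test for the real function would pass both flags and assert that a ValueError with this message is raised.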

log_prob_upars_fn=None,
log_lik_i_upars_fn=None,
upars=None,
**moment_match_kwargs,
Member

All of these inputs are new to the loo() function and should have proper documentation in the docstring parameters section. We should follow the same wording that is already used in the loo_moment_match() function for consistency.

For the moment_match_kwargs argument, we should be explicit about what these are and probably add a reference to the loo_moment_match() function itself at the end.
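As a rough illustration of the shape this could take, here is a numpydoc-style fragment. The function name `loo_signature_sketch` is hypothetical and every description is placeholder wording. The actual text should be copied from the loo_moment_match() docstring rather than from this sketch.

```python
def loo_signature_sketch(
    data,
    moment_match=False,
    log_prob_upars_fn=None,
    log_lik_i_upars_fn=None,
    upars=None,
    **moment_match_kwargs,
):
    """Docstring layout sketch only; all descriptions below are placeholders.

    Parameters
    ----------
    moment_match : bool, default False
        Whether to refine the PSIS-LOO-CV estimates with moment matching.
        Cannot be combined with ``mixture=True``.
    log_prob_upars_fn : callable, optional
        Placeholder: reuse the description from ``loo_moment_match``.
        Required when ``moment_match=True``.
    log_lik_i_upars_fn : callable, optional
        Placeholder: reuse the description from ``loo_moment_match``.
        Required when ``moment_match=True``.
    upars : optional
        Placeholder: reuse the type and description from ``loo_moment_match``.
    **moment_match_kwargs
        Extra keyword arguments forwarded unchanged to ``loo_moment_match``
        (for example ``cov``). Only used when ``moment_match=True``.

    See Also
    --------
    loo_moment_match : Moment matching for problematic observations.
    """
```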

loo_inputs = _prepare_loo_inputs(data, var_name)
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
if moment_match:
    pointwise = True
Member

It is correct that moment matching needs pointwise=True internally so the base PSIS-LOO-CV result includes per-observation Pareto k values to iterate over. But loo_moment_match() also accepts its own pointwise parameter, which controls whether pointwise fields like elpd_i, pareto_k, and influence_pareto_k are included in the final result or stripped out before returning. Overwriting the variable here loses the user's original choice, so, for example,

loo(data, pointwise=False, moment_match=True)

would still return pointwise data.

I think we should remove these two lines. The base PSIS-LOO-CV computation should request pointwise=True on its own, leaving the user's pointwise value untouched for the final result.
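The intended flow can be sketched with toy stand-ins. `base_loo`, `moment_match_step`, and `loo_sketch` are hypothetical names that work on plain dicts, and none of this is the arviz-stats implementation. The point is only that the internal call forces pointwise=True while the user's flag is passed through to the final step.

```python
def base_loo(log_lik, pointwise=False):
    """Toy stand-in for the base PSIS-LOO-CV step."""
    result = {"elpd": sum(log_lik)}
    if pointwise:
        result["elpd_i"] = list(log_lik)
    return result


def moment_match_step(loo_orig, pointwise=False):
    """Toy stand-in for loo_moment_match(): needs pointwise input,
    but honours the caller's flag when building the output."""
    if "elpd_i" not in loo_orig:
        raise ValueError("loo_orig must be computed with pointwise=True")
    # A real implementation would improve the high Pareto k entries here.
    elpd_i = list(loo_orig["elpd_i"])
    result = {"elpd": sum(elpd_i)}
    if pointwise:
        result["elpd_i"] = elpd_i
    return result


def loo_sketch(log_lik, pointwise=False, moment_match=False):
    if moment_match:
        # Always pointwise internally, regardless of what the user asked for.
        loo_orig = base_loo(log_lik, pointwise=True)
        # The user's own choice decides what the final result contains.
        return moment_match_step(loo_orig, pointwise=pointwise)
    return base_loo(log_lik, pointwise=pointwise)
```

With this structure, `loo_sketch(values, pointwise=False, moment_match=True)` returns a result without the pointwise field, which is the behaviour the override currently breaks.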

pareto_k=pareto_k,
approx_posterior=False,
log_jacobian=jacobian_da,
)
Member

Indentation here is not correct, but should be fixed if we follow the other suggestions anyway.

Comment thread tests/test_minimal.py
Member

@jordandeklerk jordandeklerk Mar 31, 2026

There are two main problems with these tests: where they live and what they test.

These tests should not be in tests/test_minimal.py. That file is for tests of the array-level API. The moment-match integration is a change to the loo() function, so the tests belong in tests/loo/test_loo.py alongside the existing PSIS-LOO tests.

There's also tests/loo/test_loo_mm.py which already has the roaches_r_example fixture, callback functions, R reference values, and everything needed to write meaningful tests for moment matching. The new tests should reuse that existing infrastructure.

As a nudge in the right direction for testing this new functionality, we want something like the following in spirit:

import pytest
from numpy.testing import assert_almost_equal, assert_array_equal

from arviz_stats import loo, loo_moment_match

from .test_loo_mm import load_roaches_r_example

@pytest.fixture(scope="module")
def roaches_r_example():
    return load_roaches_r_example()


# Does the one-step path produce the same answer as the two-step path?
def test_loo_moment_match_matches_two_step(roaches_r_example):
    data = roaches_r_example["data_tree"]
    mm_kwargs = dict(
        log_prob_upars_fn=roaches_r_example["log_prob_fn"],
        log_lik_i_upars_fn=roaches_r_example["log_lik_i_fn"],
        upars=roaches_r_example["upars"],
        var_name="log_lik",
        cov=True,
    )

    loo_orig = loo(data, pointwise=True, var_name="log_lik")
    result_two_step = loo_moment_match(data, loo_orig, pointwise=True, **mm_kwargs)
    result_one_step = loo(data, moment_match=True, pointwise=True, **mm_kwargs)

    assert_almost_equal(result_one_step.elpd, result_two_step.elpd, decimal=10)
    assert_almost_equal(result_one_step.se, result_two_step.se, decimal=10)
    assert_almost_equal(result_one_step.p, result_two_step.p, decimal=10)
    assert_array_equal(result_one_step.pareto_k.values, result_two_step.pareto_k.values)

Also, as a general principle for testing, we should not just test that a result exists. That tells us nothing about whether the functionality is correct or robust. For examples of what to look for when testing general PSIS-LOO-CV functions, see https://github.com/arviz-devs/arviz-stats/blob/main/tests/loo/test_loo.py.

Comment thread tests/test_minimal.py

def test_loo_moment_match_flag_runs():
    from arviz_stats import loo
    from arviz_base import load_arviz_data
Member

Imports should not be nested inside the test functions. These should live at the top of the module.

Comment thread tests/test_minimal.py

result = loo(data, moment_match=False)

assert result is not None
Member

This calls loo(data, moment_match=False) which is identical to loo(data), so it doesn't actually test anything new. This should be removed.

Comment thread tests/test_minimal.py

assert result is not None

def test_loo_moment_match_requires_functions():
Member

Use the match parameter to verify the error message, e.g.,

pytest.raises(ValueError, match="log_prob_upars_fn")

Comment thread tests/test_minimal.py
def test_loo_moment_match_requires_functions():
    from arviz_stats import loo
    from arviz_base import load_arviz_data
    import pytest
Member

This is already imported at the top of the module. Remove this and move the other imports to the top of the module as well.

@aatifjunaid
Contributor Author

Hi, thanks a lot for the detailed review and guidance — this is really helpful.

Apologies for the confusion on this PR.
I’m a bit tied up with a deadline right now, but I’ll come back to this shortly, clean up the branch, and rework the implementation following your suggestions (especially the two-step structure and test placement).

Really appreciate your patience — I’ll update this soon.

3 participants