Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/arviz_stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SamplingWrapper,
loo_kfold,
)

from arviz_stats.psense import psense, psense_summary
from arviz_stats.metrics import bayesian_r2, kl_divergence, metrics, residual_r2, wasserstein
from arviz_stats.sampling_diagnostics import bfmi, ess, mcse, rhat, rhat_nested, diagnose
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_stats/base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def get_bins(self, ary, axis=-1, bins="arviz"):

# pylint: disable=redefined-builtin, too-many-return-statements
# noqa: PLR0911
def histogram(self, ary, bins=None, range=None, weights=None, axis=-1, density=None):
def histogram(self, ary, bins=None, range=None, weights=None, axis=-1, density=True):
"""Compute histogram over provided axis.

Parameters
Expand All @@ -429,7 +429,7 @@ def histogram(self, ary, bins=None, range=None, weights=None, axis=-1, density=N
range : (float, float), optional
weights : array-like, optional
axis : int, sequence of int or None, default -1
density : bool, optional
density : bool, default True

Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_stats/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _get_bins(self, values, bins="arviz"):
return bins

# pylint: disable=redefined-builtin
def _histogram(self, ary, bins=None, range=None, weights=None, density=None):
def _histogram(self, ary, bins=None, range=None, weights=None, density=True):
if bins is None:
bins = self._get_bins(ary)
return np.histogram(ary, bins=bins, range=range, weights=weights, density=density)
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_stats/base/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def get_bins(self, da, dim=None, bins="arviz"):
)

# pylint: disable=redefined-builtin
def histogram(self, da, dim=None, bins=None, range=None, weights=None, density=None):
def histogram(self, da, dim=None, bins=None, range=None, weights=None, density=True):
"""Compute histogram on DataArray input."""
dims = validate_dims(dim)
edges_dim = "edges_dim" if da.name is None else f"edges_dim_{da.name}"
Expand Down
49 changes: 37 additions & 12 deletions src/arviz_stats/loo/loo.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep in mind that moment matching is a two-step algorithm. First compute standard PSIS-LOO-CV, then improve the estimates for observations with high Pareto-k. So loo_moment_match() expects a fully computed PSIS-LOO-CV result (with pointwise=True so it has per-observation Pareto k values) as its loo_orig argument so it knows which observations need fixing.

Since we're backend-agnostic and can't auto-derive the required functions from a model object the way R can with stanfit, loo() needs to handle this wiring explicitly. I think the cleanest way is to compute the base PSIS-LOO-CV via a recursive loo() call with pointwise=True, then pass that result directly to loo_moment_match() and return.

So we should probably have something like the following at the entry point of this function:

    if moment_match:
        # early validation goes here

        loo_orig = loo(
            data,
            pointwise=True,
            var_name=var_name,
            reff=reff,
            log_weights=log_weights,
            pareto_k=pareto_k,
            log_jacobian=log_jacobian,
        )

        return loo_moment_match(
            data,
            loo_orig,
            log_prob_upars_fn=log_prob_upars_fn,
            log_lik_i_upars_fn=log_lik_i_upars_fn,
            upars=upars,
            var_name=var_name,
            reff=reff,
            pointwise=pointwise,
            **moment_match_kwargs,
        )

    loo_inputs = _prepare_loo_inputs(data, var_name)
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings

from .loo_moment_match import loo_moment_match
from arviz_base import rcParams
from xarray_einstats.stats import logsumexp

Expand All @@ -28,6 +29,11 @@ def loo(
pareto_k=None,
log_jacobian=None,
mixture=False,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notice that we allow the user to pass True here for mixture importance sampling, which uses mixture estimators for high dimensional Bayesian settings where standard PSIS-LOO-CV struggles (see https://mc-stan.org/loo/articles/loo2-mixis.html). When this argument is True, the user is expected to pass a mixture posterior, which is fundamentally different than the posterior that loo_moment_match() expects and operates on.

Implicitly adaptive moment matching and mixture importance sampling are different correction strategies and shouldn't be combined. Right now if a user passes both flags, the mixture branch returns early and moment_match=True gets silently ignored with no error or warning.

We should add something like the following early in the loo() function to avoid this:

if moment_match and mixture:
    raise ValueError(
        "`moment_match` and `mixture` cannot both be True. "
        "These are different correction strategies for PSIS-LOO-CV."
    )

moment_match=False,
log_prob_upars_fn=None,
log_lik_i_upars_fn=None,
upars=None,
**moment_match_kwargs,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these inputs are new to the loo() function and should have proper documentation in the docstring parameters section. We should be following the same verbiage as what is already included in the loo_moment_match() function for consistency.

For the moment_match_kwargs argument, we should be explicit about what these are and probably add a reference to the loo_moment_match() function itself at the end.

):
r"""Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).

Expand Down Expand Up @@ -159,6 +165,8 @@ def loo(
"""
loo_inputs = _prepare_loo_inputs(data, var_name)
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
if moment_match:
pointwise = True
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is correct that moment matching needs pointwise=True internally so the base PSIS-LOO-CV result includes per-observation Pareto k values to iterate over. But loo_moment_match() also accepts its own pointwise parameter which controls whether pointwise fields like elpd_i, pareto_k, and influence_pareto_k are included in the final result or stripped out before returning. By overwriting the variable here, the user's original choice is lost, so, for example,

loo(data, pointwise=False, moment_match=True)

would still return pointwise data.

I think we should remove these two lines since this should be handled with our original loo computations.


if reff is None:
reff = _get_r_eff(data, loo_inputs.n_samples)
Expand Down Expand Up @@ -218,18 +226,35 @@ def loo(
log_weights=mix_log_weights,
)

return _compute_loo_results(
log_likelihood=loo_inputs.log_likelihood,
var_name=loo_inputs.var_name,
pointwise=pointwise,
sample_dims=loo_inputs.sample_dims,
n_samples=loo_inputs.n_samples,
n_data_points=loo_inputs.n_data_points,
log_weights=log_weights,
pareto_k=pareto_k,
approx_posterior=False,
log_jacobian=jacobian_da,
)
loo_res = _compute_loo_results(
log_likelihood=loo_inputs.log_likelihood,
var_name=loo_inputs.var_name,
pointwise=pointwise,
sample_dims=loo_inputs.sample_dims,
n_samples=loo_inputs.n_samples,
n_data_points=loo_inputs.n_data_points,
log_weights=log_weights,
pareto_k=pareto_k,
approx_posterior=False,
log_jacobian=jacobian_da,
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation here is not correct, but should be fixed if we follow the other suggestions anyway.

if moment_match:
if log_prob_upars_fn is None or log_lik_i_upars_fn is None:
raise ValueError(
"moment_match=True requires log_prob_upars_fn and log_lik_i_upars_fn"
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the checks for the required callback functions. However, this runs after _compute_loo_results already finished. So, a user who passes moment_match=True but forgets a callback function such as log_prob_upars_fn will wait through the entire PSIS-LOO-CV computation before hitting this error. If we move all validation to the top of the function body, before any computation starts, it'll fail fast on bad input.

This could have pretty major implications for a user with a particularly large log-likelihood function as the PSIS-LOO-CV could take a while.


loo_res = loo_moment_match (
data=data,
loo_orig=loo_res,
log_prob_upars_fn=log_prob_upars_fn,
log_lik_i_upars_fn=log_lik_i_upars_fn,
upars=upars,
var_name=var_name,
reff=reff,
**moment_match_kwargs,
)
return loo_res


def loo_i(
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_stats/numba/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def quantile(self, ary, quantile, axis=-1, method="linear", skipna=False, weight
return result
return np.moveaxis(result, 0, -1)

def _histogram(self, ary, bins=None, range=None, weights=None, density=None): # pylint: disable=redefined-builtin
def _histogram(self, ary, bins=None, range=None, weights=None, density=True): # pylint: disable=redefined-builtin
"""Compute the histogram of the data."""
if bins is None:
bins = self._get_bins(ary)
Expand Down
2 changes: 1 addition & 1 deletion tests/base/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def test_get_bins_axis(self, array_stats, rng, axis):

def test_histogram_basic(self, array_stats, rng):
ary = rng.normal(size=(1000,))
counts, edges = array_stats.histogram(ary)
counts, edges = array_stats.histogram(ary, density=False)
assert len(counts) == len(edges) - 1
assert counts.sum() == 1000

Expand Down
2 changes: 1 addition & 1 deletion tests/base/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_get_bins_coverage(self, core, rng):

def test_histogram_counts(self, core, rng):
x = rng.normal(size=1000)
counts, edges = core._histogram(x)
counts, edges = core._histogram(x, density=False)
assert counts.sum() == 1000
assert len(counts) == len(edges) - 1

Expand Down
4 changes: 2 additions & 2 deletions tests/base/test_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,13 +755,13 @@ def test_kde_circular_boundary_wrapping(self, density, rng):
class TestHistogramEdgeCases:
def test_histogram_empty_bins(self, density, rng):
x = rng.normal(size=100)
counts, _edges = density._histogram(x, bins=100, range=(-10, -5))
counts, _edges = density._histogram(x, bins=100, range=(-10, -5), density=False)
assert len(counts) == 100
assert np.sum(counts) == 0

def test_histogram_single_bin(self, density, rng):
x = rng.normal(size=100)
counts, edges = density._histogram(x, bins=1)
counts, edges = density._histogram(x, bins=1, density=False)
assert len(counts) == 1
assert len(edges) == 2
assert np.sum(counts) == 100
Expand Down
21 changes: 21 additions & 0 deletions tests/test_minimal.py
Copy link
Copy Markdown
Member

@jordandeklerk jordandeklerk Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two main problems here for these tests. Where the tests live, and what they test.

These tests should not be in tests/test_minimal.py. That file is for tests of the array-level API. The moment-match integration is a change to the loo() function, so the tests belong in tests/loo/test_loo.py alongside the existing PSIS-LOO tests.

There's also tests/loo/test_loo_mm.py which already has the roaches_r_example fixture, callback functions, R reference values, and everything needed to write meaningful tests for moment matching. The new tests should reuse that existing infrastructure.

As a nudge toward the right direction for testing this new functionality, we want something like the following in spirit:

from .test_loo_mm import load_roaches_r_example
from arviz_stats import loo, loo_moment_match

@pytest.fixture(scope="module")
def roaches_r_example():
    return load_roaches_r_example()


# Does the one-step path produce the same answer as the two-step path?
def test_loo_moment_match_matches_two_step(roaches_r_example):
    data = roaches_r_example["data_tree"]
    mm_kwargs = dict(
        log_prob_upars_fn=roaches_r_example["log_prob_fn"],
        log_lik_i_upars_fn=roaches_r_example["log_lik_i_fn"],
        upars=roaches_r_example["upars"],
        var_name="log_lik",
        cov=True,
    )

    loo_orig = loo(data, pointwise=True, var_name="log_lik")
    result_two_step = loo_moment_match(data, loo_orig, pointwise=True, **mm_kwargs)
    result_one_step = loo(data, moment_match=True, pointwise=True, **mm_kwargs)

    assert_almost_equal(result_one_step.elpd, result_two_step.elpd, decimal=10)
    assert_almost_equal(result_one_step.se, result_two_step.se, decimal=10)
    assert_almost_equal(result_one_step.p, result_two_step.p, decimal=10)
    assert_array_equal(result_one_step.pareto_k.values, result_two_step.pareto_k.values)

Also, as a general principle for testing, we should not just be testing that a result exists. This doesn't really tell us anything about the nature of the functionality and if it is robust or not. For some examples about what to be looking for in testing general PSIS-LOO-CV functions, see https://github.com/arviz-devs/arviz-stats/blob/main/tests/loo/test_loo.py.

Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,24 @@ def test_hdi_invalid_prob(data_c0d1):
def test_thin_invalid_factor(data_c0d1):
with pytest.raises(ValueError, match="must be greater than 1"):
array_stats.thin(data_c0d1, factor=0, chain_axis=0, draw_axis=1)


def test_loo_moment_match_flag_runs():
from arviz_stats import loo
from arviz_base import load_arviz_data
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports should not be nested inside the test functions. These should live at the top of the module.


data = load_arviz_data("centered_eight")

result = loo(data, moment_match=False)

assert result is not None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This calls loo(data, moment_match=False) which is identical to loo(data), so it doesn't actually test anything new. This should be removed.


def test_loo_moment_match_requires_functions():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the match parameter to verify the error message, e.g.,

pytest.raises(ValueError, match="log_prob_upars_fn"))

from arviz_stats import loo
from arviz_base import load_arviz_data
import pytest
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already imported at the top of the module. Remove this and move the other imports to the top of the module as well.


data = load_arviz_data("centered_eight")

with pytest.raises(ValueError):
loo(data, moment_match=True)
Loading