From ff04859888b0cfa48e6c99b6bb2ee02c79ee1e31 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 16 Sep 2025 17:54:29 +0200 Subject: [PATCH 1/5] Generalize r2_score --- src/r2_score.jl | 41 ++++++++++++++++++++--------------------- test/r2_score.jl | 19 +++++++++++++------ 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/r2_score.jl b/src/r2_score.jl index 571daf9..0972c54 100644 --- a/src/r2_score.jl +++ b/src/r2_score.jl @@ -1,5 +1,5 @@ """ - r2_score(y_true::AbstractVector, y_pred::AbstractArray) -> (; r2, r2_std) + r2_score(y_true::AbstractVector, y_pred::AbstractArray; kwargs...) -> (; r2, r2_std) ``R²`` for linear Bayesian regression models.[Gelman2019](@citep) @@ -8,6 +8,15 @@ - `y_true`: Observed data of length `noutputs` - `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)` +# Keywords + + - `summary::Bool=true`: Whether to return the mean and CI of the ``R²`` scores or the raw + samples. + - `ci_fun=eti`: The function used to compute the credible interval if `summary` is `true`. + Supported options are [`eti`](@ref) and [`hdi`](@ref). + - `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible + interval. + # Examples ```jldoctest @@ -19,34 +28,24 @@ julia> y_true = idata.observed_data.y; julia> y_pred = PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0)); -julia> r2_score(y_true, y_pred) |> pairs -pairs(::NamedTuple) with 2 entries: - :r2 => 0.683197 - :r2_std => 0.0368838 +julia> r2_score(y_true, y_pred) +(r2 = 0.683196996216511, eti = 0.6082075654135802 .. 0.7462891653797559) ``` # References - [Gelman2019](@cite) Gelman et al, The Am. Stat., 73(3) (2019) """ -function r2_score(y_true, y_pred) - r_squared = r2_samples(y_true, y_pred) - return NamedTuple{(:r2, :r2_std)}(StatsBase.mean_and_std(r_squared; corrected=false)) +function r2_score(y_true, y_pred; summary=true, ci_fun=eti, ci_prob=DEFAULT_CI_PROB) + r_squared = _r2_samples(y_true, y_pred) + summary || return r_squared + r2 = Statistics.mean(r_squared) + ci = ci_fun(r_squared; prob=ci_prob) + ci_name = Symbol(_fname(ci_fun)) + return (; r2, ci_name => ci) end -""" - r2_samples(y_true::AbstractVector, y_pred::AbstractArray) -> AbstractVector - -``R²`` samples for Bayesian regression models. Only valid for linear models. - -See also [`r2_score`](@ref). - -# Arguments - - - `y_true`: Observed data of length `noutputs` - - `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)` -""" -function r2_samples(y_true::AbstractVector, y_pred::AbstractArray) +function _r2_samples(y_true::AbstractVector, y_pred::AbstractArray) @assert ndims(y_pred) ∈ (2, 3) corrected = false dims = ndims(y_pred) diff --git a/test/r2_score.jl b/test/r2_score.jl index f04794d..00b4676 100644 --- a/test/r2_score.jl +++ b/test/r2_score.jl @@ -3,7 +3,7 @@ using PosteriorStats using Statistics using Test -@testset "r2_score/r2_sample" begin +@testset "r2_score" begin @testset "basic" begin n = 100 @testset for T in (Float32, Float64), @@ -17,11 +17,18 @@ using Test x_reshape = length(sz) == 1 ? x' : reshape(x, 1, 1, :) y_pred = slope .* x_reshape .+ intercept .+ randn(T, sz..., n) .* σ - r2_val = @inferred r2_score(y, y_pred) - @test r2_val isa NamedTuple{(:r2, :r2_std),NTuple{2,T}} - r2_draws = @inferred PosteriorStats.r2_samples(y, y_pred) - @test r2_val.r2 ≈ mean(r2_draws) - @test r2_val.r2_std ≈ std(r2_draws; corrected=false) + r2_val = @inferred r2_score(y, y_pred; ci_prob=T(0.89)) + @test r2_val isa @NamedTuple{r2::T, eti::ClosedInterval{T}} + r2_draws = @inferred PosteriorStats._r2_samples(y, y_pred) + @test r2_val.r2 == mean(r2_draws) + @test r2_val.eti == eti(r2_draws; prob=T(0.89)) + + r2_val2 = r2_score(y, y_pred; ci_fun=hdi, ci_prob=T(0.95)) + @test r2_val2 isa @NamedTuple{r2::T, hdi::ClosedInterval{T}} + @test r2_val2.hdi == hdi(r2_draws; prob=0.95) + + r2_draws2 = PosteriorStats.r2_score(y, y_pred; summary=false) + @test r2_draws2 == r2_draws # check rough consistency with GLM res = lm(@formula(y ~ 1 + x), (; x=Float64.(x), y=Float64.(y))) From 5c92abc78020b36535e2fd7f2499dfd1dbbf3a31 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 16 Sep 2025 17:54:41 +0200 Subject: [PATCH 2/5] Update r2_score docstring --- src/r2_score.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/r2_score.jl b/src/r2_score.jl index 0972c54..ddbe92c 100644 --- a/src/r2_score.jl +++ b/src/r2_score.jl @@ -3,6 +3,14 @@ ``R²`` for linear Bayesian regression models.[Gelman2019](@citep) +The ``R²``, or coefficient of determination, is defined as the proportion of variance in the +data that is explained by the model. For each draw, it is computed as the variance of the +predicted values divided by the variance of the predicted values plus the variance of the +residuals. + +The distribution of the ``R²`` scores can then be summarized using the mean and a credible +interval (CI). + # Arguments - `y_true`: Observed data of length `noutputs` From 660f590a2e23f0ca22477acae84e39676c50826c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 16 Sep 2025 18:03:35 +0200 Subject: [PATCH 3/5] Support other point estimates --- src/r2_score.jl | 29 +++++++++++++++++++++-------- test/r2_score.jl | 7 +++++-- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/r2_score.jl b/src/r2_score.jl index ddbe92c..69800fb 100644 --- a/src/r2_score.jl +++ b/src/r2_score.jl @@ -8,8 +8,8 @@ data that is explained by the model. For each draw, it is computed as the varian predicted values divided by the variance of the predicted values plus the variance of the residuals. -The distribution of the ``R²`` scores can then be summarized using the mean and a credible -interval (CI). +The distribution of the ``R²`` scores can then be summarized using a point estimate and a +credible interval (CI). # Arguments @@ -18,10 +18,16 @@ interval (CI). # Keywords - - `summary::Bool=true`: Whether to return the mean and CI of the ``R²`` scores or the raw - samples. - - `ci_fun=eti`: The function used to compute the credible interval if `summary` is `true`. - Supported options are [`eti`](@ref) and [`hdi`](@ref). + - `summary::Bool=true`: Whether to return a summary or an array of ``R²`` scores. The + summary is a named tuple with the point estimate `:r2` and the credible interval + `:`. + - `point_estimate=Statistics.mean`: The function used to compute the point estimate of the + ``R²`` scores if `summary` is `true`. Supported options are: + + [`Statistics.mean`](@extref) (default) + + [`Statistics.median`](@extref) + + [`StatsBase.mode`](@extref) + - `ci_fun=eti`: The function used to compute the credible interval if `summary` is + `true`. Supported options are [`eti`](@ref) and [`hdi`](@ref). - `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible interval. @@ -44,10 +50,17 @@ julia> r2_score(y_true, y_pred) - [Gelman2019](@cite) Gelman et al, The Am. Stat., 73(3) (2019) """ -function r2_score(y_true, y_pred; summary=true, ci_fun=eti, ci_prob=DEFAULT_CI_PROB) +function r2_score( + y_true, + y_pred; + summary=true, + point_estimate=Statistics.mean, + ci_fun=eti, + ci_prob=DEFAULT_CI_PROB, +) r_squared = _r2_samples(y_true, y_pred) summary || return r_squared - r2 = Statistics.mean(r_squared) + r2 = point_estimate(r_squared) ci = ci_fun(r_squared; prob=ci_prob) ci_name = Symbol(_fname(ci_fun)) return (; r2, ci_name => ci) diff --git a/test/r2_score.jl b/test/r2_score.jl index 00b4676..138c009 100644 --- a/test/r2_score.jl +++ b/test/r2_score.jl @@ -23,9 +23,12 @@ using Test @test r2_val.r2 == mean(r2_draws) @test r2_val.eti == eti(r2_draws; prob=T(0.89)) - r2_val2 = r2_score(y, y_pred; ci_fun=hdi, ci_prob=T(0.95)) + r2_val2 = r2_score( + y, y_pred; point_estimate=median, ci_fun=hdi, ci_prob=T(0.95) + ) @test r2_val2 isa @NamedTuple{r2::T, hdi::ClosedInterval{T}} - @test r2_val2.hdi == hdi(r2_draws; prob=0.95) + @test r2_val2.r2 == median(r2_draws) + @test r2_val2.hdi == hdi(r2_draws; prob=T(0.95)) r2_draws2 = PosteriorStats.r2_score(y, y_pred; summary=false) @test r2_draws2 == r2_draws From 013d04acad809e337531b86f8c67ddde2a4db132 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 14:12:16 +0200 Subject: [PATCH 4/5] Test default value --- test/r2_score.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/r2_score.jl b/test/r2_score.jl index 138c009..716cbb1 100644 --- a/test/r2_score.jl +++ b/test/r2_score.jl @@ -22,6 +22,7 @@ using Test r2_draws = @inferred PosteriorStats._r2_samples(y, y_pred) @test r2_val.r2 == mean(r2_draws) @test r2_val.eti == eti(r2_draws; prob=T(0.89)) + @test r2_val == r2_score(y, y_pred) r2_val2 = r2_score( y, y_pred; point_estimate=median, ci_fun=hdi, ci_prob=T(0.95) From 99c32ad9177114edb0c90077cb1698f98b502796 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 20:36:36 +0200 Subject: [PATCH 5/5] Fix tests --- test/r2_score.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/r2_score.jl b/test/r2_score.jl index 716cbb1..f9d069e 100644 --- a/test/r2_score.jl +++ b/test/r2_score.jl @@ -17,11 +17,11 @@ using Test x_reshape = length(sz) == 1 ? x' : reshape(x, 1, 1, :) y_pred = slope .* x_reshape .+ intercept .+ randn(T, sz..., n) .* σ - r2_val = @inferred r2_score(y, y_pred; ci_prob=T(0.89)) + r2_val = @inferred r2_score(y, y_pred; ci_prob=PosteriorStats.DEFAULT_CI_PROB) @test r2_val isa @NamedTuple{r2::T, eti::ClosedInterval{T}} r2_draws = @inferred PosteriorStats._r2_samples(y, y_pred) @test r2_val.r2 == mean(r2_draws) - @test r2_val.eti == eti(r2_draws; prob=T(0.89)) + @test r2_val.eti == eti(r2_draws; prob=PosteriorStats.DEFAULT_CI_PROB) @test r2_val == r2_score(y, y_pred) r2_val2 = r2_score(