Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions src/r2_score.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
"""
r2_score(y_true::AbstractVector, y_pred::AbstractArray) -> (; r2, r2_std)
r2_score(y_true::AbstractVector, y_pred::AbstractArray; kwargs...) -> (; r2, <ci_fun>)

``R²`` for linear Bayesian regression models.[Gelman2019](@citep)

The ``R²``, or coefficient of determination, is defined as the proportion of variance in the
data that is explained by the model. For each draw, it is computed as the variance of the
predicted values divided by the variance of the predicted values plus the variance of the
residuals.

The distribution of the ``R²`` scores can then be summarized using a point estimate and a
credible interval (CI).

# Arguments

- `y_true`: Observed data of length `noutputs`
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)`

# Keywords

- `summary::Bool=true`: Whether to return a summary or an array of ``R²`` scores. The
summary is a named tuple with the point estimate `:r2` and the credible interval
`:<ci_fun>`.
- `point_estimate=Statistics.mean`: The function used to compute the point estimate of the
``R²`` scores if `summary` is `true`. Supported options are:
+ [`Statistics.mean`](@extref) (default)
+ [`Statistics.median`](@extref)
+ [`StatsBase.mode`](@extref)
- `ci_fun=eti`: The function used to compute the credible interval if `summary` is
`true`. Supported options are [`eti`](@ref) and [`hdi`](@ref).
- `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible
interval.

# Examples

```jldoctest
Expand All @@ -19,34 +42,31 @@ julia> y_true = idata.observed_data.y;

julia> y_pred = PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0));

julia> r2_score(y_true, y_pred) |> pairs
pairs(::NamedTuple) with 2 entries:
:r2 => 0.683197
:r2_std => 0.0368838
julia> r2_score(y_true, y_pred)
(r2 = 0.683196996216511, eti = 0.6082075654135802 .. 0.7462891653797559)
```

# References

- [Gelman2019](@cite) Gelman et al, The Am. Stat., 73(3) (2019)
"""
function r2_score(y_true, y_pred)
r_squared = r2_samples(y_true, y_pred)
return NamedTuple{(:r2, :r2_std)}(StatsBase.mean_and_std(r_squared; corrected=false))
function r2_score(
    y_true,
    y_pred;
    summary=true,
    point_estimate=Statistics.mean,
    ci_fun=eti,
    ci_prob=DEFAULT_CI_PROB,
)
    # One R² value per posterior draw; everything below only summarizes these draws.
    draws = _r2_samples(y_true, y_pred)

    # Callers who want the raw draws opt out of the summary entirely.
    if !summary
        return draws
    end

    estimate = point_estimate(draws)
    interval = ci_fun(draws; prob=ci_prob)

    # The interval is stored under a key named after the interval function itself
    # (e.g. `:eti` or `:hdi`), so the result records which kind of interval it holds.
    interval_key = Symbol(_fname(ci_fun))
    return (; r2=estimate, interval_key => interval)
end

"""
r2_samples(y_true::AbstractVector, y_pred::AbstractArray) -> AbstractVector

``R²`` samples for Bayesian regression models. Only valid for linear models.

See also [`r2_score`](@ref).

# Arguments

- `y_true`: Observed data of length `noutputs`
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)`
"""
function r2_samples(y_true::AbstractVector, y_pred::AbstractArray)
function _r2_samples(y_true::AbstractVector, y_pred::AbstractArray)
@assert ndims(y_pred) ∈ (2, 3)
corrected = false
dims = ndims(y_pred)
Expand Down
23 changes: 17 additions & 6 deletions test/r2_score.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using PosteriorStats
using Statistics
using Test

@testset "r2_score/r2_sample" begin
@testset "r2_score" begin
@testset "basic" begin
n = 100
@testset for T in (Float32, Float64),
Expand All @@ -17,11 +17,22 @@ using Test
x_reshape = length(sz) == 1 ? x' : reshape(x, 1, 1, :)
y_pred = slope .* x_reshape .+ intercept .+ randn(T, sz..., n) .* σ

r2_val = @inferred r2_score(y, y_pred)
@test r2_val isa NamedTuple{(:r2, :r2_std),NTuple{2,T}}
r2_draws = @inferred PosteriorStats.r2_samples(y, y_pred)
@test r2_val.r2 ≈ mean(r2_draws)
@test r2_val.r2_std ≈ std(r2_draws; corrected=false)
r2_val = @inferred r2_score(y, y_pred; ci_prob=PosteriorStats.DEFAULT_CI_PROB)
@test r2_val isa @NamedTuple{r2::T, eti::ClosedInterval{T}}
r2_draws = @inferred PosteriorStats._r2_samples(y, y_pred)
@test r2_val.r2 == mean(r2_draws)
@test r2_val.eti == eti(r2_draws; prob=PosteriorStats.DEFAULT_CI_PROB)
@test r2_val == r2_score(y, y_pred)

r2_val2 = r2_score(
y, y_pred; point_estimate=median, ci_fun=hdi, ci_prob=T(0.95)
)
@test r2_val2 isa @NamedTuple{r2::T, hdi::ClosedInterval{T}}
@test r2_val2.r2 == median(r2_draws)
@test r2_val2.hdi == hdi(r2_draws; prob=T(0.95))

r2_draws2 = PosteriorStats.r2_score(y, y_pred; summary=false)
@test r2_draws2 == r2_draws

# check rough consistency with GLM
res = lm(@formula(y ~ 1 + x), (; x=Float64.(x), y=Float64.(y)))
Expand Down
Loading