-
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathr2_score.jl
More file actions
84 lines (66 loc) · 2.86 KB
/
r2_score.jl
File metadata and controls
84 lines (66 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
    r2_score(y_true::AbstractVector, y_pred::AbstractArray; kwargs...) -> (; r2, <ci_fun>)
``R²`` for linear Bayesian regression models.[Gelman2019](@citep)
The ``R²``, or coefficient of determination, is defined as the proportion of variance in the
data that is explained by the model. For each draw, it is computed as the variance of the
predicted values divided by the variance of the predicted values plus the variance of the
residuals.
The distribution of the ``R²`` scores can then be summarized using a point estimate and a
credible interval (CI).
# Arguments
- `y_true`: Observed data of length `noutputs`
- `y_pred`: Predicted data with size `(ndraws[, nchains], noutputs)`
# Keywords
- `summary::Bool=true`: Whether to return a summary or an array of ``R²`` scores. The
summary is a named tuple with the point estimate `:r2` and the credible interval
`:<ci_fun>`.
- `point_estimate=Statistics.mean`: The function used to compute the point estimate of the
``R²`` scores if `summary` is `true`. Supported options are:
+ [`Statistics.mean`](@extref) (default)
+ [`Statistics.median`](@extref)
+ [`StatsBase.mode`](@extref)
- `ci_fun=eti`: The function used to compute the credible interval if `summary` is
`true`. Supported options are [`eti`](@ref) and [`hdi`](@ref).
- `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible
interval.
# Examples
```jldoctest
julia> using ArviZExampleData
julia> idata = load_example_data("regression1d");
julia> y_true = idata.observed_data.y;
julia> y_pred = PermutedDimsArray(idata.posterior_predictive.y, (:draw, :chain, :y_dim_0));
julia> r2_score(y_true, y_pred)
(r2 = 0.683196996216511, eti = 0.6230680117869596 .. 0.7384123771046265)
```
# References
- [Gelman2019](@cite) Gelman et al, The Am. Stat., 73(3) (2019)
"""
function r2_score(
    y_true,
    y_pred;
    summary=true,
    point_estimate=Statistics.mean,
    ci_fun=eti,
    ci_prob=DEFAULT_CI_PROB,
)
    # Raw R² scores as produced by `_r2_samples`; returned untouched when no summary is
    # requested.
    scores = _r2_samples(y_true, y_pred)
    if !summary
        return scores
    end
    # Collapse the scores into a point estimate and a credible interval.
    estimate = point_estimate(scores)
    interval = ci_fun(scores; prob=ci_prob)
    # The interval's field name is derived from the interval function itself, so the
    # returned named tuple records which kind of interval was computed.
    interval_key = Symbol(_fname(ci_fun))
    return (; r2=estimate, interval_key => interval)
end
# Compute one R² score per posterior draw.
#
# `y_true` holds the observations and `y_pred` the predictions, with the output dimension
# last: `y_pred` must be 2- or 3-dimensional and its trailing dimension must have
# `length(y_true)` entries. For every index of the leading (sample) dimensions the score is
#     var(pred) / (var(pred) + var(pred - y_true))
# where both variances are uncorrected and taken over the trailing dimension.
#
# Returns an array created with `similar(y_pred, ...)` whose axes are the leading axes of
# `y_pred`.
#
# Throws `ArgumentError` if `y_pred` is not 2- or 3-dimensional and `DimensionMismatch` if
# `y_true` does not have one entry per output of `y_pred`.
function _r2_samples(y_true::AbstractVector, y_pred::AbstractArray)
    nd = ndims(y_pred)
    # `@assert` can be disabled, so caller input is validated with real exceptions.
    nd ∈ (2, 3) || throw(
        ArgumentError("`y_pred` must have 2 or 3 dimensions, but it has $(nd)."),
    )
    noutputs = size(y_pred, nd)
    # Without this check a length-1 `y_true` (or a single-output `y_pred`) would broadcast
    # silently and produce meaningless scores.
    length(y_true) == noutputs || throw(
        DimensionMismatch(
            "`y_true` has length $(length(y_true)), but the last dimension of `y_pred` " *
            "has size $(noutputs).",
        ),
    )
    corrected = false
    dims = nd
    var_y_est = dropdims(Statistics.var(y_pred; corrected, dims); dims)
    # Put `y_true` on the trailing dimension so it lines up with the outputs of `y_pred`.
    y_true_reshape = reshape(y_true, ntuple(one, nd - 1)..., :)
    var_residual = dropdims(Statistics.var(y_pred .- y_true_reshape; corrected, dims); dims)
    # allocate storage for type-stability
    T = typeof(first(var_y_est) / first(var_residual))
    sample_axes = ntuple(Base.Fix1(axes, y_pred), nd - 1)
    r_squared = similar(y_pred, T, sample_axes)
    r_squared .= var_y_est ./ (var_y_est .+ var_residual)
    return r_squared
end