Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "PSIS"
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
version = "0.9.8"
version = "0.9.9"
authors = ["Seth Axen <seth.axen@gmail.com> and contributors"]

[deps]
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
PSIS = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand All @@ -13,5 +14,6 @@ Distributions = "0.25.81"
Documenter = "1"
DocumenterCitations = "1.2.1"
DocumenterInterLinks = "1"
LogExpFunctions = "0.3.3"
MCMCDiagnosticTools = "0.3.2"
Plots = "1.10.1"
21 changes: 11 additions & 10 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ Result of Pareto-smoothed importance sampling (PSIS) using [`psis`](@ref).
- `nparams`: number of parameters in `log_weights`
- `ndraws`: number of draws in `log_weights`
- `nchains`: number of chains in `log_weights`
- `reff`: the ratio of the effective sample size of the unsmoothed importance ratios and
the actual sample size.
- `reff`: the ratio of the effective sample size of the inverse of the unsmoothed
importance ratios and the actual sample size.
- `ess`: estimated effective sample size of estimate of mean using smoothed importance
samples (see [`ess_is`](@ref))
- `tail_length`: length of the upper tail of `log_weights` that was smoothed
Expand Down Expand Up @@ -186,10 +186,11 @@ While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in
- `log_ratios`: an array of logarithms of importance ratios, with size
`(draws, [chains, [parameters...]])`, where `chains>1` would be used when chains are
generated using Markov chain Monte Carlo.
- `reff::Union{Real,AbstractArray}`: the ratio(s) of effective sample size of
`log_ratios` and the actual sample size `reff = ess/(draws * chains)`, used to account
for autocorrelation, e.g. due to Markov chain Monte Carlo. If an array, it must have the
size `(parameters...,)` to match `log_ratios`.
- `reff::Union{Real,AbstractArray}`: the ratio(s) of effective sample size of the inverse
of the unsmoothed importance ratios `1 ./ exp.(log_ratios)` and the actual sample size
`reff = ess/(draws * chains)`, used to account for autocorrelation, e.g. due to Markov
chain Monte Carlo. If an array, it must have the size `(parameters...,)` to match
`log_ratios`.

# Keywords

Expand Down Expand Up @@ -236,9 +237,9 @@ If the draws were generated using MCMC, we can compute the relative efficiency u
[`MCMCDiagnosticTools.ess`](@extref).

```jldoctest psis
julia> using MCMCDiagnosticTools
julia> using LogExpFunctions, MCMCDiagnosticTools

julia> reff = ess(log_ratios; kind=:basic, split_chains=1, relative=true);
julia> reff = ess(softmax(-log_ratios; dims=(1, 2)); kind=:basic, split_chains=1, relative=true);

julia> result = psis(log_ratios, reff)
┌ Warning: 9 parameters had Pareto shape values 0.7 < k ≤ 1. Resulting importance sampling estimates are likely to be unstable.
Expand All @@ -248,8 +249,8 @@ julia> result = psis(log_ratios, reff)
PSISResult with 1000 draws, 1 chains, and 30 parameters
Pareto shape (k) diagnostic values:
Count Min. ESS
(-Inf, 0.5] good 9 (30.0%) 806
(0.5, 0.7] okay 11 (36.7%) 842
(-Inf, 0.5] good 10 (33.3%) 835
(0.5, 0.7] okay 10 (33.3%) 849
(0.7, 1] bad 9 (30.0%) ——
(1, Inf) very bad 1 (3.3%) ——
```
Expand Down
3 changes: 2 additions & 1 deletion src/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Given normalized weights ``w_{1:n}``, the ESS is estimated using the L2-norm of
\\mathrm{ESS}(w_{1:n}) = \\frac{r_{\\mathrm{eff}}}{\\sum_{i=1}^n w_i^2}
```
where ``r_{\\mathrm{eff}}`` is the relative efficiency of the `log_weights`.
where ``r_{\\mathrm{eff}}`` is the relative efficiency of the inverse of the unsmoothed
importance ratios (see [`psis`](@ref)).
ess_is(result::PSISResult; bad_shape_nan=true)
Expand Down
Loading