Skip to content

Commit 24825f7

Browse files
authored
refactor: Reimplement pointwise log-like for JointOrderStatistics (#92)
* refactor: Reimlement pointwise log-like for JointOrderStatistics * chore: Update changelog for release * chore: Increment patch number
1 parent a33ad74 commit 24825f7

3 files changed

Lines changed: 44 additions & 23 deletions

File tree

CHANGELOG.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,18 @@
1010

1111
### Maintenance
1212

13-
- Update DimensionalData compat to v0.30 for tests ([#91](https://github.com/arviz-devs/PosteriorStats.jl/pull/91))
14-
1513
### Documentation
1614

15+
## v0.4.8 (2026-03-07)
16+
17+
### Features
18+
19+
- refactor: Refactor pointwise log-likelihoods for `JointOrderStatistics` to be more efficient ([#92](https://github.com/arviz-devs/PosteriorStats.jl/pull/92))
20+
21+
### Maintenance
22+
23+
- Update DimensionalData compat to v0.30 for tests ([#91](https://github.com/arviz-devs/PosteriorStats.jl/pull/91))
24+
1725
## v0.4.7 (2026-02-06)
1826

1927
### Features

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "PosteriorStats"
22
uuid = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5"
3-
version = "0.4.7"
3+
version = "0.4.8"
44
authors = ["Seth Axen <seth@sethaxen.com> and contributors"]
55

66
[deps]

src/pointwise_loglikelihoods.jl

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -203,38 +203,51 @@ end
203203

204204
if isdefined(Distributions, :JointOrderStatistics)
205205
function pointwise_conditional_loglikelihoods!!(
206-
log_like::AbstractVector{<:Real},
206+
log_like::AbstractVector{T},
207207
y::AbstractVector{<:Real},
208208
dist::Distributions.JointOrderStatistics,
209-
)
209+
) where {T<:Real}
210210
(; n, ranks) = dist
211+
m = length(y)
211212

212-
if length(ranks) == 1
213+
if m == 1
213214
log_like[begin] = Distributions.loglikelihood(dist, y)
214215
return log_like
215216
end
216217

218+
y_ext = Iterators.flatten((y, last(y)))
219+
ranks_ext = Iterators.flatten((ranks, n + 1))
220+
217221
udist = dist.dist
218-
r_ext = Iterators.flatten((0, ranks, n + 1))
219-
r_iter = Iterators.zip(r_ext, ranks, Iterators.drop(r_ext, 2))
220-
y_ext = Iterators.flatten((minimum(udist), y, maximum(udist)))
221-
y_iter = Iterators.zip(y_ext, y, Iterators.drop(y_ext, 2))
222-
223-
for (i, (r_minus, r_cur, r_plus), (y_minus, y_cur, y_plus)) in
224-
zip(eachindex(log_like), r_iter, y_iter)
225-
udist_trunc = if r_minus == 0
226-
Distributions.truncated(udist; upper=y_plus)
227-
elseif r_plus == n + 1
228-
Distributions.truncated(udist; lower=y_minus)
222+
yi = first(y)
223+
ri = si = first(ranks)
224+
loggi = SpecialFunctions.loggamma(T(si))
225+
logdi = Distributions.logcdf(udist, yi)
226+
for (i, (yi_plus, ri_plus)) in enumerate(Iterators.drop(zip(y_ext, ranks_ext), 1))
227+
si_plus = ri_plus - ri
228+
si_gap = si + si_plus
229+
logdi_plus = if i == m
230+
Distributions.logccdf(udist, yi_plus)
229231
else
230-
Distributions.truncated(udist; lower=y_minus, upper=y_plus)
232+
Distributions.logdiffcdf(udist, yi_plus, yi)
231233
end
232-
n_gap = r_plus - r_minus - 1
233-
r_in_gap = r_cur - r_minus
234-
dist_ostat = Distributions.OrderStatistic(udist_trunc, n_gap, r_in_gap)
235-
log_like[i] = Distributions.loglikelihood(dist_ostat, y_cur)
236-
end
234+
logdi_gap = LogExpFunctions.logaddexp(logdi, logdi_plus)
237235

236+
loggi_plus = SpecialFunctions.loggamma(T(si_plus))
237+
loggi_gap = SpecialFunctions.loggamma(T(si_gap))
238+
log_beta = loggi + loggi_plus - loggi_gap
239+
240+
logpi = Distributions.logpdf(udist, yi)
241+
242+
# likelihood is basically a change-of-variables times a ratio of Dirichlets,
243+
# where all terms cancel except for the ones that change depending on whether
244+
# ranks[i] is observed or not.
245+
log_like[i] =
246+
logpi + (si - 1) * logdi + (si_plus - 1) * logdi_plus -
247+
(si_gap - 1) * logdi_gap - log_beta
248+
249+
(yi, ri, si, logdi, loggi) = (yi_plus, ri_plus, si_plus, logdi_plus, loggi_plus)
250+
end
238251
return log_like
239252
end
240253
end

0 commit comments

Comments
 (0)