Skip to content

Commit 77a3a91

Browse files
authored
Reorganize summarize/SummaryStats source (#80)
* Copy summarize.jl * Remove now-duplicated code * Add include for new file * Split tests to match source organization * Remove unused dependency * Increment patch number
1 parent 0feb4b8 commit 77a3a91

7 files changed

Lines changed: 447 additions & 457 deletions

File tree

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PosteriorStats"
22
uuid = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5"
33
authors = ["Seth Axen <seth@sethaxen.com> and contributors"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -27,7 +27,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2727
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2828
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2929
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
30-
TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
3130
TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
3231
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3332

@@ -60,7 +59,6 @@ Setfield = "1"
6059
SpecialFunctions = "1.2, 2"
6160
Statistics = "1"
6261
StatsBase = "0.33.17, 0.34"
63-
TableOperations = "1"
6462
TableTraits = "1"
6563
Tables = "1.9"
6664
julia = "1.10"

src/PosteriorStats.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ include("model_weights.jl")
6464
include("compare.jl")
6565
include("loo_pit.jl")
6666
include("r2_score.jl")
67+
include("summarystats.jl")
6768
include("summarize.jl")
6869

6970
end # module

src/summarize.jl

Lines changed: 0 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
const LABEL_COLUMN_NAME = :label
2-
31
const _DEFAULT_SUMMARY_STATS_KIND_DOCSTRING = """
42
- `kind::Symbol`: The named collection of summary statistics to be computed:
53
+ `:all`: Everything in `:stats` and `:diagnostics`
@@ -17,137 +15,6 @@ const _DEFAULT_SUMMARY_STATS_CI_DOCSTRING = """
1715
interval `<ci>`.
1816
"""
1917

20-
"""
21-
struct SummaryStats
22-
23-
A container for a column table of values computed by [`summarize`](@ref).
24-
25-
This object implements the Tables and TableTraits column table interfaces. It has a custom
26-
`show` method.
27-
28-
!!! note
29-
`SummaryStats` behaves like an `OrderedDict` of columns, where the columns can be
30-
accessed using either `Symbol`s or a 1-based integer index. However, this interface
31-
is not part of the public API and may change in the future. We recommend using it
32-
only interactively.
33-
34-
# Constructor
35-
36-
SummaryStats(data; name="SummaryStats"[, labels])
37-
38-
Construct a `SummaryStats` from tabular `data`.
39-
40-
`data` must implement the Tables interface. If it contains a column `$(LABEL_COLUMN_NAME)`,
41-
this will be used for the row labels or will be replaced with the `labels` if provided.
42-
43-
# Keywords
44-
45-
- `name::AbstractString`: The name of the collection of summary statistics, used as the
46-
table title in display.
47-
- `labels::AbstractVector`: The names of the parameters in `data`, used as row labels in
48-
display. If not provided, then the column `$(LABEL_COLUMN_NAME)` in `data` will be
49-
used if it exists. Otherwise, the parameter names will be numeric indices.
50-
"""
51-
struct SummaryStats{D,N<:AbstractString}
52-
data::D
53-
name::N
54-
function SummaryStats(data, name::N) where {N<:AbstractString}
55-
_coltable = Tables.columntable(data)
56-
# define default parameter names if not present, and set as first column
57-
if !haskey(_coltable, LABEL_COLUMN_NAME)
58-
data_cols = _coltable
59-
labels = Base.OneTo(Tables.rowcount(data))
60-
else
61-
data_colnames = filter(k -> k !== LABEL_COLUMN_NAME, keys(_coltable))
62-
data_cols = NamedTuple{data_colnames}(_coltable)
63-
labels = _coltable[LABEL_COLUMN_NAME]
64-
end
65-
coltable = merge((; LABEL_COLUMN_NAME => labels), data_cols)
66-
return new{typeof(coltable),N}(coltable, name)
67-
end
68-
end
69-
70-
function SummaryStats(
71-
data; labels::Union{AbstractVector,Nothing}=nothing, name::AbstractString="SummaryStats"
72-
)
73-
if labels !== nothing
74-
length(labels) == Tables.rowcount(data) || throw(
75-
DimensionMismatch(
76-
"length $(length(labels)) of `labels` does not match number of rows $(Tables.rowcount(data)) in `data`.",
77-
),
78-
)
79-
data_with_varnames = merge(
80-
Tables.columntable(data), (; LABEL_COLUMN_NAME => labels)
81-
)
82-
else
83-
data_with_varnames = data
84-
end
85-
return SummaryStats(data_with_varnames, name)
86-
end
87-
88-
# forward key interfaces from its parent
89-
Base.parent(stats::SummaryStats) = getfield(stats, :data)
90-
Base.keys(stats::SummaryStats) = map(Symbol, Tables.columnnames(stats))
91-
# `nm` is a key iff it is one of the column names reported by `keys(stats)`.
Base.haskey(stats::SummaryStats, nm::Symbol) = nm ∈ keys(stats)
92-
Base.length(stats::SummaryStats) = length(parent(stats))
93-
Base.getindex(stats::SummaryStats, i::Union{Int,Symbol}) = Tables.getcolumn(stats, i)
94-
Base.iterate(stats::SummaryStats, rest...) = iterate(parent(stats), rest...)
95-
function Base.merge(stats::SummaryStats, other_stats::SummaryStats...)
96-
isempty(other_stats) && return stats
97-
stats_all = (stats, other_stats...)
98-
stats_last = last(stats_all)
99-
return SummaryStats(merge(map(parent, stats_all)...), stats_last.name)
100-
end
101-
for f in (:(==), :isequal)
102-
@eval begin
103-
function Base.$(f)(stats::SummaryStats, other_stats::SummaryStats)
104-
return $(f)(parent(stats), parent(other_stats))
105-
end
106-
end
107-
end
108-
109-
#### custom tabular show methods
110-
111-
function Base.show(io::IO, mime::MIME"text/plain", stats::SummaryStats; kwargs...)
112-
return _show(io, mime, stats; kwargs...)
113-
end
114-
function Base.show(io::IO, mime::MIME"text/html", stats::SummaryStats; kwargs...)
115-
return _show(io, mime, stats; kwargs...)
116-
end
117-
118-
function _show(io::IO, mime::MIME, stats::SummaryStats; kwargs...)
119-
nt = parent(stats)
120-
data = nt[keys(nt)[2:end]]
121-
rhat_formatter = _prettytables_rhat_formatter(data)
122-
extra_formatters = rhat_formatter === nothing ? () : (rhat_formatter,)
123-
return _show_prettytable(
124-
io,
125-
mime,
126-
data;
127-
title=stats.name,
128-
row_labels=Tables.getcolumn(stats, LABEL_COLUMN_NAME),
129-
extra_formatters,
130-
kwargs...,
131-
)
132-
end
133-
134-
#### Tables interface as column table
135-
136-
Tables.istable(::Type{<:SummaryStats}) = true
137-
Tables.columnaccess(::Type{<:SummaryStats}) = true
138-
Tables.columns(s::SummaryStats) = s
139-
Tables.columnnames(s::SummaryStats) = Tables.columnnames(parent(s))
140-
Tables.getcolumn(stats::SummaryStats, i::Int) = Tables.getcolumn(parent(stats), i)
141-
Tables.getcolumn(stats::SummaryStats, nm::Symbol) = Tables.getcolumn(parent(stats), nm)
142-
Tables.schema(s::SummaryStats) = Tables.schema(parent(s))
143-
144-
IteratorInterfaceExtensions.isiterable(::SummaryStats) = true
145-
function IteratorInterfaceExtensions.getiterator(s::SummaryStats)
146-
return Tables.datavaluerows(Tables.columntable(s))
147-
end
148-
149-
TableTraits.isiterabletable(::SummaryStats) = true
150-
15118
"""
15219
summarize(data; kind=:all,kwargs...) -> SummaryStats
15320
summarize(data, stats_funs...; kwargs...) -> SummaryStats

src/summarystats.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
const LABEL_COLUMN_NAME = :label
2+
3+
"""
4+
struct SummaryStats
5+
6+
A container for a column table of values computed by [`summarize`](@ref).
7+
8+
This object implements the Tables and TableTraits column table interfaces. It has a custom
9+
`show` method.
10+
11+
!!! note
12+
`SummaryStats` behaves like an `OrderedDict` of columns, where the columns can be
13+
accessed using either `Symbol`s or a 1-based integer index. However, this interface
14+
is not part of the public API and may change in the future. We recommend using it
15+
only interactively.
16+
17+
# Constructor
18+
19+
SummaryStats(data; name="SummaryStats"[, labels])
20+
21+
Construct a `SummaryStats` from tabular `data`.
22+
23+
`data` must implement the Tables interface. If it contains a column `$(LABEL_COLUMN_NAME)`,
24+
this will be used for the row labels or will be replaced with the `labels` if provided.
25+
26+
# Keywords
27+
28+
- `name::AbstractString`: The name of the collection of summary statistics, used as the
29+
table title in display.
30+
- `labels::AbstractVector`: The names of the parameters in `data`, used as row labels in
31+
display. If not provided, then the column `$(LABEL_COLUMN_NAME)` in `data` will be
32+
used if it exists. Otherwise, the parameter names will be numeric indices.
33+
"""
34+
struct SummaryStats{D,N<:AbstractString}
35+
data::D
36+
name::N
37+
function SummaryStats(data, name::N) where {N<:AbstractString}
38+
_coltable = Tables.columntable(data)
39+
# define default parameter names if not present, and set as first column
40+
if !haskey(_coltable, LABEL_COLUMN_NAME)
41+
data_cols = _coltable
42+
labels = Base.OneTo(Tables.rowcount(data))
43+
else
44+
data_colnames = filter(k -> k !== LABEL_COLUMN_NAME, keys(_coltable))
45+
data_cols = NamedTuple{data_colnames}(_coltable)
46+
labels = _coltable[LABEL_COLUMN_NAME]
47+
end
48+
coltable = merge((; LABEL_COLUMN_NAME => labels), data_cols)
49+
return new{typeof(coltable),N}(coltable, name)
50+
end
51+
end
52+
53+
function SummaryStats(
54+
data; labels::Union{AbstractVector,Nothing}=nothing, name::AbstractString="SummaryStats"
55+
)
56+
if labels !== nothing
57+
length(labels) == Tables.rowcount(data) || throw(
58+
DimensionMismatch(
59+
"length $(length(labels)) of `labels` does not match number of rows $(Tables.rowcount(data)) in `data`.",
60+
),
61+
)
62+
data_with_varnames = merge(
63+
Tables.columntable(data), (; LABEL_COLUMN_NAME => labels)
64+
)
65+
else
66+
data_with_varnames = data
67+
end
68+
return SummaryStats(data_with_varnames, name)
69+
end
70+
71+
# forward key interfaces from its parent
72+
Base.parent(stats::SummaryStats) = getfield(stats, :data)
73+
Base.keys(stats::SummaryStats) = map(Symbol, Tables.columnnames(stats))
74+
# `nm` is a key iff it is one of the column names reported by `keys(stats)`.
Base.haskey(stats::SummaryStats, nm::Symbol) = nm ∈ keys(stats)
75+
Base.length(stats::SummaryStats) = length(parent(stats))
76+
Base.getindex(stats::SummaryStats, i::Union{Int,Symbol}) = Tables.getcolumn(stats, i)
77+
Base.iterate(stats::SummaryStats, rest...) = iterate(parent(stats), rest...)
78+
function Base.merge(stats::SummaryStats, other_stats::SummaryStats...)
79+
isempty(other_stats) && return stats
80+
stats_all = (stats, other_stats...)
81+
stats_last = last(stats_all)
82+
return SummaryStats(merge(map(parent, stats_all)...), stats_last.name)
83+
end
84+
for f in (:(==), :isequal)
85+
@eval begin
86+
function Base.$(f)(stats::SummaryStats, other_stats::SummaryStats)
87+
return $(f)(parent(stats), parent(other_stats))
88+
end
89+
end
90+
end
91+
92+
#### custom tabular show methods
93+
94+
function Base.show(io::IO, mime::MIME"text/plain", stats::SummaryStats; kwargs...)
95+
return _show(io, mime, stats; kwargs...)
96+
end
97+
function Base.show(io::IO, mime::MIME"text/html", stats::SummaryStats; kwargs...)
98+
return _show(io, mime, stats; kwargs...)
99+
end
100+
101+
function _show(io::IO, mime::MIME, stats::SummaryStats; kwargs...)
102+
nt = parent(stats)
103+
data = nt[keys(nt)[2:end]]
104+
rhat_formatter = _prettytables_rhat_formatter(data)
105+
extra_formatters = rhat_formatter === nothing ? () : (rhat_formatter,)
106+
return _show_prettytable(
107+
io,
108+
mime,
109+
data;
110+
title=stats.name,
111+
row_labels=Tables.getcolumn(stats, LABEL_COLUMN_NAME),
112+
extra_formatters,
113+
kwargs...,
114+
)
115+
end
116+
117+
#### Tables interface as column table
118+
119+
Tables.istable(::Type{<:SummaryStats}) = true
120+
Tables.columnaccess(::Type{<:SummaryStats}) = true
121+
Tables.columns(s::SummaryStats) = s
122+
Tables.columnnames(s::SummaryStats) = Tables.columnnames(parent(s))
123+
Tables.getcolumn(stats::SummaryStats, i::Int) = Tables.getcolumn(parent(stats), i)
124+
Tables.getcolumn(stats::SummaryStats, nm::Symbol) = Tables.getcolumn(parent(stats), nm)
125+
Tables.schema(s::SummaryStats) = Tables.schema(parent(s))
126+
127+
IteratorInterfaceExtensions.isiterable(::SummaryStats) = true
128+
function IteratorInterfaceExtensions.getiterator(s::SummaryStats)
129+
return Tables.datavaluerows(Tables.columntable(s))
130+
end
131+
132+
TableTraits.isiterabletable(::SummaryStats) = true

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ Random.seed!(97)
1818
include("model_weights.jl")
1919
include("compare.jl")
2020
include("r2_score.jl")
21+
include("summarystats.jl")
2122
include("summarize.jl")
2223
end

0 commit comments

Comments
 (0)