Skip to content

Commit 485a3bc

Browse files
authored
Store SummaryStats labels separately from data (#81)
* Refactor to store labels separate from data
* Set var_names to nothing by default
* Update tests
* Increment patch number
* Run downgrade CI on minimum Julia version (DimensionalData has a fix for the latest Julia release, but not on its earliest supported version)
* Update Project.toml
1 parent 77a3a91 commit 485a3bc

5 files changed

Lines changed: 74 additions & 47 deletions

File tree

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
downgrade:
3232
- false
3333
include:
34-
- version: '1'
34+
- version: 'min'
3535
os: ubuntu-latest
3636
arch: x64
3737
downgrade: true

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PosteriorStats"
22
uuid = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5"
33
authors = ["Seth Axen <seth@sethaxen.com> and contributors"]
4-
version = "0.4.2"
4+
version = "0.4.3-DEV"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/summarize.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,16 @@ Base.@constprop :aggressive function summarize(
117117
stats_funs_and_names...;
118118
kind::Union{Symbol,Val}=:all,
119119
name::String="SummaryStats",
120-
var_names=axes(data, 3),
120+
var_names=nothing,
121121
kwargs...,
122122
)
123-
length(var_names) == size(data, 3) || throw(
124-
DimensionMismatch(
125-
"length $(length(var_names)) of `var_names` does not match number of parameters $(size(data, 3)) in `data`.",
126-
),
127-
)
123+
var_names === nothing ||
124+
length(var_names) == size(data, 3) ||
125+
throw(
126+
DimensionMismatch(
127+
"length $(length(var_names)) of `var_names` does not match number of parameters $(size(data, 3)) in `data`.",
128+
),
129+
)
128130
if isempty(stats_funs_and_names)
129131
return _summarize(data, default_summary_stats(kind; kwargs...), name, var_names)
130132
else

src/summarystats.jl

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,60 +31,64 @@ this will be used for the row labels or will be replaced with the `labels` if pr
3131
display. If not provided, then the column `$(LABEL_COLUMN_NAME)` in `data` will be
3232
used if it exists. Otherwise, the parameter names will be numeric indices.
3333
"""
34-
struct SummaryStats{D,N<:AbstractString}
34+
struct SummaryStats{D,L<:Union{Nothing,AbstractVector},N<:AbstractString}
3535
data::D
36+
labels::L
3637
name::N
37-
function SummaryStats(data, name::N) where {N<:AbstractString}
38-
_coltable = Tables.columntable(data)
39-
# define default parameter names if not present, and set as first column
40-
if !haskey(_coltable, LABEL_COLUMN_NAME)
41-
data_cols = _coltable
42-
labels = Base.OneTo(Tables.rowcount(data))
43-
else
44-
data_colnames = filter(k -> k !== LABEL_COLUMN_NAME, keys(_coltable))
45-
data_cols = NamedTuple{data_colnames}(_coltable)
46-
labels = _coltable[LABEL_COLUMN_NAME]
47-
end
48-
coltable = merge((; LABEL_COLUMN_NAME => labels), data_cols)
49-
return new{typeof(coltable),N}(coltable, name)
50-
end
5138
end
5239

5340
function SummaryStats(
5441
data; labels::Union{AbstractVector,Nothing}=nothing, name::AbstractString="SummaryStats"
5542
)
43+
_coltable = Tables.columntable(data)
5644
if labels !== nothing
5745
length(labels) == Tables.rowcount(data) || throw(
5846
DimensionMismatch(
5947
"length $(length(labels)) of `labels` does not match number of rows $(Tables.rowcount(data)) in `data`.",
6048
),
6149
)
62-
data_with_varnames = merge(
63-
Tables.columntable(data), (; LABEL_COLUMN_NAME => labels)
50+
end
51+
if haskey(_coltable, LABEL_COLUMN_NAME)
52+
labels === nothing || throw(
53+
ArgumentError(
54+
"Either `labels` or a column named `$(LABEL_COLUMN_NAME)` may be provided, but not both.",
55+
),
6456
)
65-
else
66-
data_with_varnames = data
57+
data_colnames = filter(k -> k !== LABEL_COLUMN_NAME, keys(_coltable))
58+
data_cols = NamedTuple{data_colnames}(_coltable)
59+
_labels = _coltable[LABEL_COLUMN_NAME]
60+
return SummaryStats(data_cols, _labels, name)
6761
end
68-
return SummaryStats(data_with_varnames, name)
62+
return SummaryStats(_coltable, labels, name)
6963
end
7064

7165
# forward key interfaces from its parent
7266
Base.parent(stats::SummaryStats) = getfield(stats, :data)
73-
Base.keys(stats::SummaryStats) = map(Symbol, Tables.columnnames(stats))
67+
Base.keys(stats::SummaryStats) = Tables.columnnames(stats)
7468
Base.haskey(stats::SummaryStats, nm::Symbol) = nm ∈ keys(stats)
75-
Base.length(stats::SummaryStats) = length(parent(stats))
69+
Base.length(stats::SummaryStats) = length(parent(stats)) + 1
7670
Base.getindex(stats::SummaryStats, i::Union{Int,Symbol}) = Tables.getcolumn(stats, i)
77-
Base.iterate(stats::SummaryStats, rest...) = iterate(parent(stats), rest...)
71+
Base.iterate(stats::SummaryStats) = (_labels(stats), 2)
72+
function Base.iterate(stats::SummaryStats, i::Int)
73+
state = iterate(parent(stats), i - 1)
74+
state === nothing && return nothing
75+
return (state[1], state[2] + 1)
76+
end
7877
function Base.merge(stats::SummaryStats, other_stats::SummaryStats...)
7978
isempty(other_stats) && return stats
8079
stats_all = (stats, other_stats...)
8180
stats_last = last(stats_all)
82-
return SummaryStats(merge(map(parent, stats_all)...), stats_last.name)
81+
return SummaryStats(
82+
merge(map(parent, stats_all)...),
83+
getfield(stats_last, :labels),
84+
getfield(stats_last, :name),
85+
)
8386
end
8487
for f in (:(==), :isequal)
8588
@eval begin
8689
function Base.$(f)(stats::SummaryStats, other_stats::SummaryStats)
87-
return $(f)(parent(stats), parent(other_stats))
90+
return $(f)(_labels(stats), _labels(other_stats)) &&
91+
$(f)(parent(stats), parent(other_stats))
8892
end
8993
end
9094
end
@@ -99,30 +103,42 @@ function Base.show(io::IO, mime::MIME"text/html", stats::SummaryStats; kwargs...
99103
end
100104

101105
function _show(io::IO, mime::MIME, stats::SummaryStats; kwargs...)
102-
nt = parent(stats)
103-
data = nt[keys(nt)[2:end]]
106+
data = parent(stats)
104107
rhat_formatter = _prettytables_rhat_formatter(data)
105108
extra_formatters = rhat_formatter === nothing ? () : (rhat_formatter,)
106109
return _show_prettytable(
107110
io,
108111
mime,
109112
data;
110113
title=stats.name,
111-
row_labels=Tables.getcolumn(stats, LABEL_COLUMN_NAME),
114+
row_labels=_labels(stats),
112115
extra_formatters,
113116
kwargs...,
114117
)
115118
end
116119

117120
#### Tables interface as column table
118121

122+
_labels(s::SummaryStats) = getfield(s, :labels)
123+
_labels(s::SummaryStats{<:Any,Nothing}) = eachindex(values(parent(s))...)
124+
119125
Tables.istable(::Type{<:SummaryStats}) = true
120126
Tables.columnaccess(::Type{<:SummaryStats}) = true
121127
Tables.columns(s::SummaryStats) = s
122-
Tables.columnnames(s::SummaryStats) = Tables.columnnames(parent(s))
123-
Tables.getcolumn(stats::SummaryStats, i::Int) = Tables.getcolumn(parent(stats), i)
124-
Tables.getcolumn(stats::SummaryStats, nm::Symbol) = Tables.getcolumn(parent(stats), nm)
125-
Tables.schema(s::SummaryStats) = Tables.schema(parent(s))
128+
Tables.columnnames(s::SummaryStats) = (LABEL_COLUMN_NAME, Tables.columnnames(parent(s))...)
129+
function Tables.getcolumn(stats::SummaryStats, i::Int)
130+
i == 1 && return _labels(stats)
131+
return Tables.getcolumn(parent(stats), i - 1)
132+
end
133+
function Tables.getcolumn(stats::SummaryStats, nm::Symbol)
134+
nm === LABEL_COLUMN_NAME && return _labels(stats)
135+
return Tables.getcolumn(parent(stats), nm)
136+
end
137+
function Tables.schema(s::SummaryStats)
138+
labels = _labels(s)
139+
sch = Tables.schema(parent(s))
140+
return Tables.Schema((LABEL_COLUMN_NAME, sch.names...), (eltype(labels), sch.types...))
141+
end
126142

127143
IteratorInterfaceExtensions.isiterable(::SummaryStats) = true
128144
function IteratorInterfaceExtensions.getiterator(s::SummaryStats)

test/summarystats.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,44 @@ using Test
1313

1414
@testset "constructors" begin
1515
stats1 = SummaryStats(data)
16-
@test stats1.data == data_with_default_labels
16+
@test stats1.data == data
17+
@test isnothing(stats1.labels)
1718
@test stats1.name == "SummaryStats"
1819

1920
stats2 = SummaryStats(data; name="Stats")
20-
@test stats2.data == data_with_default_labels
21+
@test stats2.data == data
22+
@test isnothing(stats2.labels)
2123
@test stats2.name == "Stats"
2224

2325
stats3 = SummaryStats(data; labels)
24-
@test stats3.data == data_with_labels
26+
@test stats3.data == data
27+
@test stats3.labels == labels
2528
@test stats3.name == "SummaryStats"
2629

2730
stats4 = SummaryStats(data; labels, name="Stats")
28-
@test stats4.data == data_with_labels
31+
@test stats4.data == data
32+
@test stats4.labels == labels
2933
@test stats4.name == "Stats"
3034

3135
stats5 = SummaryStats(data_with_labels)
32-
@test stats5.data == data_with_labels
36+
@test stats5.data == data
37+
@test stats5.labels == labels
3338

3439
stats6 = SummaryStats(merge(data, (; label=labels)))
35-
@test stats6.data == data_with_labels
40+
@test stats6.data == data
41+
@test stats6.labels == labels
3642
@test stats6.name == "SummaryStats"
43+
44+
@test_throws ArgumentError SummaryStats(data_with_labels; labels)
45+
@test_throws DimensionMismatch SummaryStats(data; labels=labels[1:(end - 1)])
3746
end
3847

3948
@inferred SummaryStats(data; name="Stats")
4049
stats_with_names(data, name, labels) = SummaryStats(data; name, labels)
4150
stats = @inferred stats_with_names(data, "Stats", labels)
4251

4352
@testset "basic interfaces" begin
44-
@test parent(stats) == data_with_labels
53+
@test parent(stats) == data
4554
@test stats.name == "Stats"
4655
@test SummaryStats(data; name="MoreStats").name == "MoreStats"
4756
@test keys(stats) == (:label, keys(data)...)

0 commit comments

Comments
 (0)