-
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfrom_namedtuple.jl
More file actions
78 lines (68 loc) · 3.39 KB
/
from_namedtuple.jl
File metadata and controls
78 lines (68 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
using InferenceObjects, Test
# Tests for `from_namedtuple` (and `convert_to_inference_data` on NamedTuple inputs).
@testset "from_namedtuple" begin
    nchains, ndraws = 4, 10
    # Per-variable trailing shapes; `x` has shape `()`, i.e. it is a scalar per draw.
    sizes = (x=(), y=(2,), z=(3, 5))
    dims = (y=[:yx], z=[:zx, :zy])
    coords = (yx=["y1", "y2"], zx=1:3, zy=1:5)
    # The same variables in the two input layouts passed to `from_namedtuple` below:
    # a NamedTuple of `(ndraws, nchains, sz...)` arrays, and a per-chain vector of
    # per-draw NamedTuples.
    nts = [
        "NamedTuple" => map(sz -> randn(ndraws, nchains, sz...), sizes),
        "Vector{Vector{NamedTuple}}" =>
            [[map(Base.splat(randn), sizes) for _ in 1:ndraws] for _ in 1:nchains],
    ]

    @testset "posterior::$(type)" for (type, nt) in nts
        # NOTE(review): inference is only asserted on Julia ≥ 1.9, presumably because
        # earlier versions cannot infer this call concretely — confirm.
        VERSION ≥ v"1.9" && @inferred from_namedtuple(nt; dims, coords, library="MyLib")
        idata1 = from_namedtuple(nt; dims, coords, library="MyLib")
        # The generic converter must agree with the specialized constructor.
        idata2 = convert_to_inference_data(nt; dims, coords, library="MyLib")
        test_idata_approx_equal(idata1, idata2)
    end

    # Each group is given either its own draws, or a tuple of names of posterior
    # variables, which then appear in that group instead of in the posterior.
    @testset "$(group)" for group in [
        :posterior_predictive, :sample_stats, :predictions, :log_likelihood
    ]
        library = "MyLib"
        @testset "::$(type)" for (type, nt) in nts
            # group supplied as a full set of draws
            idata1 = from_namedtuple(nt; group => nt, dims, coords, library)
            test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords)
            # group supplied as posterior variable names: `x` leaves the posterior
            idata2 = from_namedtuple(nt; group => (:x,), dims, coords, library)
            test_idata_group_correct(idata2, :posterior, (:y, :z); library, dims, coords)
            test_idata_group_correct(idata2, group, (:x,); library, dims, coords)
        end
    end

    # Same as above, but for groups that pair with `prior` rather than the posterior.
    @testset "$(group)" for group in [:prior_predictive, :sample_stats_prior]
        library = "MyLib"
        @testset "::$(type)" for (type, nt) in nts
            idata1 = from_namedtuple(; prior=nt, group => nt, dims, coords, library)
            test_idata_group_correct(idata1, :prior, keys(sizes); library, dims, coords)
            test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords)
            # `x` is moved out of `prior` into `group`
            idata2 = from_namedtuple(; prior=nt, group => (:x,), dims, coords, library)
            test_idata_group_correct(idata2, :prior, (:y, :z); library, dims, coords)
            test_idata_group_correct(idata2, group, (:x,); library, dims, coords)
        end
    end

    # Data groups are checked with `default_dims=()` — presumably because they carry
    # no sample (chain/draw) dimensions; confirm against `test_idata_group_correct`.
    @testset "$(group)" for group in
                            [:observed_data, :constant_data, :predictions_constant_data]
        _, nt = nts[1]
        library = "MyLib"
        dims = (; w=[:wx])
        coords = (; wx=1:2)
        # Baseline: a single data variable that has an entry in `dims`.
        idata1 = from_namedtuple(nt; group => (w=[1.0, 2.0],), dims, coords, library)
        test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords)
        test_idata_group_correct(
            idata1, group, (:w,); library, dims, coords, default_dims=()
        )
        # Ensure that dims are matched to named tuple keys: `v` has no entry in `dims`,
        # so `w`'s dims must be looked up by name rather than by position.
        # https://github.com/arviz-devs/ArviZ.jl/issues/96
        idata2 = from_namedtuple(nt; group => (w=[1.0, 2.0], v=2.5), dims, coords, library)
        test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords)
        test_idata_group_correct(
            idata2, group, (:w, :v); library, dims, coords, default_dims=()
        )
    end

    @testset "convert_to_inference_data with non-posterior `group`" begin
        data = (x=3, y=randn(2))
        idata = convert_to_inference_data(data; group=:observed_data)
        @test issetequal(keys(idata), (:observed_data,))
        @test issetequal(keys(idata.observed_data), (:x, :y))
        # the scalar `x` is expected back as a 0-dimensional array (`fill(3)`)
        @test idata.observed_data.x == fill(data.x)
        @test idata.observed_data.y == data.y
    end
end