Skip to content

Commit 71c8367

Browse files
authored
fix create_result on Enum (#2835)
* add `create_result` method for enums * remove legacy `testset` wrapper * test
1 parent d5433eb commit 71c8367

2 files changed

Lines changed: 96 additions & 56 deletions

File tree

src/Compiler.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,37 @@ Base.@nospecializeinfer function create_result(
277277
return result_cache[tocopy]
278278
end
279279

280+
Base.@nospecializeinfer function create_result(
281+
@nospecialize(tocopy::Enum),
282+
@nospecialize(path::Tuple),
283+
result_stores,
284+
path_to_shard_info,
285+
to_unreshard_results,
286+
_unresharded_code::Vector{Expr},
287+
_unresharded_arrays_cache,
288+
used_shardinfo,
289+
result_cache,
290+
var_idx,
291+
resultgen_code,
292+
)
293+
if !haskey(result_cache, tocopy)
294+
sym = Symbol("result", var_idx[])
295+
var_idx[] += 1
296+
297+
result = Meta.quot(tocopy)
298+
299+
push!(
300+
resultgen_code,
301+
quote
302+
$sym = $result
303+
end,
304+
)
305+
result_cache[tocopy] = sym
306+
end
307+
308+
return result_cache[tocopy]
309+
end
310+
280311
function create_result(
281312
tocopy::ConcretePJRTNumber{T,D},
282313
@nospecialize(path::Tuple),

test/core/compile.jl

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -13,76 +13,85 @@ intout_caller(vis) = @noinline intout(vis)
1313
@test_throws MethodError @compile intout_caller(visr)
1414
end
1515

16-
@testset "compile" begin
17-
@testset "create_result" begin
18-
@testset "NamedTuple" begin
19-
x = (; a=Reactant.TestUtils.construct_test_array(Float64, 4, 3))
20-
x2 = Reactant.to_rarray(x)
21-
22-
res = @jit sum(x2)
23-
@test res isa NamedTuple
24-
@test res.a isa ConcreteRNumber{Float64}
25-
@test isapprox(res.a, sum(x.a))
26-
end
27-
28-
@testset "Array" begin
29-
x = [1 2; 3 4; 5 6]
30-
f = Reactant.compile(() -> x, ())
31-
@test f() x
32-
end
16+
@testset "create_result" begin
17+
@testset "NamedTuple" begin
18+
x = (; a=Reactant.TestUtils.construct_test_array(Float64, 4, 3))
19+
x2 = Reactant.to_rarray(x)
20+
21+
res = @jit sum(x2)
22+
@test res isa NamedTuple
23+
@test res.a isa ConcreteRNumber{Float64}
24+
@test isapprox(res.a, sum(x.a))
3325
end
3426

35-
@testset "world-age" begin
36-
a = ones(2, 10)
37-
b = ones(10, 2)
38-
a_ra = Reactant.to_rarray(a)
39-
b_ra = Reactant.to_rarray(b)
27+
@testset "Array" begin
28+
x = [1 2; 3 4; 5 6]
29+
f = Reactant.compile(() -> x, ())
30+
@test f() x
31+
end
4032

41-
fworld(x, y) = @jit(x * y)
33+
@testset "Enum" begin
34+
@enum MyEnum begin
35+
MyEnumA = 1
36+
MyEnumB = 2
37+
end
4238

43-
@test fworld(a_ra, b_ra) ones(2, 2) * 10
39+
x = MyEnumA
40+
f = @compile identity(x)
41+
@test f(x) == x
4442
end
43+
end
4544

46-
@testset "type casting & optimized out returns" begin
47-
a = ones(2, 10)
48-
a_ra = Reactant.to_rarray(a)
45+
@testset "world-age" begin
46+
a = ones(2, 10)
47+
b = ones(10, 2)
48+
a_ra = Reactant.to_rarray(a)
49+
b_ra = Reactant.to_rarray(b)
4950

50-
ftype1(x) = Float64.(x)
51-
ftype2(x) = Float32.(x)
51+
fworld(x, y) = @jit(x * y)
5252

53-
y1 = @jit ftype1(a_ra)
54-
y2 = @jit ftype2(a_ra)
53+
@test fworld(a_ra, b_ra) ones(2, 2) * 10
54+
end
5555

56-
@test y1 isa Reactant.ConcreteRArray{Float64,2}
57-
@test y2 isa Reactant.ConcreteRArray{Float32,2}
56+
@testset "type casting & optimized out returns" begin
57+
a = ones(2, 10)
58+
a_ra = Reactant.to_rarray(a)
5859

59-
@test y1 Float64.(a)
60-
@test y2 Float32.(a)
61-
end
60+
ftype1(x) = Float64.(x)
61+
ftype2(x) = Float32.(x)
6262

63-
@testset "no variable name collisions in compile macros (#237)" begin
64-
f(x) = x
65-
g(x) = f(x)
66-
x = Reactant.TestUtils.construct_test_array(Float64, 2, 2)
67-
y = Reactant.to_rarray(x)
68-
@test (@jit g(y); true)
69-
end
63+
y1 = @jit ftype1(a_ra)
64+
y2 = @jit ftype2(a_ra)
7065

71-
# disabled due to long test time (core tests go from 2m to 7m just with this test)
72-
# @testset "resource exhaustation bug (#190)" begin
73-
# x = rand(2, 2)
74-
# y = Reactant.to_rarray(x)
75-
# @test try
76-
# for _ in 1:10_000
77-
# f = @compile sum(y)
78-
# end
79-
# true
80-
# catch e
81-
# false
82-
# end
83-
# end
66+
@test y1 isa Reactant.ConcreteRArray{Float64,2}
67+
@test y2 isa Reactant.ConcreteRArray{Float32,2}
68+
69+
@test y1 Float64.(a)
70+
@test y2 Float32.(a)
8471
end
8572

73+
@testset "no variable name collisions in compile macros (#237)" begin
74+
f(x) = x
75+
g(x) = f(x)
76+
x = Reactant.TestUtils.construct_test_array(Float64, 2, 2)
77+
y = Reactant.to_rarray(x)
78+
@test (@jit g(y); true)
79+
end
80+
81+
# disabled due to long test time (core tests go from 2m to 7m just with this test)
82+
# @testset "resource exhaustation bug (#190)" begin
83+
# x = rand(2, 2)
84+
# y = Reactant.to_rarray(x)
85+
# @test try
86+
# for _ in 1:10_000
87+
# f = @compile sum(y)
88+
# end
89+
# true
90+
# catch e
91+
# false
92+
# end
93+
# end
94+
8695
@testset "Module export" begin
8796
f(x) = sin.(cos.(x))
8897
x_ra = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float64, 3))

0 commit comments

Comments
 (0)