Skip to content

Commit 6ee79a1

Browse files
committed
only in 1.12
1 parent 847515d commit 6ee79a1

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

ext/ReactantCUDAExt.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,6 +1178,12 @@ function mlir_extract_roots_from_value!(
11781178
end
11791179
end
11801180

1181+
# On 1.12+, there was a change to the calling convention where
1182+
# an additional argument would be added for the roots, this will
1183+
# return the number of roots in the corresponding convention, or
1184+
# 0 if it does not apply https://github.com/JuliaLang/julia/pull/55767/files#diff-62cfb2606c6a323a7f26a3eddfa0bf2b819fa33e094561fee09daeb328e3a1e7
1185+
const HasInlineRootsABI = VERSION v"1.12"
1186+
11811187
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
11821188
args...;
11831189
convert=Val(true),
@@ -1280,7 +1286,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
12801286
)
12811287

12821288
trueidx = 1
1283-
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type,LLVM.LLVMType},Nothing}[]
1289+
allocs = Union{
1290+
Tuple{MLIR.IR.Value,MLIR.IR.Type,HasInlineRootsABI ? LLVM.LLVMType : Nothing},
1291+
Nothing,
1292+
}[]
12841293

12851294
llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
12861295
i8 = MLIR.IR.Type(UInt8)
@@ -1300,8 +1309,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
13001309
)
13011310
trueidx += 1
13021311
jltyp = Core.Typeof(a)
1303-
lltyp = to_llvmtype(argty)
1304-
if Enzyme.Compiler.inline_roots_type(lltyp) != 0
1312+
lltyp = HasInlineRootsABI ? to_llvmtype(argty) : nothing
1313+
if HasInlineRootsABI && Enzyme.Compiler.inline_roots_type(lltyp) != 0
13051314
trueidx += 1
13061315
end
13071316
c1 = MLIR.IR.result(
@@ -1453,7 +1462,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
14531462
alloc, argty, llvmtyp = arg
14541463
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
14551464
push!(wrapargs, argres)
1456-
if Enzyme.Compiler.inline_roots_type(llvmtyp) != 0
1465+
if HasInlineRootsABI && Enzyme.Compiler.inline_roots_type(llvmtyp) != 0
14571466
c1 = MLIR.IR.result(
14581467
MLIR.Dialects.llvm.mlir_constant(;
14591468
res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)

0 commit comments

Comments
 (0)