Skip to content

Commit a799f6a

Browse files
committed
cuda: Use LLVM type for inline roots detection
basically with tracedrnumber the julia type has ptrs when it shouldn't
1 parent 8610a4f commit a799f6a

1 file changed

Lines changed: 72 additions & 6 deletions

File tree

ext/ReactantCUDAExt.jl

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,71 @@ function to_bytes(x)
10161016
end
10171017
end
10181018

1019+
function to_llvmtype(@nospecialize(mlirty::MLIR.IR.Type))
1020+
# Void (no predicate in the C API; compare by value)
1021+
if mlirty == MLIR.IR.Type(MLIR.API.mlirLLVMVoidTypeGet(MLIR.IR.context(mlirty)))
1022+
return LLVM.VoidType()
1023+
end
1024+
1025+
# Integers
1026+
MLIR.IR.isinteger(mlirty) && return LLVM.IntType(MLIR.IR.bitwidth(mlirty))
1027+
1028+
# Float types
1029+
MLIR.IR.isf16(mlirty) && return LLVM.HalfType()
1030+
MLIR.IR.isbf16(mlirty) && return LLVM.BFloatType()
1031+
MLIR.IR.isf32(mlirty) && return LLVM.FloatType()
1032+
MLIR.IR.isf64(mlirty) && return LLVM.DoubleType()
1033+
MLIR.IR.istf32(mlirty) && return LLVM.FloatType()
1034+
1035+
# Pointer
1036+
if MLIR.API.mlirTypeIsALLVMPointerType(mlirty)
1037+
addrspace = MLIR.API.mlirLLVMPointerTypeGetAddressSpace(mlirty)
1038+
return LLVM.PointerType(addrspace)
1039+
end
1040+
1041+
# Array
1042+
if MLIR.API.mlirTypeIsALLVMArrayType(mlirty)
1043+
elem = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGetElementType(mlirty))
1044+
n = MLIR.API.mlirLLVMArrayTypeGetNumElements(mlirty)
1045+
return LLVM.ArrayType(to_llvmtype(elem), n)
1046+
end
1047+
1048+
# Struct
1049+
if MLIR.API.mlirTypeIsALLVMStructType(mlirty)
1050+
packed = MLIR.API.mlirLLVMStructTypeIsPacked(mlirty)
1051+
nfields = MLIR.API.mlirLLVMStructTypeGetNumElementTypes(mlirty)
1052+
elems = LLVM.LLVMType[
1053+
to_llvmtype(MLIR.IR.Type(MLIR.API.mlirLLVMStructTypeGetElementType(mlirty, i)))
1054+
for i in 0:(nfields - 1)
1055+
]
1056+
if !MLIR.API.mlirLLVMStructTypeIsLiteral(mlirty)
1057+
nameref = MLIR.API.mlirLLVMStructTypeGetIdentifier(mlirty)
1058+
sname = unsafe_string(nameref.data, nameref.length)
1059+
st = LLVM.StructType(sname)
1060+
isempty(elems) || LLVM.elements!(st, elems, packed)
1061+
return st
1062+
end
1063+
return LLVM.StructType(elems; packed)
1064+
end
1065+
1066+
# Function
1067+
if MLIR.API.mlirTypeIsALLVMFunctionType(mlirty)
1068+
retty = to_llvmtype(
1069+
MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetReturnType(mlirty))
1070+
)
1071+
ninputs = MLIR.API.mlirLLVMFunctionTypeGetNumInputs(mlirty)
1072+
params = LLVM.LLVMType[
1073+
to_llvmtype(MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(mlirty, i))) for
1074+
i in 0:(ninputs - 1)
1075+
]
1076+
return LLVM.FunctionType(
1077+
retty, params; vararg=MLIR.API.mlirLLVMFunctionTypeIsVarArg(mlirty)
1078+
)
1079+
end
1080+
1081+
return error("cannot convert type to llvm " * string(mlirty))
1082+
end
1083+
10191084
function Reactant.make_tracer(
10201085
seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...
10211086
)
@@ -1215,7 +1280,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
12151280
)
12161281

12171282
trueidx = 1
1218-
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type,Type},Nothing}[]
1283+
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type,LLVM.LLVMType},Nothing}[]
12191284

12201285
llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
12211286
i8 = MLIR.IR.Type(UInt8)
@@ -1235,7 +1300,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
12351300
)
12361301
trueidx += 1
12371302
jltyp = Core.Typeof(a)
1238-
if Enzyme.Compiler.inline_roots_type(jltyp) != 0
1303+
lltyp = to_llvmtype(argty)
1304+
if Enzyme.Compiler.inline_roots_type(llvmtyp) != 0
12391305
trueidx += 1
12401306
end
12411307
c1 = MLIR.IR.result(
@@ -1250,7 +1316,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
12501316
),
12511317
1,
12521318
)
1253-
push!(allocs, (alloc, argty, jltyp))
1319+
push!(allocs, (alloc, argty, lltyp))
12541320

12551321
if has_cast_float_type
12561322
# The argument `a` has BFloat16 fields but the GPU function was
@@ -1384,17 +1450,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
13841450
if arg === nothing
13851451
continue
13861452
end
1387-
alloc, argty, jltyp = arg
1453+
alloc, argty, llvmtyp = arg
13881454
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
13891455
push!(wrapargs, argres)
1390-
if Enzyme.Compiler.inline_roots_type(jltyp) != 0
1456+
if Enzyme.Compiler.inline_roots_type(llvmtyp) != 0
13911457
c1 = MLIR.IR.result(
13921458
MLIR.Dialects.llvm.mlir_constant(;
13931459
res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)
13941460
),
13951461
1,
13961462
)
1397-
roots_count = Enzyme.Compiler.inline_roots_type(jltyp)
1463+
roots_count = Enzyme.Compiler.inline_roots_type(llvmtyp)
13981464
jlvaluet = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 10))
13991465
njlvaluet = MLIR.IR.Type(
14001466
MLIR.API.mlirLLVMArrayTypeGet(jlvaluet, roots_count)

0 commit comments

Comments
 (0)