@@ -1016,6 +1016,71 @@ function to_bytes(x)
10161016 end
10171017end
10181018
1019+ function to_llvmtype (@nospecialize (mlirty:: MLIR.IR.Type ))
1020+ # Void (no predicate in the C API; compare by value)
1021+ if mlirty == MLIR. IR. Type (MLIR. API. mlirLLVMVoidTypeGet (MLIR. IR. context (mlirty)))
1022+ return LLVM. VoidType ()
1023+ end
1024+
1025+ # Integers
1026+ MLIR. IR. isinteger (mlirty) && return LLVM. IntType (MLIR. IR. bitwidth (mlirty))
1027+
1028+ # Float types
1029+ MLIR. IR. isf16 (mlirty) && return LLVM. HalfType ()
1030+ MLIR. IR. isbf16 (mlirty) && return LLVM. BFloatType ()
1031+ MLIR. IR. isf32 (mlirty) && return LLVM. FloatType ()
1032+ MLIR. IR. isf64 (mlirty) && return LLVM. DoubleType ()
1033+ MLIR. IR. istf32 (mlirty) && return LLVM. FloatType ()
1034+
1035+ # Pointer
1036+ if MLIR. API. mlirTypeIsALLVMPointerType (mlirty)
1037+ addrspace = MLIR. API. mlirLLVMPointerTypeGetAddressSpace (mlirty)
1038+ return LLVM. PointerType (addrspace)
1039+ end
1040+
1041+ # Array
1042+ if MLIR. API. mlirTypeIsALLVMArrayType (mlirty)
1043+ elem = MLIR. IR. Type (MLIR. API. mlirLLVMArrayTypeGetElementType (mlirty))
1044+ n = MLIR. API. mlirLLVMArrayTypeGetNumElements (mlirty)
1045+ return LLVM. ArrayType (to_llvmtype (elem), n)
1046+ end
1047+
1048+ # Struct
1049+ if MLIR. API. mlirTypeIsALLVMStructType (mlirty)
1050+ packed = MLIR. API. mlirLLVMStructTypeIsPacked (mlirty)
1051+ nfields = MLIR. API. mlirLLVMStructTypeGetNumElementTypes (mlirty)
1052+ elems = LLVM. LLVMType[
1053+ to_llvmtype (MLIR. IR. Type (MLIR. API. mlirLLVMStructTypeGetElementType (mlirty, i)))
1054+ for i in 0 : (nfields - 1 )
1055+ ]
1056+ if ! MLIR. API. mlirLLVMStructTypeIsLiteral (mlirty)
1057+ nameref = MLIR. API. mlirLLVMStructTypeGetIdentifier (mlirty)
1058+ sname = unsafe_string (nameref. data, nameref. length)
1059+ st = LLVM. StructType (sname)
1060+ isempty (elems) || LLVM. elements! (st, elems, packed)
1061+ return st
1062+ end
1063+ return LLVM. StructType (elems; packed)
1064+ end
1065+
1066+ # Function
1067+ if MLIR. API. mlirTypeIsALLVMFunctionType (mlirty)
1068+ retty = to_llvmtype (
1069+ MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGetReturnType (mlirty))
1070+ )
1071+ ninputs = MLIR. API. mlirLLVMFunctionTypeGetNumInputs (mlirty)
1072+ params = LLVM. LLVMType[
1073+ to_llvmtype (MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGetInput (mlirty, i))) for
1074+ i in 0 : (ninputs - 1 )
1075+ ]
1076+ return LLVM. FunctionType (
1077+ retty, params; vararg= MLIR. API. mlirLLVMFunctionTypeIsVarArg (mlirty)
1078+ )
1079+ end
1080+
1081+ return error (" cannot convert type to llvm " * string (mlirty))
1082+ end
1083+
10191084function Reactant. make_tracer (
10201085 seen, @nospecialize (prev:: CuTracedArray ), @nospecialize (path), mode; kwargs...
10211086)
@@ -1215,7 +1280,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
12151280 )
12161281
12171282 trueidx = 1
1218- allocs = Union{Tuple{MLIR. IR. Value,MLIR. IR. Type,Type },Nothing}[]
1283+ allocs = Union{Tuple{MLIR. IR. Value,MLIR. IR. Type,LLVM . LLVMType },Nothing}[]
12191284
12201285 llvmptr = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 0 ))
12211286 i8 = MLIR. IR. Type (UInt8)
@@ -1235,7 +1300,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
12351300 )
12361301 trueidx += 1
12371302 jltyp = Core. Typeof (a)
1238- if Enzyme. Compiler. inline_roots_type (jltyp) != 0
1303+ lltyp = to_llvmtype (argty)
1304+ if Enzyme. Compiler. inline_roots_type (llvmtyp) != 0
12391305 trueidx += 1
12401306 end
12411307 c1 = MLIR. IR. result (
@@ -1250,7 +1316,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
12501316 ),
12511317 1 ,
12521318 )
1253- push! (allocs, (alloc, argty, jltyp ))
1319+ push! (allocs, (alloc, argty, lltyp ))
12541320
12551321 if has_cast_float_type
12561322 # The argument `a` has BFloat16 fields but the GPU function was
@@ -1384,17 +1450,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
13841450 if arg === nothing
13851451 continue
13861452 end
1387- alloc, argty, jltyp = arg
1453+ alloc, argty, llvmtyp = arg
13881454 argres = MLIR. IR. result (MLIR. Dialects. llvm. load (alloc; res= argty), 1 )
13891455 push! (wrapargs, argres)
1390- if Enzyme. Compiler. inline_roots_type (jltyp ) != 0
1456+ if Enzyme. Compiler. inline_roots_type (llvmtyp ) != 0
13911457 c1 = MLIR. IR. result (
13921458 MLIR. Dialects. llvm. mlir_constant (;
13931459 res= MLIR. IR. Type (Int64), value= MLIR. IR. Attribute (1 )
13941460 ),
13951461 1 ,
13961462 )
1397- roots_count = Enzyme. Compiler. inline_roots_type (jltyp )
1463+ roots_count = Enzyme. Compiler. inline_roots_type (llvmtyp )
13981464 jlvaluet = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 10 ))
13991465 njlvaluet = MLIR. IR. Type (
14001466 MLIR. API. mlirLLVMArrayTypeGet (jlvaluet, roots_count)
0 commit comments