
Commit c5d087b

Authored by wsmoses, Pangoraw, github-actions[bot], claude, and SouthEndMusic
Bf16cu (#2694)
* force fusion
* BFloat16 cuda kernel
* add BFloat16s dep
* only run bf16 on 1.12
* kernel cast 🪄
* depend on Enzyme-jax kernel-cast pass
* remove xla options
* Update src/Compiler.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* kernelcast
* Copy preserved argument if needed (#2722)
* Make copyto! force a new buffer
* Update src/Compiler.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* setpahts!
* yo
* Update src/TracedRArray.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* add aliasing test
* pjrt
* Update test/core/aliasing.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Move :new_buffer marking to Base.materialize! instead of _copyto!
  The :new_buffer path should only be set on the broadcast path (a .= b), not on all copyto! calls that happen to receive a Broadcasted argument.
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* cleanup
* Add buffer aliasing test for struct field reassignment with both new buffers
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* less eager
* Update src/Compiler.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* use count
* fmt
* good args
* fixup

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* Make the CUDA.jl not loaded error more specific (#2743)
* Update LIBTPU_VERSION to 0.0.39.dev20260328
* Update Project.toml
* Update WORKSPACE
* Update ENZYMEXLA_COMMIT to a new hash
* Update ENZYMEXLA_COMMIT to new hash
* Regenerate MLIR Bindings (#2750)
  Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
* Bump version and update Reactant_jll dependency
* Regenerate MLIR Bindings (#2756)
  Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
* Multifloat options (#2757)
  * Multifloat options
  * fix
  * fix
  * fix
  * Apply suggestions from code review
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  ---------
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update Project.toml
* `div` for ConcreteRNumber (#2674)
  * div for ConcreteRNumber
  * Update test/core/math_ops.jl
    Co-authored-by: Paul Berg <naydex.mc+github@gmail.com>
  * Update src/ConcreteRArray.jl
    Co-authored-by: Paul Berg <naydex.mc+github@gmail.com>
  * Update src/ConcreteRArray.jl
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  * Update src/ConcreteRArray.jl
    Co-authored-by: Paul Berg <naydex.mc+github@gmail.com>
  * Update src/ConcreteRArray.jl
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  * Update src/ConcreteRArray.jl
  * remove methods
  * Revert "remove methods"
    This reverts commit ab20e23.
  * meth
  * fix ambiguities
  ---------
  Co-authored-by: Paul Berg <naydex.mc+github@gmail.com>
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  Co-authored-by: Paul Berg <9824244+Pangoraw@users.noreply.github.com>
* Multifloat fix pass parsing (#2759)
* Change shardy_passes default value to post_sdy_propagation
  This is currently what the docs say
* Bump version from 0.2.246 to 0.2.247
* Update ENZYMEXLA_COMMIT hash in WORKSPACE
* Update ENZYMEXLA_COMMIT to new hash (#2760)
* Regenerate MLIR Bindings (#2761)
  Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
* Update LIBTPU_VERSION to 0.0.39.dev20260401
* fix: qa test (#2763)
* feat: better StructArray & StaticArray support (#2546)
  * Draft to figure out better StructArray support
  * Simplify and generalize structarray type conversion
  * Start adding StaticArray support
  * Add StaticArray support and tweak elem_apply_while_loop to select correct container type
  * Revert tracing.jl
  * Remove info debug
  * Remove get_ith
  * Add _copyto!
  * format
  * Fix broken test and add new tests
  * format
  * add StaticArrays
  * Add LinearAlgebra
  * Remove unused function
  * Reuse the known destination for while loop if possible
  * Update ext/ReactantStructArraysExt.jl
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  * Proposed improved support for SArrays
  * fix dumb mistake
  * Add additional changes for StaticArrays
  * Apply suggestions from code review
    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  * Cleanup
  * Update
  * Update
  * Format
  * Fix for code review
  * Add comments
  * So dumb
  * Correct comment in overloaded_mul function
    Fix comment typo in overloaded_mul function.
  * Update to remove anonymous functions
  * Update
  * Add a complex test
  * Update
  ---------
  Co-authored-by: Billy Moses <wmoses@google.com>
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  Co-authored-by: Avik Pal <avikpal@mit.edu>
* Update WORKSPACE
* untested
* turn it down
* Attach data-layout to MLIR module
* claude is cooking
* dl
* Add BFloat16 extension
* compat
* make a hard dep

---------

Co-authored-by: Paul Berg <9824244+Pangoraw@users.noreply.github.com>
Co-authored-by: Paul Berg <naydex.mc+github@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Bart de Koning <74617371+SouthEndMusic@users.noreply.github.com>
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
Co-authored-by: Maximilian Gelbrecht <maximilian.gelbrecht@posteo.de>
Co-authored-by: Avik Pal <avikpal@mit.edu>
Co-authored-by: dkytezab <danielkytezable@gmail.com>
Co-authored-by: Paul Tiede <ptiede91@gmail.com>
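The "kernel cast" bullets above are the heart of this PR: CUDA kernels whose argument types mention BFloat16 are compiled against a substitute float type (Float32 by default), with casts inserted at the kernel boundary by the Enzyme-JAX kernelcast pass. A minimal sketch of the type-substitution half, simplified from the `_bfloat16_to_ft_type` / `_substitute_bfloat16_tt` helpers in the diff below; `substitute` and `substitute_tt` are illustrative names, and the real versions also recurse into parametric types:

    using BFloat16s: BFloat16

    # Swap BFloat16 for the substitute float; leave every other type untouched.
    substitute(T, FT) = T === BFloat16 ? FT : T

    # Apply the substitution across a kernel's argument tuple type.
    substitute_tt(tt::Type{<:Tuple}, FT) =
        Tuple{Any[substitute(T, FT) for T in tt.parameters]...}

    substitute_tt(Tuple{BFloat16,Int,BFloat16}, Float32)
    # Tuple{Float32, Int64, Float32} (on a 64-bit machine)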
1 parent de9b977 commit c5d087b

14 files changed: 299 additions & 80 deletions


Project.toml

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@ projects = ["docs", "test"]
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
@@ -104,6 +105,7 @@ ReactantZygoteExt = "Zygote"
 AbstractFFTs = "1.5"
 Adapt = "4.4"
 ArrayInterface = "7.17.1"
+BFloat16s = "0.6.1"
 CEnum = "0.5"
 CUDA = "5.9"
 Crayons = "4.1.1"
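The new dependency provides the 16-bit brain-float scalar used throughout this PR. A quick REPL sketch of what BFloat16s.jl gives (bfloat16 has 8 exponent bits and 7 mantissa bits, so it keeps Float32's dynamic range at much lower precision):

    using BFloat16s: BFloat16

    x = BFloat16(1.5)   # 16-bit brain float
    Float32(x)          # widens exactly: 1.5f0
    eps(BFloat16)       # 0.0078125, far coarser than Float32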

deps/ReactantExtra/API.cpp

Lines changed: 10 additions & 11 deletions
@@ -1249,15 +1249,13 @@ REACTANT_ABI uint8_t FutureIsReady(FutureType *Future) {
 
 REACTANT_ABI void FutureAwait(FutureType *Future) { Future->Await(); }
 
-xla::CompileOptions
-GenerateCompileOptions(int64_t device_id, const int64_t *mesh_ids,
-                       int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir,
-                       bool use_shardy_partitioner, int64_t num_replicas,
-                       int64_t num_partitions, bool use_spmd_partitioning,
-                       bool kernel_cache_enabled, const char *kernel_cache_path,
-                       bool autotune_cache_enabled,
-                       const char *autotune_cache_path, int process_id,
-                       bool xla_enable_enzyme_comms_opt) {
+xla::CompileOptions GenerateCompileOptions(
+    int64_t device_id, const int64_t *mesh_ids, int64_t num_mesh_ids,
+    const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner,
+    int64_t num_replicas, int64_t num_partitions, bool use_spmd_partitioning,
+    bool kernel_cache_enabled, const char *kernel_cache_path,
+    bool autotune_cache_enabled, const char *autotune_cache_path,
+    int process_id, bool xla_enable_enzyme_comms_opt) {
   xla::CompileOptions options;
   auto debug_options = options.executable_build_options.mutable_debug_options();
 
@@ -1905,14 +1903,15 @@ ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id,
              bool use_spmd_partitioning, bool kernel_cache_enabled,
              const char *kernel_cache_path, bool autotune_cache_enabled,
              const char *autotune_cache_path, int process_id,
-             bool xla_enable_enzyme_comms_opt) {
+             bool xla_enable_enzyme_comms_opt) {
   return ifrt_compile_internal(
       client, cmod,
       GenerateCompileOptions(
          device_id, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir,
          use_shardy_partitioner, num_replicas, num_partitions,
          use_spmd_partitioning, kernel_cache_enabled, kernel_cache_path,
-         autotune_cache_enabled, autotune_cache_path, process_id, xla_enable_enzyme_comms_opt));
+         autotune_cache_enabled, autotune_cache_path, process_id,
+         xla_enable_enzyme_comms_opt));
 }
 
 REACTANT_ABI HeldIfrtLoadedExecutable *

ext/ReactantCUDAExt.jl

Lines changed: 239 additions & 12 deletions
@@ -1,5 +1,6 @@
 module ReactantCUDAExt
 
+using BFloat16s: BFloat16
 using Reactant:
     Reactant,
     TracedRArray,
@@ -959,6 +960,7 @@ function compile(job)
         throw(GPUCompiler.InvalidIRError(job, errors))
     end
     # LLVM.strip_debuginfo!(mod)
+    dl = string(LLVM.datalayout(mod))
     modstr = string(mod)
     # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
     # it is probably safer to reparse a string using the right llvm module api, so we will do that.
@@ -967,7 +969,17 @@
     )
     @assert mmod != C_NULL
 
-    linkRes = MLIR.API.LinkInModule(MLIR.IR.current_module(), mmod, entryname)
+    cur_module = MLIR.IR.current_module()
+    linkRes = MLIR.API.LinkInModule(cur_module, mmod, entryname)
+
+    dl_attr_name = "llvm.data_layout"
+    prevdlattr = MLIR.IR.getattr(MLIR.IR.Operation(cur_module), dl_attr_name)
+    if !isnothing(prevdlattr)
+        prevdl = String(prevdlattr)
+        @assert prevdl == dl "data layout mismatch, tried compiling cuda kernels for different target machines?"
+    else
+        MLIR.IR.setattr!(MLIR.IR.Operation(cur_module), dl_attr_name, MLIR.IR.Attribute(dl))
+    end
 
     String(Reactant.TracedUtils.get_attribute_by_name(linkRes, "sym_name"))
 end
@@ -1135,6 +1147,13 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     for (i, prev) in enumerate(Any[func.f, args...])
         Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack)
     end
+    bfloat16_compile_type = Reactant.Compiler.BFLOAT16_COMPILE_TYPE[]
+    has_cast_float_type =
+        bfloat16_compile_type !== BFloat16 && any(values(seen)) do arg
+            (arg isa TracedRArray || arg isa TracedRNumber) &&
+                Reactant.unwrapped_eltype(typeof(arg)) === BFloat16
+        end
+
     wrapper_tys = MLIR.IR.Type[]
     for arg in values(seen)
        if !(arg isa TracedRArray || arg isa TracedRNumber)
@@ -1162,6 +1181,16 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     end
     wrapbody = MLIR.IR.Block(wrapper_tys, [MLIR.IR.Location() for _ in wrapper_tys])
     push!(MLIR.IR.region(wrapfunc, 1), wrapbody)
+    if has_cast_float_type
+        MLIR.IR.setattr!(
+            wrapfunc, "enzymexla.float_type", MLIR.IR.Attribute(MLIR.IR.Type(BFloat16))
+        )
+        MLIR.IR.setattr!(
+            wrapfunc,
+            "enzymexla.src_float_type",
+            MLIR.IR.Attribute(MLIR.IR.Type(bfloat16_compile_type)),
+        )
+    end
     for i in 1:length(wrapper_tys)
         MLIR.API.ReactantFuncSetArgAttr(
             wrapfunc,
@@ -1223,15 +1252,65 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
                 )
                 push!(allocs, (alloc, argty, jltyp))
 
-                sz = abi_sizeof(a)
-                array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
-                cdata = MLIR.IR.result(
-                    MLIR.Dialects.llvm.mlir_constant(;
-                        res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))
-                    ),
-                    1,
-                )
-                MLIR.Dialects.llvm.store(cdata, alloc)
+                if has_cast_float_type
+                    # The argument `a` has BFloat16 fields but the GPU function was
+                    # compiled with a substitute type (e.g. Float32). We need to:
+                    # 1. Create an alloca with the bf16 layout
+                    # 2. Store the raw bf16 bytes into it
+                    # 3. Load, walk the struct fields, extend bf16→f32, store into alloc
+                    compile_float_ty = MLIR.IR.Type(bfloat16_compile_type)
+                    bf16_float_ty = MLIR.IR.Type(BFloat16)
+                    bf16_ty = _replace_float_in_llvm_type(
+                        argty, compile_float_ty, bf16_float_ty
+                    )
+
+                    bf16_c1 = MLIR.IR.result(
+                        MLIR.Dialects.llvm.mlir_constant(;
+                            res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)
+                        ),
+                        1,
+                    )
+                    bf16_alloc = MLIR.IR.result(
+                        MLIR.Dialects.llvm.alloca(
+                            bf16_c1; elem_type=MLIR.IR.Attribute(bf16_ty), res=llvmptr
+                        ),
+                        1,
+                    )
+
+                    sz = abi_sizeof(a)
+                    val = to_bytes(a)
+                    array_ty = MLIR.IR.Type(
+                        MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)
+                    )
+                    cdata = MLIR.IR.result(
+                        MLIR.Dialects.llvm.mlir_constant(;
+                            res=array_ty, value=MLIR.IR.DenseElementsAttribute(val)
+                        ),
+                        1,
+                    )
+                    MLIR.Dialects.llvm.store(cdata, bf16_alloc)
+
+                    bf16_val = MLIR.IR.result(
+                        MLIR.Dialects.llvm.load(bf16_alloc; res=bf16_ty), 1
+                    )
+                    converted_val = _convert_bf16_value(
+                        bf16_val, bf16_ty, argty, bf16_float_ty, compile_float_ty
+                    )
+                    MLIR.Dialects.llvm.store(converted_val, alloc)
+                else
+                    sz = abi_sizeof(a)
+                    val = to_bytes(a)
+                    array_ty = MLIR.IR.Type(
+                        MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)
+                    )
+                    cdata = MLIR.IR.result(
+                        MLIR.Dialects.llvm.mlir_constant(;
+                            res=array_ty, value=MLIR.IR.DenseElementsAttribute(val)
+                        ),
+                        1,
+                    )
+                    MLIR.Dialects.llvm.store(cdata, alloc)
+                end
             end
         end
         LLVM.deactivate(ctx)
@@ -1275,7 +1354,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
             # we need to now compute the offset in bytes of the path
             julia_arg = allargs[p[2]]
 
-            offset = get_field_offset(typeof(julia_arg), p[3:end])
+            offset = if has_cast_float_type
+                get_field_offset(
+                    _bfloat16_to_ft_type(typeof(julia_arg), bfloat16_compile_type),
+                    p[3:end],
+                )
+            else
+                get_field_offset(typeof(julia_arg), p[3:end])
+            end
             MLIR.IR.with_block(wrapbody) do
                 ptr = MLIR.IR.result(
                     MLIR.Dialects.llvm.getelementptr(
@@ -1353,6 +1439,9 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
                 "enzymexla.kernel_call", @__FILE__, @__LINE__
             ),
         )
+        if has_cast_float_type
+            MLIR.IR.setattr!(call, "cast_float_type", MLIR.IR.UnitAttribute())
+        end
 
         argidx = 1
         for arg in values(seen)
@@ -1364,13 +1453,151 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
        end
    end
 
+function _bfloat16_to_ft_type(@nospecialize(T), @nospecialize(FT))
+    T === BFloat16 && return FT
+    T isa DataType || return T
+    isempty(T.parameters) && return T
+    new_params = Any[_bfloat16_to_ft_type(p, FT) for p in T.parameters]
+    all(p1 === p2 for (p1, p2) in zip(T.parameters, new_params)) && return T
+    return T.name.wrapper{new_params...}
+end
+
+function _substitute_bfloat16_tt(@nospecialize(tt::Type{<:Tuple}), @nospecialize(FT))
+    new_params = Any[_bfloat16_to_ft_type(T, FT) for T in tt.parameters]
+    return Tuple{new_params...}
+end
+
+"""
+    _replace_float_in_llvm_type(ty, src_float_ty, tgt_float_ty)
+
+Recursively walk an LLVM type and replace `src_float_ty` with `tgt_float_ty`.
+Handles struct types and array types.
+"""
+function _replace_float_in_llvm_type(
+    ty::MLIR.IR.Type, src_float_ty::MLIR.IR.Type, tgt_float_ty::MLIR.IR.Type
+)
+    ty == src_float_ty && return tgt_float_ty
+    if MLIR.API.mlirTypeIsALLVMStructType(ty)
+        n = MLIR.API.mlirLLVMStructTypeGetNumElementTypes(ty)
+        field_types = MLIR.IR.Type[
+            _replace_float_in_llvm_type(
+                MLIR.IR.Type(MLIR.API.mlirLLVMStructTypeGetElementType(ty, i - 1)),
+                src_float_ty,
+                tgt_float_ty,
+            ) for i in 1:n
+        ]
+        if all(
+            field_types[i] ==
+            MLIR.IR.Type(MLIR.API.mlirLLVMStructTypeGetElementType(ty, i - 1)) for i in 1:n
+        )
+            return ty
+        end
+        ctx = MLIR.IR.current_context()
+        is_packed = MLIR.API.mlirLLVMStructTypeIsPacked(ty)
+        return MLIR.IR.Type(
+            MLIR.API.mlirLLVMStructTypeLiteralGet(ctx, n, field_types, is_packed)
+        )
+    elseif MLIR.API.mlirTypeIsALLVMArrayType(ty)
+        elem_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGetElementType(ty))
+        new_elem_ty = _replace_float_in_llvm_type(elem_ty, src_float_ty, tgt_float_ty)
+        if new_elem_ty == elem_ty
+            return ty
+        end
+        num_elems = MLIR.API.mlirLLVMArrayTypeGetNumElements(ty)
+        return MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(new_elem_ty, num_elems))
+    end
+    return ty
+end
+
+"""
+    _convert_bf16_value(src_val, src_ty, tgt_ty, src_float_ty, tgt_float_ty)
+
+Recursively walk an LLVM value, converting float fields from `src_float_ty` to
+`tgt_float_ty` using arith.extf. Returns a new value of type `tgt_ty`.
+"""
+function _convert_bf16_value(
+    src_val::MLIR.IR.Value,
+    src_ty::MLIR.IR.Type,
+    tgt_ty::MLIR.IR.Type,
+    src_float_ty::MLIR.IR.Type,
+    tgt_float_ty::MLIR.IR.Type,
+)
+    src_ty == tgt_ty && return src_val
+    if src_ty == src_float_ty
+        src_width = MLIR.API.mlirFloatTypeGetWidth(src_float_ty)
+        tgt_width = MLIR.API.mlirFloatTypeGetWidth(tgt_float_ty)
+        if tgt_width > src_width
+            return MLIR.IR.result(MLIR.Dialects.llvm.fpext(src_val; res=tgt_float_ty), 1)
+        elseif tgt_width < src_width
+            return MLIR.IR.result(MLIR.Dialects.llvm.fptrunc(src_val; res=tgt_float_ty), 1)
+        else
+            return MLIR.IR.result(MLIR.Dialects.llvm.fptrunc(src_val; res=tgt_float_ty), 1)
+        end
+    end
+    if MLIR.API.mlirTypeIsALLVMStructType(src_ty)
+        n = MLIR.API.mlirLLVMStructTypeGetNumElementTypes(src_ty)
+        tgt_val = MLIR.IR.result(MLIR.Dialects.llvm.mlir_undef(; res=tgt_ty), 1)
+        for i in 0:(n - 1)
+            field_src_ty = MLIR.IR.Type(
+                MLIR.API.mlirLLVMStructTypeGetElementType(src_ty, i)
+            )
+            field_tgt_ty = MLIR.IR.Type(
+                MLIR.API.mlirLLVMStructTypeGetElementType(tgt_ty, i)
+            )
+            field_val = MLIR.IR.result(
+                MLIR.Dialects.llvm.extractvalue(
+                    src_val; res=field_src_ty, position=MLIR.IR.Attribute(Int64[i])
+                ),
+                1,
+            )
+            converted = _convert_bf16_value(
+                field_val, field_src_ty, field_tgt_ty, src_float_ty, tgt_float_ty
+            )
+            tgt_val = MLIR.IR.result(
+                MLIR.Dialects.llvm.insertvalue(
+                    tgt_val, converted; res=tgt_ty, position=MLIR.IR.Attribute(Int64[i])
+                ),
+                1,
+            )
+        end
+        return tgt_val
+    elseif MLIR.API.mlirTypeIsALLVMArrayType(src_ty)
+        num_elems = MLIR.API.mlirLLVMArrayTypeGetNumElements(src_ty)
+        elem_src_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGetElementType(src_ty))
+        elem_tgt_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGetElementType(tgt_ty))
+        tgt_val = MLIR.IR.result(MLIR.Dialects.llvm.mlir_undef(; res=tgt_ty), 1)
+        for i in 0:(num_elems - 1)
+            elem_val = MLIR.IR.result(
+                MLIR.Dialects.llvm.extractvalue(
+                    src_val; res=elem_src_ty, position=MLIR.IR.Attribute(Int64[i])
+                ),
+                1,
+            )
+            converted = _convert_bf16_value(
+                elem_val, elem_src_ty, elem_tgt_ty, src_float_ty, tgt_float_ty
+            )
+            tgt_val = MLIR.IR.result(
+                MLIR.Dialects.llvm.insertvalue(
+                    tgt_val, converted; res=tgt_ty, position=MLIR.IR.Attribute(Int64[i])
+                ),
+                1,
+            )
+        end
+        return tgt_val
+    end
+    return src_val
+end
+
 Reactant.@reactant_overlay @noinline function CUDA.cufunction(
     f::F, tt::TT=Tuple{}; kwargs...
 ) where {F,TT}
     res = Base.@lock CUDA.cufunction_lock begin
         # compile the function
         cache = llvm_compiler_cache(MLIR.IR.current_module())
-        source = CUDA.methodinstance(F, tt)
+        effective_tt = _substitute_bfloat16_tt(
+            tt, Reactant.Compiler.BFLOAT16_COMPILE_TYPE[]
+        )
+        source = CUDA.methodinstance(F, effective_tt)
         # cuda = CUDA.active_state()
         device = nothing # cuda.device
         # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
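For intuition, the conversion `_convert_bf16_value` performs at the MLIR level corresponds to the host-side Julia below, with a hypothetical KernelArg struct and a widen helper standing in for the generated llvm.extractvalue / llvm.fpext / llvm.insertvalue ops; a sketch, not what the extension executes:

    using BFloat16s: BFloat16

    struct KernelArg              # hypothetical argument struct, for illustration only
        a::BFloat16
        b::NTuple{2,BFloat16}
    end

    widen(x::BFloat16) = Float32(x)                 # the fpext leaf case
    widen(x::Tuple) = map(widen, x)                 # the array walk
    widen(x::KernelArg) = (widen(x.a), widen(x.b))  # the fieldwise struct walk

    widen(KernelArg(BFloat16(1.5), (BFloat16(0.25), BFloat16(2.0))))
    # (1.5f0, (0.25f0, 2.0f0))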

src/Compiler.jl

Lines changed: 2 additions & 1 deletion
@@ -1118,6 +1118,7 @@ function cubinFeatures()
     return "+ptx$ptx"
 end
 
+const BFLOAT16_COMPILE_TYPE = Ref{DataType}(Float32)
 const DEBUG_KERNEL = Ref{Bool}(false)
 const DUMP_LLVMIR = Ref{Bool}(false)
 const DUMP_FAILED_LOCKSTEP = Ref{Bool}(false)
@@ -1343,7 +1344,7 @@ function compile_mlir!(
     # Raise enabled but use default passes
     # TODO(#2240) remove redundant libdevice raise after fixing phase ordering
     result =
-        "canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,libdevice-funcs-raise,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,simplify-affine-exprs,affine-cfg,canonicalize,func.func(affine-loop-invariant-code-motion),canonicalize,sort-memory,raise-affine-to-stablehlo{strip_llvm_debuginfo=$(compile_options.strip_llvm_debuginfo) prefer_while_raising=false dump_failed_lockstep=$(DUMP_FAILED_LOCKSTEP[])},canonicalize,arith-raise{stablehlo=true}," *
+        "canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,libdevice-funcs-raise,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,simplify-affine-exprs,affine-cfg,canonicalize,func.func(affine-loop-invariant-code-motion),canonicalize,sort-memory,func.func(kernelcast),raise-affine-to-stablehlo{strip_llvm_debuginfo=$(compile_options.strip_llvm_debuginfo) prefer_while_raising=false dump_failed_lockstep=$(DUMP_FAILED_LOCKSTEP[])},canonicalize,arith-raise{stablehlo=true}," *
         opt_passes2
 
     if DUS_TO_CONCAT[]