Skip to content

Commit cf3d052

Browse files
feat: use new enum based API (#2766)
* Regenerate MLIR Bindings
* fix: update to use enums
* fix: workaround raising issue
* fix: more
* test: syrk tests are busted

---------

Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 5e641c4 commit cf3d052

12 files changed

Lines changed: 677 additions & 244 deletions

File tree

ext/ReactantLogExpFunctionsExt.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,16 @@ end
9696
function LogExpFunctions.softmax!(
9797
r::AnyTracedRArray{<:Real}, x::AnyTracedRArray{<:Real}=r; dims=:
9898
)
99-
return LogExpFunctions._softmax!(r, x, dims)
99+
dims isa Colon && (dims = 1:ndims(x))
100+
res = @opcall softmax(x; dims=vec(collect(Int64, dims)))
101+
copyto!(r, res)
102+
return r
100103
end
101104

102105
function LogExpFunctions.softmax(x::AnyTracedRArray{<:Real}; dims=:)
103-
return LogExpFunctions._softmax!(similar(x, float(eltype(x))), x, dims)
106+
dims isa Colon && (dims = 1:ndims(x))
107+
res = @opcall softmax(x; dims=vec(collect(Int64, dims)))
108+
return res
104109
end
105110

106111
for (T1, T2) in [(TracedRNumber, Number), (Number, TracedRNumber)]

ext/ReactantMPIExt/Ops.jl

Lines changed: 40 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -184,36 +184,20 @@ end
184184
return recvbuf
185185
end
186186

187-
@enum MPIOpEnum begin
188-
MPI_OP_NULL_ENUM = 0
189-
MPI_BAND_ENUM = 1
190-
MPI_BOR_ENUM = 2
191-
MPI_BXOR_ENUM = 3
192-
MPI_LAND_ENUM = 4
193-
MPI_LOR_ENUM = 5
194-
MPI_LXOR_ENUM = 6
195-
MPI_MAX_ENUM = 7
196-
MPI_MIN_ENUM = 8
197-
MPI_PROD_ENUM = 9
198-
MPI_REPLACE_ENUM = 10
199-
MPI_SUM_ENUM = 11
200-
MPI_NO_OP_ENUM = 12
201-
end
202-
203187
const MPI_OP_MAP = Dict(
204-
MPI.OP_NULL.val => MPI_OP_NULL_ENUM,
205-
MPI.BAND.val => MPI_BAND_ENUM,
206-
MPI.BOR.val => MPI_BOR_ENUM,
207-
MPI.BXOR.val => MPI_BXOR_ENUM,
208-
MPI.LAND.val => MPI_LAND_ENUM,
209-
MPI.LOR.val => MPI_LOR_ENUM,
210-
MPI.LXOR.val => MPI_LXOR_ENUM,
211-
MPI.MAX.val => MPI_MAX_ENUM,
212-
MPI.MIN.val => MPI_MIN_ENUM,
213-
MPI.PROD.val => MPI_PROD_ENUM,
214-
MPI.REPLACE.val => MPI_REPLACE_ENUM,
215-
MPI.SUM.val => MPI_SUM_ENUM,
216-
MPI.NO_OP.val => MPI_NO_OP_ENUM,
188+
MPI.OP_NULL.val => MLIR.API.ENZYMEXLA_MPI_OP_NULL,
189+
MPI.BAND.val => MLIR.API.ENZYMEXLA_MPI_BAND,
190+
MPI.BOR.val => MLIR.API.ENZYMEXLA_MPI_BOR,
191+
MPI.BXOR.val => MLIR.API.ENZYMEXLA_MPI_BXOR,
192+
MPI.LAND.val => MLIR.API.ENZYMEXLA_MPI_LAND,
193+
MPI.LOR.val => MLIR.API.ENZYMEXLA_MPI_LOR,
194+
MPI.LXOR.val => MLIR.API.ENZYMEXLA_MPI_LXOR,
195+
MPI.MAX.val => MLIR.API.ENZYMEXLA_MPI_MAX,
196+
MPI.MIN.val => MLIR.API.ENZYMEXLA_MPI_MIN,
197+
MPI.PROD.val => MLIR.API.ENZYMEXLA_MPI_PROD,
198+
MPI.REPLACE.val => MLIR.API.ENZYMEXLA_MPI_REPLACE,
199+
MPI.SUM.val => MLIR.API.ENZYMEXLA_MPI_SUM,
200+
MPI.NO_OP.val => MLIR.API.ENZYMEXLA_MPI_NO_OP,
217201
)
218202

219203
function get_mpi_op_enum(op)
@@ -222,64 +206,34 @@ function get_mpi_op_enum(op)
222206
end
223207
end
224208

225-
@enum MPIDataTypeEnum begin
226-
MPI_DATATYPE_NULL_ENUM = 0
227-
MPI_INT8_T_ENUM = 1
228-
MPI_UINT8_T_ENUM = 2
229-
MPI_INT16_T_ENUM = 3
230-
MPI_UINT16_T_ENUM = 4
231-
MPI_INT32_T_ENUM = 5
232-
MPI_UINT32_T_ENUM = 6
233-
MPI_INT64_T_ENUM = 7
234-
MPI_UINT64_T_ENUM = 8
235-
MPI_BYTE_ENUM = 9
236-
MPI_SHORT_ENUM = 10
237-
MPI_UNSIGNED_SHORT_ENUM = 11
238-
MPI_INT_ENUM = 12
239-
MPI_UNSIGNED_ENUM = 13
240-
MPI_LONG_ENUM = 14
241-
MPI_UNSIGNED_LONG_ENUM = 15
242-
MPI_LONG_LONG_INT_ENUM = 16
243-
MPI_UNSIGNED_LONG_LONG_ENUM = 17
244-
MPI_CHAR_ENUM = 18
245-
MPI_SIGNED_CHAR_ENUM = 19
246-
MPI_UNSIGNED_CHAR_ENUM = 20
247-
MPI_WCHAR_ENUM = 21
248-
MPI_FLOAT_ENUM = 22
249-
MPI_DOUBLE_ENUM = 23
250-
MPI_C_FLOAT_COMPLEX_ENUM = 24
251-
MPI_C_DOUBLE_COMPLEX_ENUM = 25
252-
MPI_C_BOOL_ENUM = 26
253-
end
254-
255209
const MPI_DATATYPE_MAP = Dict(
256-
MPI.DATATYPE_NULL.val => MPI_DATATYPE_NULL_ENUM,
257-
MPI.INT8_T.val => MPI_INT8_T_ENUM,
258-
MPI.UINT8_T.val => MPI_UINT8_T_ENUM,
259-
MPI.INT16_T.val => MPI_INT16_T_ENUM,
260-
MPI.UINT16_T.val => MPI_UINT16_T_ENUM,
261-
MPI.INT32_T.val => MPI_INT32_T_ENUM,
262-
MPI.UINT32_T.val => MPI_UINT32_T_ENUM,
263-
MPI.INT64_T.val => MPI_INT64_T_ENUM,
264-
MPI.UINT64_T.val => MPI_UINT64_T_ENUM,
265-
MPI.BYTE.val => MPI_BYTE_ENUM,
266-
MPI.SHORT.val => MPI_SHORT_ENUM,
267-
MPI.UNSIGNED_SHORT.val => MPI_UNSIGNED_SHORT_ENUM,
268-
MPI.INT.val => MPI_INT_ENUM,
269-
MPI.UNSIGNED.val => MPI_UNSIGNED_ENUM,
270-
MPI.LONG.val => MPI_LONG_ENUM,
271-
MPI.UNSIGNED_LONG.val => MPI_UNSIGNED_LONG_ENUM,
272-
MPI.LONG_LONG_INT.val => MPI_LONG_LONG_INT_ENUM,
273-
MPI.UNSIGNED_LONG_LONG.val => MPI_UNSIGNED_LONG_LONG_ENUM,
274-
MPI.CHAR.val => MPI_CHAR_ENUM,
275-
MPI.SIGNED_CHAR.val => MPI_SIGNED_CHAR_ENUM,
276-
MPI.UNSIGNED_CHAR.val => MPI_UNSIGNED_CHAR_ENUM,
277-
MPI.WCHAR.val => MPI_WCHAR_ENUM,
278-
MPI.FLOAT.val => MPI_FLOAT_ENUM,
279-
MPI.DOUBLE.val => MPI_DOUBLE_ENUM,
280-
MPI.C_FLOAT_COMPLEX.val => MPI_C_FLOAT_COMPLEX_ENUM,
281-
MPI.C_DOUBLE_COMPLEX.val => MPI_C_DOUBLE_COMPLEX_ENUM,
282-
MPI.C_BOOL.val => MPI_C_BOOL_ENUM,
210+
MPI.DATATYPE_NULL.val => MLIR.API.ENZYMEXLA_MPI_DATATYPE_NULL,
211+
MPI.INT8_T.val => MLIR.API.ENZYMEXLA_MPI_INT8_T,
212+
MPI.UINT8_T.val => MLIR.API.ENZYMEXLA_MPI_UINT8_T,
213+
MPI.INT16_T.val => MLIR.API.ENZYMEXLA_MPI_INT16_T,
214+
MPI.UINT16_T.val => MLIR.API.ENZYMEXLA_MPI_UINT16_T,
215+
MPI.INT32_T.val => MLIR.API.ENZYMEXLA_MPI_INT32_T,
216+
MPI.UINT32_T.val => MLIR.API.ENZYMEXLA_MPI_UINT32_T,
217+
MPI.INT64_T.val => MLIR.API.ENZYMEXLA_MPI_INT64_T,
218+
MPI.UINT64_T.val => MLIR.API.ENZYMEXLA_MPI_UINT64_T,
219+
MPI.BYTE.val => MLIR.API.ENZYMEXLA_MPI_BYTE,
220+
MPI.SHORT.val => MLIR.API.ENZYMEXLA_MPI_SHORT,
221+
MPI.UNSIGNED_SHORT.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_SHORT,
222+
MPI.INT.val => MLIR.API.ENZYMEXLA_MPI_INT,
223+
MPI.UNSIGNED.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED,
224+
MPI.LONG.val => MLIR.API.ENZYMEXLA_MPI_LONG,
225+
MPI.UNSIGNED_LONG.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_LONG,
226+
MPI.LONG_LONG_INT.val => MLIR.API.ENZYMEXLA_MPI_LONG_LONG_INT,
227+
MPI.UNSIGNED_LONG_LONG.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_LONG_LONG,
228+
MPI.CHAR.val => MLIR.API.ENZYMEXLA_MPI_CHAR,
229+
MPI.SIGNED_CHAR.val => MLIR.API.ENZYMEXLA_MPI_SIGNED_CHAR,
230+
MPI.UNSIGNED_CHAR.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_CHAR,
231+
MPI.WCHAR.val => MLIR.API.ENZYMEXLA_MPI_WCHAR,
232+
MPI.FLOAT.val => MLIR.API.ENZYMEXLA_MPI_FLOAT,
233+
MPI.DOUBLE.val => MLIR.API.ENZYMEXLA_MPI_DOUBLE,
234+
MPI.C_FLOAT_COMPLEX.val => MLIR.API.ENZYMEXLA_MPI_C_FLOAT_COMPLEX,
235+
MPI.C_DOUBLE_COMPLEX.val => MLIR.API.ENZYMEXLA_MPI_C_DOUBLE_COMPLEX,
236+
MPI.C_BOOL.val => MLIR.API.ENZYMEXLA_MPI_C_BOOL,
283237
)
284238

285239
function get_mpi_datatype_enum(datatype)

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,15 @@ end
2727

2828
function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
2929
x = T.(materialize_traced_array(x))
30-
max_ = maximum(x; dims)
31-
diff = exp.(x .- max_)
32-
# TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581
33-
# fixed
34-
# @trace if all(isfinite, max_)
35-
@. out = diff
36-
# else
37-
# @. out = ifelse(isinf(max_), ifelse(isinf(x), T(1), T(0)), diff)
38-
# end
39-
out ./= sum(out; dims)
30+
res = @opcall softmax(x; dims=vec(collect(Int64, dims)))
31+
copyto!(out, res)
4032
return out
4133
end
4234

4335
function NNlib.logsoftmax!(out::AnyTracedRArray{T}, x::AbstractArray; dims=1) where {T}
4436
x = T.(materialize_traced_array(x))
45-
max_ = maximum(x; dims)
46-
diff = x .- max_
47-
# TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581
48-
# fixed
49-
# @trace if all(isfinite, max_)
50-
@. out = diff
51-
# else
52-
# @. out = ifelse(isinf(max_), ifelse(isinf(x), T(0), -T(Inf)), diff)
53-
# end
54-
out .-= log.(sum(exp, out; dims))
37+
res = @opcall logsoftmax(x; dims=vec(collect(Int64, dims)))
38+
copyto!(out, res)
5539
return out
5640
end
5741

src/Ops.jl

Lines changed: 67 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,31 @@ using ..Reactant:
1515
using ReactantCore: ReactantCore
1616
using GPUArraysCore: GPUArraysCore
1717

18+
const GELU_APPROXIMATION_MAP = Dict(
19+
"NONE" => MLIR.API.ENZYMEXLA_GELU_APPROXIMATION_NONE,
20+
"TANH" => MLIR.API.ENZYMEXLA_GELU_APPROXIMATION_TANH,
21+
"SIGMOID" => MLIR.API.ENZYMEXLA_GELU_APPROXIMATION_SIGMOID,
22+
)
23+
24+
const LAPACK_TRANSPOSE_MAP = Dict(
25+
'N' => MLIR.API.ENZYMEXLA_LAPACK_TRANSPOSE_NONE,
26+
'T' => MLIR.API.ENZYMEXLA_LAPACK_TRANSPOSE_TRANSPOSE,
27+
'C' => MLIR.API.ENZYMEXLA_LAPACK_TRANSPOSE_CONJUGATE_TRANSPOSE,
28+
)
29+
30+
const LAPACK_UPLO_MAP = Dict(
31+
'U' => MLIR.API.ENZYMEXLA_LAPACK_UPLO_UPPER,
32+
'L' => MLIR.API.ENZYMEXLA_LAPACK_UPLO_LOWER,
33+
'F' => MLIR.API.ENZYMEXLA_LAPACK_UPLO_FULL,
34+
)
35+
36+
const SVD_ALGORITHM_MAP = Dict(
37+
"DEFAULT" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_NONE,
38+
"QRIteration" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_QRITERATION,
39+
"Jacobi" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_JACOBI,
40+
"DivideAndConquer" => MLIR.API.ENZYMEXLA_SVD_ALGORITHM_DIVIDEANDCONQUER,
41+
)
42+
1843
function _function_macro_error()
1944
throw(ArgumentError("`caller_function` is not available in this context"))
2045
end
@@ -3876,26 +3901,16 @@ end
38763901
Vt_size = (batch_sizes..., full ? n : r, n)
38773902
info_size = batch_sizes
38783903

3879-
if algorithm == "DEFAULT"
3880-
algint = 0
3881-
elseif algorithm == "QRIteration"
3882-
algint = 1
3883-
elseif algorithm == "DivideAndConquer"
3884-
algint = 2
3885-
elseif algorithm == "Jacobi"
3886-
algint = 3
3887-
else
3888-
error("Unsupported SVD algorithm: $algorithm")
3889-
end
3890-
38913904
svd_op = enzymexla.linalg_svd(
38923905
x.mlir_data;
38933906
U=mlir_type(TracedRArray{T,N}, U_size),
38943907
S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size),
38953908
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
38963909
info=mlir_type(TracedRArray{iT,N - 2}, info_size),
38973910
full=full,
3898-
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.current_context(), algint),
3911+
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(
3912+
MLIR.IR.current_context(), SVD_ALGORITHM_MAP[algorithm]
3913+
),
38993914
location,
39003915
)
39013916

@@ -4103,21 +4118,15 @@ end
41034118
approximation::String;
41044119
location=mlir_stacktrace("ml.gelu", @__FILE__, @__LINE__),
41054120
)
4106-
approx = if approximation == "NONE"
4107-
0
4108-
elseif approximation == "TANH"
4109-
1
4110-
elseif approximation == "SIGMOID"
4111-
2
4112-
else
4113-
error("Invalid gelu approximation: $approximation")
4114-
end
4115-
approx = MLIR.API.enzymexlaGeluApproximationAttrGet(
4116-
MLIR.IR.current_context(), Int32(approx)
4117-
)
4118-
41194121
res = MLIR.IR.result(
4120-
enzymexla.ml_gelu(x.mlir_data; gelu_approximation=approx, location), 1
4122+
enzymexla.ml_gelu(
4123+
x.mlir_data;
4124+
gelu_approximation=MLIR.API.enzymexlaGeluApproximationAttrGet(
4125+
MLIR.IR.current_context(), GELU_APPROXIMATION_MAP[approximation]
4126+
),
4127+
location,
4128+
),
4129+
1,
41214130
)
41224131

41234132
if x isa TracedRArray
@@ -4210,41 +4219,19 @@ end
42104219
location=mlir_stacktrace("syrk", @__FILE__, @__LINE__),
42114220
) where {T,N}
42124221
ctx = MLIR.IR.current_context()
4213-
uplo_attr = MLIR.API.enzymexlaLapackUploAttrGet(
4214-
ctx,
4215-
if uplo == 'U'
4216-
Int32(1)
4217-
elseif uplo == 'L'
4218-
Int32(0)
4219-
else
4220-
Int32(2)
4221-
end,
4222-
)
4223-
transpose_attr = MLIR.API.enzymexlaLapackTransposeAttrGet(
4224-
ctx,
4225-
if transpose_a == 'N'
4226-
Int32(0)
4227-
elseif transpose_a == 'T'
4228-
Int32(1)
4229-
elseif transpose_a == 'C'
4230-
Int32(2)
4231-
else
4232-
error("Unknown transpose mode: $transpose_a")
4233-
end,
4234-
)
4235-
4236-
alpha_ = constant(alpha; location)
4237-
beta_ = constant(beta; location)
4222+
uplo_attr = MLIR.API.enzymexlaLapackUploAttrGet(ctx, LAPACK_UPLO_MAP[uplo])
42384223

42394224
res = MLIR.IR.result(
42404225
enzymexla.blas_syrk(
42414226
A.mlir_data,
42424227
C.mlir_data,
4243-
alpha_.mlir_data,
4244-
beta_.mlir_data;
4228+
constant(alpha; location).mlir_data,
4229+
constant(beta; location).mlir_data;
42454230
uplo=uplo_attr,
42464231
output_uplo=uplo_attr,
4247-
transpose=transpose_attr,
4232+
transpose=MLIR.API.enzymexlaLapackTransposeAttrGet(
4233+
ctx, LAPACK_TRANSPOSE_MAP[transpose_a]
4234+
),
42484235
output=mlir_type(TracedRArray{T,N}, size(C)),
42494236
location,
42504237
),
@@ -4510,4 +4497,29 @@ function julia_callback(
45104497
return Tuple(results)
45114498
end
45124499

4500+
@noinline function softmax(
4501+
x::TracedRArray{T,N};
4502+
dims::Vector{Int64},
4503+
location=mlir_stacktrace("softmax", @__FILE__, @__LINE__),
4504+
) where {T,N}
4505+
max_val = Reactant.call_with_reactant(Core.kwcall, (; dims,), Base.maximum, x)
4506+
exp_diff = exponential(x .- max_val; location)
4507+
denom = Reactant.call_with_reactant(Core.kwcall, (; dims,), Base.sum, exp_diff)
4508+
return exp_diff ./ denom
4509+
end
4510+
4511+
@noinline function logsoftmax(
4512+
x::TracedRArray{T,N};
4513+
dims::Vector{Int64},
4514+
location=mlir_stacktrace("logsoftmax", @__FILE__, @__LINE__),
4515+
) where {T,N}
4516+
max_val = Reactant.call_with_reactant(Core.kwcall, (; dims,), Base.maximum, x)
4517+
diff = x .- max_val
4518+
exp_diff = exponential(diff; location)
4519+
reduced_exp_diff = Reactant.call_with_reactant(
4520+
Core.kwcall, (; dims,), Base.sum, exp_diff
4521+
)
4522+
return diff .- log(reduced_exp_diff; location)
4523+
end
4524+
45134525
end # module Ops

0 commit comments

Comments
 (0)