Skip to content

Commit 83dc55b

Browse files
committed
fix: update to use enums
1 parent 42b5f2c commit 83dc55b

2 files changed

Lines changed: 77 additions & 133 deletions

File tree

ext/ReactantMPIExt/Ops.jl

Lines changed: 40 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -184,36 +184,20 @@ end
184184
return recvbuf
185185
end
186186

187-
@enum MPIOpEnum begin
188-
MPI_OP_NULL_ENUM = 0
189-
MPI_BAND_ENUM = 1
190-
MPI_BOR_ENUM = 2
191-
MPI_BXOR_ENUM = 3
192-
MPI_LAND_ENUM = 4
193-
MPI_LOR_ENUM = 5
194-
MPI_LXOR_ENUM = 6
195-
MPI_MAX_ENUM = 7
196-
MPI_MIN_ENUM = 8
197-
MPI_PROD_ENUM = 9
198-
MPI_REPLACE_ENUM = 10
199-
MPI_SUM_ENUM = 11
200-
MPI_NO_OP_ENUM = 12
201-
end
202-
203187
const MPI_OP_MAP = Dict(
204-
MPI.OP_NULL.val => MPI_OP_NULL_ENUM,
205-
MPI.BAND.val => MPI_BAND_ENUM,
206-
MPI.BOR.val => MPI_BOR_ENUM,
207-
MPI.BXOR.val => MPI_BXOR_ENUM,
208-
MPI.LAND.val => MPI_LAND_ENUM,
209-
MPI.LOR.val => MPI_LOR_ENUM,
210-
MPI.LXOR.val => MPI_LXOR_ENUM,
211-
MPI.MAX.val => MPI_MAX_ENUM,
212-
MPI.MIN.val => MPI_MIN_ENUM,
213-
MPI.PROD.val => MPI_PROD_ENUM,
214-
MPI.REPLACE.val => MPI_REPLACE_ENUM,
215-
MPI.SUM.val => MPI_SUM_ENUM,
216-
MPI.NO_OP.val => MPI_NO_OP_ENUM,
188+
MPI.OP_NULL.val => MLIR.API.ENZYMEXLA_MPI_OP_NULL,
189+
MPI.BAND.val => MLIR.API.ENZYMEXLA_MPI_BAND,
190+
MPI.BOR.val => MLIR.API.ENZYMEXLA_MPI_BOR,
191+
MPI.BXOR.val => MLIR.API.ENZYMEXLA_MPI_BXOR,
192+
MPI.LAND.val => MLIR.API.ENZYMEXLA_MPI_LAND,
193+
MPI.LOR.val => MLIR.API.ENZYMEXLA_MPI_LOR,
194+
MPI.LXOR.val => MLIR.API.ENZYMEXLA_MPI_LXOR,
195+
MPI.MAX.val => MLIR.API.ENZYMEXLA_MPI_MAX,
196+
MPI.MIN.val => MLIR.API.ENZYMEXLA_MPI_MIN,
197+
MPI.PROD.val => MLIR.API.ENZYMEXLA_MPI_PROD,
198+
MPI.REPLACE.val => MLIR.API.ENZYMEXLA_MPI_REPLACE,
199+
MPI.SUM.val => MLIR.API.ENZYMEXLA_MPI_SUM,
200+
MPI.NO_OP.val => MLIR.API.ENZYMEXLA_MPI_NO_OP,
217201
)
218202

219203
function get_mpi_op_enum(op)
@@ -222,64 +206,34 @@ function get_mpi_op_enum(op)
222206
end
223207
end
224208

225-
@enum MPIDataTypeEnum begin
226-
MPI_DATATYPE_NULL_ENUM = 0
227-
MPI_INT8_T_ENUM = 1
228-
MPI_UINT8_T_ENUM = 2
229-
MPI_INT16_T_ENUM = 3
230-
MPI_UINT16_T_ENUM = 4
231-
MPI_INT32_T_ENUM = 5
232-
MPI_UINT32_T_ENUM = 6
233-
MPI_INT64_T_ENUM = 7
234-
MPI_UINT64_T_ENUM = 8
235-
MPI_BYTE_ENUM = 9
236-
MPI_SHORT_ENUM = 10
237-
MPI_UNSIGNED_SHORT_ENUM = 11
238-
MPI_INT_ENUM = 12
239-
MPI_UNSIGNED_ENUM = 13
240-
MPI_LONG_ENUM = 14
241-
MPI_UNSIGNED_LONG_ENUM = 15
242-
MPI_LONG_LONG_INT_ENUM = 16
243-
MPI_UNSIGNED_LONG_LONG_ENUM = 17
244-
MPI_CHAR_ENUM = 18
245-
MPI_SIGNED_CHAR_ENUM = 19
246-
MPI_UNSIGNED_CHAR_ENUM = 20
247-
MPI_WCHAR_ENUM = 21
248-
MPI_FLOAT_ENUM = 22
249-
MPI_DOUBLE_ENUM = 23
250-
MPI_C_FLOAT_COMPLEX_ENUM = 24
251-
MPI_C_DOUBLE_COMPLEX_ENUM = 25
252-
MPI_C_BOOL_ENUM = 26
253-
end
254-
255209
const MPI_DATATYPE_MAP = Dict(
256-
MPI.DATATYPE_NULL.val => MPI_DATATYPE_NULL_ENUM,
257-
MPI.INT8_T.val => MPI_INT8_T_ENUM,
258-
MPI.UINT8_T.val => MPI_UINT8_T_ENUM,
259-
MPI.INT16_T.val => MPI_INT16_T_ENUM,
260-
MPI.UINT16_T.val => MPI_UINT16_T_ENUM,
261-
MPI.INT32_T.val => MPI_INT32_T_ENUM,
262-
MPI.UINT32_T.val => MPI_UINT32_T_ENUM,
263-
MPI.INT64_T.val => MPI_INT64_T_ENUM,
264-
MPI.UINT64_T.val => MPI_UINT64_T_ENUM,
265-
MPI.BYTE.val => MPI_BYTE_ENUM,
266-
MPI.SHORT.val => MPI_SHORT_ENUM,
267-
MPI.UNSIGNED_SHORT.val => MPI_UNSIGNED_SHORT_ENUM,
268-
MPI.INT.val => MPI_INT_ENUM,
269-
MPI.UNSIGNED.val => MPI_UNSIGNED_ENUM,
270-
MPI.LONG.val => MPI_LONG_ENUM,
271-
MPI.UNSIGNED_LONG.val => MPI_UNSIGNED_LONG_ENUM,
272-
MPI.LONG_LONG_INT.val => MPI_LONG_LONG_INT_ENUM,
273-
MPI.UNSIGNED_LONG_LONG.val => MPI_UNSIGNED_LONG_LONG_ENUM,
274-
MPI.CHAR.val => MPI_CHAR_ENUM,
275-
MPI.SIGNED_CHAR.val => MPI_SIGNED_CHAR_ENUM,
276-
MPI.UNSIGNED_CHAR.val => MPI_UNSIGNED_CHAR_ENUM,
277-
MPI.WCHAR.val => MPI_WCHAR_ENUM,
278-
MPI.FLOAT.val => MPI_FLOAT_ENUM,
279-
MPI.DOUBLE.val => MPI_DOUBLE_ENUM,
280-
MPI.C_FLOAT_COMPLEX.val => MPI_C_FLOAT_COMPLEX_ENUM,
281-
MPI.C_DOUBLE_COMPLEX.val => MPI_C_DOUBLE_COMPLEX_ENUM,
282-
MPI.C_BOOL.val => MPI_C_BOOL_ENUM,
210+
MPI.DATATYPE_NULL.val => MLIR.API.ENZYMEXLA_MPI_DATATYPE_NULL,
211+
MPI.INT8_T.val => MLIR.API.ENZYMEXLA_MPI_INT8_T,
212+
MPI.UINT8_T.val => MLIR.API.ENZYMEXLA_MPI_UINT8_T,
213+
MPI.INT16_T.val => MLIR.API.ENZYMEXLA_MPI_INT16_T,
214+
MPI.UINT16_T.val => MLIR.API.ENZYMEXLA_MPI_UINT16_T,
215+
MPI.INT32_T.val => MLIR.API.ENZYMEXLA_MPI_INT32_T,
216+
MPI.UINT32_T.val => MLIR.API.ENZYMEXLA_MPI_UINT32_T,
217+
MPI.INT64_T.val => MLIR.API.ENZYMEXLA_MPI_INT64_T,
218+
MPI.UINT64_T.val => MLIR.API.ENZYMEXLA_MPI_UINT64_T,
219+
MPI.BYTE.val => MLIR.API.ENZYMEXLA_MPI_BYTE,
220+
MPI.SHORT.val => MLIR.API.ENZYMEXLA_MPI_SHORT,
221+
MPI.UNSIGNED_SHORT.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_SHORT,
222+
MPI.INT.val => MLIR.API.ENZYMEXLA_MPI_INT,
223+
MPI.UNSIGNED.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED,
224+
MPI.LONG.val => MLIR.API.ENZYMEXLA_MPI_LONG,
225+
MPI.UNSIGNED_LONG.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_LONG,
226+
MPI.LONG_LONG_INT.val => MLIR.API.ENZYMEXLA_MPI_LONG_LONG_INT,
227+
MPI.UNSIGNED_LONG_LONG.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_LONG_LONG,
228+
MPI.CHAR.val => MLIR.API.ENZYMEXLA_MPI_CHAR,
229+
MPI.SIGNED_CHAR.val => MLIR.API.ENZYMEXLA_MPI_SIGNED_CHAR,
230+
MPI.UNSIGNED_CHAR.val => MLIR.API.ENZYMEXLA_MPI_UNSIGNED_CHAR,
231+
MPI.WCHAR.val => MLIR.API.ENZYMEXLA_MPI_WCHAR,
232+
MPI.FLOAT.val => MLIR.API.ENZYMEXLA_MPI_FLOAT,
233+
MPI.DOUBLE.val => MLIR.API.ENZYMEXLA_MPI_DOUBLE,
234+
MPI.C_FLOAT_COMPLEX.val => MLIR.API.ENZYMEXLA_MPI_C_FLOAT_COMPLEX,
235+
MPI.C_DOUBLE_COMPLEX.val => MLIR.API.ENZYMEXLA_MPI_C_DOUBLE_COMPLEX,
236+
MPI.C_BOOL.val => MLIR.API.ENZYMEXLA_MPI_C_BOOL,
283237
)
284238

285239
function get_mpi_datatype_enum(datatype)

src/Ops.jl

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@ using ..Reactant:
1515
using ReactantCore: ReactantCore
1616
using GPUArraysCore: GPUArraysCore
1717

18+
const GELU_APPROXIMATION_MAP = Dict(
19+
"NONE" => MLIR.API.ENZYMEXLA_GELU_APPROXIMATION_NONE,
20+
"TANH" => MLIR.API.ENZYMEXLA_GELU_APPROXIMATION_TANH,
21+
"SIGMOID" => MLIR.API.ENZYMEXLA_GELU_APPROXIMATION_SIGMOID,
22+
)
23+
24+
const LAPACK_TRANSPOSE_MAP = Dict(
25+
'N' => MLIR.API.ENZYMEXLA_LAPACK_TRANSPOSE_NONE,
26+
'T' => MLIR.API.ENZYMEXLA_LAPACK_TRANSPOSE_TRANSPOSE,
27+
'C' => MLIR.API.ENZYMEXLA_LAPACK_TRANSPOSE_CONJUGATE_TRANSPOSE,
28+
)
29+
30+
const LAPACK_UPLO_MAP = Dict(
31+
'U' => MLIR.API.ENZYMEXLA_LAPACK_UPLO_UPPER,
32+
'L' => MLIR.API.ENZYMEXLA_LAPACK_UPLO_LOWER,
33+
'F' => MLIR.API.ENZYMEXLA_LAPACK_UPLO_FULL,
34+
)
35+
1836
function _function_macro_error()
1937
throw(ArgumentError("`caller_function` is not available in this context"))
2038
end
@@ -3877,13 +3895,13 @@ end
38773895
info_size = batch_sizes
38783896

38793897
if algorithm == "DEFAULT"
3880-
algint = 0
3898+
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_NONE
38813899
elseif algorithm == "QRIteration"
3882-
algint = 1
3900+
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_QRITERATION
38833901
elseif algorithm == "DivideAndConquer"
3884-
algint = 2
3902+
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_DIVIDEANDCONQUER
38853903
elseif algorithm == "Jacobi"
3886-
algint = 3
3904+
alg = MLIR.API.ENZYMEXLA_SVD_ALGORITHM_JACOBI
38873905
else
38883906
error("Unsupported SVD algorithm: $algorithm")
38893907
end
@@ -3895,7 +3913,7 @@ end
38953913
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
38963914
info=mlir_type(TracedRArray{iT,N - 2}, info_size),
38973915
full=full,
3898-
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.current_context(), algint),
3916+
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.current_context(), alg),
38993917
location,
39003918
)
39013919

@@ -4103,21 +4121,15 @@ end
41034121
approximation::String;
41044122
location=mlir_stacktrace("ml.gelu", @__FILE__, @__LINE__),
41054123
)
4106-
approx = if approximation == "NONE"
4107-
0
4108-
elseif approximation == "TANH"
4109-
1
4110-
elseif approximation == "SIGMOID"
4111-
2
4112-
else
4113-
error("Invalid gelu approximation: $approximation")
4114-
end
4115-
approx = MLIR.API.enzymexlaGeluApproximationAttrGet(
4116-
MLIR.IR.current_context(), Int32(approx)
4117-
)
4118-
41194124
res = MLIR.IR.result(
4120-
enzymexla.ml_gelu(x.mlir_data; gelu_approximation=approx, location), 1
4125+
enzymexla.ml_gelu(
4126+
x.mlir_data;
4127+
gelu_approximation=MLIR.API.enzymexlaGeluApproximationAttrGet(
4128+
MLIR.IR.current_context(), GELU_APPROXIMATION_MAP[approximation]
4129+
),
4130+
location,
4131+
),
4132+
1,
41214133
)
41224134

41234135
if x isa TracedRArray
@@ -4210,41 +4222,19 @@ end
42104222
location=mlir_stacktrace("syrk", @__FILE__, @__LINE__),
42114223
) where {T,N}
42124224
ctx = MLIR.IR.current_context()
4213-
uplo_attr = MLIR.API.enzymexlaLapackUploAttrGet(
4214-
ctx,
4215-
if uplo == 'U'
4216-
Int32(1)
4217-
elseif uplo == 'L'
4218-
Int32(0)
4219-
else
4220-
Int32(2)
4221-
end,
4222-
)
4223-
transpose_attr = MLIR.API.enzymexlaLapackTransposeAttrGet(
4224-
ctx,
4225-
if transpose_a == 'N'
4226-
Int32(0)
4227-
elseif transpose_a == 'T'
4228-
Int32(1)
4229-
elseif transpose_a == 'C'
4230-
Int32(2)
4231-
else
4232-
error("Unknown transpose mode: $transpose_a")
4233-
end,
4234-
)
4235-
4236-
alpha_ = constant(alpha; location)
4237-
beta_ = constant(beta; location)
4225+
uplo_attr = MLIR.API.enzymexlaLapackUploAttrGet(ctx, LAPACK_UPLO_MAP[uplo])
42384226

42394227
res = MLIR.IR.result(
42404228
enzymexla.blas_syrk(
42414229
A.mlir_data,
42424230
C.mlir_data,
4243-
alpha_.mlir_data,
4244-
beta_.mlir_data;
4231+
constant(alpha; location).mlir_data,
4232+
constant(beta; location).mlir_data;
42454233
uplo=uplo_attr,
42464234
output_uplo=uplo_attr,
4247-
transpose=transpose_attr,
4235+
transpose=MLIR.API.enzymexlaLapackTransposeAttrGet(
4236+
ctx, LAPACK_TRANSPOSE_MAP[transpose_a]
4237+
),
42484238
output=mlir_type(TracedRArray{T,N}, size(C)),
42494239
location,
42504240
),

0 commit comments

Comments
 (0)