Skip to content

Commit 79d3219

Browse files
authored
refactor: improve readability of Enzyme.jl with enums (#2842)
1 parent 7378556 commit 79d3219

2 files changed

Lines changed: 38 additions & 27 deletions

File tree

src/Enzyme.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
const enzyme_out = 0
2-
const enzyme_dup = 1
3-
const enzyme_const = 2
4-
const enzyme_dupnoneed = 3
5-
const enzyme_outnoneed = 4
6-
const enzyme_constnoneed = 5
1+
@enumx EnzymeActivity begin
2+
OUT = 0
3+
DUPLICATED = 1
4+
CONST = 2
5+
DUPLICATED_NO_NEED = 3
6+
OUT_NO_NEED = 4
7+
CONST_NO_NEED = 5
8+
end
79

810
struct StackedBatchDuplicated{T,N,M,V<:AbstractArray{T,N},W<:AbstractArray{T,M}} <:
911
Annotation{V}
@@ -201,17 +203,17 @@ end
201203
end
202204

203205
@inline function act_from_type(::Type{<:Active}, reverse, needs_primal)
204-
return needs_primal ? enzyme_out : enzyme_outnoneed
206+
return needs_primal ? EnzymeActivity.OUT : EnzymeActivity.OUT_NO_NEED
205207
end
206208
@inline function act_from_type(::Type{<:Const}, reverse, needs_primal)
207-
return needs_primal ? enzyme_const : enzyme_constnoneed
209+
return needs_primal ? EnzymeActivity.CONST : EnzymeActivity.CONST_NO_NEED
208210
end
209211

210212
@inline function act_from_type(::Type{<:Duplicated}, reverse, needs_primal)
211213
if reverse
212-
return needs_primal ? enzyme_out : enzyme_outnoneed
214+
return needs_primal ? EnzymeActivity.OUT : EnzymeActivity.OUT_NO_NEED
213215
else
214-
return needs_primal ? enzyme_dup : enzyme_dupnoneed
216+
return needs_primal ? EnzymeActivity.DUPLICATED : EnzymeActivity.DUPLICATED_NO_NEED
215217
end
216218
end
217219
@inline function act_from_type(
@@ -221,7 +223,7 @@ end
221223
end
222224

223225
@inline function act_from_type(::Type{<:DuplicatedNoNeed}, reverse, needs_primal)
224-
return reverse ? enzyme_out : enzyme_dupnoneed
226+
return reverse ? EnzymeActivity.OUT : EnzymeActivity.DUPLICATED_NO_NEED
225227
end
226228
@inline function act_from_type(
227229
::Type{<:Union{BatchDuplicatedNoNeed,StackedBatchDuplicatedNoNeed}},
@@ -289,7 +291,9 @@ function set_act!(inp, path, reverse, tostore; emptypath=false, width=1)
289291
end
290292

291293
function act_attr(val)
292-
return MLIR.IR.Attribute(MLIR.API.enzymeActivityAttrGet(MLIR.IR.current_context(), val))
294+
return MLIR.IR.Attribute(
295+
MLIR.API.enzymeActivityAttrGet(MLIR.IR.current_context(), Int32(val))
296+
)
293297
end
294298

295299
function infer_activity(
@@ -342,7 +346,7 @@ function overload_autodiff(
342346
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
343347
fnwrap = mlir_fn_res.fnwrapped
344348

345-
activity = Int32[]
349+
activity = EnzymeActivity.T[]
346350
ad_inputs = MLIR.IR.Value[]
347351

348352
reverse_seeds = Dict{Tuple,MLIR.IR.Value}()
@@ -353,7 +357,7 @@ function overload_autodiff(
353357
push!(activity, act_from_type(arg, reverse))
354358
push_acts!(ad_inputs, arg, path[3:end], reverse)
355359

356-
if CMode <: ReverseMode && act_from_type(arg, false) == enzyme_dup
360+
if CMode <: ReverseMode && act_from_type(arg, false) == EnzymeActivity.DUPLICATED
357361
x = if width == 1
358362
arg.dval
359363
elseif arg.dval isa AbstractArray
@@ -370,7 +374,7 @@ function overload_autodiff(
370374
end
371375

372376
outtys = MLIR.IR.Type[]
373-
ret_activity = Int32[]
377+
ret_activity = EnzymeActivity.T[]
374378

375379
for a in linear_results
376380
if TracedUtils.has_idx(a, resprefix)
@@ -395,7 +399,7 @@ function overload_autodiff(
395399

396400
act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode))
397401
cst = nothing
398-
if act == enzyme_out || act == enzyme_outnoneed
402+
if act == EnzymeActivity.OUT || act == EnzymeActivity.OUT_NO_NEED
399403
if width == 1
400404
cst = @opcall fill(one(unwrapped_eltype(a)), size(a))
401405
else
@@ -407,26 +411,28 @@ function overload_autodiff(
407411
if CMode <: ReverseMode && TracedUtils.has_idx(a, argprefix)
408412
idx, path = TracedUtils.get_argidx(a, argprefix)
409413
arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]
410-
if act_from_type(arg, false) == enzyme_dup
414+
if act_from_type(arg, false) == EnzymeActivity.DUPLICATED
411415
seed = reverse_seeds[path]
412-
if cst == nothing
413-
if act == enzyme_const
414-
act = enzyme_out
415-
elseif act == enzyme_constnoneed
416-
act = enzyme_outnoneed
416+
if cst === nothing
417+
if act == EnzymeActivity.CONST
418+
act = EnzymeActivity.OUT
419+
elseif act == EnzymeActivity.CONST_NO_NEED
420+
act = EnzymeActivity.OUT_NO_NEED
417421
else
418422
@assert false
419423
end
420424
cst = seed
421425
else
422-
@assert act == enzyme_out || act == enzyme_outnoneed
426+
@assert (
427+
act == EnzymeActivity.OUT || act == EnzymeActivity.OUT_NO_NEED
428+
)
423429
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.add(cst, seed), 1)
424430
end
425431
end
426432
end
427433

428434
push!(ret_activity, act)
429-
if cst != nothing
435+
if cst !== nothing
430436
push!(ad_inputs, cst)
431437
end
432438
else
@@ -437,7 +443,7 @@ function overload_autodiff(
437443
act = act_from_type(arg, reverse, true)
438444
push!(ret_activity, act)
439445

440-
if act == enzyme_out || act == enzyme_outnoneed
446+
if act == EnzymeActivity.OUT || act == EnzymeActivity.OUT_NO_NEED
441447
seed = reverse_seeds[path]
442448
push!(ad_inputs, seed)
443449
end
@@ -453,7 +459,11 @@ function overload_autodiff(
453459
end
454460

455461
for (i, act) in enumerate(activity)
456-
if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed
462+
if (
463+
act == EnzymeActivity.OUT ||
464+
act == EnzymeActivity.DUPLICATED ||
465+
act == EnzymeActivity.DUPLICATED_NO_NEED
466+
)
457467
push!(outtys, TracedUtils.batch_ty(width, in_tys[i]))
458468
end
459469
end
@@ -528,7 +538,7 @@ function overload_autodiff(
528538
idx, path = TracedUtils.get_argidx(a, argprefix)
529539

530540
arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]
531-
act_from_type(arg, reverse) != enzyme_out && continue
541+
act_from_type(arg, reverse) != EnzymeActivity.OUT && continue
532542

533543
if idx == 1 && fnwrap && arg isa Active
534544
@assert false

test/core/qa.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ end
9494
Reactant.PrecisionConfig,
9595
Reactant.InterpolationType,
9696
ReactantMPIExt.Ops,
97+
Reactant.EnzymeActivity,
9798
)
9899

99100
test_explicit_imports(

0 commit comments

Comments
 (0)