1- const enzyme_out = 0
2- const enzyme_dup = 1
3- const enzyme_const = 2
4- const enzyme_dupnoneed = 3
5- const enzyme_outnoneed = 4
6- const enzyme_constnoneed = 5
1+ @enumx EnzymeActivity begin
2+ OUT = 0
3+ DUPLICATED = 1
4+ CONST = 2
5+ DUPLICATED_NO_NEED = 3
6+ OUT_NO_NEED = 4
7+ CONST_NO_NEED = 5
8+ end
79
810struct StackedBatchDuplicated{T,N,M,V<: AbstractArray{T,N} ,W<: AbstractArray{T,M} } < :
911 Annotation{V}
@@ -201,17 +203,17 @@ end
201203end
202204
203205@inline function act_from_type (:: Type{<:Active} , reverse, needs_primal)
204- return needs_primal ? enzyme_out : enzyme_outnoneed
206+ return needs_primal ? EnzymeActivity . OUT : EnzymeActivity . OUT_NO_NEED
205207end
206208@inline function act_from_type (:: Type{<:Const} , reverse, needs_primal)
207- return needs_primal ? enzyme_const : enzyme_constnoneed
209+ return needs_primal ? EnzymeActivity . CONST : EnzymeActivity . CONST_NO_NEED
208210end
209211
210212@inline function act_from_type (:: Type{<:Duplicated} , reverse, needs_primal)
211213 if reverse
212- return needs_primal ? enzyme_out : enzyme_outnoneed
214+ return needs_primal ? EnzymeActivity . OUT : EnzymeActivity . OUT_NO_NEED
213215 else
214- return needs_primal ? enzyme_dup : enzyme_dupnoneed
216+ return needs_primal ? EnzymeActivity . DUPLICATED : EnzymeActivity . DUPLICATED_NO_NEED
215217 end
216218end
217219@inline function act_from_type (
221223end
222224
223225@inline function act_from_type (:: Type{<:DuplicatedNoNeed} , reverse, needs_primal)
224- return reverse ? enzyme_out : enzyme_dupnoneed
226+ return reverse ? EnzymeActivity . OUT : EnzymeActivity . DUPLICATED_NO_NEED
225227end
226228@inline function act_from_type (
227229 :: Type{<:Union{BatchDuplicatedNoNeed,StackedBatchDuplicatedNoNeed}} ,
@@ -289,7 +291,9 @@ function set_act!(inp, path, reverse, tostore; emptypath=false, width=1)
289291end
290292
291293function act_attr (val)
292- return MLIR. IR. Attribute (MLIR. API. enzymeActivityAttrGet (MLIR. IR. current_context (), val))
294+ return MLIR. IR. Attribute (
295+ MLIR. API. enzymeActivityAttrGet (MLIR. IR. current_context (), Int32 (val))
296+ )
293297end
294298
295299function infer_activity (
@@ -342,7 +346,7 @@ function overload_autodiff(
342346 (; result, linear_args, in_tys, linear_results) = mlir_fn_res
343347 fnwrap = mlir_fn_res. fnwrapped
344348
345- activity = Int32 []
349+ activity = EnzymeActivity . T []
346350 ad_inputs = MLIR. IR. Value[]
347351
348352 reverse_seeds = Dict {Tuple,MLIR.IR.Value} ()
@@ -353,7 +357,7 @@ function overload_autodiff(
353357 push! (activity, act_from_type (arg, reverse))
354358 push_acts! (ad_inputs, arg, path[3 : end ], reverse)
355359
356- if CMode <: ReverseMode && act_from_type (arg, false ) == enzyme_dup
360+ if CMode <: ReverseMode && act_from_type (arg, false ) == EnzymeActivity . DUPLICATED
357361 x = if width == 1
358362 arg. dval
359363 elseif arg. dval isa AbstractArray
@@ -370,7 +374,7 @@ function overload_autodiff(
370374 end
371375
372376 outtys = MLIR. IR. Type[]
373- ret_activity = Int32 []
377+ ret_activity = EnzymeActivity . T []
374378
375379 for a in linear_results
376380 if TracedUtils. has_idx (a, resprefix)
@@ -395,7 +399,7 @@ function overload_autodiff(
395399
396400 act = act_from_type (A, reverse, EnzymeCore. needs_primal (CMode))
397401 cst = nothing
398- if act == enzyme_out || act == enzyme_outnoneed
402+ if act == EnzymeActivity . OUT || act == EnzymeActivity . OUT_NO_NEED
399403 if width == 1
400404 cst = @opcall fill (one (unwrapped_eltype (a)), size (a))
401405 else
@@ -407,26 +411,28 @@ function overload_autodiff(
407411 if CMode <: ReverseMode && TracedUtils. has_idx (a, argprefix)
408412 idx, path = TracedUtils. get_argidx (a, argprefix)
409413 arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]
410- if act_from_type (arg, false ) == enzyme_dup
414+ if act_from_type (arg, false ) == EnzymeActivity . DUPLICATED
411415 seed = reverse_seeds[path]
412- if cst == nothing
413- if act == enzyme_const
414- act = enzyme_out
415- elseif act == enzyme_constnoneed
416- act = enzyme_outnoneed
416+ if cst === nothing
417+ if act == EnzymeActivity . CONST
418+ act = EnzymeActivity . OUT
419+ elseif act == EnzymeActivity . CONST_NO_NEED
420+ act = EnzymeActivity . OUT_NO_NEED
417421 else
418422 @assert false
419423 end
420424 cst = seed
421425 else
422- @assert act == enzyme_out || act == enzyme_outnoneed
426+ @assert (
427+ act == EnzymeActivity. OUT || act == EnzymeActivity. OUT_NO_NEED
428+ )
423429 cst = MLIR. IR. result (MLIR. Dialects. stablehlo. add (cst, seed), 1 )
424430 end
425431 end
426432 end
427433
428434 push! (ret_activity, act)
429- if cst != nothing
435+ if cst != = nothing
430436 push! (ad_inputs, cst)
431437 end
432438 else
@@ -437,7 +443,7 @@ function overload_autodiff(
437443 act = act_from_type (arg, reverse, true )
438444 push! (ret_activity, act)
439445
440- if act == enzyme_out || act == enzyme_outnoneed
446+ if act == EnzymeActivity . OUT || act == EnzymeActivity . OUT_NO_NEED
441447 seed = reverse_seeds[path]
442448 push! (ad_inputs, seed)
443449 end
@@ -453,7 +459,11 @@ function overload_autodiff(
453459 end
454460
455461 for (i, act) in enumerate (activity)
456- if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed
462+ if (
463+ act == EnzymeActivity. OUT ||
464+ act == EnzymeActivity. DUPLICATED ||
465+ act == EnzymeActivity. DUPLICATED_NO_NEED
466+ )
457467 push! (outtys, TracedUtils. batch_ty (width, in_tys[i]))
458468 end
459469 end
@@ -528,7 +538,7 @@ function overload_autodiff(
528538 idx, path = TracedUtils. get_argidx (a, argprefix)
529539
530540 arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]
531- act_from_type (arg, reverse) != enzyme_out && continue
541+ act_from_type (arg, reverse) != EnzymeActivity . OUT && continue
532542
533543 if idx == 1 && fnwrap && arg isa Active
534544 @assert false
0 commit comments