From 378dc40eb3426c219cf3c4d5eb50b4575f738acb Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 7 Jan 2025 13:12:50 +0100 Subject: [PATCH 01/14] WIP: switching interpreter in `set_reactant_abi` when encountering `f === overload_autodiff`. This doesn't work for some reason, the function within overload autodiff uses the original interpreter (?) --- src/Interpreter.jl | 14 ++++++++++++-- test/autodiff.jl | 12 ++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 15698aee50..74dfe0e05e 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -63,6 +63,16 @@ function set_reactant_abi( if f === Reactant.call_with_reactant arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end]) return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) + elseif interp.defer_within_autodiff && f === overload_autodiff + interp′ = Enzyme.Compiler.Interpreter.EnzymeInterpreter(interp; defer_within_autodiff=false) + return Base.@invoke abstract_call_known( + interp′::Enzyme.Compiler.Interpreter.EnzymeInterpreter, + f, + arginfo, + si, + sv, + max_methods, + ) end return Base.@invoke abstract_call_known( @@ -87,7 +97,7 @@ end false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# + true, #=defer_within_autodiff=# set_reactant_abi, ) end @@ -105,7 +115,7 @@ else false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# + true, #=defer_within_autodiff=# set_reactant_abi, ) end diff --git a/test/autodiff.jl b/test/autodiff.jl index b52891c161..a127bba279 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -126,6 +126,18 @@ end @test res1[1] ≈ ores1[1] end +function error_not_within_autodiff() + !Enzyme.within_autodiff() && error("Not within autodiff") + return +end + +fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT) + +@testset "within_autodiff" begin + @test_thows ErrorException @jit error_not_within_autodiff() + @test isnothing(@jit fwd_within_autodiff(Forward, Const)) +end + function gw(z) return Enzyme.gradient(Forward, sum, z; chunk=Val(1)) end From 5812016bca25eeb554ff372b061d4f63bc7f98b2 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 7 Jan 2025 13:17:42 +0100 Subject: [PATCH 02/14] other approach: pass within_autodiff from overload_autodiff --> make_mlir_fn. In order to pass this information from make_mlir_fn to call_with_reactant_generator, I introduced a new function `call_with_reactant_within_autodiff` which allows detection by looking at `self`. --- src/Interpreter.jl | 8 ++++---- src/TracedUtils.jl | 15 +++++++++++++-- src/utils.jl | 9 ++++++++- test/autodiff.jl | 9 ++++++--- 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 74dfe0e05e..d073411eda 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -88,7 +88,7 @@ end @static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE struct ReactantCacheToken end - function ReactantInterpreter(; world::UInt=Base.get_world_counter()) + function ReactantInterpreter(; world::UInt=Base.get_world_counter(), within_autodiff=false) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( ReactantCacheToken(), REACTANT_METHOD_TABLE, @@ -97,7 +97,7 @@ end false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - true, #=defer_within_autodiff=# + !within_autodiff, #=defer_within_autodiff=# set_reactant_abi, ) end @@ -105,7 +105,7 @@ else const REACTANT_CACHE = Enzyme.GPUCompiler.CodeCache() function ReactantInterpreter(; - world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE + world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE, within_autodiff=false ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( REACTANT_CACHE, @@ -115,7 +115,7 @@ else false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - true, #=defer_within_autodiff=# + !within_autodiff, #=defer_within_autodiff=# set_reactant_abi, ) end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 98b83ea006..c8ea10a91e 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -282,6 +282,7 @@ function make_mlir_fn( args_in_result::Symbol=:all, construct_function_without_args::Bool=false, do_transpose=true, + within_autodiff=false, input_shardings=nothing, # This is not meant to be used by the user. output_shardings=nothing, # This is not meant to be used by the user. runtime=nothing, @@ -341,9 +342,19 @@ function make_mlir_fn( process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map) if isempty(kwargs) - Reactant.call_with_reactant(f, traced_args...) + if within_autodiff + Reactant.call_with_reactant_within_autodiff(f, traced_args...) + else + Reactant.call_with_reactant(f, traced_args...) + end else - Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...) + if within_autodiff + Reactant.call_with_reactant_within_autodiff( + Core.kwcall, kwargs, f, traced_args... + ) + else + Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...) + end end finally MLIR.IR.deactivate!(fnbody) diff --git a/src/utils.jl b/src/utils.jl index d4af7bb72d..772606efda 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -19,6 +19,7 @@ function apply(f::F, args...; kwargs...) where {F} end function call_with_reactant end +function call_with_reactant_within_autodiff end function maybe_argextype(@nospecialize(x), src) return try @@ -647,7 +648,9 @@ function call_with_reactant_generator( )) end - interp = ReactantInterpreter(; world) + interp = ReactantInterpreter(; + world, within_autodiff=self == typeof(Reactant.call_with_reactant_within_autodiff) + ) min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) @@ -891,6 +894,10 @@ end $(Expr(:meta, :generated_only)) return $(Expr(:meta, :generated, call_with_reactant_generator)) end +@eval function call_with_reactant_within_autodiff($REDUB_ARGUMENTS_NAME...) + $(Expr(:meta, :generated_only)) + return $(Expr(:meta, :generated, call_with_reactant_generator)) +end @static if isdefined(Core, :BFloat16) nmantissa(::Type{Core.BFloat16}) = 7 diff --git a/test/autodiff.jl b/test/autodiff.jl index a127bba279..cae5dc15a7 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -128,14 +128,17 @@ end function error_not_within_autodiff() !Enzyme.within_autodiff() && error("Not within autodiff") - return + return nothing end fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT) @testset "within_autodiff" begin - @test_thows ErrorException @jit error_not_within_autodiff() - @test isnothing(@jit fwd_within_autodiff(Forward, Const)) + @test_throws ErrorException error_not_within_autodiff() + @test fwd_within_autodiff(Forward, Const) == () + + @test_throws ErrorException @jit error_not_within_autodiff() + @test (@jit fwd_within_autodiff(Forward, Const)) == () end function gw(z) From 0f8431811f837e87027eff57d80c4a6101004375 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 7 Jan 2025 13:33:10 +0100 Subject: [PATCH 03/14] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Interpreter.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index d073411eda..b08027be34 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -64,7 +64,9 @@ function set_reactant_abi( arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end]) return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) elseif interp.defer_within_autodiff && f === overload_autodiff - interp′ = Enzyme.Compiler.Interpreter.EnzymeInterpreter(interp; defer_within_autodiff=false) + interp′ = Enzyme.Compiler.Interpreter.EnzymeInterpreter( + interp; defer_within_autodiff=false + ) return Base.@invoke abstract_call_known( interp′::Enzyme.Compiler.Interpreter.EnzymeInterpreter, f, @@ -88,7 +90,9 @@ end @static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE struct ReactantCacheToken end - function ReactantInterpreter(; world::UInt=Base.get_world_counter(), within_autodiff=false) + function ReactantInterpreter(; + world::UInt=Base.get_world_counter(), within_autodiff=false + ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( ReactantCacheToken(), REACTANT_METHOD_TABLE, @@ -105,7 +109,9 @@ else const REACTANT_CACHE = Enzyme.GPUCompiler.CodeCache() function ReactantInterpreter(; - world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE, within_autodiff=false + world::UInt=Base.get_world_counter(), + code_cache=REACTANT_CACHE, + within_autodiff=false, ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( REACTANT_CACHE, From c78599d4682cd72c92807a9691f9bab55829722f Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:14:08 +0100 Subject: [PATCH 04/14] `!defer_within_autodiff` -> `within_autodiff_rewrite` --- src/Interpreter.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index b08027be34..0ba521e1f5 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -63,9 +63,9 @@ function set_reactant_abi( if f === Reactant.call_with_reactant arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end]) return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) - elseif interp.defer_within_autodiff && f === overload_autodiff + elseif !(interp.within_autodiff_rewrite) && f === overload_autodiff interp′ = Enzyme.Compiler.Interpreter.EnzymeInterpreter( - interp; defer_within_autodiff=false + interp; within_autodiff_rewrite=true ) return Base.@invoke abstract_call_known( interp′::Enzyme.Compiler.Interpreter.EnzymeInterpreter, @@ -101,7 +101,7 @@ end false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - !within_autodiff, #=defer_within_autodiff=# + within_autodiff, #=within_autodiff_rewrite=# set_reactant_abi, ) end @@ -121,7 +121,7 @@ else false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - !within_autodiff, #=defer_within_autodiff=# + within_autodiff, #=within_autodiff_rewrite=# set_reactant_abi, ) end From a37921ff9471f067bba33d6e0632a200928142ed Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 1 Sep 2025 19:41:40 -0400 Subject: [PATCH 05/14] fix: set within_autodiff --- src/Enzyme.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index adc802db84..baa2e0cff7 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -314,6 +314,7 @@ function overload_autodiff( argprefix, resprefix, resargprefix, + within_autodiff=true, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped From 4defdfc196e41853f18a39bc9948455791bddb6c Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:41:02 -0600 Subject: [PATCH 06/14] Don't create new interpreter work with global ref instead This has the unfortunate downside of introducing a try finally block around `overload_autodiff`. --- src/Enzyme.jl | 364 +++++++++++++++++++++++---------------------- src/Interpreter.jl | 12 -- src/Overlay.jl | 4 + 3 files changed, 192 insertions(+), 188 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index d8d6c089ab..197aaac0c7 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -295,228 +295,240 @@ function act_attr(val) return MLIR.IR.Attribute(val) end +const WITHIN_AUTODIFF = Ref(false) + +function overload_within_autodiff() + return WITHIN_AUTODIFF[] +end + function overload_autodiff( ::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs} ) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs} - reverse = CMode <: ReverseMode + WITHIN_AUTODIFF[] = true + try + reverse = CMode <: ReverseMode - width = Enzyme.same_or_one(1, args...) - if width == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end + width = Enzyme.same_or_one(1, args...) + if width == 0 + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end - primf = f.val - primargs = ((v.val for v in args)...,) - - argprefix::Symbol = gensym("autodiffarg") - resprefix::Symbol = gensym("autodiffresult") - resargprefix::Symbol = gensym("autodiffresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - primf, - primargs, - (), - string(f) * "_autodiff", - false; - argprefix, - resprefix, - resargprefix, - within_autodiff=true, - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - activity = Int32[] - ad_inputs = MLIR.IR.Value[] - - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - push!(activity, act_from_type(arg, reverse)) - push_acts!(ad_inputs, arg, path[3:end], reverse) - end + primf = f.val + primargs = ((v.val for v in args)...,) + + argprefix::Symbol = gensym("autodiffarg") + resprefix::Symbol = gensym("autodiffresult") + resargprefix::Symbol = gensym("autodiffresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + primf, + primargs, + (), + string(f) * "_autodiff", + false; + argprefix, + resprefix, + resargprefix, + within_autodiff=true, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f - outtys = MLIR.IR.Type[] - ret_activity = Int32[] + activity = Int32[] + ad_inputs = MLIR.IR.Value[] - for a in linear_results - if TracedUtils.has_idx(a, resprefix) - if EnzymeCore.needs_primal(CMode) - push!( - outtys, - TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))), - ) - end + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] + push!(activity, act_from_type(arg, reverse)) + push_acts!(ad_inputs, arg, path[3:end], reverse) + end - if CMode <: ForwardMode && !(A <: Const) - push!( - outtys, - TracedUtils.batch_ty( - width, - TracedUtils.transpose_ty( - MLIR.IR.type(TracedUtils.get_mlir_data(a)) - ), - ), - ) - end + outtys = MLIR.IR.Type[] + ret_activity = Int32[] - act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode)) - push!(ret_activity, act) - if act == enzyme_out || act == enzyme_outnoneed - if width == 1 - cst = @opcall fill(one(unwrapped_eltype(a)), size(a)) - else - cst = @opcall fill(one(unwrapped_eltype(a)), (size(a)..., width)) + for a in linear_results + if TracedUtils.has_idx(a, resprefix) + if EnzymeCore.needs_primal(CMode) + push!( + outtys, + TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))), + ) end - push!(ad_inputs, cst.mlir_data) - end - else - if TracedUtils.has_idx(a, argprefix) - idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - act = act_from_type(arg, reverse, true) - push!(ret_activity, act) + if CMode <: ForwardMode && !(A <: Const) + push!( + outtys, + TracedUtils.batch_ty( + width, + TracedUtils.transpose_ty( + MLIR.IR.type(TracedUtils.get_mlir_data(a)) + ), + ), + ) + end + act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode)) + push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed if width == 1 - TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) - elseif arg.dval isa AbstractArray - TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) + cst = @opcall fill(one(unwrapped_eltype(a)), size(a)) else - TracedUtils.push_val!( - ad_inputs, call_with_reactant(stack, arg.dval), path[3:end] - ) + cst = @opcall fill(one(unwrapped_eltype(a)), (size(a)..., width)) end + push!(ad_inputs, cst.mlir_data) end else - act = act_from_type(Const, reverse, true) - push!(ret_activity, act) - end + if TracedUtils.has_idx(a, argprefix) + idx, path = TracedUtils.get_argidx(a, argprefix) + arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] + + act = act_from_type(arg, reverse, true) + push!(ret_activity, act) + + if act == enzyme_out || act == enzyme_outnoneed + if width == 1 + TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) + elseif arg.dval isa AbstractArray + TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) + else + TracedUtils.push_val!( + ad_inputs, call_with_reactant(stack, arg.dval), path[3:end] + ) + end + end + else + act = act_from_type(Const, reverse, true) + push!(ret_activity, act) + end - push!( - outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))) - ) + push!( + outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))) + ) + end end - end - for (i, act) in enumerate(activity) - if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed - push!(outtys, TracedUtils.batch_ty(width, in_tys[i])) + for (i, act) in enumerate(activity) + if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed + push!(outtys, TracedUtils.batch_ty(width, in_tys[i])) + end end - end - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( - [TracedUtils.transpose_val(v) for v in ad_inputs]; - outputs=outtys, - fn=fname, - width, - strong_zero=EnzymeCore.strong_zero(CMode), - activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), - ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), - ) + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( + [TracedUtils.transpose_val(v) for v in ad_inputs]; + outputs=outtys, + fn=fname, + width, + strong_zero=EnzymeCore.strong_zero(CMode), + activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), + ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), + ) - residx = 1 + residx = 1 - dresult = if CMode <: ForwardMode && !(A <: Const) - if width == 1 - deepcopy(result) - else - ntuple(Val(width)) do i - Base.@_inline_meta + dresult = if CMode <: ForwardMode && !(A <: Const) + if width == 1 deepcopy(result) + else + ntuple(Val(width)) do i + Base.@_inline_meta + deepcopy(result) + end end + else + nothing end - else - nothing - end - for a in linear_results - if TracedUtils.has_idx(a, resprefix) - if EnzymeCore.needs_primal(CMode) - path = TracedUtils.get_idx(a, resprefix) - tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) - TracedUtils.set!(result, path[2:end], tval) - residx += 1 - end - if CMode <: ForwardMode && !(A <: Const) - path = TracedUtils.get_idx(a, resprefix) - tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) - if width == 1 - TracedUtils.set!(dresult, path[2:end], tval) - else - ttval = TracedRArray(tval) - for (i, sl) in enumerate(eachslice(ttval; dims=ndims(ttval))) - TracedUtils.set!( - dresult[i], - path[2:end], - @allowscalar(TracedUtils.get_mlir_data(sl)) - ) + for a in linear_results + if TracedUtils.has_idx(a, resprefix) + if EnzymeCore.needs_primal(CMode) + path = TracedUtils.get_idx(a, resprefix) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + TracedUtils.set!(result, path[2:end], tval) + residx += 1 + end + if CMode <: ForwardMode && !(A <: Const) + path = TracedUtils.get_idx(a, resprefix) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + if width == 1 + TracedUtils.set!(dresult, path[2:end], tval) + else + ttval = TracedRArray(tval) + for (i, sl) in enumerate(eachslice(ttval; dims=ndims(ttval))) + TracedUtils.set!( + dresult[i], + path[2:end], + @allowscalar(TracedUtils.get_mlir_data(sl)) + ) + end end + residx += 1 end + elseif TracedUtils.has_idx(a, argprefix) + idx, path = TracedUtils.get_argidx(a, argprefix) + arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] + TracedUtils.set!( + arg.val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + ) + residx += 1 + else + TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx))) residx += 1 end - elseif TracedUtils.has_idx(a, argprefix) - idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - TracedUtils.set!( - arg.val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx)) - ) - residx += 1 - else - TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx))) - residx += 1 end - end - restup = Any[(a isa Active) ? copy(a) : nothing for a in args] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) + restup = Any[(a isa Active) ? copy(a) : nothing for a in args] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - act_from_type(arg, reverse) != enzyme_out && continue + arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] + act_from_type(arg, reverse) != enzyme_out && continue - if idx == 1 && fnwrap && arg isa Active - @assert false - end + if idx == 1 && fnwrap && arg isa Active + @assert false + end - set_act!( - arg, - path[3:end], - reverse, - TracedUtils.transpose_val(MLIR.IR.result(res, residx)); - width, - emptypath=arg isa Active, - ) - residx += 1 - end + set_act!( + arg, + path[3:end], + reverse, + TracedUtils.transpose_val(MLIR.IR.result(res, residx)); + width, + emptypath=arg isa Active, + ) + residx += 1 + end - func2.operation = MLIR.API.MlirOperation(C_NULL) + func2.operation = MLIR.API.MlirOperation(C_NULL) - if reverse - resv = if EnzymeCore.needs_primal(CMode) - result - else - nothing - end - return ((restup...,), resv) - else - if EnzymeCore.needs_primal(CMode) - if CMode <: ForwardMode && !(A <: Const) - return (dresult, result) + if reverse + resv = if EnzymeCore.needs_primal(CMode) + result else - return (result,) + nothing end + return ((restup...,), resv) else - if CMode <: ForwardMode && !(A <: Const) - return (dresult,) + if EnzymeCore.needs_primal(CMode) + if CMode <: ForwardMode && !(A <: Const) + return (dresult, result) + else + return (result,) + end else - return () + if CMode <: ForwardMode && !(A <: Const) + return (dresult,) + else + return () + end end end + + finally + WITHIN_AUTODIFF[] = false end end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index a78485dceb..a3fc17fcbc 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -74,18 +74,6 @@ function set_reactant_abi( if f === call_with_reactant arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end]) return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) - elseif !(interp.within_autodiff_rewrite) && f === overload_autodiff - interp′ = Enzyme.Compiler.Interpreter.EnzymeInterpreter( - interp; within_autodiff_rewrite=true - ) - return Base.@invoke abstract_call_known( - interp′::Enzyme.Compiler.Interpreter.EnzymeInterpreter, - f, - arginfo, - si, - sv, - max_methods, - ) end return Base.@invoke abstract_call_known( diff --git a/src/Overlay.jl b/src/Overlay.jl index e837966cb7..63b691b608 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -9,6 +9,10 @@ end # Enzyme.jl overlays +@reactant_overlay @noinline function Enzyme.within_autodiff() + return overload_within_autodiff() +end + @reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} From 1b52e39be0458d6564d019112fbe4aae467771af Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:44:43 -0600 Subject: [PATCH 07/14] refactor put try-finally outside of overload_autodiff call --- src/Enzyme.jl | 363 ++++++++++++++++++++++++------------------------- src/Overlay.jl | 12 +- 2 files changed, 186 insertions(+), 189 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 197aaac0c7..52e0a79612 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -295,240 +295,229 @@ function act_attr(val) return MLIR.IR.Attribute(val) end -const WITHIN_AUTODIFF = Ref(false) - -function overload_within_autodiff() - return WITHIN_AUTODIFF[] -end function overload_autodiff( ::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs} ) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs} - WITHIN_AUTODIFF[] = true - try - reverse = CMode <: ReverseMode - - width = Enzyme.same_or_one(1, args...) - if width == 0 - throw(ErrorException("Cannot differentiate with a batch size of 0")) - end - - primf = f.val - primargs = ((v.val for v in args)...,) - - argprefix::Symbol = gensym("autodiffarg") - resprefix::Symbol = gensym("autodiffresult") - resargprefix::Symbol = gensym("autodiffresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - primf, - primargs, - (), - string(f) * "_autodiff", - false; - argprefix, - resprefix, - resargprefix, - within_autodiff=true, - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f + reverse = CMode <: ReverseMode - activity = Int32[] - ad_inputs = MLIR.IR.Value[] + width = Enzyme.same_or_one(1, args...) + if width == 0 + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - push!(activity, act_from_type(arg, reverse)) - push_acts!(ad_inputs, arg, path[3:end], reverse) - end + primf = f.val + primargs = ((v.val for v in args)...,) + + argprefix::Symbol = gensym("autodiffarg") + resprefix::Symbol = gensym("autodiffresult") + resargprefix::Symbol = gensym("autodiffresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + primf, + primargs, + (), + string(f) * "_autodiff", + false; + argprefix, + resprefix, + resargprefix, + within_autodiff=true, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + activity = Int32[] + ad_inputs = MLIR.IR.Value[] + + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] + push!(activity, act_from_type(arg, reverse)) + push_acts!(ad_inputs, arg, path[3:end], reverse) + end - outtys = MLIR.IR.Type[] - ret_activity = Int32[] + outtys = MLIR.IR.Type[] + ret_activity = Int32[] - for a in linear_results - if TracedUtils.has_idx(a, resprefix) - if EnzymeCore.needs_primal(CMode) - push!( - outtys, - TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))), - ) - end + for a in linear_results + if TracedUtils.has_idx(a, resprefix) + if EnzymeCore.needs_primal(CMode) + push!( + outtys, + TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))), + ) + end - if CMode <: ForwardMode && !(A <: Const) - push!( - outtys, - TracedUtils.batch_ty( - width, - TracedUtils.transpose_ty( - MLIR.IR.type(TracedUtils.get_mlir_data(a)) - ), + if CMode <: ForwardMode && !(A <: Const) + push!( + outtys, + TracedUtils.batch_ty( + width, + TracedUtils.transpose_ty( + MLIR.IR.type(TracedUtils.get_mlir_data(a)) ), - ) + ), + ) + end + + act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode)) + push!(ret_activity, act) + if act == enzyme_out || act == enzyme_outnoneed + if width == 1 + cst = @opcall fill(one(unwrapped_eltype(a)), size(a)) + else + cst = @opcall fill(one(unwrapped_eltype(a)), (size(a)..., width)) end + push!(ad_inputs, cst.mlir_data) + end + else + if TracedUtils.has_idx(a, argprefix) + idx, path = TracedUtils.get_argidx(a, argprefix) + arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode)) + act = act_from_type(arg, reverse, true) push!(ret_activity, act) + if act == enzyme_out || act == enzyme_outnoneed if width == 1 - cst = @opcall fill(one(unwrapped_eltype(a)), size(a)) + TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) + elseif arg.dval isa AbstractArray + TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) else - cst = @opcall fill(one(unwrapped_eltype(a)), (size(a)..., width)) + TracedUtils.push_val!( + ad_inputs, call_with_reactant(stack, arg.dval), path[3:end] + ) end - push!(ad_inputs, cst.mlir_data) end else - if TracedUtils.has_idx(a, argprefix) - idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - - act = act_from_type(arg, reverse, true) - push!(ret_activity, act) - - if act == enzyme_out || act == enzyme_outnoneed - if width == 1 - TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) - elseif arg.dval isa AbstractArray - TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) - else - TracedUtils.push_val!( - ad_inputs, call_with_reactant(stack, arg.dval), path[3:end] - ) - end - end - else - act = act_from_type(Const, reverse, true) - push!(ret_activity, act) - end - - push!( - outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))) - ) + act = act_from_type(Const, reverse, true) + push!(ret_activity, act) end + + push!( + outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))) + ) end + end - for (i, act) in enumerate(activity) - if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed - push!(outtys, TracedUtils.batch_ty(width, in_tys[i])) - end + for (i, act) in enumerate(activity) + if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed + push!(outtys, TracedUtils.batch_ty(width, in_tys[i])) end + end - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( - [TracedUtils.transpose_val(v) for v in ad_inputs]; - outputs=outtys, - fn=fname, - width, - strong_zero=EnzymeCore.strong_zero(CMode), - activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), - ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), - ) + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( + [TracedUtils.transpose_val(v) for v in ad_inputs]; + outputs=outtys, + fn=fname, + width, + strong_zero=EnzymeCore.strong_zero(CMode), + activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), + ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), + ) - residx = 1 + residx = 1 - dresult = if CMode <: ForwardMode && !(A <: Const) - if width == 1 + dresult = if CMode <: ForwardMode && !(A <: Const) + if width == 1 + deepcopy(result) + else + ntuple(Val(width)) do i + Base.@_inline_meta deepcopy(result) - else - ntuple(Val(width)) do i - Base.@_inline_meta - deepcopy(result) - end end - else - nothing end + else + nothing + end - for a in linear_results - if TracedUtils.has_idx(a, resprefix) - if EnzymeCore.needs_primal(CMode) - path = TracedUtils.get_idx(a, resprefix) - tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) - TracedUtils.set!(result, path[2:end], tval) - residx += 1 - end - if CMode <: ForwardMode && !(A <: Const) - path = TracedUtils.get_idx(a, resprefix) - tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) - if width == 1 - TracedUtils.set!(dresult, path[2:end], tval) - else - ttval = TracedRArray(tval) - for (i, sl) in enumerate(eachslice(ttval; dims=ndims(ttval))) - TracedUtils.set!( - dresult[i], - path[2:end], - @allowscalar(TracedUtils.get_mlir_data(sl)) - ) - end + for a in linear_results + if TracedUtils.has_idx(a, resprefix) + if EnzymeCore.needs_primal(CMode) + path = TracedUtils.get_idx(a, resprefix) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + TracedUtils.set!(result, path[2:end], tval) + residx += 1 + end + if CMode <: ForwardMode && !(A <: Const) + path = TracedUtils.get_idx(a, resprefix) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + if width == 1 + TracedUtils.set!(dresult, path[2:end], tval) + else + ttval = TracedRArray(tval) + for (i, sl) in enumerate(eachslice(ttval; dims=ndims(ttval))) + TracedUtils.set!( + dresult[i], + path[2:end], + @allowscalar(TracedUtils.get_mlir_data(sl)) + ) end - residx += 1 end - elseif TracedUtils.has_idx(a, argprefix) - idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - TracedUtils.set!( - arg.val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx)) - ) - residx += 1 - else - TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx))) residx += 1 end - end - - restup = Any[(a isa Active) ? copy(a) : nothing for a in args] - for a in linear_args + elseif TracedUtils.has_idx(a, argprefix) idx, path = TracedUtils.get_argidx(a, argprefix) - arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] - act_from_type(arg, reverse) != enzyme_out && continue - - if idx == 1 && fnwrap && arg isa Active - @assert false - end - - set_act!( - arg, - path[3:end], - reverse, - TracedUtils.transpose_val(MLIR.IR.result(res, residx)); - width, - emptypath=arg isa Active, + TracedUtils.set!( + arg.val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx)) ) residx += 1 + else + TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx))) + residx += 1 end + end + + restup = Any[(a isa Active) ? copy(a) : nothing for a in args] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + + arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] + act_from_type(arg, reverse) != enzyme_out && continue + + if idx == 1 && fnwrap && arg isa Active + @assert false + end + + set_act!( + arg, + path[3:end], + reverse, + TracedUtils.transpose_val(MLIR.IR.result(res, residx)); + width, + emptypath=arg isa Active, + ) + residx += 1 + end - func2.operation = MLIR.API.MlirOperation(C_NULL) + func2.operation = MLIR.API.MlirOperation(C_NULL) - if reverse - resv = if EnzymeCore.needs_primal(CMode) - result + if reverse + resv = if EnzymeCore.needs_primal(CMode) + result + else + nothing + end + return ((restup...,), resv) + else + if EnzymeCore.needs_primal(CMode) + if CMode <: ForwardMode && !(A <: Const) + return (dresult, result) else - nothing + return (result,) end - return ((restup...,), resv) else - if EnzymeCore.needs_primal(CMode) - if CMode <: ForwardMode && !(A <: Const) - return (dresult, result) - else - return (result,) - end + if CMode <: ForwardMode && !(A <: Const) + return (dresult,) else - if CMode <: ForwardMode && !(A <: Const) - return (dresult,) - else - return () - end + return () end end - - finally - WITHIN_AUTODIFF[] = false end end diff --git a/src/Overlay.jl b/src/Overlay.jl index 63b691b608..399fdeffa7 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -9,14 +9,22 @@ end # Enzyme.jl overlays +const WITHIN_AUTODIFF = Ref(false) + @reactant_overlay @noinline function Enzyme.within_autodiff() - return overload_within_autodiff() + return WITHIN_AUTODIFF[] end @reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) + original_within_autodiff = WITHIN_AUTODIFF[] + try + WITHIN_AUTODIFF[] = true + return overload_autodiff(rmode, f, rt, args...) + finally + WITHIN_AUTODIFF[] = original_within_autodiff + end end @reactant_overlay @noinline function Enzyme.autodiff( From fe3cd15167434a48ba1cd52e1dda22f2502ccbea Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:54:03 -0600 Subject: [PATCH 08/14] remove call_with_reactant_within_autodiff --- src/TracedUtils.jl | 14 ++------------ src/utils.jl | 8 +------- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 4a6b841b37..37d80acfa8 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -343,19 +343,9 @@ function make_mlir_fn( process_linear_args!(linear_args, fnbody, do_transpose, optimize_then_pad, inv_map) if isempty(kwargs) - if within_autodiff - Reactant.call_with_reactant_within_autodiff(f, traced_args...) - else - Reactant.call_with_reactant(f, traced_args...) - end + Reactant.call_with_reactant(f, traced_args...) else - if within_autodiff - Reactant.call_with_reactant_within_autodiff( - Core.kwcall, kwargs, f, traced_args... - ) - else - Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...) - end + Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...) end finally MLIR.IR.deactivate!(fnbody) diff --git a/src/utils.jl b/src/utils.jl index 8bbdcd7e18..8deee83983 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -678,9 +678,7 @@ function call_with_reactant_generator( )) end - interp = ReactantInterpreter(; - world, within_autodiff=self == typeof(Reactant.call_with_reactant_within_autodiff) - ) + interp = ReactantInterpreter(; world) min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) @@ -948,10 +946,6 @@ end $(Expr(:meta, :generated_only)) return $(Expr(:meta, :generated, call_with_reactant_generator)) end -@eval function call_with_reactant_within_autodiff($REDUB_ARGUMENTS_NAME...) - $(Expr(:meta, :generated_only)) - return $(Expr(:meta, :generated, call_with_reactant_generator)) -end @static if isdefined(Core, :BFloat16) nmantissa(::Type{Core.BFloat16}) = 7 From 025288b957a0ddda5f070e863bef0af7557ad465 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:56:13 -0600 Subject: [PATCH 09/14] format --- src/Enzyme.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 52e0a79612..d8794d2cef 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -96,7 +96,7 @@ end @register_make_zero_inplace(Enzyme.remake_zero!) function Enzyme.make_zero( - ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) + ::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false) )::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} if haskey(seen, prev) return seen[prev] @@ -142,7 +142,7 @@ function EnzymeRules.augmented_primal( ) where {RT} primargs = ntuple(Val(length(args))) do i Base.@_inline_meta - args[i].val + return args[i].val end primal = if EnzymeCore.needs_primal(config) @@ -162,7 +162,7 @@ function EnzymeRules.augmented_primal( else ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - ConcretePJRTArray( + return ConcretePJRTArray( zeros(T.val, primargs...); client=XLA.client(uval.val), device=XLA.device(uval.val), @@ -192,7 +192,7 @@ function EnzymeRules.reverse( ) where {RT,N} ntuple(Val(N + 2)) do i Base.@_inline_meta - nothing + return nothing end end @@ -295,7 +295,6 @@ function act_attr(val) return MLIR.IR.Attribute(val) end - function overload_autodiff( ::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs} ) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs} @@ -428,7 +427,7 @@ function overload_autodiff( else ntuple(Val(width)) do i Base.@_inline_meta - deepcopy(result) + return deepcopy(result) end end else From bb72f5273cc4f1b7b5f4ae75d8581a3783fc9714 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 18 Nov 2025 10:19:45 -0600 Subject: [PATCH 10/14] cleanup and fix --- src/Enzyme.jl | 1 - src/Interpreter.jl | 23 +++++++++++------------ src/Overlay.jl | 8 +++++++- src/utils.jl | 1 - 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index d8794d2cef..b3a4305a13 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -321,7 +321,6 @@ function overload_autodiff( argprefix, resprefix, resargprefix, - within_autodiff=true, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped diff --git a/src/Interpreter.jl b/src/Interpreter.jl index a3fc17fcbc..9c1d7c1c9e 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -90,17 +90,17 @@ end struct ReactantCacheToken end function ReactantInterpreter(; - world::UInt=Base.get_world_counter(), within_autodiff=false + world::UInt=Base.get_world_counter() ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( ReactantCacheToken(), REACTANT_METHOD_TABLE, world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - within_autodiff, #=within_autodiff_rewrite=# + false, #=forward_rules=# + false, #=reverse_rules=# + false, #=inactive_rules=# + false, #=broadcast_rewrite=# + false, #=within_autodiff_rewrite=# set_reactant_abi, ) end @@ -110,17 +110,16 @@ else function ReactantInterpreter(; world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE, - within_autodiff=false, ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( REACTANT_CACHE, REACTANT_METHOD_TABLE, world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - within_autodiff, #=within_autodiff_rewrite=# + false, #=forward_rules=# + false, #=reverse_rules=# + false, #=inactive_rules=# + false, #=broadcast_rewrite=# + false, #=within_autodiff_rewrite=# set_reactant_abi, ) end diff --git a/src/Overlay.jl b/src/Overlay.jl index 399fdeffa7..b928ae4892 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -30,7 +30,13 @@ end @reactant_overlay @noinline function Enzyme.autodiff( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) + original_within_autodiff = WITHIN_AUTODIFF[] + try + WITHIN_AUTODIFF[] = true + return overload_autodiff(rmode, f, rt, args...) + finally + WITHIN_AUTODIFF[] = original_within_autodiff + end end @reactant_overlay function EnzymeCore.ignore_derivatives(args...) diff --git a/src/utils.jl b/src/utils.jl index 8deee83983..8688603ff1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -19,7 +19,6 @@ function apply(f::F, args...; kwargs...) where {F} end function call_with_reactant end -function call_with_reactant_within_autodiff end function maybe_argextype(@nospecialize(x), src) return try From a43fbb25550d9adfc20a443be720bfed22f7298e Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 18 Nov 2025 10:35:36 -0600 Subject: [PATCH 11/14] format --- src/Interpreter.jl | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 9c1d7c1c9e..673ceb8fa3 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -89,38 +89,35 @@ end @static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE struct ReactantCacheToken end - function ReactantInterpreter(; - world::UInt=Base.get_world_counter() - ) + function ReactantInterpreter(; world::UInt=Base.get_world_counter()) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( ReactantCacheToken(), REACTANT_METHOD_TABLE, world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# - set_reactant_abi, + false, + #=forward_rules=#false, + #=reverse_rules=#false, + #=inactive_rules=#false, + #=broadcast_rewrite=#false, + #=within_autodiff_rewrite=#set_reactant_abi, ) end else const REACTANT_CACHE = Enzyme.GPUCompiler.CodeCache() function ReactantInterpreter(; - world::UInt=Base.get_world_counter(), - code_cache=REACTANT_CACHE, + world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( REACTANT_CACHE, REACTANT_METHOD_TABLE, world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# - set_reactant_abi, + false, + #=forward_rules=#false, + #=reverse_rules=#false, + #=inactive_rules=#false, + #=broadcast_rewrite=#false, + #=within_autodiff_rewrite=#set_reactant_abi, ) end end From 64d3611134bdfe1beb8acb17004ca79eb8d7090d Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 18 Nov 2025 11:47:24 -0600 Subject: [PATCH 12/14] format with v1 --- src/Interpreter.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 673ceb8fa3..d24a622b8e 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -95,11 +95,11 @@ end REACTANT_METHOD_TABLE, world, false, - #=forward_rules=#false, - #=reverse_rules=#false, - #=inactive_rules=#false, - #=broadcast_rewrite=#false, - #=within_autodiff_rewrite=#set_reactant_abi, + false, #=forward_rules=# + false, #=reverse_rules=# + false, #=inactive_rules=# + false, #=broadcast_rewrite=# + set_reactant_abi, #=within_autodiff_rewrite=# ) end else @@ -113,11 +113,11 @@ else REACTANT_METHOD_TABLE, world, false, - #=forward_rules=#false, - #=reverse_rules=#false, - #=inactive_rules=#false, - #=broadcast_rewrite=#false, - #=within_autodiff_rewrite=#set_reactant_abi, + false, #=forward_rules=# + false, #=reverse_rules=# + false, #=inactive_rules=# + false, #=broadcast_rewrite=# + set_reactant_abi, #=within_autodiff_rewrite=# ) end end From 4dbb406cc2052d8c4daf1608b569bb42aaf720e9 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 5 Dec 2025 23:51:57 -0600 Subject: [PATCH 13/14] revert wrong formatting changes to Enzyme.jl Co-authored-by: avik Co-authored-by: Avik Pal --- src/Enzyme.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 4131b80c34..e263e4d716 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -96,7 +96,7 @@ end @register_make_zero_inplace(Enzyme.remake_zero!) function Enzyme.make_zero( - ::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false) + ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) )::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} if haskey(seen, prev) return seen[prev] @@ -142,7 +142,7 @@ function EnzymeRules.augmented_primal( ) where {RT} primargs = ntuple(Val(length(args))) do i Base.@_inline_meta - return args[i].val + args[i].val end primal = if EnzymeCore.needs_primal(config) @@ -162,7 +162,7 @@ function EnzymeRules.augmented_primal( else ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - return ConcretePJRTArray( + ConcretePJRTArray( zeros(T.val, primargs...); client=XLA.client(uval.val), device=XLA.device(uval.val), @@ -192,7 +192,7 @@ function EnzymeRules.reverse( ) where {RT,N} ntuple(Val(N + 2)) do i Base.@_inline_meta - return nothing + nothing end end @@ -426,7 +426,7 @@ function overload_autodiff( else ntuple(Val(width)) do i Base.@_inline_meta - return deepcopy(result) + deepcopy(result) end end else From 626e9090e57f57e9236fa06c08ce5b1751fd7f0a Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 5 Dec 2025 23:50:40 -0600 Subject: [PATCH 14/14] fix argument comments Co-authored-by: Avik Pal --- src/Interpreter.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index d24a622b8e..6f23ef582b 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -94,12 +94,12 @@ end ReactantCacheToken(), REACTANT_METHOD_TABLE, world, - false, false, #=forward_rules=# false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - set_reactant_abi, #=within_autodiff_rewrite=# + false, #=within_autodiff_rewrite=# + set_reactant_abi, #=handler=# ) end else @@ -112,12 +112,12 @@ else REACTANT_CACHE, REACTANT_METHOD_TABLE, world, - false, false, #=forward_rules=# false, #=reverse_rules=# false, #=inactive_rules=# false, #=broadcast_rewrite=# - set_reactant_abi, #=within_autodiff_rewrite=# + false, #=within_autodiff_rewrite=# + set_reactant_abi, #=handler=# ) end end