Add defer_within_autodiff to EnzymeInterpreter#2254
Conversation
| (; fargs, argtypes) = arginfo | ||
|
|
||
| if f === Enzyme.within_autodiff | ||
| if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff |
There was a problem hiding this comment.
Why is this necessary? This fundamentally breaks this functionality?
There was a problem hiding this comment.
This is for a Reactant bug: EnzymeAD/Reactant.jl#442 (comment)
Reason being that Reactant uses EnzymeInterpreter as well, while not necessarily doing autodiff.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2254 +/- ##
==========================================
- Coverage 74.93% 74.92% -0.01%
==========================================
Files 56 56
Lines 17434 17436 +2
==========================================
Hits 13064 13064
- Misses 4370 4372 +2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
@vchuravy can you give this a review before merge |
|
Seems fine. |
3e37885 to
9f54068
Compare
|
bump on this |
…_autodiff` to no return true during Reactant compilation. When this flag is true, `interp.handler` is responsible for handling within_autodiff, or to toggle defer_within_autodiff to false somewhere down the call chain.
9f54068 to
f1e15c9
Compare
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl
index 77f0027..260c90b 100644
--- a/src/compiler/interpreter.jl
+++ b/src/compiler/interpreter.jl
@@ -173,7 +173,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
- within_autodiff_rewrite::Bool = true,
+ within_autodiff_rewrite::Bool = true,
handler = nothing
)
@assert world <= Base.get_world_counter()
@@ -250,23 +250,27 @@ EnzymeInterpreter(
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)
-function EnzymeInterpreter(interp::EnzymeInterpreter;
- cache_or_token = (@static if HAS_INTEGRATED_CACHE
- interp.token
- else
- interp.code_cache
- end),
- mt = interp.method_table,
- local_cache = interp.local_cache,
- world = interp.world,
- inf_params = interp.inf_params,
- opt_params = interp.opt_params,
- forward_rules = interp.forward_rules,
- reverse_rules = interp.reverse_rules,
- inactive_rules = interp.inactive_rules,
- broadcast_rewrite = interp.broadcast_rewrite,
- within_autodiff_rewrite = interp.within_autodiff_rewrite,
- handler = interp.handler)
+function EnzymeInterpreter(
+ interp::EnzymeInterpreter;
+ cache_or_token = (
+ @static if HAS_INTEGRATED_CACHE
+ interp.token
+ else
+ interp.code_cache
+ end
+ ),
+ mt = interp.method_table,
+ local_cache = interp.local_cache,
+ world = interp.world,
+ inf_params = interp.inf_params,
+ opt_params = interp.opt_params,
+ forward_rules = interp.forward_rules,
+ reverse_rules = interp.reverse_rules,
+ inactive_rules = interp.inactive_rules,
+ broadcast_rewrite = interp.broadcast_rewrite,
+ within_autodiff_rewrite = interp.within_autodiff_rewrite,
+ handler = interp.handler
+ )
return EnzymeInterpreter(
cache_or_token,
mt, |
Together with Reactant pr: EnzymeAD/Reactant.jl#490