Conversation
dd14b85 to
8731c7f
Compare
|
|
||
| if isempty(kwargs) | ||
| Reactant.call_with_reactant(f, traced_args...) | ||
| if within_autodiff |
There was a problem hiding this comment.
I feel like a cleaner way to do this, is not to have a second interpreter. But instead we can create a new global ref set to false, and overlay within_autodiff to lookup that var, and during autodiff set that to true
There was a problem hiding this comment.
I'm okay with this though, but if we were to do it in this form, I would probably change call_with_reactant to take a config type var, which stores the current state of whether in autodiff or not (and also we can extend to other things down the line as well)
There was a problem hiding this comment.
Ok, I got rid of the second interpreter like you described
a3115b5 to
2a4d0f2
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #490 +/- ##
==========================================
+ Coverage 42.55% 42.56% +0.01%
==========================================
Files 123 123
Lines 21816 21826 +10
==========================================
+ Hits 9283 9290 +7
- Misses 12533 12536 +3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
julia> using Enzyme
julia> function error_not_within_autodiff()
!Enzyme.within_autodiff() && error("Not within autodiff")
return nothing
end
error_not_within_autodiff (generic function with 1 method)
julia> fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)
fwd_within_autodiff (generic function with 1 method)
julia> error_not_within_autodiff()
ERROR: Not within autodiff
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] error_not_within_autodiff()
@ Main ./REPL[5]:2
[3] top-level scope
@ REPL[7]:1
[4] top-level scope
@ none:1
julia> fwd_within_autodiff(Forward, Const)
()
julia> error_not_within_autodiff()
julia> Enzyme.within_autodiff()
falseI am extremely confused why is the 2nd call not throw an error here. Only happens if I call fwd_within_autodiff in between. cc @wsmoses this is in isolation from Reactant |
|
@vchuravy er wat |
… === overload_autodiff`. This doesn't work for some reason, the function within overload autodiff uses the original interpreter (?)
…mlir_fn. In order to pass this information from make_mlir_fn to call_with_reactant_generator, I introduced a new function `call_with_reactant_within_autodiff` which allows detection by looking at `self`.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
15eaa8f to
159837e
Compare
work with global ref instead This has the unfortunate downside of introducing a try finally block around `overload_autodiff`.
14864f3 to
bb72f52
Compare
| #=forward_rules=#false, | ||
| #=reverse_rules=#false, | ||
| #=inactive_rules=#false, | ||
| #=broadcast_rewrite=#false, | ||
| #=within_autodiff_rewrite=#set_reactant_abi, |
There was a problem hiding this comment.
[JuliaFormatter] reported by reviewdog 🐶
| #=forward_rules=#false, | |
| #=reverse_rules=#false, | |
| #=inactive_rules=#false, | |
| #=broadcast_rewrite=#false, | |
| #=within_autodiff_rewrite=#set_reactant_abi, | |
| false, #=forward_rules=# | |
| false, #=reverse_rules=# | |
| false, #=inactive_rules=# | |
| false, #=broadcast_rewrite=# | |
| set_reactant_abi, #=within_autodiff_rewrite=# |
| #=forward_rules=#false, | ||
| #=reverse_rules=#false, | ||
| #=inactive_rules=#false, | ||
| #=broadcast_rewrite=#false, | ||
| #=within_autodiff_rewrite=#set_reactant_abi, |
There was a problem hiding this comment.
[JuliaFormatter] reported by reviewdog 🐶
| #=forward_rules=#false, | |
| #=reverse_rules=#false, | |
| #=inactive_rules=#false, | |
| #=broadcast_rewrite=#false, | |
| #=within_autodiff_rewrite=#set_reactant_abi, | |
| false, #=forward_rules=# | |
| false, #=reverse_rules=# | |
| false, #=inactive_rules=# | |
| false, #=broadcast_rewrite=# | |
| set_reactant_abi, #=within_autodiff_rewrite=# |
| false, #=broadcast_rewrite=# | ||
| false, #=within_autodiff_rewrite=# | ||
| set_reactant_abi, | ||
| set_reactant_abi, #=within_autodiff_rewrite=# |
| @@ -142,7 +142,7 @@ function EnzymeRules.augmented_primal( | |||
| ) where {RT} | |||
| primargs = ntuple(Val(length(args))) do i | |||
| Base.@_inline_meta | |||
| args[i].val | |||
| return args[i].val | |||
| end | |||
|
|
|||
| primal = if EnzymeCore.needs_primal(config) | |||
| @@ -162,7 +162,7 @@ function EnzymeRules.augmented_primal( | |||
| else | |||
| ntuple(Val(EnzymeRules.width(config))) do i | |||
| Base.@_inline_meta | |||
| ConcretePJRTArray( | |||
| return ConcretePJRTArray( | |||
| zeros(T.val, primargs...); | |||
| client=XLA.client(uval.val), | |||
| device=XLA.device(uval.val), | |||
| @@ -192,7 +192,7 @@ function EnzymeRules.reverse( | |||
| ) where {RT,N} | |||
| ntuple(Val(N + 2)) do i | |||
| Base.@_inline_meta | |||
| nothing | |||
| return nothing | |||
| end | |||
| end | |||
|
|
|||
| @@ -426,7 +426,7 @@ function overload_autodiff( | |||
| else | |||
| ntuple(Val(width)) do i | |||
| Base.@_inline_meta | |||
| deepcopy(result) | |||
| return deepcopy(result) | |||
There was a problem hiding this comment.
revert these changes. possible originated from juliaformatter v2?
e0a228f to
106375a
Compare
Co-authored-by: avik Co-authored-by: Avik Pal <avikpal@mit.edu>
106375a to
232b01f
Compare
Co-authored-by: Avik Pal <avikpal@mit.edu>
232b01f to
626e909
Compare
|
@jumerckx is this good to go? |
|
Yes! Will merge after CI has finished |
fixes #442
needs Enzyme.jl: EnzymeAD/Enzyme.jl#2254
I had to introduce a new function
call_with_reactant_within_autodiffto smuggle thewithin_autodiffin thecall_with_reactant_generatorthrough theselfargument.I also tried doing things through
set_reactant_abibut that didn't seem to suffice (first commit).Perhaps the extra code in
set_reactant_abiisn't strictly necessary now so I can try removing it again if wanted.