Skip to content

Properly set within_autodiff (#442)#490

Merged
jumerckx merged 18 commits intomainfrom
jm/deferred_within_autodiff
Dec 14, 2025
Merged

Properly set within_autodiff (#442)#490
jumerckx merged 18 commits intomainfrom
jm/deferred_within_autodiff

Conversation

@jumerckx
Copy link
Copy Markdown
Collaborator

@jumerckx jumerckx commented Jan 7, 2025

fixes #442
needs Enzyme.jl: EnzymeAD/Enzyme.jl#2254

I had to introduce a new function call_with_reactant_within_autodiff to smuggle the within_autodiff in the call_with_reactant_generator through the self argument.
I also tried doing things through set_reactant_abi but that didn't seem to suffice (first commit).
Perhaps the extra code in set_reactant_abi isn't strictly necessary now so I can try removing it again if wanted.

@avik-pal avik-pal force-pushed the jm/deferred_within_autodiff branch 2 times, most recently from dd14b85 to 8731c7f Compare September 1, 2025 23:13
@avik-pal avik-pal requested a review from wsmoses September 1, 2025 23:41
Comment thread src/TracedUtils.jl Outdated

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
if within_autodiff
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like a cleaner way to do this, is not to have a second interpreter. But instead we can create a new global ref set to false, and overlay within_autodiff to lookup that var, and during autodiff set that to true

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with this though, but if we were to do it in this form, I would probably change call_with_reactant to take a config type var, which stores the current state of whether in autodiff or not (and also we can extend to other things down the line as well)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I got rid of the second interpreter like you described

@avik-pal avik-pal force-pushed the jm/deferred_within_autodiff branch from a3115b5 to 2a4d0f2 Compare September 2, 2025 00:18
@codecov
Copy link
Copy Markdown

codecov Bot commented Sep 2, 2025

Codecov Report

❌ Patch coverage is 78.57143% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 42.56%. Comparing base (71b744e) to head (2a4d0f2).

Files with missing lines Patch % Lines
src/utils.jl 50.00% 2 Missing ⚠️
src/TracedUtils.jl 83.33% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #490      +/-   ##
==========================================
+ Coverage   42.55%   42.56%   +0.01%     
==========================================
  Files         123      123              
  Lines       21816    21826      +10     
==========================================
+ Hits         9283     9290       +7     
- Misses      12533    12536       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@avik-pal
Copy link
Copy Markdown
Collaborator

julia> using Enzyme

julia> function error_not_within_autodiff()
           !Enzyme.within_autodiff() && error("Not within autodiff")
           return nothing
       end
error_not_within_autodiff (generic function with 1 method)

julia> fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)
fwd_within_autodiff (generic function with 1 method)

julia> error_not_within_autodiff()
ERROR: Not within autodiff
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] error_not_within_autodiff()
   @ Main ./REPL[5]:2
 [3] top-level scope
   @ REPL[7]:1
 [4] top-level scope
   @ none:1

julia> fwd_within_autodiff(Forward, Const)
()

julia> error_not_within_autodiff()

julia> Enzyme.within_autodiff()
false

I am extremely confused why is the 2nd call not throw an error here. Only happens if I call fwd_within_autodiff in between. cc @wsmoses this is in isolation from Reactant

@wsmoses
Copy link
Copy Markdown
Member

wsmoses commented Sep 13, 2025

@vchuravy er wat

jumerckx and others added 5 commits November 17, 2025 14:17
… === overload_autodiff`.

This doesn't work for some reason, the function within overload autodiff uses the original interpreter (?)
…mlir_fn. In order to pass this information from make_mlir_fn to call_with_reactant_generator, I introduced a new function `call_with_reactant_within_autodiff` which allows detection by looking at `self`.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch 2 times, most recently from 15eaa8f to 159837e Compare November 18, 2025 16:19
work with global ref instead

This has the unfortunate downside of introducing a try finally block 
around `overload_autodiff`.
put try-finally outside of overload_autodiff call
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch 2 times, most recently from 14864f3 to bb72f52 Compare November 18, 2025 16:21
Comment thread src/Interpreter.jl Outdated
Comment on lines +98 to +102
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi, #=within_autodiff_rewrite=#

Comment thread src/Interpreter.jl Outdated
Comment on lines +116 to +120
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi, #=within_autodiff_rewrite=#

@jumerckx jumerckx requested a review from wsmoses November 20, 2025 16:53
Comment thread src/Interpreter.jl Outdated
false, #=broadcast_rewrite=#
false, #=within_autodiff_rewrite=#
set_reactant_abi,
set_reactant_abi, #=within_autodiff_rewrite=#
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment seems off

Comment thread src/Enzyme.jl Outdated
Comment on lines +99 to +429
@@ -142,7 +142,7 @@ function EnzymeRules.augmented_primal(
) where {RT}
primargs = ntuple(Val(length(args))) do i
Base.@_inline_meta
args[i].val
return args[i].val
end

primal = if EnzymeCore.needs_primal(config)
@@ -162,7 +162,7 @@ function EnzymeRules.augmented_primal(
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
ConcretePJRTArray(
return ConcretePJRTArray(
zeros(T.val, primargs...);
client=XLA.client(uval.val),
device=XLA.device(uval.val),
@@ -192,7 +192,7 @@ function EnzymeRules.reverse(
) where {RT,N}
ntuple(Val(N + 2)) do i
Base.@_inline_meta
nothing
return nothing
end
end

@@ -426,7 +426,7 @@ function overload_autodiff(
else
ntuple(Val(width)) do i
Base.@_inline_meta
deepcopy(result)
return deepcopy(result)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert these changes. possible originated from juliaformatter v2?

@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch from e0a228f to 106375a Compare December 6, 2025 05:52
jumerckx and others added 2 commits December 5, 2025 23:53
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch from 106375a to 232b01f Compare December 6, 2025 05:53
Co-authored-by: Avik Pal <avikpal@mit.edu>
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch from 232b01f to 626e909 Compare December 6, 2025 05:55
@avik-pal
Copy link
Copy Markdown
Collaborator

@jumerckx is this good to go?

@jumerckx
Copy link
Copy Markdown
Collaborator Author

Yes! Will merge after CI has finished

@jumerckx jumerckx merged commit 4f8b904 into main Dec 14, 2025
70 checks passed
@jumerckx jumerckx deleted the jm/deferred_within_autodiff branch December 14, 2025 18:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Enzyme.within_autodiff returns true inside compile

3 participants