Skip to content

Add support for FFT plans#1931

Merged
giordano merged 3 commits intoEnzymeAD:mainfrom
giordano:mg/fft-plan
Dec 6, 2025
Merged

Add support for FFT plans#1931
giordano merged 3 commits intoEnzymeAD:mainfrom
giordano:mg/fft-plan

Conversation

@giordano
Copy link
Copy Markdown
Member

@giordano giordano commented Dec 4, 2025

Should fix #1926, but for some reason I don't quite understand, the in-place plans aren't in-place (the added tests should fail), even though the code in #1926 (comment) works as expected. I now get

julia> using Reactant, FFTW

julia> x_host = rand(ComplexF32, 16, 16); x_r = Reactant.to_rarray(x_host);

julia> planned_fft!(x) = plan_fft!(x) * x
planned_fft! (generic function with 1 method)

julia> @code_hlo planned_fft!(x_r)
module @reactant_planned... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<16x16xcomplex<f32>> {enzymexla.memory_effects = []}) -> tensor<16x16xcomplex<f32>> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.fft %arg0, type =  FFT, length = [16, 16] : (tensor<16x16xcomplex<f32>>) -> tensor<16x16xcomplex<f32>>
    return %0 : tensor<16x16xcomplex<f32>>
  }
}

Note that contrary to #1926 (comment), the input is not aliasing the output. @avik-pal do you have any clue of what's going on here?

Edit: I simply missed a ! 🤦

@giordano giordano added the enhancement New feature or request label Dec 4, 2025
Comment thread ext/ReactantAbstractFFTsExt.jl Outdated
@giordano giordano force-pushed the mg/fft-plan branch 2 times, most recently from 26824de to 9cb5e4f Compare December 4, 2025 17:11
@giordano giordano requested review from avik-pal and wsmoses December 5, 2025 10:42
Comment thread ext/ReactantAbstractFFTsExt.jl Outdated
giordano and others added 3 commits December 6, 2025 01:24
@giordano
Copy link
Copy Markdown
Member Author

giordano commented Dec 6, 2025

I'd hope XLA is able to reuse plans applied to different arrays of the same shape. Since all the planning is completely opaque to the frontend we can't do much about it, but reusing plans is quite a crucial optimisation (and the whole point plans exist)

@wsmoses
Copy link
Copy Markdown
Member

wsmoses commented Dec 6, 2025

I'd assume so as well, but also since we already separate compile from run, during compile even if we make a different plan for all the fft instructions, they'd still all be planned during runtime ? It already does this with matmul by trying all possible matmul plans I know

@giordano
Copy link
Copy Markdown
Member Author

giordano commented Dec 6, 2025

The requisites for a plan are the shape and the eltype of the target array, and XLA should known them all at compile-time, right? So I'd speculate it's possible it does all the planning during compilation. But still, planning can be quite expensive, not repeating it unnecessarily would be welcome in large applications, whether it's during compile- or run-time.

@wsmoses
Copy link
Copy Markdown
Member

wsmoses commented Dec 6, 2025

yeah Id assume it has a cache, but we can benchmark, check, and fix, as needed as a follow up

AbstractFFTs.$(op)(x)

# In-place plan
if op !== :rfft
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whats up with rfft?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rfft returns an array of different shape than the input one, it can't be done in-place. There's no plan_rfft!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah got it!

@giordano giordano merged commit d0cace0 into EnzymeAD:main Dec 6, 2025
176 of 181 checks passed
@giordano giordano deleted the mg/fft-plan branch December 6, 2025 19:42
@giordano giordano added the FFT FFT operations label Feb 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request FFT FFT operations

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support FFT plans

3 participants