Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions ext/ReactantAbstractFFTsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ for op in (:rfft, :fft, :ifft)
invperm(perm),
)
end

# Out-of-place plan
plan_name = Symbol("Reactant", uppercase(string(op)), "Plan")
plan_f = Symbol("plan_", op)
@eval struct $(plan_name){T} <: AbstractFFTs.Plan{T} end
@eval AbstractFFTs.$(plan_f)(::Reactant.TracedRArray{T}) where {T} = $(plan_name){T}()
@eval Base.:*(::$(plan_name){T}, x::Reactant.TracedRArray{T}) where {T} =
AbstractFFTs.$(op)(x)

# In-place plan
if op !== :rfft
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whats up with rfft?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rfft returns an array of different shape than the input one, it can't be done in-place. There's no plan_rfft!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah got it!

plan_name! = Symbol("Reactant", uppercase(string(op)), "InPlacePlan")
plan_f! = Symbol("plan_", op, "!")
@eval struct $(plan_name!){T} <: AbstractFFTs.Plan{T} end
@eval AbstractFFTs.$(plan_f!)(::Reactant.TracedRArray{T}) where {T} =
$(plan_name!){T}()
@eval Base.:*(::$(plan_name!){T}, x::Reactant.TracedRArray{T}) where {T} =
copyto!(x, AbstractFFTs.$(op)(x))
end
end

for op in (:irfft,)
Expand Down
42 changes: 42 additions & 0 deletions test/integration/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,45 @@ end
@test @jit(rfft(x_ra, (1, 2, 3))) ≈ rfft(x, (1, 2, 3))
end
end

@testset "Planned FFTs" begin
@testset "Out-of-place [$(fft), size $(size)]" for size in ((16,), (16, 16)),
(plan, fft) in (
(FFTW.plan_fft, FFTW.fft),
(FFTW.plan_ifft, FFTW.ifft),
(FFTW.plan_rfft, FFTW.rfft),
)

x = randn(fft === FFTW.rfft ? Float32 : ComplexF32, size)
x_r = Reactant.to_rarray(x)
# We make a copy of the original array to make sure the operation does
# not modify the input.
copied_x_r = copy(x_r)

planned_fft(x) = plan(x) * x
compiled_planned_fft = @compile planned_fft(x_r)
# Make sure the result is correct
@test compiled_planned_fft(x_r) ≈ fft(x)
# Make sure the operation is not in-place
@test x_r == copied_x_r
end

@testset "In-place [$(fft!), size $(size)]" for size in ((16,), (16, 16)),
(plan!, fft!) in ((FFTW.plan_fft!, FFTW.fft!), (FFTW.plan_ifft!, FFTW.ifft!))

x = randn(ComplexF32, size)
x_r = Reactant.to_rarray(x)
# We make a copy of the original array to make sure the operation
# modifies the input.
copied_x_r = copy(x_r)

planned_fft!(x) = plan!(x) * x
compiled_planned_fft! = @compile planned_fft!(x_r)
planned_y_r = compiled_planned_fft!(x_r)
# Make sure the result is correct
@test planned_y_r ≈ fft!(x)
# Make sure the operation is in-place
@test planned_y_r ≈ x_r
@test x_r ≉ copied_x_r
end
end
Loading