forked from EnzymeAD/Reactant.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathReactantAbstractFFTsExt.jl
More file actions
94 lines (80 loc) · 3.15 KB
/
ReactantAbstractFFTsExt.jl
File metadata and controls
94 lines (80 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
module ReactantAbstractFFTsExt
using AbstractFFTs: AbstractFFTs
using Reactant: Reactant, MLIR, Ops, AnyTracedRArray, TracedRArray, TracedUtils
using Reactant.Ops: @opcall
"""
    __permutation_to_move_dims_to_end(dims, N::Integer) -> Vector

Return a permutation of `1:N` that keeps every axis not listed in `dims` in its original
order and then appends `dims` in reverse order, so that `dims[1]` becomes the last axis.
A single integer `dims` is treated as the 1-tuple `(dims,)`.
"""
function __permutation_to_move_dims_to_end(dims, N::Integer)
    # Build the membership set once rather than once per axis.
    dimset = Set(dims)
    perm = [i for i in 1:N if i ∉ dimset]
    append!(perm, reverse(dims))
    return perm
end
function __permutation_to_move_dims_to_end(dim::Integer, N::Integer)
    # `reverse` needs a collection, so wrap the single dimension.
    return __permutation_to_move_dims_to_end((dim,), N)
end
"""
    __is_valid_stablehlo_fft_dims(dims, N::Integer) -> Bool

Return `true` when `dims` already names the trailing axes of an `N`-dimensional array in
descending order, i.e. `dims == (N, N - 1, ..., N - length(dims) + 1)`. A single integer
qualifies only when it equals `N`. An empty `dims` compares equal to the empty range and
therefore yields `true`.
"""
__is_valid_stablehlo_fft_dims(dim::Integer, N::Integer) = N == dim
function __is_valid_stablehlo_fft_dims(dims, N::Integer)
    ndims_fft = length(dims)
    # The last `ndims_fft` axes, counted from the end.
    trailing_desc = N:-1:(N - ndims_fft + 1)
    return collect(dims) == trailing_desc
end
# Define `AbstractFFTs.rfft`, `fft` and `ifft` for traced arrays by lowering them to the
# StableHLO `fft` op. Also define out-of-place `plan_*` plans, plus in-place `plan_*!`
# plans for every op except `rfft`; all plans simply call those functions.
for op in (:rfft, :fft, :ifft)
    # StableHLO `fft` type attribute for this op: "RFFT", "FFT" or "IFFT".
    fft_type = uppercase(string(op))

    @eval function AbstractFFTs.$(op)(x::AnyTracedRArray, dims)
        # Accept a single integer dimension (e.g. `fft(x, 2)`) as well as a collection;
        # `reverse` and the permutation helper below need a collection.
        fft_dims = dims isa Integer ? (dims,) : dims
        @assert(
            !isempty(fft_dims) && all(k -> 1 <= k <= ndims(x), fft_dims),
            "Invalid dimensions for fft: $(dims)"
        )
        N = ndims(x)
        # One length per transformed axis, innermost last. After the permutation below,
        # the innermost transformed axis is `fft_dims[1]`.
        fft_lengths = Int64[size(x, k) for k in reverse(fft_dims)]
        if __is_valid_stablehlo_fft_dims(fft_dims, N)
            # `fft_dims` already names the trailing axes in the required order.
            return @opcall fft(
                TracedUtils.materialize_traced_array(x);
                type=$(fft_type),
                length=fft_lengths,
            )
        end
        # Move `fft_dims` to the end, transform there, then restore the axis order.
        perm = __permutation_to_move_dims_to_end(fft_dims, N)
        y = @opcall fft(
            TracedUtils.materialize_traced_array(permutedims(x, perm));
            type=$(fft_type),
            length=fft_lengths,
        )
        return permutedims(y, invperm(perm))
    end

    # Out-of-place plan. `dims === nothing` means "call the one-argument `op(x)`";
    # otherwise the stored `dims` are forwarded as `op(x, dims)`.
    plan_name = Symbol("Reactant", fft_type, "Plan")
    plan_f = Symbol("plan_", op)
    @eval begin
        struct $(plan_name){T,D} <: AbstractFFTs.Plan{T}
            dims::D
        end
        # Zero-argument constructor, kept for backward compatibility.
        $(plan_name){T}() where {T} = $(plan_name){T,Nothing}(nothing)
        # Planner keywords (e.g. `flags`, `timelimit`) are accepted and ignored.
        AbstractFFTs.$(plan_f)(::Reactant.TracedRArray{T}; kws...) where {T} =
            $(plan_name){T}()
        function AbstractFFTs.$(plan_f)(
            ::Reactant.TracedRArray{T}, dims; kws...
        ) where {T}
            return $(plan_name){T,typeof(dims)}(dims)
        end
        Base.:*(::$(plan_name){T,Nothing}, x::Reactant.TracedRArray{T}) where {T} =
            AbstractFFTs.$(op)(x)
        Base.:*(p::$(plan_name){T}, x::Reactant.TracedRArray{T}) where {T} =
            AbstractFFTs.$(op)(x, p.dims)
    end

    # In-place plan (not defined for `rfft`): the result is written back into `x`
    # with `copyto!`.
    if op !== :rfft
        plan_name! = Symbol("Reactant", fft_type, "InPlacePlan")
        plan_f! = Symbol("plan_", op, "!")
        @eval begin
            struct $(plan_name!){T,D} <: AbstractFFTs.Plan{T}
                dims::D
            end
            # Zero-argument constructor, kept for backward compatibility.
            $(plan_name!){T}() where {T} = $(plan_name!){T,Nothing}(nothing)
            # Planner keywords (e.g. `flags`, `timelimit`) are accepted and ignored.
            AbstractFFTs.$(plan_f!)(::Reactant.TracedRArray{T}; kws...) where {T} =
                $(plan_name!){T}()
            function AbstractFFTs.$(plan_f!)(
                ::Reactant.TracedRArray{T}, dims; kws...
            ) where {T}
                return $(plan_name!){T,typeof(dims)}(dims)
            end
            Base.:*(::$(plan_name!){T,Nothing}, x::Reactant.TracedRArray{T}) where {T} =
                copyto!(x, AbstractFFTs.$(op)(x))
            Base.:*(p::$(plan_name!){T}, x::Reactant.TracedRArray{T}) where {T} =
                copyto!(x, AbstractFFTs.$(op)(x, p.dims))
        end
    end
end
# Define `AbstractFFTs.irfft` for traced arrays by lowering it to the StableHLO `fft` op.
for op in (:irfft,)
    # StableHLO `fft` type attribute for this op ("IRFFT").
    mode = uppercase(string(op))
    @eval function AbstractFFTs.$(op)(x::AnyTracedRArray, d::Integer, dims)
        # Accept a single integer dimension as well as a collection; `[2:end]` and
        # `reverse` below need a collection.
        irfft_dims = dims isa Integer ? (dims,) : dims
        @assert(
            !isempty(irfft_dims) && all(k -> 1 <= k <= ndims(x), irfft_dims),
            "Invalid dimensions for irfft: $(dims)"
        )
        # `irfft_dims[1]` is the half-spectrum axis: for an output of length `d`, its
        # input length must be `d ÷ 2 + 1` (AbstractFFTs `irfft` contract).
        halfdim = first(irfft_dims)
        if size(x, halfdim) != d ÷ 2 + 1
            throw(
                DimensionMismatch(
                    "irfft: size(x, $(halfdim)) = $(size(x, halfdim)), but output " *
                    "length d = $(d) requires $(d ÷ 2 + 1)",
                ),
            )
        end
        N = ndims(x)
        # One length per transformed axis, innermost last. The innermost entry is the
        # full output length `d` of the half-spectrum axis, which the permutation
        # below places last.
        fft_lengths = vcat(Int64[size(x, k) for k in reverse(irfft_dims[2:end])], d)
        if __is_valid_stablehlo_fft_dims(irfft_dims, N)
            # `irfft_dims` already names the trailing axes in the required order.
            return @opcall fft(
                TracedUtils.materialize_traced_array(x);
                type=$(mode),
                length=fft_lengths,
            )
        end
        # Move `irfft_dims` to the end, transform there, then restore the axis order.
        perm = __permutation_to_move_dims_to_end(irfft_dims, N)
        y = @opcall fft(
            TracedUtils.materialize_traced_array(permutedims(x, perm));
            type=$(mode),
            length=fft_lengths,
        )
        return permutedims(y, invperm(perm))
    end
end
end