@@ -20,7 +20,8 @@ using FFTW
2020# in columns. In Julia column-major, reshape (F1F2, F0) puts stride-F0 elements in rows.
2121# We use right-multiply X @ W instead of W @ X to process rows instead of columns.
2222#
23- # Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving.
23+ # Input/output memory layout: (D, BS, N2D) where D=2 for real/imag interleaving.
24+ # Internally, BS is permuted to trailing position for batched matmul convention.
2425function fft_kernel (
2526 x_packed_in:: ct.TileArray{Float32, 3} , # Input (D, BS, N2D) - natural Julia complex layout
2627 y_packed_out:: ct.TileArray{Float32, 3} , # Output (D, BS, N2D)
@@ -55,96 +56,94 @@ function fft_kernel(
5556 bid = ct. bid (1 )
5657
5758 # --- Load Input Data ---
58- # Input is (D, BS, N2D) where D=2 for real/imag. Load and reshape to (2, BS, N).
59- X_ri = reshape (ct. load (x_packed_in; index= (1 , bid, 1 ), shape= (D, BS, N2D)), (2 , BS, N))
59+ # Input is (D, BS, N2D) where D=2 for real/imag. Load and permute BS to trailing.
60+ X_ri_mem = reshape (ct. load (x_packed_in; index= (1 , bid, 1 ), shape= (D, BS, N2D)), (2 , BS, N))
61+ X_ri = permutedims (X_ri_mem, (1 , 3 , 2 )) # (2, N, BS) — trailing batch
6062
6163 # Split real and imaginary parts (extract from first dimension)
62- X_r = reshape (ct. extract (X_ri, (1 , 1 , 1 ), (1 , BS, N )), (BS, F1F2, F0))
63- X_i = reshape (ct. extract (X_ri, (2 , 1 , 1 ), (1 , BS, N )), (BS, F1F2, F0))
64+ X_r = reshape (ct. extract (X_ri, (1 , 1 , 1 ), (1 , N, BS )), (F1F2, F0, BS ))
65+ X_i = reshape (ct. extract (X_ri, (2 , 1 , 1 ), (1 , N, BS )), (F1F2, F0, BS ))
6466
6567 # --- Load DFT Matrices ---
66- # W0 (F0 x F0) - for right-multiply X @ W0
68+ # W0 (F0 x F0) - for right-multiply X @ W0, batch dim trailing
6769 W0_ri = reshape (ct. load (W0; index= (1 , 1 , 1 ), shape= (F0, F0, 2 )), (F0, F0, 2 ))
68- W0_r = ct. broadcast_to (reshape (ct. extract (W0_ri, (1 , 1 , 1 ), (F0, F0, 1 )), (1 , F0, F0 )), (BS , F0, F0 ))
69- W0_i = ct. broadcast_to (reshape (ct. extract (W0_ri, (1 , 1 , 2 ), (F0, F0, 1 )), (1 , F0, F0 )), (BS , F0, F0 ))
70+ W0_r = ct. broadcast_to (reshape (ct. extract (W0_ri, (1 , 1 , 1 ), (F0, F0, 1 )), (F0 , F0, 1 )), (F0 , F0, BS ))
71+ W0_i = ct. broadcast_to (reshape (ct. extract (W0_ri, (1 , 1 , 2 ), (F0, F0, 1 )), (F0 , F0, 1 )), (F0 , F0, BS ))
7072
7173 # W1 (F1 x F1)
7274 W1_ri = reshape (ct. load (W1; index= (1 , 1 , 1 ), shape= (F1, F1, 2 )), (F1, F1, 2 ))
73- W1_r = ct. broadcast_to (reshape (ct. extract (W1_ri, (1 , 1 , 1 ), (F1, F1, 1 )), (1 , F1, F1 )), (BS , F1, F1 ))
74- W1_i = ct. broadcast_to (reshape (ct. extract (W1_ri, (1 , 1 , 2 ), (F1, F1, 1 )), (1 , F1, F1 )), (BS , F1, F1 ))
75+ W1_r = ct. broadcast_to (reshape (ct. extract (W1_ri, (1 , 1 , 1 ), (F1, F1, 1 )), (F1 , F1, 1 )), (F1 , F1, BS ))
76+ W1_i = ct. broadcast_to (reshape (ct. extract (W1_ri, (1 , 1 , 2 ), (F1, F1, 1 )), (F1 , F1, 1 )), (F1 , F1, BS ))
7577
7678 # W2 (F2 x F2)
7779 W2_ri = reshape (ct. load (W2; index= (1 , 1 , 1 ), shape= (F2, F2, 2 )), (F2, F2, 2 ))
78- W2_r = ct. broadcast_to (reshape (ct. extract (W2_ri, (1 , 1 , 1 ), (F2, F2, 1 )), (1 , F2, F2 )), (BS , F2, F2 ))
79- W2_i = ct. broadcast_to (reshape (ct. extract (W2_ri, (1 , 1 , 2 ), (F2, F2, 1 )), (1 , F2, F2 )), (BS , F2, F2 ))
80+ W2_r = ct. broadcast_to (reshape (ct. extract (W2_ri, (1 , 1 , 1 ), (F2, F2, 1 )), (F2 , F2, 1 )), (F2 , F2, BS ))
81+ W2_i = ct. broadcast_to (reshape (ct. extract (W2_ri, (1 , 1 , 2 ), (F2, F2, 1 )), (F2 , F2, 1 )), (F2 , F2, BS ))
8082
8183 # --- Load Twiddle Factors ---
8284 # T0 (F1F2, F0) - note swapped from Python's (F0, F1F2)
8385 T0_ri = reshape (ct. load (T0; index= (1 , 1 , 1 ), shape= (F1F2, F0, 2 )), (F1F2, F0, 2 ))
84- T0_r = reshape (ct. extract (T0_ri, (1 , 1 , 1 ), (F1F2, F0, 1 )), (1 , N ))
85- T0_i = reshape (ct. extract (T0_ri, (1 , 1 , 2 ), (F1F2, F0, 1 )), (1 , N ))
86+ T0_r = reshape (ct. extract (T0_ri, (1 , 1 , 1 ), (F1F2, F0, 1 )), (N, 1 ))
87+ T0_i = reshape (ct. extract (T0_ri, (1 , 1 , 2 ), (F1F2, F0, 1 )), (N, 1 ))
8688
8789 # T1 (F0F2, F1) - note swapped from Python's (F1, F2)
8890 T1_ri = reshape (ct. load (T1; index= (1 , 1 , 1 ), shape= (F0F2, F1, 2 )), (F0F2, F1, 2 ))
89- T1_r = reshape (ct. extract (T1_ri, (1 , 1 , 1 ), (F0F2, F1, 1 )), (1 , F0F2 * F1))
90- T1_i = reshape (ct. extract (T1_ri, (1 , 1 , 2 ), (F0F2, F1, 1 )), (1 , F0F2 * F1))
91+ T1_r = reshape (ct. extract (T1_ri, (1 , 1 , 1 ), (F0F2, F1, 1 )), (F0F2 * F1, 1 ))
92+ T1_i = reshape (ct. extract (T1_ri, (1 , 1 , 2 ), (F0F2, F1, 1 )), (F0F2 * F1, 1 ))
9193
9294 # --- Stage 0: F0-point DFT ---
93- # X is (BS, F1F2, F0), W0 is (BS , F0, F0)
95+ # X is (F1F2, F0, BS ), W0 is (F0 , F0, BS) — trailing batch
9496 # Right-multiply: X @ W0 processes each row (F1F2 rows, each with F0 elements)
95- # Each row has elements at stride F1F2 in the original array - exactly what we need!
96- X_r_ = X_r * W0_r - X_i * W0_i # (BS, F1F2, F0) @ (BS, F0, F0) → (BS, F1F2, F0)
97+ X_r_ = X_r * W0_r - X_i * W0_i # (F1F2, F0, BS) @ (F0, F0, BS) → (F1F2, F0, BS)
9798 X_i_ = X_r * W0_i + X_i * W0_r
9899
99100 # --- Twiddle & Permute 0 ---
100- # Reshape to (BS, N ) for element-wise twiddle multiply
101- X_r_flat = reshape (X_r_, (BS, N ))
102- X_i_flat = reshape (X_i_, (BS, N ))
101+ # Reshape to (N, BS ) for element-wise twiddle multiply
102+ X_r_flat = reshape (X_r_, (N, BS ))
103+ X_i_flat = reshape (X_i_, (N, BS ))
103104 X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat
104105 X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat
105106
106107 # Reshape and permute for stage 1
107- # Current logical layout after reshape (BS, F1F2, F0): data at (bs, f1*F2+f2, f0)
108- # Reshape to (BS, F2, F1, F0) then permute to (BS, F0F2, F1) for stage 1
109- X_r3 = reshape (X_r2, (BS, F2, F1, F0))
110- X_i3 = reshape (X_i2, (BS, F2, F1, F0))
111- X_r4 = permutedims (X_r3, (1 , 2 , 4 , 3 )) # (BS, F2, F0, F1)
112- X_i4 = permutedims (X_i3, (1 , 2 , 4 , 3 ))
113- X_r5 = reshape (X_r4, (BS, F0F2, F1))
114- X_i5 = reshape (X_i4, (BS, F0F2, F1))
108+ # Reshape to (F2, F1, F0, BS) then permute to (F0F2, F1, BS) for stage 1
109+ X_r3 = reshape (X_r2, (F2, F1, F0, BS))
110+ X_i3 = reshape (X_i2, (F2, F1, F0, BS))
111+ X_r4 = permutedims (X_r3, (1 , 3 , 2 , 4 )) # (F2, F0, F1, BS)
112+ X_i4 = permutedims (X_i3, (1 , 3 , 2 , 4 ))
113+ X_r5 = reshape (X_r4, (F0F2, F1, BS))
114+ X_i5 = reshape (X_i4, (F0F2, F1, BS))
115115
116116 # --- Stage 1: F1-point DFT ---
117- # X is (BS, F0F2, F1), W1 is (BS , F1, F1 )
117+ # X is (F0F2, F1, BS ), W1 is (F1 , F1, BS )
118118 X_r6 = X_r5 * W1_r - X_i5 * W1_i
119119 X_i6 = X_r5 * W1_i + X_i5 * W1_r
120120
121121 # --- Twiddle & Permute 1 ---
122- X_r_flat2 = reshape (X_r6, (BS, N ))
123- X_i_flat2 = reshape (X_i6, (BS, N ))
122+ X_r_flat2 = reshape (X_r6, (N, BS ))
123+ X_i_flat2 = reshape (X_i6, (N, BS ))
124124 X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2
125125 X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2
126126
127127 # Reshape and permute for stage 2
128- X_r8 = reshape (X_r7, (BS, F2, F0, F1))
129- X_i8 = reshape (X_i7, (BS, F2, F0, F1))
130- X_r9 = permutedims (X_r8, (1 , 3 , 4 , 2 )) # (BS, F0, F1, F2)
131- X_i9 = permutedims (X_i8, (1 , 3 , 4 , 2 ))
132- X_r10 = reshape (X_r9, (BS, F0F1, F2))
133- X_i10 = reshape (X_i9, (BS, F0F1, F2))
128+ X_r8 = reshape (X_r7, (F2, F0, F1, BS ))
129+ X_i8 = reshape (X_i7, (F2, F0, F1, BS ))
130+ X_r9 = permutedims (X_r8, (2 , 3 , 1 , 4 )) # (F0, F1, F2, BS )
131+ X_i9 = permutedims (X_i8, (2 , 3 , 1 , 4 ))
132+ X_r10 = reshape (X_r9, (F0F1, F2, BS ))
133+ X_i10 = reshape (X_i9, (F0F1, F2, BS ))
134134
135135 # --- Stage 2: F2-point DFT ---
136- # X is (BS, F0F1, F2), W2 is (BS , F2, F2 )
136+ # X is (F0F1, F2, BS ), W2 is (F2 , F2, BS )
137137 X_r11 = X_r10 * W2_r - X_i10 * W2_i
138138 X_i11 = X_r10 * W2_i + X_i10 * W2_r
139139
140140 # --- Final Output ---
141- # After stage 2, data is in (BS, F0F1, F2) layout
142- # Reshape to (BS, F0, F1, F2) - output is already in frequency order
143- X_r_final = reshape (X_r11, (1 , BS, N))
144- X_i_final = reshape (X_i11, (1 , BS, N))
141+ X_r_final = reshape (X_r11, (1 , N, BS))
142+ X_i_final = reshape (X_i11, (1 , N, BS))
145143
146144 # --- Concatenate and Store ---
147- Y_ri = reshape (ct. cat ((X_r_final, X_i_final), 1 ), (D, BS, N2D))
145+ # Permute BS back to middle for memory layout (D, BS, N2D)
146+ Y_ri = permutedims (reshape (ct. cat ((X_r_final, X_i_final), 1 ), (D, N2D, BS)), (1 , 3 , 2 ))
148147 ct. store (y_packed_out; index= (1 , bid, 1 ), tile= Y_ri)
149148
150149 return
0 commit comments