Skip to content

Commit caf0027

Browse files
AntonOresten, maleadt, claude
authored
matmul: Switch to trailing batch dims and allow mat-vec, vec-mat (#132)
Co-authored-by: Tim Besard <tim.besard@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9d6ba70 commit caf0027

5 files changed

Lines changed: 543 additions & 68 deletions

File tree

examples/fft.jl

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ using FFTW
2020
# in columns. In Julia column-major, reshape (F1F2, F0) puts stride-F0 elements in rows.
2121
# We use right-multiply X @ W instead of W @ X to process rows instead of columns.
2222
#
23-
# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving.
23+
# Input/output memory layout: (D, BS, N2D) where D=2 for real/imag interleaving.
24+
# Internally, BS is permuted to trailing position for batched matmul convention.
2425
function fft_kernel(
2526
x_packed_in::ct.TileArray{Float32, 3}, # Input (D, BS, N2D) - natural Julia complex layout
2627
y_packed_out::ct.TileArray{Float32, 3}, # Output (D, BS, N2D)
@@ -55,96 +56,94 @@ function fft_kernel(
5556
bid = ct.bid(1)
5657

5758
# --- Load Input Data ---
58-
# Input is (D, BS, N2D) where D=2 for real/imag. Load and reshape to (2, BS, N).
59-
X_ri = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N))
59+
# Input is (D, BS, N2D) where D=2 for real/imag. Load and permute BS to trailing.
60+
X_ri_mem = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N))
61+
X_ri = permutedims(X_ri_mem, (1, 3, 2)) # (2, N, BS) — trailing batch
6062

6163
# Split real and imaginary parts (extract from first dimension)
62-
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0))
63-
X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0))
64+
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F1F2, F0, BS))
65+
X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, N, BS)), (F1F2, F0, BS))
6466

6567
# --- Load DFT Matrices ---
66-
# W0 (F0 x F0) - for right-multiply X @ W0
68+
# W0 (F0 x F0) - for right-multiply X @ W0, batch dim trailing
6769
W0_ri = reshape(ct.load(W0; index=(1, 1, 1), shape=(F0, F0, 2)), (F0, F0, 2))
68-
W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
69-
W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
70+
W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS))
71+
W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS))
7072

7173
# W1 (F1 x F1)
7274
W1_ri = reshape(ct.load(W1; index=(1, 1, 1), shape=(F1, F1, 2)), (F1, F1, 2))
73-
W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
74-
W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
75+
W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS))
76+
W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS))
7577

7678
# W2 (F2 x F2)
7779
W2_ri = reshape(ct.load(W2; index=(1, 1, 1), shape=(F2, F2, 2)), (F2, F2, 2))
78-
W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
79-
W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
80+
W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS))
81+
W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS))
8082

8183
# --- Load Twiddle Factors ---
8284
# T0 (F1F2, F0) - note swapped from Python's (F0, F1F2)
8385
T0_ri = reshape(ct.load(T0; index=(1, 1, 1), shape=(F1F2, F0, 2)), (F1F2, F0, 2))
84-
T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N))
85-
T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N))
86+
T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (N, 1))
87+
T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (N, 1))
8688

8789
# T1 (F0F2, F1) - note swapped from Python's (F1, F2)
8890
T1_ri = reshape(ct.load(T1; index=(1, 1, 1), shape=(F0F2, F1, 2)), (F0F2, F1, 2))
89-
T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1))
90-
T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1))
91+
T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (F0F2 * F1, 1))
92+
T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (F0F2 * F1, 1))
9193

9294
# --- Stage 0: F0-point DFT ---
93-
# X is (BS, F1F2, F0), W0 is (BS, F0, F0)
95+
# X is (F1F2, F0, BS), W0 is (F0, F0, BS) — trailing batch
9496
# Right-multiply: X @ W0 processes each row (F1F2 rows, each with F0 elements)
95-
# Each row has elements at stride F1F2 in the original array - exactly what we need!
96-
X_r_ = X_r * W0_r - X_i * W0_i # (BS, F1F2, F0) @ (BS, F0, F0) → (BS, F1F2, F0)
97+
X_r_ = X_r * W0_r - X_i * W0_i # (F1F2, F0, BS) @ (F0, F0, BS) → (F1F2, F0, BS)
9798
X_i_ = X_r * W0_i + X_i * W0_r
9899

99100
# --- Twiddle & Permute 0 ---
100-
# Reshape to (BS, N) for element-wise twiddle multiply
101-
X_r_flat = reshape(X_r_, (BS, N))
102-
X_i_flat = reshape(X_i_, (BS, N))
101+
# Reshape to (N, BS) for element-wise twiddle multiply
102+
X_r_flat = reshape(X_r_, (N, BS))
103+
X_i_flat = reshape(X_i_, (N, BS))
103104
X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat
104105
X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat
105106

106107
# Reshape and permute for stage 1
107-
# Current logical layout after reshape (BS, F1F2, F0): data at (bs, f1*F2+f2, f0)
108-
# Reshape to (BS, F2, F1, F0) then permute to (BS, F0F2, F1) for stage 1
109-
X_r3 = reshape(X_r2, (BS, F2, F1, F0))
110-
X_i3 = reshape(X_i2, (BS, F2, F1, F0))
111-
X_r4 = permutedims(X_r3, (1, 2, 4, 3)) # (BS, F2, F0, F1)
112-
X_i4 = permutedims(X_i3, (1, 2, 4, 3))
113-
X_r5 = reshape(X_r4, (BS, F0F2, F1))
114-
X_i5 = reshape(X_i4, (BS, F0F2, F1))
108+
# Reshape to (F2, F1, F0, BS) then permute to (F0F2, F1, BS) for stage 1
109+
X_r3 = reshape(X_r2, (F2, F1, F0, BS))
110+
X_i3 = reshape(X_i2, (F2, F1, F0, BS))
111+
X_r4 = permutedims(X_r3, (1, 3, 2, 4)) # (F2, F0, F1, BS)
112+
X_i4 = permutedims(X_i3, (1, 3, 2, 4))
113+
X_r5 = reshape(X_r4, (F0F2, F1, BS))
114+
X_i5 = reshape(X_i4, (F0F2, F1, BS))
115115

116116
# --- Stage 1: F1-point DFT ---
117-
# X is (BS, F0F2, F1), W1 is (BS, F1, F1)
117+
# X is (F0F2, F1, BS), W1 is (F1, F1, BS)
118118
X_r6 = X_r5 * W1_r - X_i5 * W1_i
119119
X_i6 = X_r5 * W1_i + X_i5 * W1_r
120120

121121
# --- Twiddle & Permute 1 ---
122-
X_r_flat2 = reshape(X_r6, (BS, N))
123-
X_i_flat2 = reshape(X_i6, (BS, N))
122+
X_r_flat2 = reshape(X_r6, (N, BS))
123+
X_i_flat2 = reshape(X_i6, (N, BS))
124124
X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2
125125
X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2
126126

127127
# Reshape and permute for stage 2
128-
X_r8 = reshape(X_r7, (BS, F2, F0, F1))
129-
X_i8 = reshape(X_i7, (BS, F2, F0, F1))
130-
X_r9 = permutedims(X_r8, (1, 3, 4, 2)) # (BS, F0, F1, F2)
131-
X_i9 = permutedims(X_i8, (1, 3, 4, 2))
132-
X_r10 = reshape(X_r9, (BS, F0F1, F2))
133-
X_i10 = reshape(X_i9, (BS, F0F1, F2))
128+
X_r8 = reshape(X_r7, (F2, F0, F1, BS))
129+
X_i8 = reshape(X_i7, (F2, F0, F1, BS))
130+
X_r9 = permutedims(X_r8, (2, 3, 1, 4)) # (F0, F1, F2, BS)
131+
X_i9 = permutedims(X_i8, (2, 3, 1, 4))
132+
X_r10 = reshape(X_r9, (F0F1, F2, BS))
133+
X_i10 = reshape(X_i9, (F0F1, F2, BS))
134134

135135
# --- Stage 2: F2-point DFT ---
136-
# X is (BS, F0F1, F2), W2 is (BS, F2, F2)
136+
# X is (F0F1, F2, BS), W2 is (F2, F2, BS)
137137
X_r11 = X_r10 * W2_r - X_i10 * W2_i
138138
X_i11 = X_r10 * W2_i + X_i10 * W2_r
139139

140140
# --- Final Output ---
141-
# After stage 2, data is in (BS, F0F1, F2) layout
142-
# Reshape to (BS, F0, F1, F2) - output is already in frequency order
143-
X_r_final = reshape(X_r11, (1, BS, N))
144-
X_i_final = reshape(X_i11, (1, BS, N))
141+
X_r_final = reshape(X_r11, (1, N, BS))
142+
X_i_final = reshape(X_i11, (1, N, BS))
145143

146144
# --- Concatenate and Store ---
147-
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D))
145+
# Permute BS back to middle for memory layout (D, BS, N2D)
146+
Y_ri = permutedims(reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)), (1, 3, 2))
148147
ct.store(y_packed_out; index=(1, bid, 1), tile=Y_ri)
149148

150149
return

src/language/operations.jl

Lines changed: 127 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -922,32 +922,145 @@ end
922922
=============================================================================#
923923

924924
# Matrix multiply-accumulate: muladd(a, b, acc) = a * b + acc
925-
@inline Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC} =
925+
# Handles 1D promotion, type promotion, and batched dims (≥3D).
926+
# Note: SA, SB, SC type parameters required to avoid ambiguity with scalar methods during codegen
927+
@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC}
928+
_muladd(a, b, acc, Val(ndims(a)), Val(ndims(b)))
929+
end
930+
931+
# 2D × 2D: direct MmaFOp with type promotion
932+
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2})
926933
Intrinsics.mma(a, b, acc)
934+
end
935+
936+
# Vec-mat (1D × 2D): reshape (M,) → (M, 1), MmaFOp, acc is already (M, N)
937+
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2})
938+
a2d = reshape(a, (size(a, 1), 1))
939+
_muladd(a2d, b, acc, Val(2), Val(2))
940+
end
941+
942+
# Mat-vec (2D × 1D): reshape b (K,) → (K, 1), acc (M,) → (M, 1), MmaFOp, squeeze back
943+
@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1})
944+
M, K = size(a, 1), size(b, 1)
945+
b2d = reshape(b, (K, 1))
946+
acc2d = reshape(acc, (M, 1))
947+
result = _muladd(a, b2d, acc2d, Val(2), Val(2))
948+
reshape(result, (M,))
949+
end
950+
951+
# Vec-vec (1D × 1D): not supported
952+
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1})
953+
return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported.")))
954+
end
955+
956+
# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported, unsqueeze manually
957+
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
958+
NB >= 3 || return :(throw(ArgumentError("unreachable")))
959+
return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first.")))
960+
end
961+
@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
962+
NA >= 3 || return :(throw(ArgumentError("unreachable")))
963+
return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first.")))
964+
end
965+
966+
# Batched matmul (≥3D × ≥3D): trailing batch dims with broadcast
967+
# Julia convention: first two dims are matrix (M,K)/(K,N), trailing dims are batch.
968+
# For batched inputs, MmaFOp expects exactly 3D tiles with a leading batch dim (B, M, K), so we:
969+
# 1. Broadcast batch dims to a common shape
970+
# 2. Permute trailing batch → leading
971+
# 3. Flatten multiple batch dims into one for MmaFOp
972+
# 4. Unflatten + permute back after
973+
@generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC},
974+
::Val{NA}, ::Val{NB}) where {T1, T2, T3, SA, SB, SC, NA, NB}
975+
sa = Tuple(SA.parameters)
976+
sb = Tuple(SB.parameters)
977+
978+
# Matrix dims are first two; batch dims are trailing
979+
M = sa[1]; K = sa[2]; N = sb[2]
980+
a_batch = sa[3:end]
981+
b_batch = sb[3:end]
982+
983+
# Broadcast batch dims (pad shorter with trailing 1s, then broadcast)
984+
n_batch = max(length(a_batch), length(b_batch))
985+
a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...)
986+
b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...)
987+
batch_shape = map(max, a_batch_padded, b_batch_padded)
988+
B_flat = prod(batch_shape)
989+
990+
quote
991+
# Reshape + broadcast to align batch dims (still trailing)
992+
a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...)))
993+
b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...)))
994+
acc_bc = broadcast_to(acc, $((M, N, batch_shape...)))
995+
# Flatten batch dims to one (still trailing), then permute to leading
996+
a_3d = permutedims(reshape(a_bc, $((M, K, B_flat))), (3, 1, 2))
997+
b_3d = permutedims(reshape(b_bc, $((K, N, B_flat))), (3, 1, 2))
998+
acc_3d = permutedims(reshape(acc_bc, $((M, N, B_flat))), (3, 1, 2))
999+
# MmaFOp
1000+
result_3d = Intrinsics.mma(a_3d, b_3d, acc_3d)
1001+
# Permute back to trailing, unflatten batch dims
1002+
reshape(permutedims(result_3d, (2, 3, 1)), $((M, N, batch_shape...)))
1003+
end
1004+
end
9271005

928-
# Matrix multiplication (A * B like Julia arrays)
1006+
# Matrix multiplication: A * B = muladd(A, B, zeros)
9291007
# Note: SA, SB type parameters required to avoid ambiguity with scalar*tile methods during codegen
9301008
@inline function Base.:(*)(a::Tile{T1, SA}, b::Tile{T2, SB}) where {T1, T2, SA, SB}
931-
_matmul(a, b, Val(ndims(a)))
1009+
_matmul(a, b, Val(ndims(a)), Val(ndims(b)))
9321010
end
9331011

934-
# 2D matmul: (M, K) × (K, N) → (M, N)
935-
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}) where {T1}
936-
M = size(a, 1)
937-
N = size(b, 2)
938-
acc = zeros(T1, (M, N))
1012+
# 2D × 2D → (M, N)
1013+
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}, ::Val{2}) where {T1}
1014+
acc = zeros(T1, (size(a, 1), size(b, 2)))
9391015
muladd(a, b, acc)
9401016
end
9411017

942-
# 3D batched matmul: (B, M, K) × (B, K, N) → (B, M, N)
943-
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{3}) where {T1}
944-
B = max(size(a, 1), size(b, 1)) # Broadcast batch dimension
945-
M = size(a, 2)
946-
N = size(b, 3)
947-
acc = zeros(T1, (B, M, N))
1018+
# Vec-mat (1D × 2D) → (M, N)
1019+
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{1}, ::Val{2}) where {T1}
1020+
acc = zeros(T1, (size(a, 1), size(b, 2)))
9481021
muladd(a, b, acc)
9491022
end
9501023

1024+
# Mat-vec (2D × 1D) → (M,)
1025+
@inline function _matmul(a::Tile{T1}, b::Tile, ::Val{2}, ::Val{1}) where {T1}
1026+
acc = zeros(T1, (size(a, 1),))
1027+
muladd(a, b, acc)
1028+
end
1029+
1030+
# Vec-vec (1D × 1D): not supported
1031+
@generated function _matmul(::Tile, ::Tile, ::Val{1}, ::Val{1})
1032+
return :(throw(ArgumentError("Vector-vector multiplication is not supported. Use dot(a, b) for inner products, or reshape explicitly.")))
1033+
end
1034+
1035+
# Batched (≥3D × ≥3D) → (M, N, batch...)
1036+
@generated function _matmul(a::Tile{T1, SA}, b::Tile{T2, SB},
1037+
::Val{NA}, ::Val{NB}) where {T1, T2, SA, SB, NA, NB}
1038+
sa = Tuple(SA.parameters)
1039+
sb = Tuple(SB.parameters)
1040+
a_batch = sa[3:end]
1041+
b_batch = sb[3:end]
1042+
n_batch = max(length(a_batch), length(b_batch))
1043+
a_batch_padded = (a_batch..., ntuple(_ -> 1, n_batch - length(a_batch))...)
1044+
b_batch_padded = (b_batch..., ntuple(_ -> 1, n_batch - length(b_batch))...)
1045+
batch_shape = map(max, a_batch_padded, b_batch_padded)
1046+
M = sa[1]; N = sb[2]
1047+
out_shape = (M, N, batch_shape...)
1048+
quote
1049+
acc = zeros(T1, $out_shape)
1050+
muladd(a, b, acc)
1051+
end
1052+
end
1053+
1054+
# Batched × 1D and 1D × batched: not supported — reshape the 1D operand to 2D manually
1055+
@generated function _matmul(::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA}
1056+
NA >= 3 || return :(throw(ArgumentError("unreachable")))
1057+
return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first.")))
1058+
end
1059+
@generated function _matmul(::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB}
1060+
NB >= 3 || return :(throw(ArgumentError("unreachable")))
1061+
return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first.")))
1062+
end
1063+
9511064
#=============================================================================
9521065
Selection
9531066
=============================================================================#

0 commit comments

Comments
 (0)