Skip to content

Commit d595fcc

Browse files
authored
Merge pull request #87 from JuliaGPU/tb/const
Pass constants as scalars, infer as constants.
2 parents 8627ec3 + 5b49cab commit d595fcc

20 files changed

Lines changed: 354 additions & 240 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ DLFP8TypesExt = "DLFP8Types"
2525
BFloat16s = "0.6"
2626
CUDA_Compiler_jll = "0.4"
2727
CUDA_Tile_jll = "13.1"
28-
CompilerCaching = "0.1"
28+
CompilerCaching = "0.1.2"
2929
IRStructurizer = "0.1"
3030
julia = "1.11"

README.md

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ using CUDA
3434
import cuTile as ct
3535

3636
# Define kernel
37-
function vadd(a, b, c, tile_size::ct.Constant{Int})
37+
function vadd(a, b, c, tile_size::Int)
3838
pid = ct.bid(1)
39-
tile_a = ct.load(a, pid, (tile_size[],))
40-
tile_b = ct.load(b, pid, (tile_size[],))
39+
tile_a = ct.load(a, pid, (tile_size,))
40+
tile_b = ct.load(b, pid, (tile_size,))
4141
ct.store(c, pid, tile_a + tile_b)
4242
return
4343
end
@@ -297,28 +297,32 @@ permutedims(tile, (3, 1, 2))
297297

298298
This applies to `bid`, `num_blocks`, `permutedims`, `reshape`, dimension arguments, etc.
299299

300-
### `Val`-like constants
300+
### Compile-time constants
301301

302-
CuTile.jl uses `ct.Constant{T}` to encode compile-time constant values in the type domain, similar to how `Val` works. An explicit `[]` is needed to extract the value at runtime:
302+
Python annotates constant parameters in the kernel signature and passes plain values at launch.
303+
Julia is the reverse: kernel signatures use plain types, and constants are wrapped at launch:
303304

304305
```python
305306
# Python
306307
@ct.kernel
307-
def kernel(a, b, tile_size):
308+
def kernel(a, b, tile_size: ct.Constant[int]):
308309
tile = ct.load(a, index=(0,), shape=(tile_size,))
309310

310311
ct.launch(stream, grid, kernel, (a, b, 16))
311312
```
312313

313314
```julia
314315
# Julia
315-
function kernel(a, b, tile_size::ct.Constant{Int})
316-
tile = ct.load(a, 1, (tile_size[],))
316+
function kernel(a, b, tile_size::Int)
317+
tile = ct.load(a, 1, (tile_size,))
317318
end
318319

319320
ct.launch(kernel, grid, a, b, ct.Constant(16))
320321
```
321322

323+
`ct.Constant` arguments generate no kernel parameter; the value is embedded directly in
324+
the compiled code. Different constant values produce different kernel specializations.
325+
322326
### Broadcasting and Math Functions
323327

324328
Python's operators and math functions work directly on tiles with automatic broadcasting.

examples/batchmatmul.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,31 @@ import cuTile as ct
1212
# A: (M, K, Batch), B: (K, N, Batch), C: (M, N, Batch)
1313
# Grid: (M_tiles, N_tiles, Batch)
1414
function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.TileArray{T,3},
15-
tm::ct.Constant{Int}, tn::ct.Constant{Int},
16-
tk::ct.Constant{Int}) where {T}
15+
tm::Int, tn::Int, tk::Int) where {T}
1716
# Grid dimensions (1-indexed)
1817
bid_m = ct.bid(1) # M tile index
1918
bid_n = ct.bid(2) # N tile index
2019
pid_batch = ct.bid(3) # Batch index
2120

2221
# Number of K tiles to iterate over
2322
K = size(A, 2)
24-
num_k = cld(K, Int32(tk[]))
23+
num_k = cld(K, Int32(tk))
2524

2625
# Initialize accumulator with Float32 for precision
27-
acc = ct.full((tm[], tn[]), zero(Float32), Float32)
26+
acc = ct.full((tm, tn), zero(Float32), Float32)
2827

2928
# K reduction loop
3029
k = Int32(1)
3130
while k <= num_k
3231
# Load 3D tiles: (tm, tk, 1) and (tk, tn, 1)
33-
a = ct.load(A, (bid_m, k, pid_batch), (tm[], tk[], 1);
32+
a = ct.load(A, (bid_m, k, pid_batch), (tm, tk, 1);
3433
padding_mode=ct.PaddingMode.Zero)
35-
b = ct.load(B, (k, bid_n, pid_batch), (tk[], tn[], 1);
34+
b = ct.load(B, (k, bid_n, pid_batch), (tk, tn, 1);
3635
padding_mode=ct.PaddingMode.Zero)
3736

3837
# Reshape 3D tiles to 2D for mma
39-
a_2d = reshape(a, (tm[], tk[]))
40-
b_2d = reshape(b, (tk[], tn[]))
38+
a_2d = reshape(a, (tm, tk))
39+
b_2d = reshape(b, (tk, tn))
4140

4241
# Convert to TF32 for tensor cores (Float32 inputs only)
4342
if T === Float32
@@ -51,7 +50,7 @@ function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.T
5150

5251
# Convert to output type, reshape to 3D, and store
5352
result = convert(ct.Tile{T}, acc)
54-
result_3d = reshape(result, (tm[], tn[], 1))
53+
result_3d = reshape(result, (tm, tn, 1))
5554
ct.store(C, (bid_m, bid_n, pid_batch), result_3d)
5655

5756
return nothing

examples/fft.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,28 @@ function fft_kernel(
2929
W2::ct.TileArray{Float32, 3}, # W2 (F2, F2, 2)
3030
T0::ct.TileArray{Float32, 3}, # T0 (F1F2, F0, 2) twiddle factors
3131
T1::ct.TileArray{Float32, 3}, # T1 (F0F2, F1, 2) twiddle factors
32-
n_const::ct.Constant{Int},
33-
f0_const::ct.Constant{Int},
34-
f1_const::ct.Constant{Int},
35-
f2_const::ct.Constant{Int},
36-
f0f1_const::ct.Constant{Int},
37-
f1f2_const::ct.Constant{Int},
38-
f0f2_const::ct.Constant{Int},
39-
bs_const::ct.Constant{Int},
40-
d_const::ct.Constant{Int},
41-
n2d_const::ct.Constant{Int}
32+
n_const::Int,
33+
f0_const::Int,
34+
f1_const::Int,
35+
f2_const::Int,
36+
f0f1_const::Int,
37+
f1f2_const::Int,
38+
f0f2_const::Int,
39+
bs_const::Int,
40+
d_const::Int,
41+
n2d_const::Int
4242
)
4343
# Extract constant values
44-
N = n_const[]
45-
F0 = f0_const[]
46-
F1 = f1_const[]
47-
F2 = f2_const[]
48-
F0F1 = f0f1_const[]
49-
F1F2 = f1f2_const[]
50-
F0F2 = f0f2_const[]
51-
BS = bs_const[]
52-
D = d_const[]
53-
N2D = n2d_const[]
44+
N = n_const
45+
F0 = f0_const
46+
F1 = f1_const
47+
F2 = f2_const
48+
F0F1 = f0f1_const
49+
F1F2 = f1f2_const
50+
F0F2 = f0f2_const
51+
BS = bs_const
52+
D = d_const
53+
N2D = n2d_const
5454

5555
bid = ct.bid(1)
5656

examples/layernorm.jl

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
using CUDA
66
import cuTile as ct
77

8-
const ConstInt = ct.Constant{Int}
9-
108
#=============================================================================
119
LayerNorm Forward Kernel
1210
@@ -25,43 +23,43 @@ const ConstInt = ct.Constant{Int}
2523
function layer_norm_fwd(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1},
2624
B::ct.TileArray{Float32, 1}, Y::ct.TileArray{Float32, 2},
2725
Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1},
28-
eps::ct.Constant{Float32}, TILE_N::ConstInt)
26+
eps::Float32, TILE_N::Int)
2927
bid_m = ct.bid(1)
30-
num_tiles = ct.num_tiles(X, 2, (1, TILE_N[]))
28+
num_tiles = ct.num_tiles(X, 2, (1, TILE_N))
3129
N = size(X, 2)
3230

3331
# Compute mean
34-
mean = ct.full((1, TILE_N[]), 0.0f0, Float32)
32+
mean = ct.full((1, TILE_N), 0.0f0, Float32)
3533
j = Int32(1)
3634
while j <= num_tiles
37-
tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
35+
tx = ct.load(X, (bid_m, j), (1, TILE_N); padding_mode=ct.PaddingMode.Zero)
3836
mean = mean .+ tx
3937
j += Int32(1)
4038
end
4139
mean = sum(mean; dims=2) / N
4240
ct.store(Mean, bid_m, mean)
4341

4442
# Compute variance
45-
var = ct.full((1, TILE_N[]), 0.0f0, Float32)
43+
var = ct.full((1, TILE_N), 0.0f0, Float32)
4644
j = Int32(1)
4745
while j <= num_tiles
48-
tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
46+
tx = ct.load(X, (bid_m, j), (1, TILE_N); padding_mode=ct.PaddingMode.Zero)
4947
# Mask for valid elements
50-
mask = ct.broadcast_to(((j - Int32(1)) * Int32(TILE_N[]) .+ ct.arange((TILE_N[],), Int32)) .<= N, (1, TILE_N[]))
48+
mask = ct.broadcast_to(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange((TILE_N,), Int32)) .<= N, (1, TILE_N))
5149
centered_tx = ifelse.(mask, tx .- mean, 0.0f0)
5250
var = var .+ (centered_tx .^ 2.0f0)
5351
j += Int32(1)
5452
end
5553
var = sum(var; dims=2) / N
56-
rstd = 1.0f0 ./ sqrt.(var .+ eps[])
54+
rstd = 1.0f0 ./ sqrt.(var .+ eps)
5755
ct.store(Rstd, bid_m, rstd)
5856

5957
# Normalize and apply affine transformation
6058
j = Int32(1)
6159
while j <= num_tiles
62-
tx = ct.load(X, (bid_m, j), (1, TILE_N[]); padding_mode=ct.PaddingMode.Zero)
63-
tw = reshape(ct.load(W, j, (TILE_N[],); padding_mode=ct.PaddingMode.Zero), (1, TILE_N[]))
64-
tb = reshape(ct.load(B, j, (TILE_N[],); padding_mode=ct.PaddingMode.Zero), (1, TILE_N[]))
60+
tx = ct.load(X, (bid_m, j), (1, TILE_N); padding_mode=ct.PaddingMode.Zero)
61+
tw = reshape(ct.load(W, j, (TILE_N,); padding_mode=ct.PaddingMode.Zero), (1, TILE_N))
62+
tb = reshape(ct.load(B, j, (TILE_N,); padding_mode=ct.PaddingMode.Zero), (1, TILE_N))
6563
ty = (tx .- mean) .* rstd
6664
ty = ty .* tw .+ tb
6765
ct.store(Y, (bid_m, j), ty)
@@ -123,21 +121,21 @@ Args:
123121
function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float32, 2},
124122
X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1},
125123
Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1},
126-
TILE_N::ConstInt)
124+
TILE_N::Int)
127125
bid_m = ct.bid(1)
128-
num_tiles = ct.num_tiles(X, 2, (1, TILE_N[]))
126+
num_tiles = ct.num_tiles(X, 2, (1, TILE_N))
129127
N = size(X, 2)
130128

131129
# Load mean and rstd for this row
132130
mean = ct.load(Mean, bid_m, (1,); padding_mode=ct.PaddingMode.Zero)
133131
rstd = ct.load(Rstd, bid_m, (1,); padding_mode=ct.PaddingMode.Zero)
134132

135133
# First pass: compute c1 and c2 reduction terms
136-
c1 = ct.full((1, TILE_N[]), 0.0f0, Float32)
137-
c2 = ct.full((1, TILE_N[]), 0.0f0, Float32)
134+
c1 = ct.full((1, TILE_N), 0.0f0, Float32)
135+
c2 = ct.full((1, TILE_N), 0.0f0, Float32)
138136
j = Int32(1)
139137
while j <= num_tiles
140-
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N[], N)
138+
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
141139
c1 = c1 .+ (xhat .* wdy)
142140
c2 = c2 .+ wdy
143141
j += Int32(1)
@@ -148,7 +146,7 @@ function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float3
148146
# Second pass: compute dX
149147
j = Int32(1)
150148
while j <= num_tiles
151-
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N[], N)
149+
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
152150
tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
153151
ct.store(DX, (bid_m, j), tdx)
154152
j += Int32(1)
@@ -181,22 +179,22 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
181179
X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1},
182180
Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1},
183181
Locks::ct.TileArray{Int, 1},
184-
GROUP_SIZE_M::ConstInt, TILE_N::ConstInt)
182+
GROUP_SIZE_M::Int, TILE_N::Int)
185183
bid_m = ct.bid(1)
186-
num_tiles = ct.num_tiles(X, 2, (1, TILE_N[]))
184+
num_tiles = ct.num_tiles(X, 2, (1, TILE_N))
187185
N = size(X, 2)
188-
group_bid_m = ((bid_m - Int32(1)) % Int32(GROUP_SIZE_M[])) + Int32(1)
186+
group_bid_m = ((bid_m - Int32(1)) % Int32(GROUP_SIZE_M)) + Int32(1)
189187

190188
# Load mean and rstd for this row
191189
mean = ct.load(Mean, bid_m, (1,); padding_mode=ct.PaddingMode.Zero)
192190
rstd = ct.load(Rstd, bid_m, (1,); padding_mode=ct.PaddingMode.Zero)
193191

194192
# First pass: compute c1 and c2 reduction terms
195-
c1 = ct.full((1, TILE_N[]), 0.0f0, Float32)
196-
c2 = ct.full((1, TILE_N[]), 0.0f0, Float32)
193+
c1 = ct.full((1, TILE_N), 0.0f0, Float32)
194+
c2 = ct.full((1, TILE_N), 0.0f0, Float32)
197195
j = Int32(1)
198196
while j <= num_tiles
199-
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N[], N)
197+
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
200198
c1 = c1 .+ (xhat .* wdy)
201199
c2 = c2 .+ wdy
202200
j += Int32(1)
@@ -207,12 +205,12 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
207205
# Second pass: compute dX and partial dW/dB
208206
j = Int32(1)
209207
while j <= num_tiles
210-
tdy, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N[], N)
208+
tdy, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
211209
tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
212210
ct.store(DX, (bid_m, j), tdx)
213211

214-
partial_dw = reshape(tdy .* xhat, (TILE_N[], 1))
215-
partial_db = reshape(tdy, (TILE_N[], 1))
212+
partial_dw = reshape(tdy .* xhat, (TILE_N, 1))
213+
partial_db = reshape(tdy, (TILE_N, 1))
216214

217215
# Acquire spinlock
218216
while ct.atomic_cas(Locks, group_bid_m, 0, 1;
@@ -221,8 +219,8 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
221219
end
222220

223221
# Critical section: accumulate partial gradients
224-
partial_dw = partial_dw .+ ct.load(DW, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
225-
partial_db = partial_db .+ ct.load(DB, (j, group_bid_m), (TILE_N[], 1); padding_mode=ct.PaddingMode.Zero)
222+
partial_dw = partial_dw .+ ct.load(DW, (j, group_bid_m), (TILE_N, 1); padding_mode=ct.PaddingMode.Zero)
223+
partial_db = partial_db .+ ct.load(DB, (j, group_bid_m), (TILE_N, 1); padding_mode=ct.PaddingMode.Zero)
226224
ct.store(DW, (j, group_bid_m), partial_dw)
227225
ct.store(DB, (j, group_bid_m), partial_db)
228226

@@ -251,16 +249,16 @@ Args:
251249
"""
252250
function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Float32, 2},
253251
FINAL_DW::ct.TileArray{Float32, 1}, FINAL_DB::ct.TileArray{Float32, 1},
254-
TILE_M::ConstInt, TILE_N::ConstInt)
252+
TILE_M::Int, TILE_N::Int)
255253
bid_n = ct.bid(1)
256-
num_tiles = ct.num_tiles(DW, 2, (TILE_N[], TILE_M[]))
254+
num_tiles = ct.num_tiles(DW, 2, (TILE_N, TILE_M))
257255

258-
dw = ct.zeros((TILE_N[], TILE_M[]), Float32)
259-
db = ct.zeros((TILE_N[], TILE_M[]), Float32)
256+
dw = ct.zeros((TILE_N, TILE_M), Float32)
257+
db = ct.zeros((TILE_N, TILE_M), Float32)
260258
i = Int32(1)
261259
while i <= num_tiles
262-
dw = dw .+ ct.load(DW, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
263-
db = db .+ ct.load(DB, (bid_n, i), (TILE_N[], TILE_M[]); padding_mode=ct.PaddingMode.Zero)
260+
dw = dw .+ ct.load(DW, (bid_n, i), (TILE_N, TILE_M); padding_mode=ct.PaddingMode.Zero)
261+
db = db .+ ct.load(DB, (bid_n, i), (TILE_N, TILE_M); padding_mode=ct.PaddingMode.Zero)
264262
i += Int32(1)
265263
end
266264
sum_dw = sum(dw; dims=2)

examples/matmul.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@ end
2323
# Matrix multiplication kernel with K reduction loop and 2D swizzle
2424
# C = A @ B where A is (M, K), B is (K, N), C is (M, N)
2525
function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArray{T,2},
26-
tm::ct.Constant{Int}, tn::ct.Constant{Int}, tk::ct.Constant{Int}) where {T}
26+
tm::Int, tn::Int, tk::Int) where {T}
2727
# Use 1D grid with swizzle for better cache locality
2828
bid = ct.bid(1)
2929
M = size(A, 1)
3030
N = size(B, 2)
3131
# swizzle_2d expects 0-indexed bid, returns 0-indexed tile coords
32-
bid_m_0, bid_n_0 = swizzle_2d(M, N, tm[], tn[], 8, bid - Int32(1))
32+
bid_m_0, bid_n_0 = swizzle_2d(M, N, tm, tn, 8, bid - Int32(1))
3333
# Convert to 1-indexed tile coordinates
3434
bid_m = bid_m_0 + Int32(1)
3535
bid_n = bid_n_0 + Int32(1)
3636

3737
# Number of K tiles to iterate over
38-
num_k = ct.num_tiles(A, 2, (tm[], tk[]))
38+
num_k = ct.num_tiles(A, 2, (tm, tk))
3939

4040
# Initialize accumulator with Float32 for precision
41-
acc = ct.full((tm[], tn[]), zero(Float32), Float32)
41+
acc = ct.full((tm, tn), zero(Float32), Float32)
4242

4343
# K reduction loop - accumulate partial products
4444
# NOTE: Uses while-loop pattern. Native `for k in 0:n` syntax generates complex
@@ -47,8 +47,8 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr
4747
while k <= num_k
4848
# Load and convert to TF32 for tensor cores (Float32 only)
4949
# padding_mode=Zero ensures out-of-bounds reads return zero (for non-aligned dimensions)
50-
a = ct.load(A, (bid_m, k), (tm[], tk[]); padding_mode=ct.PaddingMode.Zero)
51-
b = ct.load(B, (k, bid_n), (tk[], tn[]); padding_mode=ct.PaddingMode.Zero)
50+
a = ct.load(A, (bid_m, k), (tm, tk); padding_mode=ct.PaddingMode.Zero)
51+
b = ct.load(B, (k, bid_n), (tk, tn); padding_mode=ct.PaddingMode.Zero)
5252
if T === Float32
5353
a = convert(ct.Tile{ct.TFloat32}, a)
5454
b = convert(ct.Tile{ct.TFloat32}, b)

examples/transpose.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ import cuTile as ct
88
# Transpose kernel with TileArray and constant tile sizes
99
# TileArray carries size/stride metadata; tile sizes are now plain Ints, wrapped as ct.Constant at launch (NOTE(review): this context comment still says "Constant parameters are ghost types" in the source — stale after this commit, should be updated)
1010
function transpose_kernel(x::ct.TileArray{T,2}, y::ct.TileArray{T,2},
11-
tm::ct.Constant{Int}, tn::ct.Constant{Int}) where {T}
11+
tm::Int, tn::Int) where {T}
1212
bidx = ct.bid(1)
1313
bidy = ct.bid(2)
14-
input_tile = ct.load(x, (bidx, bidy), (tm[], tn[]))
14+
input_tile = ct.load(x, (bidx, bidy), (tm, tn))
1515
transposed_tile = transpose(input_tile)
1616
ct.store(y, (bidy, bidx), transposed_tile)
1717
return

0 commit comments

Comments
 (0)