Skip to content

Commit 84f8fa8

Browse files
maleadt and claude authored
Use Julia-native for loops (#174)
The new IRStructurizer engine handles Julia's for-in-range iterator protocol, so counting while loops are no longer needed. Convert all `k = Int32(1); while k <= n; ...; k += Int32(1); end` patterns to `for k in Int32(1):n; ...; end` across examples and tests. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 93bc239 commit 84f8fa8

File tree

13 files changed

+141
-187
lines changed

13 files changed

+141
-187
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ CUDA_Tile_jll = "13.1"
3030
CompilerCaching = "0.2"
3131
EnumX = "1.0"
3232
GPUArrays = "11"
33-
IRStructurizer = "0.4"
33+
IRStructurizer = "0.5"
3434
julia = "1.11"

README.md

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -511,30 +511,6 @@ standard Julia silently produce truncated or wrapped results instead:
511511
Assertions may be added in the future for testing purposes.
512512

513513

514-
## Limitations
515-
516-
### `for` loops
517-
518-
The compiler recognizes simple while-loop patterns but not Julia's iterator-based `for` loops. Write such loops as:
519-
520-
```julia
521-
# Do this:
522-
i = 1
523-
while i <= n
524-
# ...
525-
i += 1
526-
end
527-
528-
# Not this:
529-
for i in 1:n
530-
# ...
531-
end
532-
```
533-
534-
Also make sure `i`, `n`, and the increment all have the same type.
535-
536-
537-
538514
## Host-level operations
539515

540516
cuTile.jl also provides a limited set of host-level APIs to use cuTile without

examples/batchmatmul.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.T
2626
acc = zeros(Float32, tm, tn)
2727

2828
# K reduction loop
29-
k = Int32(1)
30-
while k <= num_k
29+
for k in Int32(1):num_k
3130
# Load 3D tiles: (tm, tk, 1) and (tk, tn, 1)
3231
a = ct.load(A; index=(bid_m, k, pid_batch), shape=(tm, tk, 1),
3332
padding_mode=ct.PaddingMode.Zero)
@@ -45,7 +44,6 @@ function batch_matmul_kernel(A::ct.TileArray{T,3}, B::ct.TileArray{T,3}, C::ct.T
4544
end
4645

4746
acc = muladd(a_2d, b_2d, acc)
48-
k += Int32(1)
4947
end
5048

5149
# Convert to output type, reshape to 3D, and store

examples/fmha.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ function fmha_kernel(Q::ct.TileArray{T, 4}, K::ct.TileArray{T, 4},
7373
end
7474

7575
# Loop over K, V blocks
76-
j = Int32(0)
77-
while j < Tc
76+
for j in Int32(0):Tc-Int32(1)
7877
# QK product
7978
# K is (D_k, SeqLen_KV, KVH, Batch)
8079
# Load (TILE_N, TILE_D, 1, 1) with order=(2,1,3,4) to transpose D and N
@@ -123,8 +122,6 @@ function fmha_kernel(Q::ct.TileArray{T, 4}, K::ct.TileArray{T, 4},
123122
# (TILE_D, TILE_N) @ (TILE_N, TILE_M) = (TILE_D, TILE_M)
124123
acc = muladd(v, p, acc)
125124
m_i = m_ij
126-
127-
j += Int32(1)
128125
end
129126

130127
# Final normalization and store

examples/layernorm.jl

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,40 +33,34 @@ function layer_norm_fwd(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1}
3333

3434
# Compute mean
3535
mean = zeros(Float32, (TILE_N, 1))
36-
j = Int32(1)
37-
while j <= num_tiles
36+
for j in Int32(1):num_tiles
3837
tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero)
3938
mean = mean .+ tx
40-
j += Int32(1)
4139
end
4240
mean = sum(mean; dims=1) / N
4341
ct.store(Mean; index=bid_m, tile=mean)
4442

4543
# Compute variance
4644
var = zeros(Float32, (TILE_N, 1))
47-
j = Int32(1)
48-
while j <= num_tiles
45+
for j in Int32(1):num_tiles
4946
tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero)
5047
# Mask for valid elements
5148
mask = reshape(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange(TILE_N)) .<= N, (TILE_N, 1))
5249
centered_tx = ifelse.(mask, tx .- mean, 0.0f0)
5350
var = var .+ (centered_tx .^ 2.0f0)
54-
j += Int32(1)
5551
end
5652
var = sum(var; dims=1) / N
5753
rstd = 1.0f0 ./ sqrt.(var .+ eps)
5854
ct.store(Rstd; index=bid_m, tile=rstd)
5955

6056
# Normalize and apply affine transformation
61-
j = Int32(1)
62-
while j <= num_tiles
57+
for j in Int32(1):num_tiles
6358
tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero)
6459
tw = reshape(ct.load(W; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (TILE_N, 1))
6560
tb = reshape(ct.load(B; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (TILE_N, 1))
6661
ty = (tx .- mean) .* rstd
6762
ty = ty .* tw .+ tb
6863
ct.store(Y; index=(j, bid_m), tile=ty)
69-
j += Int32(1)
7064
end
7165

7266
return
@@ -136,23 +130,19 @@ function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float3
136130
# First pass: compute c1 and c2 reduction terms
137131
c1 = zeros(Float32, (TILE_N, 1))
138132
c2 = zeros(Float32, (TILE_N, 1))
139-
j = Int32(1)
140-
while j <= num_tiles
133+
for j in Int32(1):num_tiles
141134
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
142135
c1 = c1 .+ (xhat .* wdy)
143136
c2 = c2 .+ wdy
144-
j += Int32(1)
145137
end
146138
c1 = sum(c1; dims=1) / N
147139
c2 = sum(c2; dims=1) / N
148140

149141
# Second pass: compute dX
150-
j = Int32(1)
151-
while j <= num_tiles
142+
for j in Int32(1):num_tiles
152143
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
153144
tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
154145
ct.store(DX; index=(j, bid_m), tile=tdx)
155-
j += Int32(1)
156146
end
157147

158148
return
@@ -195,19 +185,16 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
195185
# First pass: compute c1 and c2 reduction terms
196186
c1 = zeros(Float32, (TILE_N, 1))
197187
c2 = zeros(Float32, (TILE_N, 1))
198-
j = Int32(1)
199-
while j <= num_tiles
188+
for j in Int32(1):num_tiles
200189
_, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
201190
c1 = c1 .+ (xhat .* wdy)
202191
c2 = c2 .+ wdy
203-
j += Int32(1)
204192
end
205193
c1 = sum(c1; dims=1) / N
206194
c2 = sum(c2; dims=1) / N
207195

208196
# Second pass: compute dX and partial dW/dB
209-
j = Int32(1)
210-
while j <= num_tiles
197+
for j in Int32(1):num_tiles
211198
tdy, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
212199
tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
213200
ct.store(DX; index=(j, bid_m), tile=tdx)
@@ -230,8 +217,6 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
230217
# Release spinlock
231218
ct.atomic_xchg(Locks, group_bid_m, 0;
232219
memory_order=ct.MemoryOrder.Release)
233-
234-
j += Int32(1)
235220
end
236221

237222
return
@@ -258,11 +243,9 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa
258243

259244
dw = zeros(Float32, (TILE_N, TILE_M))
260245
db = zeros(Float32, (TILE_N, TILE_M))
261-
i = Int32(1)
262-
while i <= num_tiles
246+
for i in Int32(1):num_tiles
263247
dw = dw .+ ct.load(DW; index=(bid_n, i), shape=(TILE_N, TILE_M), padding_mode=ct.PaddingMode.Zero)
264248
db = db .+ ct.load(DB; index=(bid_n, i), shape=(TILE_N, TILE_M), padding_mode=ct.PaddingMode.Zero)
265-
i += Int32(1)
266249
end
267250
sum_dw = sum(dw; dims=2)
268251
sum_db = sum(db; dims=2)

examples/matmul.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr
4242
acc = zeros(Float32, tm, tn)
4343

4444
# K reduction loop - accumulate partial products
45-
# NOTE: Uses while-loop pattern. Native `for k in 0:n` syntax generates complex
46-
# iterator protocol IR that doesn't map cleanly to ForOp. Use while-loops for now.
47-
k = Int32(1)
48-
while k <= num_k
45+
for k in Int32(1):num_k
4946
# Load and convert to TF32 for tensor cores (Float32 only)
5047
# padding_mode=Zero ensures out-of-bounds reads return zero (for non-aligned dimensions)
5148
a = ct.load(A; index=(bid_m, k), shape=(tm, tk), padding_mode=ct.PaddingMode.Zero)
@@ -55,7 +52,6 @@ function matmul_kernel(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, C::ct.TileArr
5552
b = convert(ct.Tile{ct.TFloat32}, b)
5653
end
5754
acc = muladd(a, b, acc)
58-
k += Int32(1)
5955
end
6056

6157
# Convert accumulator to output type and store

examples/moe.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
6868
acc = zeros(Float32, TILE_N, TILE_M)
6969
num_k = cld(K, Int32(TILE_K))
7070

71-
k = Int32(1)
72-
while k <= num_k
71+
for k in Int32(1):num_k
7372
# 1-indexed row indices into A's K dimension
7473
a_k_indices = (k - Int32(1)) * Int32(TILE_K) .+ ct.arange(TILE_K)
7574

@@ -87,7 +86,6 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
8786

8887
# acc(N,M) += b(N,K) @ a(K,M)
8988
acc = muladd(b, a, acc)
90-
k += Int32(1)
9189
end
9290

9391
if mul_routed_weight

src/language/overlays.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,29 @@ macro overlay(ex)
88
end
99

1010

11+
#=============================================================================
12+
StepRange Construction
13+
=============================================================================#
14+
15+
# GPU-safe replacement for Base.steprange_last to enable `for i in start:step:stop`.
16+
# The original pulls in ArgumentError, @noinline overflow_case, and checked_srem_int.
17+
# This overlay uses unsigned arithmetic (bitcast → unsigned rem → bitcast) which
18+
# produces identical results and maps cleanly to Tile IR (signless integers make
19+
# signed↔unsigned bitcasts no-ops).
20+
@overlay function Base.steprange_last(start::T, step::T, stop::T) where {T<:Base.BitInteger}
21+
stop == start && return stop
22+
if step > zero(step)
23+
stop < start && return start - oneunit(step) # empty range
24+
remain = signed(unsigned(stop - start) % unsigned(step))
25+
return stop - remain
26+
else
27+
stop > start && return start + oneunit(step) # empty range
28+
remain = signed(unsigned(start - stop) % unsigned(-step))
29+
return stop + remain
30+
end
31+
end
32+
33+
1134
#=============================================================================
1235
Broadcasting
1336
=============================================================================#

src/mapreduce.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ _atomic_op(_, ::Type) = nothing
2727
start_d = idx_d
2828
end
2929

30-
@nwhileloops($N,
31-
d -> (idx_d <= n_d),
32-
d -> (idx_d = start_d),
33-
d -> (idx_d = idx_d + reduce_stride[d]),
30+
@nloops($N, idx, d -> (start_d:reduce_stride[d]:n_d),
3431
begin
3532
tile = load(src, (@ntuple $N d -> idx_d), tile_size; padding_mode=pad_mode)
3633
acc = op.(acc, f.(tile))

src/utils.jl

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Base.Cartesian: @nexprs, @ntuple, inlineanonymous
1+
using Base.Cartesian: @nexprs, @ntuple, @nloops
22

33
#=============================================================================
44
Grid and tile sizing helpers (used by broadcast and mapreduce)
@@ -71,61 +71,3 @@ function _compute_tile_sizes(input_size::NTuple{N,Int}, dim_order; budget::Int=4
7171
return NTuple{N,Int}(ts)
7272
end
7373

74-
#=============================================================================
75-
@nwhileloops — while-loop variant of Base.Cartesian.@nloops
76-
=============================================================================#
77-
78-
"""
79-
@nwhileloops N condexpr [preexpr [postexpr]] body
80-
81-
Generate N nested `while` loops, analogous to `Base.Cartesian.@nloops` but
82-
using `while` instead of `for`. This is needed because the cuTile compiler
83-
only recognizes while-loop patterns for structured control flow.
84-
85-
`condexpr` and the optional `preexpr`/`postexpr` are `d->` anonymous functions
86-
specialized per dimension with Cartesian `_d` suffix naming. If you want just
87-
a post-expression, supply `nothing` for the pre-expression.
88-
89-
# Example
90-
```julia
91-
@nwhileloops 2 d->(idx_d <= n_d) d->(idx_d = start_d) d->(idx_d += stride[d]) begin
92-
# innermost body
93-
end
94-
```
95-
generates:
96-
```julia
97-
idx_2 = start_2
98-
while idx_2 <= n_2
99-
idx_1 = start_1
100-
while idx_1 <= n_1
101-
# innermost body
102-
idx_1 += stride[1]
103-
end
104-
idx_2 += stride[2]
105-
end
106-
```
107-
"""
108-
macro nwhileloops(N, condexpr, args...)
109-
_nwhileloops(N, condexpr, args...)
110-
end
111-
112-
function _nwhileloops(N::Int, condexpr::Expr, args::Expr...)
113-
if !(1 <= length(args) <= 3)
114-
throw(ArgumentError("expected 1 to 3 trailing arguments (body, or pre+body, or pre+post+body), got $(length(args))"))
115-
end
116-
body = args[end]
117-
ex = Expr(:escape, body)
118-
for d in 1:N
119-
cond = esc(inlineanonymous(condexpr, d))
120-
preexpr = length(args) > 1 ? esc(inlineanonymous(args[1], d)) : nothing
121-
postexpr = length(args) > 2 ? esc(inlineanonymous(args[2], d)) : nothing
122-
ex = quote
123-
$preexpr
124-
while $cond
125-
$ex
126-
$postexpr
127-
end
128-
end
129-
end
130-
ex
131-
end

0 commit comments

Comments (0)