@@ -33,40 +33,34 @@ function layer_norm_fwd(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1}
3333
3434 # Compute mean
3535 mean = zeros (Float32, (TILE_N, 1 ))
36- j = Int32 (1 )
37- while j <= num_tiles
36+ for j in Int32 (1 ): num_tiles
3837 tx = ct. load (X; index= (j, bid_m), shape= (TILE_N, 1 ), padding_mode= ct. PaddingMode. Zero)
3938 mean = mean .+ tx
40- j += Int32 (1 )
4139 end
4240 mean = sum (mean; dims= 1 ) / N
4341 ct. store (Mean; index= bid_m, tile= mean)
4442
4543 # Compute variance
4644 var = zeros (Float32, (TILE_N, 1 ))
47- j = Int32 (1 )
48- while j <= num_tiles
45+ for j in Int32 (1 ): num_tiles
4946 tx = ct. load (X; index= (j, bid_m), shape= (TILE_N, 1 ), padding_mode= ct. PaddingMode. Zero)
5047 # Mask for valid elements
5148 mask = reshape (((j - Int32 (1 )) * Int32 (TILE_N) .+ ct. arange (TILE_N)) .<= N, (TILE_N, 1 ))
5249 centered_tx = ifelse .(mask, tx .- mean, 0.0f0 )
5350 var = var .+ (centered_tx .^ 2.0f0 )
54- j += Int32 (1 )
5551 end
5652 var = sum (var; dims= 1 ) / N
5753 rstd = 1.0f0 ./ sqrt .(var .+ eps)
5854 ct. store (Rstd; index= bid_m, tile= rstd)
5955
6056 # Normalize and apply affine transformation
61- j = Int32 (1 )
62- while j <= num_tiles
57+ for j in Int32 (1 ): num_tiles
6358 tx = ct. load (X; index= (j, bid_m), shape= (TILE_N, 1 ), padding_mode= ct. PaddingMode. Zero)
6459 tw = reshape (ct. load (W; index= j, shape= (TILE_N,), padding_mode= ct. PaddingMode. Zero), (TILE_N, 1 ))
6560 tb = reshape (ct. load (B; index= j, shape= (TILE_N,), padding_mode= ct. PaddingMode. Zero), (TILE_N, 1 ))
6661 ty = (tx .- mean) .* rstd
6762 ty = ty .* tw .+ tb
6863 ct. store (Y; index= (j, bid_m), tile= ty)
69- j += Int32 (1 )
7064 end
7165
7266 return
@@ -136,23 +130,19 @@ function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float3
136130 # First pass: compute c1 and c2 reduction terms
137131 c1 = zeros (Float32, (TILE_N, 1 ))
138132 c2 = zeros (Float32, (TILE_N, 1 ))
139- j = Int32 (1 )
140- while j <= num_tiles
133+ for j in Int32 (1 ): num_tiles
141134 _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
142135 c1 = c1 .+ (xhat .* wdy)
143136 c2 = c2 .+ wdy
144- j += Int32 (1 )
145137 end
146138 c1 = sum (c1; dims= 1 ) / N
147139 c2 = sum (c2; dims= 1 ) / N
148140
149141 # Second pass: compute dX
150- j = Int32 (1 )
151- while j <= num_tiles
142+ for j in Int32 (1 ): num_tiles
152143 _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
153144 tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
154145 ct. store (DX; index= (j, bid_m), tile= tdx)
155- j += Int32 (1 )
156146 end
157147
158148 return
@@ -195,19 +185,16 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
195185 # First pass: compute c1 and c2 reduction terms
196186 c1 = zeros (Float32, (TILE_N, 1 ))
197187 c2 = zeros (Float32, (TILE_N, 1 ))
198- j = Int32 (1 )
199- while j <= num_tiles
188+ for j in Int32 (1 ): num_tiles
200189 _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
201190 c1 = c1 .+ (xhat .* wdy)
202191 c2 = c2 .+ wdy
203- j += Int32 (1 )
204192 end
205193 c1 = sum (c1; dims= 1 ) / N
206194 c2 = sum (c2; dims= 1 ) / N
207195
208196 # Second pass: compute dX and partial dW/dB
209- j = Int32 (1 )
210- while j <= num_tiles
197+ for j in Int32 (1 ): num_tiles
211198 tdy, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
212199 tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
213200 ct. store (DX; index= (j, bid_m), tile= tdx)
@@ -230,8 +217,6 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
230217 # Release spinlock
231218 ct. atomic_xchg (Locks, group_bid_m, 0 ;
232219 memory_order= ct. MemoryOrder. Release)
233-
234- j += Int32 (1 )
235220 end
236221
237222 return
@@ -258,11 +243,9 @@ function layer_norm_bwd_dwdb(DW::ct.TileArray{Float32, 2}, DB::ct.TileArray{Floa
258243
259244 dw = zeros (Float32, (TILE_N, TILE_M))
260245 db = zeros (Float32, (TILE_N, TILE_M))
261- i = Int32 (1 )
262- while i <= num_tiles
246+ for i in Int32 (1 ): num_tiles
263247 dw = dw .+ ct. load (DW; index= (bid_n, i), shape= (TILE_N, TILE_M), padding_mode= ct. PaddingMode. Zero)
264248 db = db .+ ct. load (DB; index= (bid_n, i), shape= (TILE_N, TILE_M), padding_mode= ct. PaddingMode. Zero)
265- i += Int32 (1 )
266249 end
267250 sum_dw = sum (dw; dims= 2 )
268251 sum_db = sum (db; dims= 2 )
0 commit comments