55using CUDA
66import cuTile as ct
77
8- const ConstInt = ct. Constant{Int}
9-
108#= ============================================================================
119 LayerNorm Forward Kernel
1210
@@ -25,43 +23,43 @@ const ConstInt = ct.Constant{Int}
2523function layer_norm_fwd (X:: ct.TileArray{Float32, 2} , W:: ct.TileArray{Float32, 1} ,
2624 B:: ct.TileArray{Float32, 1} , Y:: ct.TileArray{Float32, 2} ,
2725 Mean:: ct.TileArray{Float32, 1} , Rstd:: ct.TileArray{Float32, 1} ,
28- eps:: ct.Constant{ Float32} , TILE_N:: ConstInt )
26+ eps:: Float32 , TILE_N:: Int )
2927 bid_m = ct. bid (1 )
30- num_tiles = ct. num_tiles (X, 2 , (1 , TILE_N[] ))
28+ num_tiles = ct. num_tiles (X, 2 , (1 , TILE_N))
3129 N = size (X, 2 )
3230
3331 # Compute mean
34- mean = ct. full ((1 , TILE_N[] ), 0.0f0 , Float32)
32+ mean = ct. full ((1 , TILE_N), 0.0f0 , Float32)
3533 j = Int32 (1 )
3634 while j <= num_tiles
37- tx = ct. load (X, (bid_m, j), (1 , TILE_N[] ); padding_mode= ct. PaddingMode. Zero)
35+ tx = ct. load (X, (bid_m, j), (1 , TILE_N); padding_mode= ct. PaddingMode. Zero)
3836 mean = mean .+ tx
3937 j += Int32 (1 )
4038 end
4139 mean = sum (mean; dims= 2 ) / N
4240 ct. store (Mean, bid_m, mean)
4341
4442 # Compute variance
45- var = ct. full ((1 , TILE_N[] ), 0.0f0 , Float32)
43+ var = ct. full ((1 , TILE_N), 0.0f0 , Float32)
4644 j = Int32 (1 )
4745 while j <= num_tiles
48- tx = ct. load (X, (bid_m, j), (1 , TILE_N[] ); padding_mode= ct. PaddingMode. Zero)
46+ tx = ct. load (X, (bid_m, j), (1 , TILE_N); padding_mode= ct. PaddingMode. Zero)
4947 # Mask for valid elements
50- mask = ct. broadcast_to (((j - Int32 (1 )) * Int32 (TILE_N[] ) .+ ct. arange ((TILE_N[] ,), Int32)) .<= N, (1 , TILE_N[] ))
48+ mask = ct. broadcast_to (((j - Int32 (1 )) * Int32 (TILE_N) .+ ct. arange ((TILE_N,), Int32)) .<= N, (1 , TILE_N))
5149 centered_tx = ifelse .(mask, tx .- mean, 0.0f0 )
5250 var = var .+ (centered_tx .^ 2.0f0 )
5351 j += Int32 (1 )
5452 end
5553 var = sum (var; dims= 2 ) / N
56- rstd = 1.0f0 ./ sqrt .(var .+ eps[] )
54+ rstd = 1.0f0 ./ sqrt .(var .+ eps)
5755 ct. store (Rstd, bid_m, rstd)
5856
5957 # Normalize and apply affine transformation
6058 j = Int32 (1 )
6159 while j <= num_tiles
62- tx = ct. load (X, (bid_m, j), (1 , TILE_N[] ); padding_mode= ct. PaddingMode. Zero)
63- tw = reshape (ct. load (W, j, (TILE_N[] ,); padding_mode= ct. PaddingMode. Zero), (1 , TILE_N[] ))
64- tb = reshape (ct. load (B, j, (TILE_N[] ,); padding_mode= ct. PaddingMode. Zero), (1 , TILE_N[] ))
60+ tx = ct. load (X, (bid_m, j), (1 , TILE_N); padding_mode= ct. PaddingMode. Zero)
61+ tw = reshape (ct. load (W, j, (TILE_N,); padding_mode= ct. PaddingMode. Zero), (1 , TILE_N))
62+ tb = reshape (ct. load (B, j, (TILE_N,); padding_mode= ct. PaddingMode. Zero), (1 , TILE_N))
6563 ty = (tx .- mean) .* rstd
6664 ty = ty .* tw .+ tb
6765 ct. store (Y, (bid_m, j), ty)
@@ -123,21 +121,21 @@ Args:
123121function layer_norm_bwd_dx (DX:: ct.TileArray{Float32, 2} , DY:: ct.TileArray{Float32, 2} ,
124122 X:: ct.TileArray{Float32, 2} , W:: ct.TileArray{Float32, 1} ,
125123 Mean:: ct.TileArray{Float32, 1} , Rstd:: ct.TileArray{Float32, 1} ,
126- TILE_N:: ConstInt )
124+ TILE_N:: Int )
127125 bid_m = ct. bid (1 )
128- num_tiles = ct. num_tiles (X, 2 , (1 , TILE_N[] ))
126+ num_tiles = ct. num_tiles (X, 2 , (1 , TILE_N))
129127 N = size (X, 2 )
130128
131129 # Load mean and rstd for this row
132130 mean = ct. load (Mean, bid_m, (1 ,); padding_mode= ct. PaddingMode. Zero)
133131 rstd = ct. load (Rstd, bid_m, (1 ,); padding_mode= ct. PaddingMode. Zero)
134132
135133 # First pass: compute c1 and c2 reduction terms
136- c1 = ct. full ((1 , TILE_N[] ), 0.0f0 , Float32)
137- c2 = ct. full ((1 , TILE_N[] ), 0.0f0 , Float32)
134+ c1 = ct. full ((1 , TILE_N), 0.0f0 , Float32)
135+ c2 = ct. full ((1 , TILE_N), 0.0f0 , Float32)
138136 j = Int32 (1 )
139137 while j <= num_tiles
140- _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N[] , N)
138+ _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
141139 c1 = c1 .+ (xhat .* wdy)
142140 c2 = c2 .+ wdy
143141 j += Int32 (1 )
@@ -148,7 +146,7 @@ function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float3
148146 # Second pass: compute dX
149147 j = Int32 (1 )
150148 while j <= num_tiles
151- _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N[] , N)
149+ _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
152150 tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
153151 ct. store (DX, (bid_m, j), tdx)
154152 j += Int32 (1 )
@@ -181,22 +179,22 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
181179 X:: ct.TileArray{Float32, 2} , W:: ct.TileArray{Float32, 1} ,
182180 Mean:: ct.TileArray{Float32, 1} , Rstd:: ct.TileArray{Float32, 1} ,
183181 Locks:: ct.TileArray{Int, 1} ,
184- GROUP_SIZE_M:: ConstInt , TILE_N:: ConstInt )
182+ GROUP_SIZE_M:: Int , TILE_N:: Int )
185183 bid_m = ct. bid (1 )
186- num_tiles = ct. num_tiles (X, 2 , (1 , TILE_N[] ))
184+ num_tiles = ct. num_tiles (X, 2 , (1 , TILE_N))
187185 N = size (X, 2 )
188- group_bid_m = ((bid_m - Int32 (1 )) % Int32 (GROUP_SIZE_M[] )) + Int32 (1 )
186+ group_bid_m = ((bid_m - Int32 (1 )) % Int32 (GROUP_SIZE_M)) + Int32 (1 )
189187
190188 # Load mean and rstd for this row
191189 mean = ct. load (Mean, bid_m, (1 ,); padding_mode= ct. PaddingMode. Zero)
192190 rstd = ct. load (Rstd, bid_m, (1 ,); padding_mode= ct. PaddingMode. Zero)
193191
194192 # First pass: compute c1 and c2 reduction terms
195- c1 = ct. full ((1 , TILE_N[] ), 0.0f0 , Float32)
196- c2 = ct. full ((1 , TILE_N[] ), 0.0f0 , Float32)
193+ c1 = ct. full ((1 , TILE_N), 0.0f0 , Float32)
194+ c2 = ct. full ((1 , TILE_N), 0.0f0 , Float32)
197195 j = Int32 (1 )
198196 while j <= num_tiles
199- _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N[] , N)
197+ _, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
200198 c1 = c1 .+ (xhat .* wdy)
201199 c2 = c2 .+ wdy
202200 j += Int32 (1 )
@@ -207,12 +205,12 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
207205 # Second pass: compute dX and partial dW/dB
208206 j = Int32 (1 )
209207 while j <= num_tiles
210- tdy, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N[] , N)
208+ tdy, xhat, wdy = bwd_helper (X, W, DY, bid_m, j, mean, rstd, TILE_N, N)
211209 tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd
212210 ct. store (DX, (bid_m, j), tdx)
213211
214- partial_dw = reshape (tdy .* xhat, (TILE_N[] , 1 ))
215- partial_db = reshape (tdy, (TILE_N[] , 1 ))
212+ partial_dw = reshape (tdy .* xhat, (TILE_N, 1 ))
213+ partial_db = reshape (tdy, (TILE_N, 1 ))
216214
217215 # Acquire spinlock
218216 while ct. atomic_cas (Locks, group_bid_m, 0 , 1 ;
@@ -221,8 +219,8 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til
221219 end
222220
223221 # Critical section: accumulate partial gradients
224- partial_dw = partial_dw .+ ct. load (DW, (j, group_bid_m), (TILE_N[] , 1 ); padding_mode= ct. PaddingMode. Zero)
225- partial_db = partial_db .+ ct. load (DB, (j, group_bid_m), (TILE_N[] , 1 ); padding_mode= ct. PaddingMode. Zero)
222+ partial_dw = partial_dw .+ ct. load (DW, (j, group_bid_m), (TILE_N, 1 ); padding_mode= ct. PaddingMode. Zero)
223+ partial_db = partial_db .+ ct. load (DB, (j, group_bid_m), (TILE_N, 1 ); padding_mode= ct. PaddingMode. Zero)
226224 ct. store (DW, (j, group_bid_m), partial_dw)
227225 ct. store (DB, (j, group_bid_m), partial_db)
228226
@@ -251,16 +249,16 @@ Args:
251249"""
252250function layer_norm_bwd_dwdb (DW:: ct.TileArray{Float32, 2} , DB:: ct.TileArray{Float32, 2} ,
253251 FINAL_DW:: ct.TileArray{Float32, 1} , FINAL_DB:: ct.TileArray{Float32, 1} ,
254- TILE_M:: ConstInt , TILE_N:: ConstInt )
252+ TILE_M:: Int , TILE_N:: Int )
255253 bid_n = ct. bid (1 )
256- num_tiles = ct. num_tiles (DW, 2 , (TILE_N[] , TILE_M[] ))
254+ num_tiles = ct. num_tiles (DW, 2 , (TILE_N, TILE_M))
257255
258- dw = ct. zeros ((TILE_N[] , TILE_M[] ), Float32)
259- db = ct. zeros ((TILE_N[] , TILE_M[] ), Float32)
256+ dw = ct. zeros ((TILE_N, TILE_M), Float32)
257+ db = ct. zeros ((TILE_N, TILE_M), Float32)
260258 i = Int32 (1 )
261259 while i <= num_tiles
262- dw = dw .+ ct. load (DW, (bid_n, i), (TILE_N[] , TILE_M[] ); padding_mode= ct. PaddingMode. Zero)
263- db = db .+ ct. load (DB, (bid_n, i), (TILE_N[] , TILE_M[] ); padding_mode= ct. PaddingMode. Zero)
260+ dw = dw .+ ct. load (DW, (bid_n, i), (TILE_N, TILE_M); padding_mode= ct. PaddingMode. Zero)
261+ db = db .+ ct. load (DB, (bid_n, i), (TILE_N, TILE_M); padding_mode= ct. PaddingMode. Zero)
264262 i += Int32 (1 )
265263 end
266264 sum_dw = sum (dw; dims= 2 )
0 commit comments