Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 60 additions & 22 deletions src/liger_kernel/ops/backends/_ascend/ops/attn_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count

# Pad the last dim to a multiple of _FEAT_ALIGN for aligned vector UB loads.
# Kernels use the logical D for math/masks and d_stride for the storage row pitch.
_FEAT_ALIGN = 16


def _pad_features_aligned(V: torch.Tensor, align: int = _FEAT_ALIGN) -> torch.Tensor:
    """Zero-pad the last dimension of ``V`` up to a multiple of ``align``.

    Args:
        V: Tensor whose trailing (feature) dimension is padded.
        align: Positive multiple the last dimension is rounded up to.
            Defaults to ``_FEAT_ALIGN``.

    Returns:
        ``V`` itself (no copy) when its last dimension is already a multiple
        of ``align``; otherwise a new tensor with zeros appended on the right
        of the last dimension.

    Raises:
        ValueError: If ``align`` is not a positive integer.
    """
    # A zero alignment would divide by zero and a negative one would yield a
    # negative pad width, which crops features instead of padding them.
    if align <= 0:
        raise ValueError(f"align must be a positive integer, got {align}")

    # (-d) % align is the distance from d up to the next multiple of align.
    pad = (-V.shape[-1]) % align
    if not pad:
        return V
    return torch.nn.functional.pad(V, (0, pad))


@triton.jit
def _attn_res_fwd_kernel(
Expand All @@ -38,6 +48,7 @@ def _attn_res_fwd_kernel(
RSTD_ptr, # [B*T, N]
n_tokens,
D,
d_stride,
eps,
n_blocks: tl.constexpr,
BLOCK_D: tl.constexpr,
Expand All @@ -58,7 +69,7 @@ def _attn_res_fwd_kernel(
score_max = tl.full((), float("-inf"), dtype=tl.float32)

for i in tl.static_range(0, n_blocks):
v_off = i * n_tokens * D + row_idx * D
v_off = i * n_tokens * d_stride + row_idx * d_stride
v = tl.load(V_ptr + v_off + cols, mask=d_mask, other=0.0).to(tl.float32)

# RMSNorm
Expand All @@ -84,12 +95,12 @@ def _attn_res_fwd_kernel(
a_i = tl.sum(tl.where(tl.arange(0, n_blocks) == i, alpha, 0.0))
tl.store(Alpha_ptr + row_idx * n_blocks + i, a_i)

v_off = i * n_tokens * D + row_idx * D
v_off = i * n_tokens * d_stride + row_idx * d_stride
v = tl.load(V_ptr + v_off + cols, mask=d_mask, other=0.0).to(tl.float32)

h += a_i * v

tl.store(Out_ptr + row_idx * D + cols, h, mask=d_mask)
tl.store(Out_ptr + row_idx * d_stride + cols, h, mask=d_mask)


@triton.jit
Expand All @@ -102,6 +113,7 @@ def _attn_res_fwd_kernel_tiled(
RSTD_ptr, # [T, N]
n_tokens,
D,
d_stride,
eps,
n_blocks: tl.constexpr,
BLOCK_D: tl.constexpr,
Expand All @@ -121,7 +133,9 @@ def _attn_res_fwd_kernel_tiled(
cols = d + tl.arange(0, BLOCK_D)
mask = cols < D

v = tl.load(V_ptr + i * n_tokens * D + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * d_stride + tok * d_stride + cols, mask=mask, other=0.0).to(
tl.float32
)

sum_sq += tl.sum(v * v, axis=0)

Expand All @@ -134,7 +148,9 @@ def _attn_res_fwd_kernel_tiled(
cols = d + tl.arange(0, BLOCK_D)
mask = cols < D

v = tl.load(V_ptr + i * n_tokens * D + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * d_stride + tok * d_stride + cols, mask=mask, other=0.0).to(
tl.float32
)

wq = tl.load(W_query_ptr + cols, mask=mask, other=0.0).to(tl.float32)
wn = tl.load(W_norm_ptr + cols, mask=mask, other=0.0).to(tl.float32)
Expand Down Expand Up @@ -165,11 +181,13 @@ def _attn_res_fwd_kernel_tiled(
for i in tl.static_range(0, n_blocks):
a_i = tl.sum(tl.where(tl.arange(0, n_blocks) == i, alpha, 0.0))

v = tl.load(V_ptr + i * n_tokens * D + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * d_stride + tok * d_stride + cols, mask=mask, other=0.0).to(
tl.float32
)

h += a_i * v

tl.store(Out_ptr + tok * D + cols, h, mask=mask)
tl.store(Out_ptr + tok * d_stride + cols, h, mask=mask)


@triton.jit
Expand All @@ -185,6 +203,7 @@ def _attn_res_bwd_kernel(
dW_norm_ptr, # [D]
n_tokens,
D,
d_stride,
n_blocks: tl.constexpr,
BLOCK_D: tl.constexpr,
):
Expand All @@ -196,7 +215,7 @@ def _attn_res_bwd_kernel(
d_mask = cols < D

# Load shared vectors
dh = tl.load(dOut_ptr + row_idx * D + cols, mask=d_mask, other=0.0).to(tl.float32)
dh = tl.load(dOut_ptr + row_idx * d_stride + cols, mask=d_mask, other=0.0).to(tl.float32)
w_query = tl.load(W_query_ptr + cols, mask=d_mask, other=0.0).to(tl.float32)
w_norm = tl.load(W_norm_ptr + cols, mask=d_mask, other=0.0).to(tl.float32)

Expand All @@ -205,7 +224,7 @@ def _attn_res_bwd_kernel(
d_alpha = tl.zeros((n_blocks,), dtype=tl.float32)

for i in tl.static_range(0, n_blocks):
v_off = i * n_tokens * D + row_idx * D
v_off = i * n_tokens * d_stride + row_idx * d_stride
v = tl.load(V_ptr + v_off + cols, mask=d_mask, other=0.0).to(tl.float32)

a_i = tl.load(Alpha_ptr + row_idx * n_blocks + i)
Expand All @@ -224,7 +243,7 @@ def _attn_res_bwd_kernel(

# Main loop
for i in tl.static_range(0, n_blocks):
v_off = i * n_tokens * D + row_idx * D
v_off = i * n_tokens * d_stride + row_idx * d_stride
v = tl.load(V_ptr + v_off + cols, mask=d_mask, other=0.0).to(tl.float32)

a_i = tl.sum(tl.where(tl.arange(0, n_blocks) == i, alpha, 0.0))
Expand Down Expand Up @@ -266,6 +285,7 @@ def _attn_res_bwd_kernel_tiled(
dW_norm_ptr,
n_tokens,
D,
d_stride,
n_blocks: tl.constexpr,
BLOCK_D: tl.constexpr,
):
Expand All @@ -287,8 +307,10 @@ def _attn_res_bwd_kernel_tiled(
cols = d + tl.arange(0, BLOCK_D)
mask = cols < D

dh = tl.load(dOut_ptr + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * D + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
dh = tl.load(dOut_ptr + tok * d_stride + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * d_stride + tok * d_stride + cols, mask=mask, other=0.0).to(
tl.float32
)

da_i += tl.sum(dh * v, axis=0)

Expand All @@ -311,7 +333,9 @@ def _attn_res_bwd_kernel_tiled(
cols = d + tl.arange(0, BLOCK_D)
mask = cols < D

v = tl.load(V_ptr + i * n_tokens * D + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * d_stride + tok * d_stride + cols, mask=mask, other=0.0).to(
tl.float32
)

wq = tl.load(W_query_ptr + cols, mask=mask, other=0.0).to(tl.float32)
wn = tl.load(W_norm_ptr + cols, mask=mask, other=0.0).to(tl.float32)
Expand All @@ -324,8 +348,10 @@ def _attn_res_bwd_kernel_tiled(
cols = d + tl.arange(0, BLOCK_D)
mask = cols < D

dh = tl.load(dOut_ptr + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * D + tok * D + cols, mask=mask, other=0.0).to(tl.float32)
dh = tl.load(dOut_ptr + tok * d_stride + cols, mask=mask, other=0.0).to(tl.float32)
v = tl.load(V_ptr + i * n_tokens * d_stride + tok * d_stride + cols, mask=mask, other=0.0).to(
tl.float32
)

wq = tl.load(W_query_ptr + cols, mask=mask, other=0.0).to(tl.float32)
wn = tl.load(W_norm_ptr + cols, mask=mask, other=0.0).to(tl.float32)
Expand All @@ -338,7 +364,7 @@ def _attn_res_bwd_kernel_tiled(

dv_total = dv_from_sum + dv_from_score

tl.store(dV_ptr + i * n_tokens * D + tok * D + cols, dv_total, mask=mask)
tl.store(dV_ptr + i * n_tokens * d_stride + tok * d_stride + cols, dv_total, mask=mask)

# Accumulate dW
v_norm = v * rstd
Expand Down Expand Up @@ -398,16 +424,15 @@ def attn_res_forward(blocks, w_query, w_norm, eps=1e-6):
orig_shape = V.shape # [N, B, T, D] or [N, B*T, D]
N = V.shape[0]
D = V.shape[-1]
# import pdb; pdb.set_trace()

# Flatten to [N, B*T, D]
V_3d = V.reshape(N, -1, D).contiguous()
V = _pad_features_aligned(V)
V_3d = V.reshape(N, -1, V.shape[-1]).contiguous()
d_stride = V_3d.shape[-1]
n_tokens = V_3d.shape[1]

w_query = w_query.contiguous()
w_norm = w_norm.contiguous()

Out = torch.empty(n_tokens, D, device=V.device, dtype=V.dtype)
Out = torch.empty(n_tokens, V_3d.shape[-1], device=V.device, dtype=V.dtype)
# Layout [B*T, N] for coalesced access per token
Alpha = torch.empty(n_tokens, N, device=V.device, dtype=torch.float32)
RSTD = torch.empty(n_tokens, N, device=V.device, dtype=torch.float32)
Expand All @@ -426,6 +451,7 @@ def attn_res_forward(blocks, w_query, w_norm, eps=1e-6):
RSTD,
n_tokens,
D,
d_stride,
eps,
BLOCK_D=BLOCK_D,
n_blocks=N,
Expand All @@ -440,11 +466,15 @@ def attn_res_forward(blocks, w_query, w_norm, eps=1e-6):
RSTD,
n_tokens,
D,
d_stride,
eps,
BLOCK_D=BLOCK_D,
n_blocks=N,
)

if d_stride != D:
Out = Out[:, :D].contiguous()

# Reshape output to match input spatial dims
out_shape = list(orig_shape[1:]) # [B, T, D] or [B*T, D]
return Out.view(out_shape), V_3d, Alpha, RSTD
Expand All @@ -455,8 +485,11 @@ def attn_res_backward(dh, V_3d, w_query, w_norm, Alpha, RSTD, eps=1e-6):
Returns: dV [N, B*T, D], dW_query [D], dW_norm [D]
"""
dh = dh.contiguous()
N, n_tokens, D = V_3d.shape
N, n_tokens, d_stride = V_3d.shape
D = dh.shape[-1]
dh_2d = dh.reshape(n_tokens, D)
if d_stride != D:
dh_2d = torch.nn.functional.pad(dh_2d, (0, d_stride - D))

dV = torch.empty_like(V_3d)
dW_query = torch.zeros(D, dtype=torch.float32, device=dh.device)
Expand All @@ -479,6 +512,7 @@ def attn_res_backward(dh, V_3d, w_query, w_norm, Alpha, RSTD, eps=1e-6):
dW_norm,
n_tokens,
D,
d_stride,
BLOCK_D=BLOCK_D,
n_blocks=N,
)
Expand All @@ -495,10 +529,14 @@ def attn_res_backward(dh, V_3d, w_query, w_norm, Alpha, RSTD, eps=1e-6):
dW_norm,
n_tokens,
D,
d_stride,
BLOCK_D=BLOCK_D,
n_blocks=N,
)

if d_stride != D:
dV = dV[:, :, :D].contiguous()

return dV, dW_query.to(w_query.dtype), dW_norm.to(w_norm.dtype)


Expand Down
Loading