[NPU, A3] Add NPU kernel support for A3 machines #1220

Draft
zheliuyu wants to merge 1 commit into linkedin:main from pt-ecosystem:main

Contributor

@zheliuyu zheliuyu commented May 11, 2026

Motivation

This work follows the roadmap in linkedin/Liger-Kernel#969. The goal is to exercise the NPU kernel on Atlas 800T A3 (64G) and report how the test suite behaves on that hardware.

Details

  • Fixed the Ascend implementation of attn_res so it passes on A3 alongside the rest of the suite.
  • After the fix, the full test run completes successfully. 🍾

Why attn_res failed on A3

The failure surfaced as vector-core / ACL runtime errors (e.g. 507035, with device sync failing), not as an ordinary atol/rtol numerical mismatch.

  • The Ascend attn_res kernel uses wide masked loads along the feature dimension (e.g. BLOCK_D = next_power_of_2(D)).
  • The tests include awkward sizes such as D = 123 with float32, so the row pitch is 123 × 4 = 492 bytes, which is not a multiple of 32 or 64 bytes. Many vectorized paths on this stack are not friendly to such a pitch, and it can trigger vector-core faults for that lowering.
  • Fix: pad the last dim to a multiple of 16, pass d_stride as the real memory pitch, keep D for math and masks, and slice/pad tensors so callers still see the logical D.
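
The arithmetic behind the fix can be sketched in plain Python. This is only an illustration under stated assumptions, not the code from this PR: the helper names (round_up, pad_rows, unpad_rows) are hypothetical, nested lists stand in for tensors, and the "multiple of 16" is assumed to count elements rather than bytes.

```python
# Illustrative sketch only; helper names are hypothetical and not taken from the PR.
ALIGN_ELEMS = 16  # assumption: "multiple of 16" counts elements, not bytes
ITEMSIZE = 4      # bytes per float32 element


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


def pad_rows(rows, d_stride, fill=0.0):
    """Pad every row on the right so its length equals d_stride."""
    return [row + [fill] * (d_stride - len(row)) for row in rows]


def unpad_rows(rows, d):
    """Slice every row back to the logical width d."""
    return [row[:d] for row in rows]


D = 123                                  # logical feature size used in the tests
block_d = next_power_of_2(D)             # width of the masked load
d_stride = round_up(D, ALIGN_ELEMS)      # padded row length in elements
unpadded_pitch = D * ITEMSIZE            # 492 bytes, not a multiple of 32
padded_pitch = d_stride * ITEMSIZE       # row pitch after padding

# The mask still uses the logical D, so padded lanes never enter the math.
mask = [i < D for i in range(block_d)]

x = [[float(i) for i in range(D)] for _ in range(2)]
x_padded = pad_rows(x, d_stride)         # what the kernel would read
y = unpad_rows(x_padded, D)              # what the caller gets back
```

With D = 123 the padded stride comes out to 128 elements, so the row pitch becomes 512 bytes, while the mask keeps exactly 123 active lanes and the caller still sees rows of width 123.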

Benchmark results for the 4 most frequently used kernels

Benchmark plots (images not shown), each measured against token length:

  • cross_entropy: memory_full, speed_forward, speed_backward, speed_full, speed_no-grad-forward
  • rms_norm: memory_full, speed_forward, speed_backward, speed_full
  • rope: memory_full, speed_forward, speed_backward, speed_full
  • swiglu: memory_full, speed_forward, speed_backward, speed_full

Testing Done

  • Hardware Type: All NPUs.
  • Run make test to ensure correctness
  • Run make checkstyle to ensure code style
  • Run make test-convergence to ensure convergence

Collaborator

@Tcc0403 Tcc0403 left a comment

LGTM
