Extract shared FP8 dequant primitives into sycl_tla_moe_dequant.hpp (PR-A1)#1813
Extract shared FP8 dequant primitives into sycl_tla_moe_dequant.hpp (PR-A1)#1813Copilot wants to merge 18 commits into
Conversation
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/95841e6d-d5d1-4662-8db0-4dd69690bc28 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/95841e6d-d5d1-4662-8db0-4dd69690bc28 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
for more information, see https://pre-commit.ci
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/91221649-2c90-4404-ae86-3321b1581428 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/91221649-2c90-4404-ae86-3321b1581428 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
|
@copilot resolve the merge conflicts in this pull request |
…ecode-implementation # Conflicts: # auto_round_extension/ark/auto_round_kernel/ark.cpp Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Merged |
There was a problem hiding this comment.
Pull request overview
This PR adds an XPU-optimized MoE decode-phase GEMV kernel (small M per expert) with multiple weight formats, and wires it through the C++/PyTorch extension layer with corresponding unit tests.
Changes:
- Added a SYCL decode GEMV kernel supporting FP16/BF16, INT8/INT4/INT2 (sym/asym), and FP8 (E4M3/E5M2) weights.
- Exposed the kernel via pybind (
moe_gemm_decode) and added a Python wrapper with argument validation. - Added unit tests covering the new decode paths and key validation error cases.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| auto_round_extension/ark/test/test_moe.py | Adds decode-path unit tests plus packing/dequant reference helpers for INT2/4/8 and FP8. |
| auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp | Introduces the new SYCL MoE decode GEMV kernel implementations and dispatch. |
| auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp | Declares the new moe_gemm_decode API (but docs currently lag implementation). |
| auto_round_extension/ark/auto_round_kernel/ark.cpp | Includes the new header and binds moe_gemm_decode via pybind. |
| auto_round_extension/ark/auto_round_kernel/init.py | Adds the ARK.moe_gemm_decode Python wrapper and validation logic. |
Comments suppressed due to low confidence (2)
auto_round_extension/ark/auto_round_kernel/init.py:871
- num_tokens_per_expert is converted to int32/contiguous but its device is not validated. If it’s a CPU tensor, the kernel will treat a host pointer as device memory. Please ensure num_tokens_per_expert is on XPU (and matches activations.device), or move it to XPU explicitly before calling into the extension.
weights = weights.contiguous()
if num_tokens_per_expert.dtype != torch.int32:
num_tokens_per_expert = num_tokens_per_expert.to(torch.int32)
if not num_tokens_per_expert.is_contiguous():
auto_round_extension/ark/auto_round_kernel/init.py:896
- group_size is used in modulo/division checks (e.g.,
K % group_size) without validating group_size > 0. Passing group_size=0 will raise a ZeroDivisionError rather than a clear ValueError. Please add an explicit check that group_size is a positive integer before any modulo/division operations.
if scales is None:
raise ValueError("scales is required for FP8 weights")
if scales.dtype != activations.dtype:
raise ValueError("scales dtype must match activations dtype")
if K % group_size != 0:
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/132db2ab-85c0-45b6-81a7-b9baaa533e5e Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
Signed-off-by: Dong, Bo1 <bo1.dong@intel.com>
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
First step of the path-A MoE prefill plan: lift the FP8 weight dequantization primitives and the
ARK_FP8_DECODE_USE_LUTenv-var reader out of the decode kernel into a standalone header so the upcoming mixed-input prefill mainloop can share one definition with the GEMV decode path. Pure refactor, zero behavior change.Changes
wrapper/include/sycl_tla_moe_dequant.hpp(ark::moe_dequantnamespace):decode_fp8_e4m3_lut/decode_fp8_e5m2_lut/*_bits/decode_fp8<IsE4M3, UseLut>+ host-side cachedfp8_decode_use_lut(). Identical bit-level behavior to the previous in-decode definitions.sycl_tla_moe_decode.hpp: replace the inline FP8 helper block (~89 lines) with#include "sycl_tla_moe_dequant.hpp"andusingre-exports insidemoe_decode_detail, so every existing call site (decode_fp8<...>(raw)in the kernel,fp8_decode_use_lut()in the dispatcher) resolves unchanged. Drops now-unused<cctype>,<cstdlib>,<limits>,<string>,bestla/sycl/fp8_lut.hincludes.sycl_tla_moe_decode.hppinwrapper/include/, picked up by the existing include path.Scope note
INT2/INT4/INT8 dequant are still inlined inside the decode kernel; they will be lifted into this header in PR-A2 when the mixed-input prefill mainloop starts consuming them, to avoid landing dead code now.
Type of Change
Refactor
Checklist Before Submitting
/azp run Unit-Test-CUDA-AutoRound.