Refactor use_triton_kernel to use nvfp4_quantize_kernel_choice #3911

jerryzh168 wants to merge 16 commits into gh/jerryzh168/42/base

Conversation
Summary: This is to prepare for the addition of the flashinfer quantize kernel path in the next PR. Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3911
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 9d61f60 with merge base 15df843.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…rence`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…rence`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…rence`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…rence`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…rence`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…rence`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Title changed from "use_triton_kernel to use quantize_kernel_preference" to "use_triton_kernel to use nvfp4_quantize_kernel_choice"
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
    leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]
    - if use_triton_kernel:
    + if nvfp4_quantize_kernel_choice == NVFP4QuantizeKernelChoice.TRITON:
this logic is choosing a single kernel, can we just have one if statement instead of adding an intermediate variable kernel_choice?
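A minimal sketch of the single if/elif shape this comment is asking for. The members TRITON and FLASHINFER and the function flashinfer_nvfp4_quantize come from this PR's diff; TORCH, the string values, and the triton_quantize_nvfp4 / torch_quantize_nvfp4 names are placeholders, not the PR's actual helpers:

```python
from enum import Enum

class NVFP4QuantizeKernelChoice(str, Enum):
    # TRITON and FLASHINFER appear in the diff; TORCH and the string
    # values are assumptions made for this sketch.
    TORCH = "torch"
    TRITON = "triton"
    FLASHINFER = "flashinfer"

def nvfp4_quantize(data_hp, per_tensor_scale, kernel_choice):
    # Dispatch on the enum in one if/elif chain, with no intermediate
    # string variable in between.
    if kernel_choice == NVFP4QuantizeKernelChoice.TRITON:
        return triton_quantize_nvfp4(data_hp, per_tensor_scale)  # placeholder name
    elif kernel_choice == NVFP4QuantizeKernelChoice.FLASHINFER:
        # flashinfer's global scale is the reciprocal of per_tensor_scale
        # (see the scale discussion below); its argument list is assumed here.
        return flashinfer_nvfp4_quantize(data_hp, 1.0 / per_tensor_scale)
    else:
        return torch_quantize_nvfp4(data_hp, per_tensor_scale)  # placeholder name
```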
    - if use_triton_kernel:
    + if nvfp4_quantize_kernel_choice == NVFP4QuantizeKernelChoice.TRITON:
    +     kernel_choice = "triton"
    + elif nvfp4_quantize_kernel_choice == NVFP4QuantizeKernelChoice.FLASHINFER:
is this PR adding flashinfer or is that the next PR?
ah sorry, this should be in the next PR
    # flashinfer uses global_sf = (F8E4M3_MAX * F4_E2M1_MAX) / amax
    # which is 1 / per_tensor_scale
    global_sf = 1.0 / per_tensor_scale
    data_lp, blockwise_scales = flashinfer_nvfp4_quantize(
are data_lp and blockwise_scales bitwise equivalent to the torch and triton paths?
these are not bitwise equivalent I think, tested in next PR
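For context on the reciprocal relationship this thread discusses, a small worked sketch; the constants follow from the fp8 e4m3 and fp4 e2m1 formats, and the per_tensor_scale convention is the one implied by the diff comment above:

```python
import math
import torch

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
F4_E2M1_MAX = 6.0  # largest magnitude representable in fp4 e2m1

amax = 12.0  # example: absolute max of the tensor being quantized

# torchao convention implied by the diff comment:
per_tensor_scale = amax / (F8E4M3_MAX * F4_E2M1_MAX)
# flashinfer convention, per the diff comment:
global_sf = (F8E4M3_MAX * F4_E2M1_MAX) / amax

# the two conventions are reciprocals, hence `global_sf = 1.0 / per_tensor_scale`
assert math.isclose(global_sf, 1.0 / per_tensor_scale)
```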
    class NVFP4QuantizeKernelChoice(str, Enum):
        """Enum for specifying the kernel used for NVFP4 quantization."""
nit: make this more specific to explain what exactly this kernel is doing, "nvfp4 quantization" is correct but ambiguous
sg. btw, I saw "block_size: Block size for quantization (must be 16)", is this true? why do we make this an argument if it has to be fixed?
It's just the specification: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference
yeah confirmed with Vasiliy that we can remove this arg in the future
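One way the docstring could be made more specific, as the nit asks; the wording here is a suggestion based on the NVFP4 spec linked in the reply, not text from the PR, and the member values are assumptions:

```python
from enum import Enum

class NVFP4QuantizeKernelChoice(str, Enum):
    """Which kernel implementation computes the packed fp4 (e2m1) data and
    the blockwise fp8 (e4m3) scales when quantizing a high precision tensor
    to NVFP4.

    Per the NVFP4 spec, the scaling block size is fixed at 16, so this enum
    only selects an implementation; it does not change the output format.
    """

    TRITON = "triton"
    FLASHINFER = "flashinfer"
```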
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
    aten = torch.ops.aten

    class NVFP4QuantizeKernelChoice(str, Enum):
put in torchao/prototype/mx_formats/constants.py (or a similar file if already exists)
    orig_dtype (torch.dtype): Original tensor dtype before quantization
    is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
    - use_triton_kernel (bool): Whether to use triton kernels
    + nvfp4_quantize_kernel_choice (NVFP4QuantizeKernelChoice): Kernel preference for quantization
quantize_kernel_choice? can be consistent for similar functionality in other workflows
or, quantize_to_nvfp4_kernel_choice? clearer name
OK will change to quantize_to_nvfp4_kernel_choice for now
can rename to quantize_kernel_choice later when there are similar cases I think
    use_triton_kernel: Optional[bool] = None

    def __post_init__(self):
        self.nvfp4_quantize_kernel_choice = _handle_use_triton_kernel(
I think this should throw an exception if the user specified use_triton_kernel=True and anything other than kernel preference triton, and do nothing else. Setting self.use_triton_kernel to None is confusing here, let's just enforce everything is consistent and keep it.
this would mean BC-breaking for current callsites; I was planning to not break BC in this PR, then refactor all OSS and internal callsites, and then break BC, does that sound OK?
the BC should not break as long as use_triton_kernel and nvfp4_quantize_kernel_choice both default to using the triton kernel
> I think this should throw an exception if the user specified use_triton_kernel=True and anything other than kernel preference triton, and do nothing else.

should this ignore the case of use_triton_kernel=False and kernel_choice triton?

> the BC should not break as long as use_triton_kernel and nvfp4_quantize_kernel_choice both default to using the triton kernel

what about user setting use_triton_kernel=False?
you can throw an exception if the two vars do not match
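A sketch of the consistency check agreed on in this thread. _handle_use_triton_kernel is named in the diff above, but its real signature is not shown on this page, so the one below is an assumption; the enum is repeated so the snippet is self-contained:

```python
from enum import Enum
from typing import Optional

class NVFP4QuantizeKernelChoice(str, Enum):  # as in the PR's diff
    TRITON = "triton"
    FLASHINFER = "flashinfer"

def _handle_use_triton_kernel(
    use_triton_kernel: Optional[bool],
    kernel_choice: NVFP4QuantizeKernelChoice,
) -> NVFP4QuantizeKernelChoice:
    # Deprecated flag not passed: the new enum is the single source of truth.
    if use_triton_kernel is None:
        return kernel_choice
    # Flag passed: throw if the two vars do not match, per the suggestion above.
    wants_triton = kernel_choice == NVFP4QuantizeKernelChoice.TRITON
    if use_triton_kernel != wants_triton:
        raise ValueError(
            f"use_triton_kernel={use_triton_kernel} conflicts with "
            f"nvfp4_quantize_kernel_choice={kernel_choice!r}"
        )
    return kernel_choice
```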
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…_choice`" Summary: This is to prefer the addition of flashinfer quantize kernel path in next PR Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Seems MSLK kernels can give us similar performance as the flashinfer kernel, so this is no longer needed.
Stack from ghstack (oldest at bottom):
use_triton_kernel to use nvfp4_quantize_kernel_choice #3911

Summary:
This is to prepare for the addition of the flashinfer quantize kernel path in the next PR.
- use_triton_kernel==True --> QuantizeToNVFP4KernelChoice.MSLK
- use_triton_kernel==False --> QuantizeToNVFP4KernelChoice.TRITON

Note: this breaks BC for the users of the prototype API:
- for configs whose default is use_triton_kernel = True (e.g. NVFP4DynamicActivationNVFP4WeightConfig), an error will be thrown when the flag is set to False
- for configs whose default is use_triton_kernel = False (e.g. NVFP4FakeQuantizeConfig), an error will be thrown when the flag is set to True

we'll make these changes internally later
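To make the BC note above concrete, a hypothetical before/after; the config name is from this description, but the import path and the exception type are assumptions:

```python
# import path assumed for this sketch
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

# default is use_triton_kernel=True for this config, so this keeps working:
config = NVFP4DynamicActivationNVFP4WeightConfig()

# explicitly setting the deprecated flag to a conflicting value now throws
# (exception type assumed):
try:
    config = NVFP4DynamicActivationNVFP4WeightConfig(use_triton_kernel=False)
except ValueError:
    print("conflicting use_triton_kernel flag rejected")
```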
Test Plan:
python test/prototype/mx_formats/test_inference_workflow.py
Reviewers:
Subscribers:
Tasks:
Tags: