Support 32x32 scaling for weights in MXFP8 weight quantization kernel #4254

danielvegamyhre merged 1 commit into main.
Conversation
Fixes #4185.

To test:

To benchmark:

Benchmarking results:

To test with grouped MM, apply this patch and run (tolerances updated, as 32x32 gives "worse" quantization quality than 32x1):
```python
        self,
        amax: cutlass.Float32,
    ):
        amax = cute.arch.fmax(
```
**danielvegamyhre:** Using a warp reduction for this makes sense. One nit: I think this is better done via a for loop; see this example from quack:

```python
import math
from typing import Callable

import cutlass
import cutlass.cute as cute

@cute.jit
def warp_reduce(val: cute.Numeric,
                op: Callable,
                width: cutlass.Constexpr = 32) -> cute.Numeric:
    for i in range(int(math.log2(width))):
        # cute.arch.shuffle_sync_bfly will read from another thread's registers
        val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
    return val
```
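For reference, a sketch of how such a helper could replace an explicit shuffle sequence for the amax reduction. This is an illustration only: `warp_reduce` is the quack helper quoted above, and `cute.arch.fmax` is the op this kernel already uses elsewhere.

```python
# Sketch: reduce amax across all 32 lanes of the warp in one call.
amax = warp_reduce(amax, cute.arch.fmax)
```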
```python
def mxfp8_quantize_cuda_3d(
    x: torch.Tensor,
    block_size: int = 32,
    scale_block_k: int = 1,
```
**danielvegamyhre:** I think this API is a bit confusing: the function is just called `mxfp8_quantize_cuda_3d`, and combined with the `scale_block_k` param (without a corresponding `scale_block_n` param), it seems to communicate that this is a kernel for scaling across K. I think we should update this to be:

```python
def mxfp8_quantize_3d_cutedsl(
    x: torch.Tensor,
    scale_block_n: Literal[1, 32],
    scale_block_k: Literal[1, 32],
) ...
# .. (validate only 32x1 and 32x32 are currently supported)
# .. (in the future would be nice to support 1x32 as well)
```
**alexsamardzic:** Changed. (Probably 1x32 will not be a straightforward extension like 32x32.)
**danielvegamyhre:** Yeah, I agree, based on my experience with the kernels for 1x32 vs 32x1 scaling on 2D tensors. This can just be the user-facing API, which dispatches to the appropriate kernel based on the scale_k/scale_n values.
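A minimal sketch of such a dispatching wrapper, assuming two hypothetical kernel entry points (`_quantize_3d_32x1` and `_quantize_3d_32x32` are placeholder names, not from the PR):

```python
from typing import Literal

import torch

def mxfp8_quantize_3d_cutedsl(
    x: torch.Tensor,
    scale_block_n: Literal[1, 32],
    scale_block_k: Literal[1, 32],
):
    # Validate: only 32x1 and 32x32 granularities are currently supported.
    if (scale_block_n, scale_block_k) == (32, 1):
        return _quantize_3d_32x1(x)    # placeholder kernel name
    if (scale_block_n, scale_block_k) == (32, 32):
        return _quantize_3d_32x32(x)   # placeholder kernel name
    raise ValueError(
        f"unsupported scaling granularity {scale_block_n}x{scale_block_k}; "
        "only 32x1 and 32x32 are currently supported"
    )
```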
```python
    )
else:
    padded_scale_cols = cute.round_up(n_blocks, 4)
    k_block_tiles = cute.ceil_div(K, 128)
```
**danielvegamyhre:** I just noticed: shouldn't we always be padding the scale tensor rows/cols such that it is evenly divisible into 128x4 blocks (or 4x128, depending on the scaling dim), for both 32x1 scaling and 32x32 scaling?

This seems to assume that in the case of 32x1 scaling, K is divisible by 128; do we actually validate that, or will it just silently fail/crash?

I would prefer we don't do if/else here, and just unconditionally use padded rows/cols.
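What unconditional padding might look like, sketched with the kernel's own helpers; `scale_rows` / `scale_cols` follow the naming in the snippets later in this thread, and this is an illustration, not the PR's final code:

```python
# Always pad so the scale tensor tiles evenly into 128x4 SF blocks
# (or 4x128, depending on scaling dim), for both 32x1 and 32x32 scaling.
padded_scale_rows = cute.round_up(scale_rows, 128)
padded_scale_cols = cute.round_up(scale_cols, 4)
k_block_tiles = padded_scale_rows // cutlass.Int64(128)
n_block_tiles = padded_scale_cols // cutlass.Int64(4)
```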
(force-pushed from cd4c46f to 3cc1b4a)
```python
weight_e4m3, weight_scales_blocked = mxfp8_quantize_cuda_3d(
    weight._data if hasattr(weight, "_data") else weight,
    block_size,
    scale_block_n=block_size,
```
**danielvegamyhre:** Can you make this more explicit by also specifying `scale_block_k=1`? That way it is clear at a glance that we are still using 32x1 scaling here, not 32x32 (yet).
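The requested change, sketched against the snippet above:

```python
weight_e4m3, weight_scales_blocked = mxfp8_quantize_cuda_3d(
    weight._data if hasattr(weight, "_data") else weight,
    block_size,
    scale_block_n=block_size,
    scale_block_k=1,  # explicit: still 32x1 scaling here, not 32x32 (yet)
)
```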
(force-pushed from 3cc1b4a to 68aa274)
```python
    block_size: int = 32,
    scale_block_k: int = 1,
```
**danielvegamyhre:** Can we make everything consistent with the user-facing API, where `scale_block_n` and `scale_block_k` are both explicit, rather than N being implicit and K being explicit? I.e., in the kernel as well.
**alexsamardzic:** Sure, fixed; sorry about that. For the kernel, as discussed above, considerable changes would probably be needed for `scale_block_n` other than 32, but it still makes sense to have it explicit.
(force-pushed from bb90db2 to fb9bd5f)
```python
    K: int,
    scale_block_k: int,
) -> Tuple[str, Tuple[int, int, int, int]]:
    del scale_block_n
```
**danielvegamyhre:** One last nit: in this particular helper function, I think it is confusing to pass the `K` and `scale_block_n` args and then immediately delete them via `del K` and `del scale_block_n`. Let's just pass only the args we need here.

The other places where you added explicit `scale_block_n`/`scale_block_k` look good (main kernel entry point, user-facing API/wrapper, etc.).
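A before/after sketch of this nit; `_pick_config` is a hypothetical helper name, and the return value is taken from the config table below for illustration:

```python
from typing import Tuple

# Before: unused args are passed in and immediately discarded.
def _pick_config_before(
    K: int, scale_block_n: int, scale_block_k: int
) -> Tuple[str, Tuple[int, int, int, int]]:
    del K, scale_block_n  # confusing: why accept them at all?
    return "bf16_default", (6, 32, 128, 4)

# After: only the argument that is actually used remains.
def _pick_config(scale_block_k: int) -> Tuple[str, Tuple[int, int, int, int]]:
    return "bf16_default", (6, 32, 128, 4)
```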
```python
# (compute_warps, tile_n, tile_k, k_tiles_per_cta)
_CUTEDSL_CONFIGS = {
    "bf16_default": (6, 32, 128, 4),
    "bf16_1x32": (4, 32, 128, 4),
```
**danielvegamyhre:** Should this one be 32x1 (not 1x32), since it is quantizing shape (E, N, K) with 32x1 scaling along N?
**alexsamardzic:** That was a typo, fixed.
```python
amax = self._warp_reduce_max(amax)
scale_biased, inv_scale = compute_scale_from_amax(amax, USE_RCEIL)
if cutlass.const_expr(SCALE_DIM_K_VALUE == 32):
    if lane == cutlass.Int32(0):
```
**danielvegamyhre:** Based on your `store_scale_32x32` implementation here:

```python
if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
    n_row = n_block * cutlass.Int64(32) + cutlass.Int64(lane)
    scales_expert[n_row, k_block] = scale_u8
```

don't we need all 32 lanes to store the warp-reduced scale factor, not just thread 0, in order to replicate it to all 32 rows of the SF tile?
**alexsamardzic:** Indeed, that was a bug. I actually encountered it when I tried the end-to-end test and fixed it, but instead of adding the fix to the commit, it stayed in the patch attached to my first comment. I've pushed this change now, and here is the proper patch for end-to-end testing: diff-end-to-end-test.txt. Shall we add this patch as an option for the given grouped MM test?
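For clarity, the shape of the fix under discussion, sketched with the names from the snippet quoted above: every lane in the warp stores the warp-reduced scale to its own row, instead of gating the store on lane 0.

```python
# Replicate the warp-reduced scale to all 32 rows of the SF tile:
# every lane writes one row, rather than only lane == 0 writing.
if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
    n_row = n_block * cutlass.Int64(32) + cutlass.Int64(lane)
    scales_expert[n_row, k_block] = scale_u8  # executed by all 32 lanes
```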
```python
padded_scale_cols = cute.round_up(n_blocks, 4)
k_block_tiles = cute.ceil_div(K, 128)
n_block_tiles = padded_scale_cols // cutlass.Int64(4)
scale_rows = K if cutlass.const_expr(SCALE_DIM_K_VALUE == 1) else N
```
**danielvegamyhre:** I'm a bit confused: why does the number of rows in the scale factor tensor depend on the granularity of our scale factor (32x32 vs 32x1)? Same question for `scale_cols` below. Is this right? If so, could you please add a brief comment explaining it?
**alexsamardzic:** I think the confusion comes from the fact that I tried to keep a single scale factor per 32x32 block as an option on the quantization side, to be ready in case we decide to try a grouped MM kernel that can handle this situation. This variant is under the `blocked_scale_output = False` condition. I've added some comments to clarify it, and I've also added an explicit test variant for this case. Does this explain it, or did you have something else in mind here?
**danielvegamyhre:** Would this complexity of swapping K/N for rows/cols be eliminated if we only supported the blocked layout for 32x32?
**alexsamardzic:** Sorry, I think I answered a different question above. You were asking about the `blocked_scale_output = True` case, while I replied about the compact `blocked_scale_output = False` 32x32 representation. In the blocked path this is intentional: 32x1 and 32x32 map to different logical scale layouts before blocking, so `scale_rows` / `scale_cols` do depend on the granularity there.
**danielvegamyhre:** I see, so to summarize my understanding:

- for the 32x1 blocked layout, the scale shape is (K, N//32), since the grouped GEMM expects this shape
- for the 32x32 blocked layout, we conceptually have one scale per (N//32, K//32) block, then replicate it by having every lane write it to its associated row, giving us (N, K//32)

Is that correct? Also, have you tested the 32x32 blocked outputs with the torch._scaled_grouped_mm mxfp8 grouped GEMM?
**alexsamardzic:** Yes, this is correct.

As for the testing: what is tested for 32x32 at the moment is the quantization itself, in test_cuda_mx_3d_cutedsl_numerics. If you want, I can add a grouped MM test that consumes the blocked 32x32 scales.
**danielvegamyhre:** Yes, please add a test that confirms A @ B with 1x32 scaling for A and 32x32 for B works; that would be great.
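To make the two blocked layouts summarized above concrete, a small shape sketch; the sizes are illustrative examples, not values from the PR, and the views are per expert of an (E, N, K) weight:

```python
N, K = 256, 512  # example per-expert dims

# 32x1 blocked layout: one scale per 32 values along N, laid out as
# (K, N // 32), the shape the grouped GEMM expects.
scales_32x1_shape = (K, N // 32)    # (512, 8)

# 32x32 blocked layout: conceptually one scale per (N//32, K//32) block,
# replicated across its 32 rows, giving (N, K // 32).
scales_32x32_shape = (N, K // 32)   # (256, 16)
```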
(force-pushed from fb9bd5f to 5b89ea9)
```python
            USE_RCEIL,
            BLOCKED_SCALE_OUTPUT_VALUE,
        )
        self._quantize_store_tail(
```
**danielvegamyhre:** Comparing the `if cutlass.const_expr(IS_FULL_K_TILES)` "true" path vs "false" path, the code looks identical except:

- in the 'false' path, there is an extra condition, `k_in_bounds` (e.g. `if k_rel < TILE_K and k_in_bounds: ...`)
- at the end we call `self._quantize_store_full` vs `self._quantize_store_tail`

Could we simplify this by:

- always just doing the `k_in_bounds` check (unnecessary for full blocks, but I doubt that one boolean check will hurt perf?)
- executing the shared code (which would now be one unified path)
- doing the if/else at the end for full vs tail store?

If I missed something, please let me know.
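The proposed restructuring, sketched as control flow only, with helper names taken from the thread:

```python
# One unified quantization path: the bounds check is now unconditional
# (always true for full tiles), and only the final store step differs.
if k_rel < TILE_K and k_in_bounds:
    ...  # shared quantize logic, formerly duplicated across both paths

if cutlass.const_expr(IS_FULL_K_TILES):
    self._quantize_store_full(...)
else:
    self._quantize_store_tail(...)
```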
**danielvegamyhre:** Benchmarks still look the same after the change?
**alexsamardzic:** Yes, I benchmarked with and without this change; no noticeable difference.
(force-pushed from 5b89ea9 to e465aec)
```python
scale_biased, inv_scale = compute_scale_from_amax(amax, USE_RCEIL)
if cutlass.const_expr(SCALE_DIM_K_VALUE == 32):
    if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
        self._store_scale_32x32(
```
**danielvegamyhre:** Can you please add a comment clarifying that the blocked path uses all 32 lanes to replicate the warp-reduced scale, whereas the non-blocked (row-major) path uses a single lane to write a single, non-replicated scale factor for the 32x32 block?

Just to make this calling logic more immediately understandable, so the reader doesn't have to dive deeper into the code to understand the higher-level calling logic.
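Something along these lines, as a sketch of the requested comment at the call site from the snippets above:

```python
if cutlass.const_expr(SCALE_DIM_K_VALUE == 32):
    if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
        # Blocked layout: all 32 lanes store the warp-reduced scale,
        # replicating it to every row of the 32-row SF tile.
        self._store_scale_32x32(...)
    else:
        # Non-blocked (row-major) layout: a single lane writes one
        # non-replicated scale factor for the whole 32x32 block.
        ...
```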
```python
            BLOCKED_SCALE_OUTPUT,
        )
    else:
        self._store_scale(
```
**danielvegamyhre:** A nit, but can we call this `_store_scale_32x1` now, to be explicit, since this kernel now supports different scale granularities? It would be preferable for both 32x32 and 32x1 to be explicit, rather than one implicit and one explicit.
(force-pushed from e465aec to 182a766)
(force-pushed from 182a766 to 3db7b3d)
```python
@pytest.mark.parametrize(
    "scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_emulate_mxfp8_grouped_gemm_2d_3d(
```
**danielvegamyhre:** @alexsamardzic it seems like this test was changed from testing the emulated path to testing the non-emulated/real-sm100 path? Instead, we should keep the test for the emulated path and make a new test for this real 1x32 @ 32x32 case (make sure to add a skip if not on sm100).
**alexsamardzic:** Done. We now have test_emulate_mxfp8_grouped_gemm_2d_3d and test_mxfp8_grouped_gemm_2d_3d, which both test the 32x1 and 32x32 variants that we have at the moment.
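A sketch of the sm100 gating asked for above; the exact skip condition used in the PR may differ, and this version assumes a plain torch.cuda capability check:

```python
import pytest
import torch

# Skip the real (non-emulated) grouped GEMM test on anything but SM100.
@pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() != (10, 0),
    reason="MXFP8 grouped GEMM requires SM100",
)
def test_mxfp8_grouped_gemm_2d_3d():
    ...
```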
(force-pushed from 3db7b3d to b7809e3)
Fixes #4185