4 | 4 | # This source code is licensed under the license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 | |
| 7 | +import importlib.util |
7 | 8 | import logging |
8 | 9 | from typing import Optional, Tuple |
9 | 10 | |
@@ -1387,3 +1388,69 @@ def mxfp8_quantize_cuda( |
1387 | 1388 | raise NotImplementedError( |
1388 | 1389 | "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details." |
1389 | 1390 | ) |
| 1391 | + |
| 1392 | + |
| 1393 | +_mslk_available = importlib.util.find_spec("mslk") is not None |
| 1394 | + |
| 1395 | + |
| 1396 | +def mslk_quantize_nvfp4( |
| 1397 | + x: torch.Tensor, per_tensor_scale: torch.Tensor |
| 1398 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 1399 | +    """Quantize a tensor to NVFP4 using the MSLK Triton kernel. |
| 1400 | + |
| 1401 | + Args: |
| 1402 | + x: Input tensor to quantize. |
| 1403 | + per_tensor_scale: Per-tensor scale (TorchAO convention: amax / (F8E4M3_MAX * F4_E2M1_MAX)). |
| 1404 | + |
| 1405 | + Returns: |
| 1406 | + Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention. |
| 1407 | + """ |
| 1408 | + mslk_global_scale = per_tensor_scale.reciprocal() |
| 1409 | + return _mslk_quantize_nvfp4_custom_op(x, mslk_global_scale) |
| 1410 | + |
| 1411 | + |
| 1412 | +@torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=()) |
| 1413 | +def _mslk_quantize_nvfp4_custom_op( |
| 1414 | + x: torch.Tensor, global_scale: torch.Tensor |
| 1415 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 1416 | + """Inner custom op for MSLK NVFP4 quantization. |
| 1417 | + |
| 1418 | + Args: |
| 1419 | + x: Input tensor to quantize. |
| 1420 | + global_scale: Global scale in MSLK convention (1.0 / per_tensor_scale). |
| 1421 | + |
| 1422 | + Returns: |
| 1423 | + Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention. |
| 1424 | + """ |
| 1425 | + assert _mslk_available, ( |
| 1426 | +        "mslk is required for NVFP4 Triton quantization. " |
| 1427 | + "Install from https://github.com/pytorch/MSLK" |
| 1428 | + ) |
| 1429 | + from mslk.quantize.triton.fp4_quantize import ( |
| 1430 | + triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4, |
| 1431 | + ) |
| 1432 | + |
| 1433 | + data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale) |
| 1434 | + return blockwise_scales, data_lp.view(torch.uint8) |
| 1435 | + |
| 1436 | + |
| 1437 | +@_mslk_quantize_nvfp4_custom_op.register_fake |
| 1438 | +def _(x, global_scale): |
| 1439 | + # Mirror the reshape logic from the real MSLK kernel |
| 1440 | + orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1] |
| 1441 | + x_2d = x.reshape(-1, orig_N) |
| 1442 | + M, N = x_2d.shape |
| 1443 | + |
| 1444 | +    num_scales = N // 16  # one FP8 scale per 16-element NVFP4 block |
| 1445 | +    n_row_blocks = triton.cdiv(M, 128)  # scale layout is padded to 128x4 swizzled tiles |
| 1446 | + n_col_blocks = triton.cdiv(num_scales, 4) |
| 1447 | + padded_rows = n_row_blocks * 128 |
| 1448 | + padded_cols = n_col_blocks * 4 |
| 1449 | + |
| 1450 | + scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn) |
| 1451 | +    xq = x.new_empty(M, N // 2, dtype=torch.uint8)  # two FP4 values packed per byte |
| 1452 | + |
| 1453 | + # Reshape back to match original leading dims |
| 1454 | + scales = scales.view(*orig_leading_dims, -1, padded_cols) |
| 1455 | + xq = xq.view(*orig_leading_dims, -1, N // 2) |
| 1456 | + return scales, xq |
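
For orientation, here is a minimal usage sketch of the scale-convention handoff and the output shapes implied by the meta function above. The shapes and values are hypothetical; it assumes a CUDA device, a local mslk install, and the standard NVFP4 format constants (F8E4M3_MAX = 448.0, F4_E2M1_MAX = 6.0):

```python
import torch

F8E4M3_MAX = 448.0  # max representable value in float8_e4m3fn
F4_E2M1_MAX = 6.0   # max representable value in FP4 E2M1

# Hypothetical input: 2D bf16 tensor on a CUDA device.
x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)

# TorchAO convention per the docstring: amax / (F8E4M3_MAX * F4_E2M1_MAX).
per_tensor_scale = x.abs().amax().float() / (F8E4M3_MAX * F4_E2M1_MAX)

# The wrapper passes per_tensor_scale.reciprocal() to the MSLK kernel internally.
blockwise_scales, data_u8 = mslk_quantize_nvfp4(x, per_tensor_scale)

# Shapes follow the meta function above: N=64 -> 4 FP8 scales per row,
# padded to a 128x4 tile; the packed payload stores two FP4 values per uint8.
assert blockwise_scales.shape == (128, 4)
assert blockwise_scales.dtype == torch.float8_e4m3fn
assert data_u8.shape == (128, 32)
```

Because the custom op registers a fake implementation, a call like this should also trace under torch.compile without launching the real kernel.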