Skip to content

Commit 5edc119

Browse files
committed
[claude] Make MXFP8 cuda kernels ABI stable
Prompt: ``` Make these two files ABI stable: torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp Read these for instructions: pytorch/docs/source/notes/libtorch_stable_abi.md cppdocs/_sources/stable.rst.txt Use these files for an example: Before: flash-attention/hopper/flash_api.cpp After: flash-attention/hopper/flash_api_stable.cpp Additional instructions: Replace at::zeros with torch::stable::new_zeros Replace at::empty with torch::stable::new_empty Replace TORCH_CHECK with STD_TORCH_CHECK without changing the content of the check, just replace the function name Replace c10::cuda::CUDAGuard with DeviceGuard When calling tensor.stride(), no need to check for tensor.numel() > 0 beforehand When using cudaStream_t, remember to include cuda_runtime.h When calling aoti_torch_get_current_cuda_stream, get the device index from a tensor, not from the general torch::stable::accelerator::getCurrentDeviceIndex(); Don’t do these things: Don’t use torch::stable::new_empty_strided since this API does not exist Don’t define a dummy _C module that can be accessed from python Don’t declare aoti_torch_get_current_cuda_stream, just include it from torch/csrc/inductor/aoti_torch/c/shim.h and add -DUSE_CUDA to both cxx and nvcc in setup.py Don’t box kernels manually, just use TORCH_BOX Don’t change switch statements into if else blocks ``` Follow-up prompts: ``` Revert the content changes of STD_TORCH_CHECK, just replace the function name For the changes that involved tensor.data_ptr<some_dtype>() in the old code, revert the changes and replace it with tensor.const_data_ptr<some_dtype>() ``` Output: ``` ● I've successfully made both files ABI stable. Here's a summary of the changes: Changes to mxfp8_cuda.cu: 1. 
Updated includes: Replaced ATen and torch/extension.h with stable ABI headers (torch/csrc/stable/tensor.h, torch/csrc/inductor/aoti_torch/c/shim.h, torch/headeronly/core/ScalarType.h, torch/headeronly/util/Exception.h, torch/headeronly/util/shim_utils.h) 2. Added cuda_runtime.h for cudaStream_t 3. Changed tensor type: torch::Tensor → torch::stable::Tensor (with using declaration) 4. Changed scalar type enums: torch::kFloat32 → torch::headeronly::ScalarType::Float, etc. 5. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK 6. Changed CUDA stream access: at::cuda::getCurrentCUDAStream() → aoti_torch_get_current_cuda_stream(input.get_device_index(), &stream_ptr) with TORCH_ERROR_CODE_CHECK 7. Fixed stride access: .strides()[i] → .stride(i) and .sizes()[i] → .size(i) Changes to mxfp8_extension.cpp: 1. Updated includes: Replaced torch/library.h, ATen/ATen.h, ATen/cuda/CUDAContext.h, c10/cuda/CUDAGuard.h with stable headers 2. Changed tensor type: at::Tensor → torch::stable::Tensor 3. Changed DeviceGuard: c10::cuda::CUDAGuard → torch::stable::accelerator::DeviceGuard (with input.get_device_index() instead of input.device()) 4. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK 5. Changed tensor creation: - at::empty → torch::stable::new_empty - at::zeros → torch::stable::new_zeros - For strided tensors: Created with transposed dimensions then used torch::stable::transpose to get column-major layout 6. Changed CUDA stream access: Used aoti_torch_get_current_cuda_stream with device index from tensor 7. Changed library registration: TORCH_LIBRARY_IMPL → STABLE_TORCH_LIBRARY_IMPL with TORCH_BOX wrapper Changes to setup.py: 1. Added -DUSE_CUDA to both cxx and nvcc compiler arguments for the mxfp8 extension ```
1 parent d988122 commit 5edc119

3 files changed

Lines changed: 152 additions & 136 deletions

File tree

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,11 +755,16 @@ def get_extensions():
755755
f"-DPy_LIMITED_API={min_supported_cpython_hexcode}",
756756
"-std=c++17",
757757
"-O3",
758+
"-DUSE_CUDA",
759+
# define TORCH_TARGET_VERSION with min version 2.11 for ABI stable Float8_e8m0fnu
760+
"-DTORCH_TARGET_VERSION=0x020b000000000000",
758761
],
759762
"nvcc": nvcc_args
760763
+ [
761764
"-gencode=arch=compute_100,code=sm_100",
762765
"-gencode=arch=compute_120,code=compute_120",
766+
"-DUSE_CUDA",
767+
"-DTORCH_TARGET_VERSION=0x020b000000000000",
763768
],
764769
},
765770
),

torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,33 @@
11
// CUDA bridge for MXFP8 quantization
22

33
#include "mxfp8_quantize.cuh"
4-
#include <ATen/cuda/CUDAContext.h>
4+
5+
#include <torch/csrc/stable/tensor.h>
6+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
7+
#include <torch/headeronly/core/ScalarType.h>
8+
#include <torch/headeronly/util/Exception.h>
9+
#include <torch/headeronly/util/shim_utils.h>
10+
11+
#include <cuda_runtime.h>
512
#include <string>
6-
#include <torch/extension.h>
713

14+
using torch::stable::Tensor;
815

916
namespace mxfp8 {
1017

1118
// Convert PyTorch scalar type to our DType enum
12-
DType get_input_dtype(const torch::Tensor &t) {
19+
DType get_input_dtype(const Tensor &t) {
1320
switch (t.scalar_type()) {
14-
case torch::kFloat32:
21+
case torch::headeronly::ScalarType::Float:
1522
return DType::kFloat32;
16-
case torch::kFloat16:
23+
case torch::headeronly::ScalarType::Half:
1724
return DType::kFloat16;
18-
case torch::kBFloat16:
25+
case torch::headeronly::ScalarType::BFloat16:
1926
return DType::kBFloat16;
20-
case torch::kUInt8:
27+
case torch::headeronly::ScalarType::Byte:
2128
return DType::kByte;
2229
default:
23-
TORCH_CHECK(false, "Unsupported input tensor dtype: ", t.scalar_type());
30+
STD_TORCH_CHECK(false, "Unsupported input tensor dtype: ", t.scalar_type());
2431
}
2532
}
2633

@@ -30,7 +37,7 @@ ScaleCalculationMode get_scaling_mode(const std::string &scaling_mode) {
3037
} else if (scaling_mode.compare("rceil") == 0) {
3138
return ScaleCalculationMode::RCEIL;
3239
} else {
33-
TORCH_CHECK(false, "Unsupported scaling mode: ", scaling_mode, ". Only ['floor', 'rceil'] are supported.");
40+
STD_TORCH_CHECK(false, "Unsupported scaling mode: ", scaling_mode, ". Only ['floor', 'rceil'] are supported.");
3441
}
3542
}
3643

@@ -39,16 +46,16 @@ DType get_output_dtype(const std::string &fp8_format) {
3946
if (fp8_format.compare("e4m3") == 0) {
4047
return DType::kFloat8E4M3;
4148
} else {
42-
TORCH_CHECK(false, "Unsupported FP8 format: ", fp8_format,
49+
STD_TORCH_CHECK(false, "Unsupported FP8 format: ", fp8_format,
4350
". Only 'e4m3' is supported.");
4451
}
4552
}
4653

47-
void mxfp8_quantize_cuda(const torch::Tensor &input,
48-
torch::Tensor &output_rowwise,
49-
torch::Tensor &output_colwise,
50-
torch::Tensor &scales_rowwise,
51-
torch::Tensor &scales_colwise,
54+
void mxfp8_quantize_cuda(const Tensor &input,
55+
Tensor &output_rowwise,
56+
Tensor &output_colwise,
57+
Tensor &scales_rowwise,
58+
Tensor &scales_colwise,
5259
int64_t scale_dim_x,
5360
int64_t scale_dim_y,
5461
const std::string &fp8_format,
@@ -73,23 +80,29 @@ void mxfp8_quantize_cuda(const torch::Tensor &input,
7380
? reinterpret_cast<e8m0_t *>(scales_colwise.data_ptr())
7481
: nullptr;
7582

76-
// Get CUDA stream
77-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
83+
// Get CUDA stream using stable ABI
84+
void* stream_ptr = nullptr;
85+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(input.get_device_index(), &stream_ptr));
86+
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
7887

79-
// Get strides of scale ptrs
80-
int64_t scale_rowwise_stride_dim0 = scales_rowwise.strides()[0];
81-
int64_t scale_rowwise_stride_dim1 = scales_rowwise.strides()[1];
82-
int64_t scale_colwise_stride_dim0 = scales_colwise.strides()[0];
83-
int64_t scale_colwise_stride_dim1 = scales_colwise.strides()[1];
88+
// Get strides of scale ptrs (guard against 1D empty tensors when rowwise/colwise is false)
89+
int64_t scale_rowwise_stride_dim0 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(0) : 0;
90+
int64_t scale_rowwise_stride_dim1 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(1) : 0;
91+
int64_t scale_colwise_stride_dim0 = scales_colwise.dim() >= 2 ? scales_colwise.stride(0) : 0;
92+
int64_t scale_colwise_stride_dim1 = scales_colwise.dim() >= 2 ? scales_colwise.stride(1) : 0;
8493

8594
#if defined(DEBUG)
8695
printf("mxfp8_quantize_cuda:\n");
8796
printf("Quantizing input tensor of size %ld x %ld\n", rows, cols);
8897
printf("scaling_mode: %s\n", scaling_mode.c_str());
8998
printf("Scale dim x: %ld\n", scale_dim_x);
9099
printf("Scale dim y: %ld\n", scale_dim_y);
91-
printf("Rowwise scale shape: %ld x %ld\n", scales_rowwise.sizes()[0], scales_rowwise.sizes()[1]);
92-
printf("Colwise scale shape: %ld x %ld\n", scales_colwise.sizes()[0], scales_colwise.sizes()[1]);
100+
printf("Rowwise scale shape: %ld x %ld\n",
101+
scales_rowwise.dim() >= 1 ? scales_rowwise.size(0) : 0,
102+
scales_rowwise.dim() >= 2 ? scales_rowwise.size(1) : 0);
103+
printf("Colwise scale shape: %ld x %ld\n",
104+
scales_colwise.dim() >= 1 ? scales_colwise.size(0) : 0,
105+
scales_colwise.dim() >= 2 ? scales_colwise.size(1) : 0);
93106
printf("scale_rowwise_stride_dim0 = %ld\n", scale_rowwise_stride_dim0);
94107
printf("scale_rowwise_stride_dim1 = %ld\n", scale_rowwise_stride_dim1);
95108
printf("scale_colwise_stride_dim0 = %ld\n", scale_colwise_stride_dim0);
@@ -109,9 +122,9 @@ void mxfp8_quantize_cuda(const torch::Tensor &input,
109122
stream);
110123
}
111124

112-
void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
113-
torch::Tensor &output_colwise,
114-
torch::Tensor &scales_colwise,
125+
void mxfp8_quantize_3d_cuda(const Tensor &input,
126+
Tensor &output_colwise,
127+
Tensor &scales_colwise,
115128
int64_t scale_dim_n,
116129
const std::string &fp8_format,
117130
const std::string &scaling_mode) {
@@ -127,8 +140,10 @@ void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
127140
e8m0_t *scales_colwise_ptr =
128141
reinterpret_cast<e8m0_t *>(scales_colwise.data_ptr());
129142

130-
// Get CUDA stream
131-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
143+
// Get CUDA stream using stable ABI
144+
void* stream_ptr = nullptr;
145+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(input.get_device_index(), &stream_ptr));
146+
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
132147

133148
// Get strides of scales tensor
134149
int64_t scales_colwise_stride_dim0 = scales_colwise.stride(0);
@@ -152,7 +167,7 @@ void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
152167
printf("scaling_mode: %s\n", scaling_mode.c_str());
153168
printf("Scale dim n: %ld\n", scale_dim_n);
154169
printf("Output scale shape: %ld x %ld x %ld\n",
155-
scales_colwise.sizes()[0], scales_colwise.sizes()[1], scales_colwise.sizes()[2]);
170+
scales_colwise.size(0), scales_colwise.size(1), scales_colwise.size(2));
156171
printf("scales_colwise_stride_dim0 = %ld\n", scales_colwise_stride_dim0);
157172
printf("scales_colwise_stride_dim1 = %ld\n", scales_colwise_stride_dim1);
158173
printf("input_stride_dim0 = %ld\n", input_stride_dim0);

0 commit comments

Comments (0)