Commit 5edc119
committed
[claude] Make MXFP8 cuda kernels ABI stable
Prompt:
```
Make these two files ABI stable:
torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
Read these for instructions:
pytorch/docs/source/notes/libtorch_stable_abi.md
cppdocs/_sources/stable.rst.txt
Use these files for an example:
Before: flash-attention/hopper/flash_api.cpp
After: flash-attention/hopper/flash_api_stable.cpp
Additional instructions:
Replace at::zeros with torch::stable::new_zeros
Replace at::empty with torch::stable::new_empty
Replace TORCH_CHECK with STD_TORCH_CHECK without changing the content of the check, just replace the function name
Replace c10::cuda::CUDAGuard with DeviceGuard
When calling tensor.stride(), no need to check for tensor.numel() > 0 beforehand
When using cudaStream_t, remember to include cuda_runtime.h
When calling aoti_torch_get_current_cuda_stream, get the device index from a tensor, not from the general torch::stable::accelerator::getCurrentDeviceIndex();
Don’t do these things:
Don’t use torch::stable::new_empty_strided since this API does not exist
Don’t define a dummy _C module that can be accessed from python
Don’t declare aoti_torch_get_current_cuda_stream, just include it from torch/csrc/inductor/aoti_torch/c/shim.h and add -DUSE_CUDA to both cxx and nvcc in setup.py
Don’t box kernels manually, just use TORCH_BOX
Don’t change switch statements into if else blocks
```
Follow-up prompts:
```
Revert the content changes of STD_TORCH_CHECK, just replace the function name
For the changes that involved tensor.data_ptr<some_dtype>() in the old code, revert the changes and replace it with tensor.const_data_ptr<some_dtype>()
```
Output:
```
● I've successfully made both files ABI stable. Here's a summary of the changes:
Changes to mxfp8_cuda.cu:
1. Updated includes: Replaced ATen and torch/extension.h with stable ABI headers (torch/csrc/stable/tensor.h, torch/csrc/inductor/aoti_torch/c/shim.h, torch/headeronly/core/ScalarType.h, torch/headeronly/util/Exception.h, torch/headeronly/util/shim_utils.h)
2. Added cuda_runtime.h for cudaStream_t
3. Changed tensor type: torch::Tensor → torch::stable::Tensor (with using declaration)
4. Changed scalar type enums: torch::kFloat32 → torch::headeronly::ScalarType::Float, etc.
5. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK
6. Changed CUDA stream access: at::cuda::getCurrentCUDAStream() → aoti_torch_get_current_cuda_stream(input.get_device_index(), &stream_ptr) with TORCH_ERROR_CODE_CHECK
7. Fixed stride access: .strides()[i] → .stride(i) and .sizes()[i] → .size(i)
Changes to mxfp8_extension.cpp:
1. Updated includes: Replaced torch/library.h, ATen/ATen.h, ATen/cuda/CUDAContext.h, c10/cuda/CUDAGuard.h with stable headers
2. Changed tensor type: at::Tensor → torch::stable::Tensor
3. Changed DeviceGuard: c10::cuda::CUDAGuard → torch::stable::accelerator::DeviceGuard (with input.get_device_index() instead of input.device())
4. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK
5. Changed tensor creation:
- at::empty → torch::stable::new_empty
- at::zeros → torch::stable::new_zeros
- For strided tensors: Created with transposed dimensions then used torch::stable::transpose to get column-major layout
6. Changed CUDA stream access: Used aoti_torch_get_current_cuda_stream with device index from tensor
7. Changed library registration: TORCH_LIBRARY_IMPL → STABLE_TORCH_LIBRARY_IMPL with TORCH_BOX wrapper
Changes to setup.py:
1. Added -DUSE_CUDA to both cxx and nvcc compiler arguments for the mxfp8 extension
```1 parent d988122 commit 5edc119
3 files changed
Lines changed: 152 additions & 136 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
755 | 755 | | |
756 | 756 | | |
757 | 757 | | |
| 758 | + | |
| 759 | + | |
| 760 | + | |
758 | 761 | | |
759 | 762 | | |
760 | 763 | | |
761 | 764 | | |
762 | 765 | | |
| 766 | + | |
| 767 | + | |
763 | 768 | | |
764 | 769 | | |
765 | 770 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
4 | | - | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
5 | 12 | | |
6 | | - | |
7 | 13 | | |
| 14 | + | |
8 | 15 | | |
9 | 16 | | |
10 | 17 | | |
11 | 18 | | |
12 | | - | |
| 19 | + | |
13 | 20 | | |
14 | | - | |
| 21 | + | |
15 | 22 | | |
16 | | - | |
| 23 | + | |
17 | 24 | | |
18 | | - | |
| 25 | + | |
19 | 26 | | |
20 | | - | |
| 27 | + | |
21 | 28 | | |
22 | 29 | | |
23 | | - | |
| 30 | + | |
24 | 31 | | |
25 | 32 | | |
26 | 33 | | |
| |||
30 | 37 | | |
31 | 38 | | |
32 | 39 | | |
33 | | - | |
| 40 | + | |
34 | 41 | | |
35 | 42 | | |
36 | 43 | | |
| |||
39 | 46 | | |
40 | 47 | | |
41 | 48 | | |
42 | | - | |
| 49 | + | |
43 | 50 | | |
44 | 51 | | |
45 | 52 | | |
46 | 53 | | |
47 | | - | |
48 | | - | |
49 | | - | |
50 | | - | |
51 | | - | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
52 | 59 | | |
53 | 60 | | |
54 | 61 | | |
| |||
73 | 80 | | |
74 | 81 | | |
75 | 82 | | |
76 | | - | |
77 | | - | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
78 | 87 | | |
79 | | - | |
80 | | - | |
81 | | - | |
82 | | - | |
83 | | - | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
84 | 93 | | |
85 | 94 | | |
86 | 95 | | |
87 | 96 | | |
88 | 97 | | |
89 | 98 | | |
90 | 99 | | |
91 | | - | |
92 | | - | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
93 | 106 | | |
94 | 107 | | |
95 | 108 | | |
| |||
109 | 122 | | |
110 | 123 | | |
111 | 124 | | |
112 | | - | |
113 | | - | |
114 | | - | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
115 | 128 | | |
116 | 129 | | |
117 | 130 | | |
| |||
127 | 140 | | |
128 | 141 | | |
129 | 142 | | |
130 | | - | |
131 | | - | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
132 | 147 | | |
133 | 148 | | |
134 | 149 | | |
| |||
152 | 167 | | |
153 | 168 | | |
154 | 169 | | |
155 | | - | |
| 170 | + | |
156 | 171 | | |
157 | 172 | | |
158 | 173 | | |
| |||
0 commit comments