Conversation
/build
Greptile Summary

This PR enables 256-bit (32-byte) global load/store transactions for NVIDIA Blackwell (sm_100+) GPUs in MatX's vectorized kernel dispatch path. It does so across three files: a compile-time …

Key changes and observations:
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Host: Operator capability query] --> B{GetComputeCapability\n>= 1000?}
    B -- Yes --> C[max_vec_width_bytes = 32]
    B -- No --> D[max_vec_width_bytes = 16\nMAX_VEC_WIDTH_BYTES]
    C --> E[width = max_vec_width_bytes / sizeof T]
    D --> E
    E --> F{Lsize % width == 0\nAND ptr is aligned?}
    F -- No --> G[width /= 2]
    G --> F
    F -- Yes --> H[Dispatch kernel with EPT = width]
    H --> I{Target architecture?}
    I -- sm_100+ device code --> J[MAX_VEC_WIDTH_BYTES = 32\nvec_width_elems = 32/sizeof T]
    I -- sm_90 / host --> K[MAX_VEC_WIDTH_BYTES = 16\nvec_width_elems = 16/sizeof T]
    J --> L{static_assert\nEPT >= vec_width_elems}
    K --> M[Vector::load uses\nvec_load_t = Vector of T\nnum_iterations = EPT / vec_width_elems]
    L -- PASS --> M
    L -- FAIL --> N[Compile error:\nEPT too small for sm_100]
    M --> O[Vectorized load loop:\nnum_iterations x vec_load_t loads]
```
Last reviewed commit: c9a991f
```cpp
const int max_vec_width_bytes = (detail::GetComputeCapability() >= 1000) ? 32 : MAX_VEC_WIDTH_BYTES;
int width = in.jit ? 32 : max_vec_width_bytes / sizeof(T);
```
Runtime GetComputeCapability() may diverge from compile-time MAX_VEC_WIDTH_BYTES in device code
GetComputeCapability() is a host-only function that queries the active CUDA device at runtime. On the host, MAX_VEC_WIDTH_BYTES is always 16 (because __CUDA_ARCH__ is undefined on the host). So this runtime check is the correct mechanism to allow 32-byte host-side dispatch on Blackwell.
However, if the fat binary does not include an sm_100 image (e.g., only compiled with -gencode arch=compute_90,code=sm_90), running on a Blackwell GPU will:
- Make `GetComputeCapability()` return ≥ 1000 here → `max_vec_width_bytes = 32` → dispatches `EPT = 8` for `float`.
- But the JIT-compiled (from `sm_90` PTX) device code has `MAX_VEC_WIDTH_BYTES = 16` → `vec_width_elems = 4` → `num_iterations = 8 / 4 = 2` (two 128-bit loads).
This is still functionally correct (total 32 bytes loaded), but users relying on single 256-bit instruction throughput would not get the optimization without explicitly targeting sm_100. A comment here noting this compile-target dependency would help future maintainers.
/build

/build
No description provided.