
Added 256b LD/ST for Blackwell+ #1138

Merged
cliffburdick merged 2 commits into main from 256b_ldst on Mar 12, 2026

Conversation

@cliffburdick
Collaborator

No description provided.

@copy-pr-bot

copy-pr-bot bot commented Mar 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick
Collaborator Author

/build

@greptile-apps
Contributor

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR enables 256-bit (32-byte) global load/store transactions for NVIDIA Blackwell (sm_100+) GPUs in MatX's vectorized kernel dispatch path. It does so across three files: a compile-time __CUDA_ARCH__ guard sets MAX_VEC_WIDTH_BYTES = 32 for sm_100+ device code, a runtime GetComputeCapability() check in host-side dispatch selects the 32-byte width on CC100+ hardware, and the Vector::load helper replaces the previous type-unsafe float4 reinterpret cast with a type-correct Vector<T, vec_width_elems> alias backed by new static_assert guards.
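A minimal sketch of that compile-time guard, assuming the macro name and architecture threshold given in the summary (the actual defines.h may differ in detail):

```cpp
// Hypothetical reconstruction of the guard in include/matx/core/defines.h.
// __CUDA_ARCH__ >= 1000 corresponds to sm_100 (Blackwell); it is undefined
// in host code, so the host always sees the 16-byte (128-bit) default.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
  #define MAX_VEC_WIDTH_BYTES 32  // 256-bit LD/ST available on Blackwell+
#else
  #define MAX_VEC_WIDTH_BYTES 16  // host and pre-Blackwell: 128-bit max
#endif
```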

Key changes and observations:

  • Type-safety fix: The previous float4 cast in Vector::load was incorrect for any T with sizeof(T) != 4; the new vec_load_t = Vector<T, vec_width_elems> cast is type-safe for all element types.
  • static_assert tightening: The new asserts (EPT >= vec_width_elems and EPT % vec_width_elems == 0) correctly prevent the previous silent zero-iteration load bug — but because MAX_VEC_WIDTH_BYTES doubles on sm_100, valid pre-Blackwell EPT values (e.g., EPT = 4 for float) may trigger a compile-time failure when an sm_100 image is included in the fat binary.
  • Duplicated runtime CC query: The (detail::GetComputeCapability() >= 1000) ? 32 : MAX_VEC_WIDTH_BYTES expression is computed independently in both the ELEMENTS_PER_THREAD and MAX_EPT_VEC_LOAD capability branches, resulting in two cudaDeviceGetAttribute calls per dispatch query; hoisting it would avoid the redundancy.
  • Alignment correctness: The Vector<T, vec_width_elems> struct's alignas(sizeof(T) * vec_width_elems) attribute guarantees the 32-byte alignment required for 256-bit loads on sm_100; the dispatch alignment check in tensor_impl.h ensures the source pointer is equally aligned.
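Putting these observations together, the Vector / Vector::load pattern might look like the following sketch. The names vec_load_t, vec_width_elems, and EPT are taken from the summary; this is an illustrative reconstruction, not the actual MatX source:

```cpp
// Host / pre-Blackwell default, per the summary (sketch only).
#ifndef MAX_VEC_WIDTH_BYTES
#define MAX_VEC_WIDTH_BYTES 16
#endif

// Aligned vector wrapper: alignas(sizeof(T) * N) guarantees the alignment
// the wide load requires (32 bytes for 256-bit loads on sm_100).
template <typename T, int N>
struct alignas(sizeof(T) * N) Vector {
  T data[N];
};

// Load EPT elements using the widest type-correct vector available.
template <typename T, int EPT>
void vector_load(T *dst, const T *src) {
  constexpr int vec_width_elems = MAX_VEC_WIDTH_BYTES / sizeof(T);
  // The new guards: reject EPT values that previously produced a
  // silent zero-iteration loop.
  static_assert(EPT >= vec_width_elems, "EPT too small for vector width");
  static_assert(EPT % vec_width_elems == 0, "EPT must be a multiple of vector width");

  using vec_load_t = Vector<T, vec_width_elems>;  // type-correct, unlike float4
  constexpr int num_iterations = EPT / vec_width_elems;
  for (int i = 0; i < num_iterations; i++) {
    reinterpret_cast<vec_load_t *>(dst)[i] =
        reinterpret_cast<const vec_load_t *>(src)[i];
  }
}
```

With T = float and the 16-byte default, vec_width_elems = 4, so EPT = 8 compiles into two 128-bit copies; under a 32-byte MAX_VEC_WIDTH_BYTES the same EPT would be one 256-bit copy.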

Confidence Score: 3/5

  • The PR is functionally correct for fully-aligned, properly-sized Blackwell tensors, but a latent compile-time breakage exists when an sm_100 fat-binary image is paired with EPT values smaller than MAX_VEC_WIDTH_BYTES/sizeof(T).
  • The float4-to-vec_load_t fix and the runtime CC dispatch are sound. However, the new static_assert in Vector::load combined with the halving dispatch logic in tensor_impl.h can produce a hard compile error for sm_100 targets when tensor dimensions or pointer alignment only support the narrower pre-Blackwell EPT (e.g., EPT = 4 for float). This is not a latent runtime issue — it will surface as a build failure for real-world users who compile for sm_100 with non-32-byte-aligned/sized tensors.
  • include/matx/core/vector.h (static_assert / dispatch EPT interaction) and include/matx/core/tensor_impl.h (halving loop and duplicated CC query)

Important Files Changed

Filename Overview
include/matx/core/defines.h Adds __CUDA_ARCH__ >= 1000 guard to set MAX_VEC_WIDTH_BYTES = 32 for Blackwell device code; falls through to 16 on the host and older architectures. Change is minimal and correct.
include/matx/core/vector.h Replaces the previously type-unsafe float4 reinterpret cast with a type-correct Vector<T, vec_width_elems> alias (vec_load_t), and adds static_assert guards enforcing EPT ≥ vec_width_elems. The assert converts the previous silent zero-iteration bug into a compile-time failure when EPT < MAX_VEC_WIDTH_BYTES/sizeof(T).
include/matx/core/tensor_impl.h Host-side dispatch paths for ELEMENTS_PER_THREAD and MAX_EPT_VEC_LOAD now query GetComputeCapability() at runtime to enable 32-byte vector dispatch on Blackwell. The identical max_vec_width_bytes computation is duplicated in both branches.
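The duplicated-query observation suggests hoisting the computation into a single helper. A sketch with a stubbed GetComputeCapability() (the real function wraps a CUDA device-attribute query; the stub and helper names here are hypothetical):

```cpp
// Stub standing in for detail::GetComputeCapability(), which in MatX
// queries the active device at runtime (e.g. via cudaDeviceGetAttribute).
int g_compute_capability = 1000;  // pretend we are on Blackwell
int GetComputeCapability() { return g_compute_capability; }

constexpr int MAX_VEC_WIDTH_BYTES_HOST = 16;  // host-side default (128-bit)

// Computed once, then reused by both the ELEMENTS_PER_THREAD and
// MAX_EPT_VEC_LOAD capability branches instead of querying the device twice.
int MaxVecWidthBytes() {
  return (GetComputeCapability() >= 1000) ? 32 : MAX_VEC_WIDTH_BYTES_HOST;
}
```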

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Host: Operator capability query] --> B{GetComputeCapability\n>= 1000?}
    B -- Yes --> C[max_vec_width_bytes = 32]
    B -- No --> D[max_vec_width_bytes = 16\nMAX_VEC_WIDTH_BYTES]
    C --> E[width = max_vec_width_bytes / sizeof T]
    D --> E
    E --> F{Lsize % width == 0\nAND ptr is aligned?}
    F -- No --> G[width /= 2]
    G --> F
    F -- Yes --> H[Dispatch kernel with EPT = width]
    H --> I{Target architecture?}
    I -- sm_100+ device code --> J[MAX_VEC_WIDTH_BYTES = 32\nvec_width_elems = 32/sizeof T]
    I -- sm_90 / host --> K[MAX_VEC_WIDTH_BYTES = 16\nvec_width_elems = 16/sizeof T]
    J --> L{static_assert\nEPT >= vec_width_elems}
    K --> M[Vector::load uses\nvec_load_t = Vector of T\nnum_iterations = EPT / vec_width_elems]
    L -- PASS --> M
    L -- FAIL --> N[Compile error:\nEPT too small for sm_100]
    M --> O[Vectorized load loop:\nnum_iterations x vec_load_t loads]
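The host-side halving loop in the flowchart can be sketched as follows (hypothetical PickVecWidth helper; the real logic lives in include/matx/core/tensor_impl.h):

```cpp
#include <cstddef>
#include <cstdint>

// Starting from the widest vector width the device supports, halve until
// both the innermost-dimension size (lsize) and the pointer alignment
// permit a vectorized access of that width.
template <typename T>
int PickVecWidth(const T *ptr, size_t lsize, int max_vec_width_bytes) {
  int width = max_vec_width_bytes / sizeof(T);
  auto addr = reinterpret_cast<uintptr_t>(ptr);
  while (width > 1 &&
         (lsize % width != 0 || addr % (width * sizeof(T)) != 0)) {
    width /= 2;
  }
  return width;  // the EPT the kernel is dispatched with
}
```

For float on Blackwell (32-byte max), a 32-byte-aligned pointer with lsize divisible by 8 yields width 8; a misaligned pointer or awkward size degrades gracefully toward scalar access.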

Last reviewed commit: c9a991f

Comment on lines +1463 to +1464
const int max_vec_width_bytes = (detail::GetComputeCapability() >= 1000) ? 32 : MAX_VEC_WIDTH_BYTES;
int width = in.jit ? 32 : max_vec_width_bytes / sizeof(T);
Contributor

Runtime GetComputeCapability() may diverge from compile-time MAX_VEC_WIDTH_BYTES in device code

GetComputeCapability() is a host-only function that queries the active CUDA device at runtime. On the host, MAX_VEC_WIDTH_BYTES is always 16 (because __CUDA_ARCH__ is undefined on the host). So this runtime check is the correct mechanism to allow 32-byte host-side dispatch on Blackwell.

However, if the fat binary does not include an sm_100 image (e.g., only compiled with -gencode arch=compute_90,code=sm_90), running on a Blackwell GPU will:

  1. Make GetComputeCapability() return ≥ 1000 here → max_vec_width_bytes = 32 → dispatches EPT = 8 for float.
  2. But the JIT-compiled (from sm_90 PTX) device code has MAX_VEC_WIDTH_BYTES = 16 → vec_width_elems = 4 → num_iterations = 8 / 4 = 2 (two 128-bit loads).

This is still functionally correct (total 32 bytes loaded), but users relying on single 256-bit instruction throughput would not get the optimization without explicitly targeting sm_100. A comment here noting this compile-target dependency would help future maintainers.
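The arithmetic in the two steps above, spelled out as a small hypothetical helper (illustrative only, not MatX code):

```cpp
// Divergence scenario: Blackwell host dispatch paired with an sm_90-only
// device image. Values are for float elements.
int Sm90FallbackIterations() {
  constexpr int elem_bytes = 4;                            // sizeof(float)
  constexpr int host_max_bytes = 32;                       // host sees CC >= 1000
  constexpr int ept = host_max_bytes / elem_bytes;         // dispatches EPT = 8
  constexpr int dev_vec_width_elems = 16 / elem_bytes;     // sm_90 image: 16B max
  return ept / dev_vec_width_elems;  // 2 x 128-bit loads, 32 bytes total
}
```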

@cliffburdick
Collaborator Author

/build

@cliffburdick
Collaborator Author

/build

@cliffburdick cliffburdick merged commit 1ae3b1a into main Mar 12, 2026
1 check passed
@cliffburdick cliffburdick deleted the 256b_ldst branch March 12, 2026 21:34