Install CUDAGuard before Arch() in api entry points (fixes #158) #176
Draft
yurekami wants to merge 1 commit into
Conversation
Arch() reads at::cuda::getCurrentDeviceProperties(), so it must run after the CUDAGuard that binds the current CUDA device to q.device(). The three api/*.h entry points had the ordering inverted: the guard was installed mid-function, after Arch had already probed whatever device happened to be current. On multi-GPU setups (especially heterogeneous SM90 + SM100 boxes), this made the dispatcher pick the wrong SM-specific impl and compute num_sm_parts from the wrong SM count.

Move the guard to immediately before Arch construction in:

- csrc/api/dense_decode.h
- csrc/api/sparse_decode.h
- csrc/api/sparse_fwd.h

Add a tests/test_multi_gpu_device_guard.py regression pair that pins torch.cuda.set_device(0) while placing tensors on cuda:1 and asserts bit-equal output against a matched-device reference run. Tests skip when fewer than two CUDA devices are visible.

Refs deepseek-ai#158
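The reordering the message describes can be sketched as a before/after fragment. The surrounding lines are illustrative, not the repo's exact source; only the two reordered statements come from the commit message:

```cpp
// Before (buggy): Arch() runs first and probes whatever device is current.
Arch arch = Arch();                                // reads getCurrentDeviceProperties()
// ... other setup ...
at::cuda::CUDAGuard guard{(char)q.get_device()};   // installed too late

// After (fixed): bind the current device to q.device() before probing.
at::cuda::CUDAGuard guard{(char)q.get_device()};   // guard first
Arch arch = Arch();                                // now probes q's device
```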
Summary
Fixes #158 (the surviving half).
`at::cuda::CUDAGuard` was placed after `Arch arch = Arch();` in the `csrc/api/*.h` entry points. Because `Arch()` reads `at::cuda::getCurrentDeviceProperties()`, it queries the current device rather than the tensor's device. In multi-GPU setups, and especially on heterogeneous SM90 + SM100 boxes, this caused:

- the dispatcher to select the wrong SM-specific implementation,
- `arch.num_sms` → wrong `num_sm_parts` fed into the scheduler metadata.

The SM100 prefill fwd/bwd paths already had this ordering correct; the three API entry points (`dense_decode`, `sparse_decode`, `sparse_prefill`) did not.

Fix
Move the existing `at::cuda::CUDAGuard{(char)q.get_device()}` up to immediately before `Arch arch = Arch();` in all three entry points, and add a `TORCH_CHECK(q.is_cuda(), ...)` sanity check. No behavior change on single-GPU workloads.

- csrc/api/dense_decode.h
- csrc/api/sparse_decode.h
- csrc/api/sparse_fwd.h

Test plan
Added `tests/test_multi_gpu_device_guard.py` with two tests, both skipped when `torch.cuda.device_count() < 2`:

- `test_dense_decode_respects_input_device_when_current_device_differs`: pin `torch.cuda.set_device(0)`, place all tensors on `cuda:1`, assert the output is on `cuda:1` and bit-equal to a matched-device reference run.
- `test_dense_decode_current_device_unchanged_after_call`: assert the guard restores the caller's current device on exit (no leak).

I don't have a 2×GPU Hopper/Blackwell box to run the suite locally, so I'm marking this PR as draft so maintainers (or I, once I get access) can validate on real hardware before merge. The source change itself is a surgical reordering of existing lines.
Out of scope
The reporter's original claim about a hardcoded `hw_info.device_id = 0` in `run_fmha_fwd` is already fixed in HEAD (`csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh:292-298` and the matching bwd path), so no changes there.