
Install CUDAGuard before Arch() in api entry points (fixes #158)#176

Draft
yurekami wants to merge 1 commit into deepseek-ai:main from yurekami:fix/issue-158-arch-before-guard

Conversation

@yurekami

Summary

Fixes #158 (the half that still reproduces on HEAD; see "Out of scope" below for the other half). at::cuda::CUDAGuard was installed after Arch arch = Arch(); in the csrc/api/*.h entry points. Because Arch() reads at::cuda::getCurrentDeviceProperties(), it probes whatever device happens to be current rather than the device the input tensors live on. On multi-GPU setups, and especially heterogeneous SM90 + SM100 boxes, this caused:

  • the dispatcher to pick the wrong SM-specific impl (e.g. SM90 impl for tensors on an SM100 device),
  • arch.num_sms → wrong num_sm_parts fed into the scheduler metadata,
  • and in general, hardware probing that disagreed with where the kernel actually runs.

The SM100 prefill fwd/bwd paths already had this ordering correct; the three API entry points (dense_decode, sparse_decode, sparse_prefill) did not.

Fix

Move the existing at::cuda::CUDAGuard{(char)q.get_device()} up to immediately before Arch arch = Arch(); in all three entry points, and add a TORCH_CHECK(q.is_cuda(), ...) sanity check. No behavior change on single-GPU workloads.

  • csrc/api/dense_decode.h
  • csrc/api/sparse_decode.h
  • csrc/api/sparse_fwd.h
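Why the ordering matters can be shown without CUDA at all. The following is a minimal, torch-free Python sketch (not the actual C++ code; CURRENT_DEVICE, device_guard, probe_arch, and the arch table are all invented stand-ins for the CUDA current-device state, at::cuda::CUDAGuard, and Arch()):

```python
from contextlib import contextmanager

# Invented stand-ins for CUDA runtime state and per-device properties.
CURRENT_DEVICE = 0
DEVICE_ARCH = {0: "sm90", 1: "sm100"}  # heterogeneous box

@contextmanager
def device_guard(device):
    """Models at::cuda::CUDAGuard: switch current device, restore on exit."""
    global CURRENT_DEVICE
    prev, CURRENT_DEVICE = CURRENT_DEVICE, device
    try:
        yield
    finally:
        CURRENT_DEVICE = prev

def probe_arch():
    """Models Arch(): reads the *current* device's properties."""
    return DEVICE_ARCH[CURRENT_DEVICE]

def entry_point_buggy(tensor_device):
    arch = probe_arch()                 # probes whatever device is current...
    with device_guard(tensor_device):   # ...guard installed too late
        return arch

def entry_point_fixed(tensor_device):
    with device_guard(tensor_device):   # guard first, then probe
        return probe_arch()

# Caller's current device is 0 (SM90); the tensors live on device 1 (SM100).
print(entry_point_buggy(1))   # sm90  -- wrong SM-specific impl selected
print(entry_point_fixed(1))   # sm100 -- matches the tensor's device
```

The fix in the PR is exactly the buggy-to-fixed reordering above, applied to the existing C++ lines.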

Test plan

Added tests/test_multi_gpu_device_guard.py with two tests, both skipped when torch.cuda.device_count() < 2:

  • test_dense_decode_respects_input_device_when_current_device_differs — pin torch.cuda.set_device(0), place all tensors on cuda:1, assert output is on cuda:1 and bit-equal to a matched-device reference run.
  • test_dense_decode_current_device_unchanged_after_call — assert the guard restores the caller's current device on exit (no leak).
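The no-leak invariant the second test checks is plain RAII save/restore, and it must hold on the exception path too. A torch-free sketch of that invariant (all names invented; this is not the test file itself, which needs real hardware):

```python
from contextlib import contextmanager

CURRENT_DEVICE = 0  # invented stand-in for torch.cuda.current_device()

@contextmanager
def device_guard(device):
    # Save the caller's device, switch, and restore on exit --
    # even when the guarded region raises.
    global CURRENT_DEVICE
    prev, CURRENT_DEVICE = CURRENT_DEVICE, device
    try:
        yield
    finally:
        CURRENT_DEVICE = prev

def entry_point(tensor_device, fail=False):
    with device_guard(tensor_device):
        if fail:
            raise RuntimeError("kernel launch failed")

entry_point(1)
assert CURRENT_DEVICE == 0  # normal exit: caller's device restored

try:
    entry_point(1, fail=True)
except RuntimeError:
    pass
assert CURRENT_DEVICE == 0  # exceptional exit: still restored, no leak
print("no device leak")
```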

I don't have a 2×GPU Hopper/Blackwell box to run the suite locally — marking this PR as draft so maintainers (or I, once I get access) can validate on real hardware before merge. The source change itself is a surgical reordering of existing lines.

Out of scope

The reporter's original claim about a hardcoded hw_info.device_id = 0 in run_fmha_fwd is already fixed in HEAD (csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh:292-298 and the matching bwd path), so no changes there.

Arch() reads at::cuda::getCurrentDeviceProperties(), so it must run after
the CUDAGuard that binds the current CUDA device to q.device(). The three
api/*.h entry points had the ordering inverted: the guard was installed
mid-function, after Arch had already probed whatever device happened to
be current. On multi-GPU setups (especially heterogeneous SM90 + SM100
boxes), this made the dispatcher pick the wrong SM-specific impl and
compute num_sm_parts from the wrong SM count.

Move the guard to immediately before Arch construction in:
- csrc/api/dense_decode.h
- csrc/api/sparse_decode.h
- csrc/api/sparse_fwd.h

Add a tests/test_multi_gpu_device_guard.py regression pair that pins
torch.cuda.set_device(0) while placing tensors on cuda:1 and asserts
bit-equal output against a matched-device reference run. Tests skip
when fewer than two CUDA devices are visible.

Refs deepseek-ai#158


Successfully merging this pull request may close these issues.

[Bug/Correctness] Hardcoded device_id=0 + missing CUDAGuard can break multi-GPU correctness (wrong hw_info / stream mismatch)
