Commit 5ded17f

Add DEBUG_2CTA.md
1 parent 120b306 commit 5ded17f

1 file changed: +118 -0 lines changed

AI/DEBUG_2CTA.md

# Debugging GPU Kernel Hangs (Deadlocks) in CUTLASS DSL / 2CTA Kernels

## General Approach to Debugging Kernel Hangs

### Step 1: Build a minimal repro

Strip the test case down to the smallest input that triggers the hang:

- batch=1, nheads=1, the smallest seqlen that hangs
- A single config, no loops, no benchmarking
- Add a timeout or run under `compute-sanitizer` so you can distinguish a hang from slow execution
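A hard timeout makes a hang unambiguous. One way to wrap the repro is a small plain-Python harness; the command and timeout value are placeholders, not part of the original setup:

```python
import subprocess
import sys

def run_with_timeout(cmd, timeout_s=120):
    """Run a repro command and classify the outcome: ok / crash / hang."""
    try:
        proc = subprocess.run(cmd, timeout=timeout_s,
                              capture_output=True, text=True)
        return "ok" if proc.returncode == 0 else "crash"
    except subprocess.TimeoutExpired:
        # The child is killed after timeout_s: treat as a deadlock candidate.
        return "hang"

# Trivial command standing in for the actual repro script:
print(run_with_timeout([sys.executable, "-c", "print('done')"]))
```

This also gives CI a clean pass/fail signal while you bisect.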
### Step 2: Add printf to locate the hang

GPU `printf` (`cute.printf`) is the primary tool. The goal is a binary search: narrow down which warp and which operation is blocked.

**Printf guards** — avoid print storms:

```python
# One thread per warp:
if cute.arch.thread_idx()[0] % 32 == 0:
    cute.printf("...")

# One thread per CTA (elect_one is a context manager, not a bool):
with cute.arch.elect_one():
    cute.printf("...")

# One specific thread:
if tidx == 0:
    cute.printf("...")
```

**Strategy — coarse to fine:**

1. First, print at the entry/exit of each warp's main function (load, mma, softmax, correction). This tells you which warp is stuck.
2. Then add prints before/after each pipeline wait (`consumer_wait`, `producer_acquire`). This tells you which barrier is stuck.
3. Then print the barrier index, phase, and stage to understand the pipeline state.

**What to print:**

- CTA index (`cute.arch.block_idx()[0]`) — critical for multi-CTA debugging
- Pipeline stage index and phase
- Loop iteration count
- Whether a `try_wait` succeeds or fails (use the `try_wait_token` parameter)

### Step 3: Identify the deadlock chain

A hang is always a cycle. A typical chain in a pipelined kernel:

```
MMA waiting for K from load (pipeline_kv full barrier)
  -> Load finished but stuck in producer_tail (waiting for MMA to release the empty barrier)
  -> MMA can't release because it's waiting for K
```

Once you see which barrier is stuck, trace backwards: who is supposed to signal it, and why haven't they?
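Tracing backwards is just cycle-hunting in a wait-for graph. A minimal sketch of that reasoning in plain Python (the agent names are illustrative, not taken from the kernel):

```python
def find_cycle(waits_on):
    """waits_on maps each agent to the agent it is blocked on (or None).
    Returns a deadlock cycle as a list of agents, or None if there is none."""
    for start in waits_on:
        seen = []
        node = start
        while node is not None and node not in seen:
            seen.append(node)
            node = waits_on.get(node)
        if node is not None:  # revisited an agent: that's the cycle
            return seen[seen.index(node):]
    return None

# The example chain above, as a wait-for graph:
waits = {"mma": "load", "load": "mma", "softmax": None}
print(find_cycle(waits))  # ['mma', 'load']
```

Writing the chain down this way forces you to name the exact barrier each edge represents.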

### Step 4: Vary the problem size systematically

Test with different sequence lengths / block counts to find the pattern:

| seqlen | n_blocks | Result |
|--------|----------|--------|
| 128    | 1        | ?      |
| 256    | 2        | ?      |
| 384    | 3        | ?      |
| 512    | 4        | ?      |

If the hang correlates with the number of visits to a pipeline stage (e.g., works for n_blocks <= kv_stages but fails once the stages wrap around), the problem is likely in barrier tx_count or phase tracking.
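The wrap-around hypothesis can be checked mechanically: with `kv_stages` pipeline stages, KV block `i` lands in stage `i % kv_stages`, so stage reuse (and phase wrap) begins only once `n_blocks > kv_stages`. A small sketch, assuming `kv_stages=2` and a 128-wide KV tile (both assumed values):

```python
def stages_visited(n_blocks, kv_stages):
    """Map each KV block to its pipeline stage; report whether any stage is reused."""
    stages = [i % kv_stages for i in range(n_blocks)]
    wraps = n_blocks > kv_stages
    return stages, wraps

for seqlen in (128, 256, 384, 512):
    n_blocks = seqlen // 128   # assuming one KV block per 128 tokens
    stages, wraps = stages_visited(n_blocks, kv_stages=2)
    print(seqlen, stages, "wraps" if wraps else "no wrap")
```

If the first hanging row in the table is also the first "wraps" row, suspect tx_count or phase tracking first.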

### Step 5: Check barrier byte counts (tx_count)

For TMA-based pipelines, `arrive_and_expect_tx` sets the expected transaction byte count on an mbarrier. If the expected count doesn't match the actual bytes arriving, the barrier either:

- Fires too early (expected < actual) — causes data races
- Never fires (expected > actual) — causes hangs

In **2CTA / cluster mode**, both CTAs' TMAs signal the **same** cluster-level mbarrier. If each CTA's TMA contributes N bytes, the barrier receives 2N bytes total. The tx_count must be `N * cta_group_size`, not just `N`.

**All TMA pipelines need the doubling** — Q, K, and V. Even though each CTA loads a different M-tile for Q, both CTAs' TMA operations still signal the same cluster-level barrier, so the expected byte count must account for both.
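As a sanity check, the value passed to `arrive_and_expect_tx` can be recomputed from the tile shape. A sketch under assumed tile sizes (128x128 fp16; the function name is illustrative):

```python
def expected_tx_bytes(tile_rows, tile_cols, elem_bytes, cta_group_size):
    """Bytes one CTA's TMA delivers per stage, scaled by how many CTAs
    signal the same cluster-level mbarrier."""
    per_cta = tile_rows * tile_cols * elem_bytes
    return per_cta * cta_group_size

# A 128x128 fp16 K tile:
print(expected_tx_bytes(128, 128, 2, 1))  # single-CTA pipeline: 32768
print(expected_tx_bytes(128, 128, 2, 2))  # 2CTA mode: must be doubled to 65536
```

If the kernel's tx_count equals the `cta_group_size=1` value while running in 2CTA mode, the barrier will wait for bytes that already arrived twice over, yet still under-count: expected > actual per phase never completes, which is exactly the "never fires" hang above.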

### Step 6: Check phase / parity tracking

`mbarrier_try_wait_parity` uses a single parity bit (0 or 1). If your pipeline state tracks the phase as a monotonically increasing counter (0, 1, 2, 3, ...), you need `phase % 2` before passing it to the barrier wait. Without this, phase=2 looks like phase=0 to the hardware, which can cause waits on already-completed barriers or misses on pending ones.
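A minimal model of the bookkeeping, sketching the usual pipeline-state pattern (this is not the DSL's actual class, just the index/parity arithmetic):

```python
class PipelineState:
    """Tracks (stage index, parity bit) across a ring of num_stages barriers."""
    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.index = 0
        self.count = 0        # monotonically increasing visit counter

    @property
    def phase(self):
        # Parity bit expected by mbarrier_try_wait_parity: it flips each
        # time the stage ring wraps around, so reduce mod 2.
        return (self.count // self.num_stages) % 2

    def advance(self):
        self.count += 1
        self.index = self.count % self.num_stages

state = PipelineState(num_stages=3)
parities = []
for _ in range(7):
    parities.append((state.index, state.phase))
    state.advance()
print(parities)  # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```

Note how stage 0 is revisited with parity 1, then parity 0 again: passing the raw wrap count instead of its parity is the bug described above.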

### Step 7: Beware the compiler as a bug source

If the kernel works WITH printf but hangs WITHOUT it, the printf is acting as a **compiler barrier**. The MLIR/LLVM backend cannot optimize through an opaque function call like printf, which prevents harmful instruction reordering.

Signs this is happening:

- A single `cute.printf("\n")` in the right function fixes the hang
- PTX fences (`fence_view_async_shared`, `fence_acq_rel_cluster`, `sync_warp`, `fence_proxy`) do NOT fix it — these affect hardware memory ordering, not compiler scheduling
- The fix is location-sensitive (a printf in one function fixes it, in another it doesn't)

Possible workarounds:

- A `@dsl_user_op` decorator on pipeline methods to make them opaque to the compiler
- `asm volatile` barriers (if available in the DSL)
- Compare the generated PTX/SASS with and without printf to identify what the compiler is reordering
- File a bug against the CUTLASS DSL / MLIR pipeline

---

## 2CTA-Specific Pitfalls

### tcgen05.commit with empty commit groups

`tcgen05.commit(mbar, mask, cta_group::2)` is supposed to signal an mbarrier after all pending MMA operations complete. But if there are **no pending operations** (an empty commit group), the signal only reaches the local CTA's barrier, not the remote CTA's. Fix: use an explicit `mbarrier_arrive(barrier, dst_cta_rank)` to both CTAs.

### producer_tail deadlock

The default `producer_tail` (inherited from the sm90 pipelines) drains the pipeline by calling `producer_acquire` in a loop. In 2CTA mode this deadlocks because the consumer (the MMA warp) may have already exited without releasing all stages. Fix: make `producer_tail` a no-op for 2CTA.
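A sketch of that fix, with hypothetical class names standing in for the DSL's real pipeline classes (only the override structure is the point):

```python
class PipelineTmaUmmaBase:
    """Stand-in for an sm90-style pipeline base class (hypothetical)."""
    def producer_tail(self, state):
        # sm90-style drain: acquire every stage before the producer exits.
        for _ in range(state["num_stages"]):
            self.producer_acquire(state)

    def producer_acquire(self, state):
        # Stand-in for waiting on the stage's empty barrier.
        state["acquired"] = state.get("acquired", 0) + 1

class PipelineTmaUmma2CTA(PipelineTmaUmmaBase):
    def producer_tail(self, state):
        # 2CTA fix: the MMA consumer may already have exited without
        # releasing all stages, so draining here would wait forever. No-op.
        pass

state = {"num_stages": 4}
PipelineTmaUmma2CTA().producer_tail(state)
print(state.get("acquired", 0))  # 0: the tail no longer touches the barriers
```

The trade-off is that the producer exits without confirming stage release, which is only safe if kernel exit itself synchronizes the cluster.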

### Tile scheduler must account for cluster shape

Both CTAs in a cluster must get the **same** tile coordinate. Raw `blockIdx.x` assigns consecutive values to CTAs in the same cluster. Fix: divide `blockIdx.x` by `cluster_shape_m`.
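The division in plain Python (variable names are assumptions; the quotient is the shared tile, the remainder is the CTA's rank within its cluster):

```python
def tile_coord(block_idx_x, cluster_shape_m):
    """Map a raw CTA index to a cluster-shared tile coordinate plus the
    CTA's rank within its cluster."""
    m_block = block_idx_x // cluster_shape_m   # same for both CTAs in a cluster
    cta_rank_in_cluster = block_idx_x % cluster_shape_m
    return m_block, cta_rank_in_cluster

# With cluster_shape_m = 2, CTAs 0 and 1 share tile 0, CTAs 2 and 3 share tile 1:
print([tile_coord(b, 2) for b in range(4)])  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Using raw `blockIdx.x` instead would give the two CTAs of one cluster tiles 0 and 1, so they would never agree on which K/V blocks to exchange.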

### Cross-CTA vs per-CTA pipelines

Pipelines where CTA 1's threads remotely arrive on CTA 0's barriers need cluster-sized cooperative-group counts. Pipelines that are purely local to each CTA keep per-CTA counts.

### Softmax masking offset

Causal mask row positions must account for the CTA's position within the cluster. Multiply `m_block` by `cta_group_size` when computing mask coordinates.
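A sketch of the offset arithmetic; the tile size, function names, and the `cta_rank` term are illustrative assumptions around the `m_block * cta_group_size` rule:

```python
def causal_row_base(m_block, cta_rank, tile_m, cta_group_size):
    """Global row index of the first query row this CTA owns.
    Without the cta_group_size factor, both CTAs in a cluster would
    compute mask rows as if they owned the same M range."""
    return (m_block * cta_group_size + cta_rank) * tile_m

def causal_allowed(row, col):
    # Standard causal predicate: key position must not exceed query position.
    return col <= row

# 2CTA cluster on m_block 1, tile_m 128: CTA 0 starts at row 256, CTA 1 at 384.
print(causal_row_base(1, 0, 128, 2), causal_row_base(1, 1, 128, 2))  # 256 384
```

A symptom of getting this wrong is not a hang but silently wrong output: one CTA masks too much or too little of the diagonal.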
