
[BUG] CuTe on sm_120: forward TMA epilogue mis-gating, backward non-sm90 config bug, and NVVM atomic API incompatibility #2386

@LoserCheems

Summary

While bringing the CuTe path up on sm_120, I ran into 3 separate issues in the CuTe forward/backward stack.

These were all hit in a minimal smoke test that only does:

  1. import flash_attn.cute
  2. run a small causal forward
  3. run a simple backward from a scalar loss

After fixing these issues locally in a vendored copy, the same forward+backward smoke test passes on an RTX 5090.

Environment

  • GPU: NVIDIA GeForce RTX 5090
  • Compute capability: sm_120
  • torch: 2.10.0a0+b4e4ee81d3.nv25.12
  • nvidia-cutlass-dsl: 4.4.2
  • quack-kernels: 0.3.5
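The version details above can be collected with a small script like the following. This is a sketch that only relies on importlib.metadata and, if it is installed, torch.cuda. The distribution names are the ones listed above.

```python
import importlib.metadata as md
import platform

# Distribution names as listed in the Environment section.
PACKAGES = ("torch", "nvidia-cutlass-dsl", "quack-kernels")


def collect_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = md.version(name)
        except md.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


def gpu_info():
    """Return (device name, 'sm_XY') when torch with CUDA is usable, else None."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    major, minor = torch.cuda.get_device_capability()
    return torch.cuda.get_device_name(), f"sm_{major}{minor}"


if __name__ == "__main__":
    print("python:", platform.python_version())
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
    print("gpu:", gpu_info())
```

On an RTX 5090, get_device_capability() reports (12, 0), which the script prints as sm_120.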

Minimal repro

```python
import torch
from flash_attn.cute.interface import flash_attn_func

device = "cuda"
dtype = torch.float16

q = torch.randn(1, 128, 2, 64, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(1, 128, 2, 64, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(1, 128, 2, 64, device=device, dtype=dtype, requires_grad=True)

out, lse = flash_attn_func(q, k, v, causal=True)
loss = out.float().square().mean() + 1e-3 * lse.float().square().mean()
loss.backward()
```

Issue 1: forward path enables TMA output-store epilogue even though tma_atom_O is None

Error

AttributeError: 'NoneType' object has no attribute '_trait'

Traceback location

The failure comes from the forward epilogue calling copy_utils.tma_get_copy_fn(...), which eventually reaches cpasync.tma_partition(...) with tma_atom_O=None.

Root cause

In the shared forward base path, epilogue() is called with tma_atom_O=None, but self.use_tma_O can still be enabled by architecture gating.

So this path can enter the TMA output-store branch even though it has no valid TMA output atom.
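A minimal sketch of the kind of guard that avoids this, assuming the TMA decision can be tied to whether an output atom was actually built. EpilogueConfig, epilogue_kind, and the arch >= 90 threshold are hypothetical stand-ins for illustration, not the actual flash_attn.cute code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EpilogueConfig:
    """Hypothetical stand-in for the forward kernel's epilogue settings."""

    arch: int  # e.g. 90 for sm_90, 120 for sm_120
    tma_atom_O: Optional[object] = None  # None when no TMA store atom was built

    @property
    def use_tma_O(self) -> bool:
        # Architecture gating alone is not enough: the TMA store branch also
        # needs a real atom, otherwise tma_partition() would receive None.
        arch_allows_tma = self.arch >= 90  # assumed threshold, for illustration
        return arch_allows_tma and self.tma_atom_O is not None


def epilogue_kind(cfg: EpilogueConfig) -> str:
    """Pick the output-store path; 'tma' only when an atom actually exists."""
    return "tma" if cfg.use_tma_O else "non-tma"
```

Deriving the flag from the presence of the atom, rather than from the architecture alone, keeps use_tma_O and the arguments passed to epilogue() from disagreeing.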


Issue 2: backward path uses dQ_single_wg before assignment on non-sm90 paths

Error

UnboundLocalError: cannot access local variable 'dQ_single_wg' where it is not associated with a value

Root cause

In _flash_attn_bwd, dQ_single_wg is assigned only in the arch // 10 == 9 branch, but the compile-key construction, which runs for arch // 10 in [8, 9, 12], still references it.

So on sm_120, backward compile-key generation fails before kernel compilation.
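The failure pattern, reduced to plain Python. The function names, the head_dim heuristic, and the False default are hypothetical; the real compile key has many more fields.

```python
def compile_key_buggy(arch: int, head_dim: int):
    if arch // 10 == 9:
        dQ_single_wg = head_dim <= 64  # hypothetical sm90-only heuristic
    if arch // 10 in [8, 9, 12]:
        # On sm_80 / sm_120 the branch above never ran, so this raises
        # UnboundLocalError.
        return (arch, head_dim, dQ_single_wg)
    return (arch, head_dim)


def compile_key_fixed(arch: int, head_dim: int):
    dQ_single_wg = False  # hypothetical default for non-sm90 paths
    if arch // 10 == 9:
        dQ_single_wg = head_dim <= 64
    if arch // 10 in [8, 9, 12]:
        return (arch, head_dim, dQ_single_wg)
    return (arch, head_dim)
```

Initializing the variable before the arch-specific branch means every path that builds the compile key sees a defined value.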


Issue 3: direct nvvm.atomicrmw(...) call in backward utils is not compatible with newer NVVM bindings

Error

TypeError: atomicrmw() got an unexpected keyword argument 'res'

Root cause

The helper used by backward atomic accumulation directly calls:

nvvm.atomicrmw(res=..., op=..., ptr=..., a=...)

This keyword form is not accepted by the NVVM binding exposed in this environment (nvidia-cutlass-dsl 4.4.2).
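To check which keyword form an installed binding accepts before touching call sites, a small inspect.signature probe can help. This is a diagnostic sketch only; the two atomicrmw stand-ins below are hypothetical and do not reproduce the real NVVM signatures.

```python
import inspect


def accepts_keyword(fn, name: str) -> bool:
    """Heuristic: True if fn can take `name` as a keyword argument."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False  # signature not introspectable; be conservative
    param = params.get(name)
    if param is not None and param.kind is not inspect.Parameter.POSITIONAL_ONLY:
        return True
    # A **kwargs parameter would also swallow the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for an older and a newer binding shape.
def old_style_atomicrmw(res, op, ptr, a): ...


def new_style_atomicrmw(op, ptr, a, *, loc=None, ip=None): ...


if __name__ == "__main__":
    print(accepts_keyword(old_style_atomicrmw, "res"))
    print(accepts_keyword(new_style_atomicrmw, "res"))
```

In a real environment you would pass the installed nvvm.atomicrmw instead of the stand-ins. Generated bindings may not always expose a usable signature, in which case the helper returns False. The fix that worked here was switching to cute.arch.atomic_add(...) rather than branching on the binding's signature.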


Result after local fixes

After applying the following local fixes:

  1. restrict or disable the invalid TMA output-store branch in the shared forward path
  2. initialize dQ_single_wg for non-sm90 backward paths
  3. replace the direct nvvm.atomicrmw(...) float atomic-add call with cute.arch.atomic_add(...)

the same forward + backward smoke test succeeds on sm_120.
