Summary
While bringing the CuTe path up on sm_120, I ran into 3 separate issues in the CuTe forward/backward stack.
These were all hit in a minimal smoke test that only does:
- import `flash_attn.cute`
- run a small causal forward
- run a simple backward from a scalar loss
After fixing these issues locally in a vendored copy, the same forward+backward smoke test passes on RTX 5090.
Environment
- GPU: NVIDIA GeForce RTX 5090
- Compute capability: sm_120
- torch: 2.10.0a0+b4e4ee81d3.nv25.12
- nvidia-cutlass-dsl: 4.4.2
- quack-kernels: 0.3.5
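When filing a similar report, the fields above can be gathered with a small helper like the one below. It is a sketch, and `collect_env` is a hypothetical name. It uses only `torch.cuda.get_device_name`, `torch.cuda.get_device_capability` and `importlib.metadata.version`, and leaves any field it cannot determine as `None`.

```python
import importlib.metadata


def collect_env() -> dict:
    """Gather the fields listed under Environment; None where unavailable."""
    env = {"gpu": None, "compute_capability": None, "torch": None}
    try:
        import torch

        env["torch"] = torch.__version__
        if torch.cuda.is_available():
            env["gpu"] = torch.cuda.get_device_name(0)
            # get_device_capability returns (major, minor), e.g. (12, 0) -> "sm_120"
            major, minor = torch.cuda.get_device_capability(0)
            env["compute_capability"] = f"sm_{major}{minor}"
    except (ImportError, OSError):
        pass  # torch missing or its native libraries failed to load
    for dist in ("nvidia-cutlass-dsl", "quack-kernels"):
        try:
            env[dist] = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            env[dist] = None
    return env


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"- {key}: {value}")
```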
Minimal repro
```python
import torch
from flash_attn.cute.interface import flash_attn_func

device = "cuda"
dtype = torch.float16

# Small causal problem, layout (batch, seqlen, nheads, headdim) = (1, 128, 2, 64)
q = torch.randn(1, 128, 2, 64, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(1, 128, 2, 64, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(1, 128, 2, 64, device=device, dtype=dtype, requires_grad=True)

# Forward: returns the attention output and the log-sum-exp
out, lse = flash_attn_func(q, k, v, causal=True)

# Scalar loss depending on both out and lse, then backward
loss = out.float().square().mean() + 1e-3 * lse.float().square().mean()
loss.backward()
```
Issue 1: forward path enables TMA output-store epilogue even though tma_atom_O is None
Error
AttributeError: 'NoneType' object has no attribute '_trait'
Traceback location
The failure comes from the forward epilogue calling `copy_utils.tma_get_copy_fn(...)`, which eventually reaches `cpasync.tma_partition(...)` with `tma_atom_O=None`.
Root cause
In the shared forward base path, `epilogue()` is called with `tma_atom_O=None`, but `self.use_tma_O` can still be enabled by architecture gating.
As a result, this path can enter the TMA output-store branch even though it has no valid TMA output atom.
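The guard needed here can be sketched in plain Python. This is a simplified stand-in and not the flash_attn.cute code. `FwdEpilogueSketch`, the `arch >= 90` gating rule and the string return values are all hypothetical. The point is that the TMA branch should require both the architecture flag and a non-`None` atom.

```python
from typing import Optional


class FwdEpilogueSketch:
    """Hypothetical stand-in for the shared forward base path (not the real class)."""

    def __init__(self, arch: int):
        # Hypothetical architecture gating: assume TMA output stores are enabled on sm90+.
        self.use_tma_O = arch >= 90

    def epilogue(self, tma_atom_O: Optional[object] = None) -> str:
        # Take the TMA output-store branch only when a TMA atom was actually built;
        # otherwise fall back to the non-TMA store instead of partitioning with None.
        if self.use_tma_O and tma_atom_O is not None:
            return "tma_store"
        return "non_tma_store"


# sm_120 with no TMA atom now falls back instead of failing
print(FwdEpilogueSketch(arch=120).epilogue(tma_atom_O=None))
```

Keeping the check at the point of use (rather than only adjusting the gating in `__init__`) means any caller that passes `tma_atom_O=None` stays safe regardless of how the architecture flag is set.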
Issue 2: backward path uses dQ_single_wg before assignment on non-sm90 paths
Error
UnboundLocalError: cannot access local variable 'dQ_single_wg' where it is not associated with a value
Root cause
In `_flash_attn_bwd`, `dQ_single_wg` is assigned only in the `arch // 10 == 9` (sm90) path, but the compile-key construction for `arch // 10 in [8, 9, 12]` still references it.
On sm_120 (`arch // 10 == 12`) the variable is never bound, so backward compile-key generation fails before kernel compilation.
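The failure mode and the fix can be reproduced in isolation. This is a minimal sketch, not the real `_flash_attn_bwd`. The `True` value on the sm90 path is a placeholder, the key tuple is simplified, and the `False` default for non-sm90 paths is an assumption.

```python
def compile_key_buggy(arch: int) -> tuple:
    if arch // 10 == 9:
        dQ_single_wg = True  # placeholder for the sm90-only setting
    if arch // 10 in [8, 9, 12]:
        # In this sketch, any arch other than sm90 reaches here with
        # dQ_single_wg unbound, which raises UnboundLocalError.
        return (arch, dQ_single_wg)
    return (arch,)


def compile_key_fixed(arch: int) -> tuple:
    dQ_single_wg = False  # assumed default for non-sm90 backward paths
    if arch // 10 == 9:
        dQ_single_wg = True  # placeholder for the sm90-only setting
    if arch // 10 in [8, 9, 12]:
        return (arch, dQ_single_wg)
    return (arch,)


print(compile_key_fixed(120))
```

Binding the variable before the arch branches makes every path that builds the compile key see a defined value, instead of depending on which branch ran.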
Issue 3: direct nvvm.atomicrmw(...) call in backward utils is not compatible with newer NVVM bindings
Error
TypeError: atomicrmw() got an unexpected keyword argument 'res'
Root cause
The helper used by backward atomic accumulation calls `nvvm.atomicrmw(res=..., op=..., ptr=..., a=...)` directly.
That keyword signature is not compatible with the NVVM binding exposed in this environment (nvidia-cutlass-dsl 4.4.2), which rejects the `res` keyword.
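Because this is a binding-signature mismatch rather than a kernel bug, one way to confirm it on a given install is to inspect the callable's parameters before calling it. The sketch below uses the standard-library `inspect` module. `missing_kwargs` is a hypothetical helper. The two stand-in functions only mimic an "accepts `res`" and a "rejects `res`" call shape and are not the real NVVM bindings. Applying it to the installed `nvvm.atomicrmw` assumes that binding is a plain Python function that `inspect.signature` can introspect.

```python
import inspect
from typing import Callable, Iterable, List


def missing_kwargs(fn: Callable, names: Iterable[str]) -> List[str]:
    """Return the keyword names that fn would reject if passed by keyword."""
    params = inspect.signature(fn).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return []  # a **kwargs parameter accepts any keyword
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return [name for name in names if name not in accepted]


# Hypothetical stand-ins for the two call shapes (not the real bindings).
def old_style_atomicrmw(*, res, op, ptr, a):
    return None


def new_style_atomicrmw(op, ptr, a):
    return None


print(missing_kwargs(new_style_atomicrmw, ["res", "op", "ptr", "a"]))
```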
Result after local fixes
After applying the following local fixes:
- restrict / disable the invalid TMA output-store branch in the shared forward path
- initialize `dQ_single_wg` for non-sm90 backward paths
- replace the direct `nvvm.atomicrmw(...)` float atomic-add call with `cute.arch.atomic_add(...)`

the same forward + backward smoke test succeeds on sm_120.