142 changes: 142 additions & 0 deletions benchmark/bench_sm80_decode.py
@@ -0,0 +1,142 @@
"""Benchmark for the SM80 dense MLA decode kernel.

Compares against a PyTorch eager (BMM-based) reference. The eager path is
slow for long sequences -- iteration counts shrink accordingly. Reports
latency, KV bandwidth, and speedup across a config sweep."""
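# Example invocations (assuming a CUDA build of flash_mla is importable from the repo root):
#   python benchmark/bench_sm80_decode.py --check              # sweep + one-shot correctness check
#   python benchmark/bench_sm80_decode.py --no-torch-baseline  # kernel-only sweep, much faster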

import argparse
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import flash_mla.cuda as cuda


def torch_eager_mla(q, kcache, block_table, seqlens_k, softmax_scale, head_size_v, is_causal=False):
"""PyTorch BMM-based MLA decode. Returns the same shape as the kernel output."""
b, sq, hq, dk = q.shape
_, pbs, hk, _ = kcache.shape
nq_per_hk = hq // hk
out = torch.zeros(b, sq, hq, head_size_v, dtype=q.dtype, device=q.device)
for bi in range(b):
sk = int(seqlens_k[bi].item())
bt = block_table[bi]
nb = (sk + pbs - 1) // pbs
ks = [kcache[bt[bl].item()] for bl in range(nb)]
kc = torch.cat(ks, dim=0)[:sk] # (sk, hk, dk)
k_full = kc.transpose(0, 1).contiguous() # (hk, sk, dk)
v_full = k_full[:, :, :head_size_v] # (hk, sk, dv)
q_b = q[bi] # (sq, hq, dk)
q_rs = (q_b.view(sq, hk, nq_per_hk, dk)
.permute(1, 0, 2, 3)
.reshape(hk, sq * nq_per_hk, dk))
scores = torch.bmm(q_rs.float(), k_full.float().transpose(1, 2)) * softmax_scale
if is_causal:
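            # Query sq_idx sits at absolute position sk - sq + sq_idx; it may only
            # attend to keys [0, rb), so everything from rb onward is masked.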
for sq_idx in range(sq):
rb = max(0, sk - (sq - sq_idx - 1))
if rb < sk:
for nq_idx in range(nq_per_hk):
scores[:, sq_idx * nq_per_hk + nq_idx, rb:] = float('-inf')
probs = torch.softmax(scores, dim=-1)
        o = torch.bmm(probs, v_full.float())  # (hk, sq * nq_per_hk, dv)
        o_rs = (o.view(hk, sq, nq_per_hk, head_size_v)
                .permute(1, 0, 2, 3)  # back to (sq, hk, nq_per_hk, dv), matching the q head grouping
                .reshape(sq, hq, head_size_v))
out[bi] = o_rs.to(q.dtype)
return out


def bench(fn, iters, warmup):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
for _ in range(iters):
fn()
e.record()
torch.cuda.synchronize()
return s.elapsed_time(e) / iters


def make_inputs(batch, sq, hq, hk, sk, dtype, device):
head_size_k = 576
page_block_size = 64
q = torch.randn(batch, sq, hq, head_size_k, dtype=dtype, device=device) * 0.1
nb = (sk + page_block_size - 1) // page_block_size
kcache = torch.randn(nb * batch, page_block_size, hk, head_size_k, dtype=dtype, device=device) * 0.1
seqlens_k = torch.full((batch,), sk, dtype=torch.int32, device=device)
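    # Identity block table: batch element bi owns blocks [bi * nb, (bi + 1) * nb).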
block_table = torch.arange(nb * batch, dtype=torch.int32, device=device).view(batch, nb)
return q, kcache, seqlens_k, block_table


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dtype', default='bfloat16', choices=['bfloat16', 'float16'])
parser.add_argument('--no-torch-baseline', action='store_true',
help='skip PyTorch eager reference (much faster sweep)')
parser.add_argument('--check', action='store_true',
help='also run a one-shot correctness check vs eager')
args = parser.parse_args()

torch.manual_seed(0)
device = 'cuda'
dtype = getattr(torch, args.dtype)
head_size_k = 576
head_size_v = 512
softmax_scale = 1.0 / (head_size_k ** 0.5)
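    # Standard 1/sqrt(d_k) scaling over the full 576-dim absorbed Q/K head dim.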

configs = [
# (batch, sq, hq, hk, sk)
(1, 1, 16, 1, 256),
(1, 1, 16, 1, 1024),
(1, 1, 16, 1, 4096),
(1, 1, 16, 1, 16384),
(1, 1, 16, 1, 65536),
(4, 1, 16, 1, 1024),
(4, 1, 16, 1, 4096),
(16, 1, 16, 1, 1024),
(16, 1, 16, 1, 4096),
(64, 1, 16, 1, 1024),
(64, 1, 16, 1, 4096),
(1, 1, 64, 1, 4096),
]

print(f'{"config":<32} {"ours(ms)":>9} {"torch(ms)":>10} {"speedup":>8} {"ours BW(GB/s)":>14}')
print('-' * 75)
for batch, sq, hq, hk, sk in configs:
q, kcache, seqlens_k, block_table = make_inputs(batch, sq, hq, hk, sk, dtype, device)

ours_fn = lambda: cuda.dense_decode_fwd(
q, kcache, head_size_v, seqlens_k, block_table, softmax_scale, False, None, None
)
ours_ms = bench(ours_fn, iters=200, warmup=20)
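        # KV cache bytes read once per decode step (2 bytes/elem for bf16/fp16).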
kv_bytes = batch * hk * sk * head_size_k * 2
bw = kv_bytes / (ours_ms * 1e-3) / 1e9

if args.no_torch_baseline:
torch_str = 'skip'
speedup_str = '-'
else:
iters = 5 if sk * batch >= 8192 else (20 if sk * batch >= 1024 else 50)
warmup = 2
torch_fn = lambda: torch_eager_mla(q, kcache, block_table, seqlens_k, softmax_scale, head_size_v)
torch_ms = bench(torch_fn, iters=iters, warmup=warmup)
torch_str = f'{torch_ms:.3f}'
speedup_str = f'{torch_ms / ours_ms:.1f}x'

if args.check:
out, _, _, _ = ours_fn()
ref = torch_eager_mla(q, kcache, block_table, seqlens_k, softmax_scale, head_size_v)
diff = (out.float() - ref.float()).abs().max().item()
tag = 'OK' if diff < 0.02 else f'FAIL diff={diff:.4f}'
speedup_str = f'{speedup_str} ({tag})'

cfg = f'b={batch} sq={sq} hq={hq} hk={hk} sk={sk}'
print(f'{cfg:<32} {ours_ms:>9.3f} {torch_str:>10} {speedup_str:>8} {bw:>14.1f}')


if __name__ == '__main__':
main()
129 changes: 129 additions & 0 deletions benchmark/profile_decode_step.py
@@ -0,0 +1,129 @@
"""Profile dense_decode_fwd's share of an MLA decode step.

Uses a DeepSeek-V3-shaped 1-layer attention block:
x (b, H) -> Q proj -> q (b, hq, dk)
q + KV cache -> dense_decode_fwd -> o (b, hq, dv)
o -> O proj -> y (b, H)

This underestimates full decode step time (no FFN / MoE / layernorm), so
dense_decode's measured share here is an UPPER bound on its share of a full
step. If decode is < 30% even here, a BLOCK_M=8 redesign (which gains 3-5
percentage points on decode itself) won't move the full-step needle meaningfully."""
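# Example invocation (assuming a CUDA build of flash_mla is importable from the repo root):
#   python benchmark/profile_decode_step.py --dtype bfloat16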

import argparse
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import flash_mla.cuda as cuda


def time_fn(fn, iters=100, warmup=20):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
for _ in range(iters):
fn()
e.record()
torch.cuda.synchronize()
return s.elapsed_time(e) / iters


def profile(batch, seqlen, hidden, num_q_heads, num_kv_heads, head_dim_k, head_dim_v, dtype):
device = 'cuda'
pbs = 64
softmax_scale = 1.0 / (head_dim_k ** 0.5)

# Linear projections (Q absorbed: hidden -> num_q_heads*head_dim_k).
x = torch.randn(batch, hidden, dtype=dtype, device=device) * 0.1
W_q = torch.randn(hidden, num_q_heads * head_dim_k, dtype=dtype, device=device) * 0.01
W_o = torch.randn(num_q_heads * head_dim_v, hidden, dtype=dtype, device=device) * 0.01

# KV cache (paged, num_kv_heads heads).
nb = (seqlen + pbs - 1) // pbs
total_blocks = nb * batch
kcache = torch.randn(total_blocks, pbs, num_kv_heads, head_dim_k, dtype=dtype, device=device) * 0.1
seqlens_k = torch.full((batch,), seqlen, dtype=torch.int32, device=device)
block_table = torch.arange(total_blocks, dtype=torch.int32, device=device).view(batch, nb)

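    # Three timed variants: the full block (Q proj + decode + O proj), the block
    # without the O projection, and the decode kernel alone. Their differences
    # give the per-stage times reported below.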
def attn_block():
q = (x @ W_q).view(batch, 1, num_q_heads, head_dim_k)
out, _, _, _ = cuda.dense_decode_fwd(
q, kcache, head_dim_v, seqlens_k, block_table,
softmax_scale, False, None, None,
)
out_flat = out.contiguous().view(batch, num_q_heads * head_dim_v)
return out_flat @ W_o

def attn_no_oproj():
q = (x @ W_q).view(batch, 1, num_q_heads, head_dim_k)
out, _, _, _ = cuda.dense_decode_fwd(
q, kcache, head_dim_v, seqlens_k, block_table,
softmax_scale, False, None, None,
)
return out

def decode_only():
q = torch.randn(batch, 1, num_q_heads, head_dim_k, dtype=dtype, device=device) * 0.1
out, _, _, _ = cuda.dense_decode_fwd(
q, kcache, head_dim_v, seqlens_k, block_table,
softmax_scale, False, None, None,
)
return out

full = time_fn(attn_block)
no_op = time_fn(attn_no_oproj)
only = time_fn(decode_only)
qproj = no_op - only
oproj = full - no_op
share = only / full * 100.0
return full, qproj, only, oproj, share


def main():
ap = argparse.ArgumentParser()
ap.add_argument('--dtype', default='bfloat16', choices=['bfloat16', 'float16'])
args = ap.parse_args()
dtype = getattr(torch, args.dtype)

# DeepSeek-V3 architectural constants (MLA absorbed mode).
HIDDEN = 7168
NUM_Q_HEADS = 128
NUM_KV_HEADS = 1
HEAD_DIM_K = 576
HEAD_DIM_V = 512

print(f'DeepSeek-V3-shaped 1-layer attention block, dtype={args.dtype}')
print(f' hidden={HIDDEN} hq={NUM_Q_HEADS} hk={NUM_KV_HEADS} dk={HEAD_DIM_K} dv={HEAD_DIM_V}')
print()
print(f'{"config":<22} {"attn(ms)":>10} {"qproj":>8} {"decode":>8} {"oproj":>8} {"decode%":>9}')
print('-' * 70)

configs = [
# (batch, seqlen)
(1, 1024), (1, 4096), (1, 16384), (1, 65536),
(4, 1024), (4, 4096), (4, 16384),
(16, 1024), (16, 4096), (16, 16384),
(64, 1024), (64, 4096), (64, 16384),
(128, 4096),
]
for b, sk in configs:
try:
full, qproj, only, oproj, share = profile(b, sk, HIDDEN, NUM_Q_HEADS, NUM_KV_HEADS, HEAD_DIM_K, HEAD_DIM_V, dtype)
tag = f'b={b} sk={sk}'
print(f'{tag:<22} {full:>10.3f} {qproj:>8.3f} {only:>8.3f} {oproj:>8.3f} {share:>8.1f}%')
except torch.cuda.OutOfMemoryError:
print(f'b={b} sk={sk}: OOM (skipped)')

print()
print('Note: this 1-layer attention block excludes FFN/MoE/layernorm/residual,')
print('which together typically dominate full step time. The "decode%" above')
    print("is therefore an UPPER bound on dense_decode's share of a real decode step.")


if __name__ == '__main__':
main()
14 changes: 12 additions & 2 deletions csrc/api/api.cpp
@@ -1,15 +1,25 @@
#include <pybind11/pybind11.h>

#include "dense_decode.h"

#if !defined(FLASH_MLA_DISABLE_SM90)
#include "sparse_fwd.h"
#include "sparse_decode.h"
#include "dense_decode.h"
#endif

#if !defined(FLASH_MLA_DISABLE_SM100)
#include "dense_fwd.h"
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA";
m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("dense_decode_fwd", &dense_attn_decode_interface);
#if !defined(FLASH_MLA_DISABLE_SM90)
m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
#endif
#if !defined(FLASH_MLA_DISABLE_SM100)
m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun);
m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun);
#endif
}
4 changes: 4 additions & 0 deletions csrc/api/common.h
@@ -31,6 +31,10 @@ struct Arch {
num_sms = device_prop->multiProcessorCount;
}

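    // Ampere data-center GPUs (e.g. A100) report compute capability 8.0.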
bool is_sm80() const {
return major == 8 && minor == 0;
}

bool is_sm90a() const {
return major == 9 && minor == 0;
}
42 changes: 37 additions & 5 deletions csrc/api/dense_decode.h
@@ -6,7 +6,12 @@
#include "common.h"
#include "params.h"

#ifndef FLASH_MLA_DISABLE_SM90
#include "sm90/decode/dense/splitkv_mla.h"
#endif
#ifndef FLASH_MLA_DISABLE_SM80
#include "sm80/decode/dense/splitkv_mla.h"
#endif
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"

@@ -24,8 +29,8 @@ dense_attn_decode_interface(
) {
// Check arch
Arch arch = Arch();
if (!arch.is_sm90a()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture");
if (!arch.is_sm90a() && !arch.is_sm80()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM80 or SM90a architectures");
}

// Check data types
@@ -172,18 +177,45 @@

params.stream = at::cuda::getCurrentCUDAStream().stream();

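    // Pick the arch-specific dense decode kernel at runtime; the CALL_SM*_DENSE_DECODE
    // macros below turn into a TORCH_CHECK failure when that backend was compiled out.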
#define DISPATCH_DENSE_DECODE_KERNEL(SCALAR_T) \
do { \
if (arch.is_sm90a()) { \
CALL_SM90_DENSE_DECODE(SCALAR_T); \
} else if (arch.is_sm80()) { \
CALL_SM80_DENSE_DECODE(SCALAR_T); \
} else { \
TORCH_CHECK(false, "Unsupported arch for dense MLA decode"); \
} \
} while (0)

#ifndef FLASH_MLA_DISABLE_SM90
#define CALL_SM90_DENSE_DECODE(SCALAR_T) sm90::run_flash_splitkv_mla_kernel<SCALAR_T>(params)
#else
#define CALL_SM90_DENSE_DECODE(SCALAR_T) TORCH_CHECK(false, "FlashMLA was built with FLASH_MLA_DISABLE_SM90; cannot run on SM90 GPU")
#endif

#ifndef FLASH_MLA_DISABLE_SM80
#define CALL_SM80_DENSE_DECODE(SCALAR_T) sm80::run_flash_splitkv_mla_kernel<SCALAR_T>(params)
#else
#define CALL_SM80_DENSE_DECODE(SCALAR_T) TORCH_CHECK(false, "FlashMLA was built with FLASH_MLA_DISABLE_SM80; cannot run on SM80 GPU")
#endif

if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);
DISPATCH_DENSE_DECODE_KERNEL(cutlass::bfloat16_t);
} else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params);
DISPATCH_DENSE_DECODE_KERNEL(cutlass::half_t);
#endif
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
TORCH_CHECK(false, "Unsupported dtype for dense MLA decode");
}

#undef DISPATCH_DENSE_DECODE_KERNEL
#undef CALL_SM90_DENSE_DECODE
#undef CALL_SM80_DENSE_DECODE

CombineParams combine_params = {
batch_size, seqlen_q_ori,
num_heads_q, head_size_v,
13 changes: 13 additions & 0 deletions csrc/sm80/decode/dense/config.h
@@ -0,0 +1,13 @@
#pragma once

namespace sm80::Config {

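// Kernel configuration constants (presumably: BLOCK_SIZE_M = query rows per CTA tile,
// PAGE_BLOCK_SIZE = paged-KV block length in tokens). HEAD_DIM_K = 576 and
// HEAD_DIM_V = 512 match the absorbed MLA layout (512-dim latent + 64-dim RoPE)
// used by the benchmarks in this PR.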
static constexpr int BLOCK_SIZE_M = 64;
static constexpr int PAGE_BLOCK_SIZE = 64;

static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;

static constexpr int NUM_THREADS = 128;

}