def torch_eager_mla(q, kcache, block_table, seqlens_k, softmax_scale, head_size_v, is_causal=False):
    """PyTorch BMM-based MLA decode reference.

    Args:
        q: query tensor, unpacked as ``(b, sq, hq, dk)``. The ``hq`` axis is
            split as ``(hk, hq // hk)``, i.e. query heads that share a KV head
            are adjacent.
        kcache: paged KV cache, unpacked as ``(num_pages, page_size, hk, dk)``.
            V is taken as the first ``head_size_v`` channels of K.
        block_table: per-batch page indices into ``kcache``; row ``bi`` lists
            the pages of sequence ``bi`` in order.
        seqlens_k: per-batch number of valid KV tokens.
        softmax_scale: multiplier applied to the QK^T scores.
        head_size_v: number of leading K channels used as V.
        is_causal: when true, query position ``i`` may not attend to the last
            ``sq - i - 1`` KV tokens.

    Returns:
        Tensor of shape ``(b, sq, hq, head_size_v)`` with the dtype and device
        of ``q``. Rows that have no visible key (empty sequence or fully
        masked by causality) are zero, matching the kernel's handling of a
        zero softmax denominator.
    """
    b, sq, hq, dk = q.shape
    _, pbs, hk, _ = kcache.shape
    nq_per_hk = hq // hk
    out = torch.zeros(b, sq, hq, head_size_v, dtype=q.dtype, device=q.device)
    # One host transfer for all lengths instead of one .item() sync per batch.
    seq_lens = seqlens_k.tolist()
    for bi in range(b):
        sk = int(seq_lens[bi])
        if sk == 0:
            # Nothing to attend to; keep the zero-initialised output row.
            continue
        nb = (sk + pbs - 1) // pbs
        # Gather all pages of this sequence with a single indexed read.
        page_ids = block_table[bi, :nb].long()
        kc = kcache[page_ids].flatten(0, 1)[:sk]          # (sk, hk, dk)
        k_full = kc.transpose(0, 1).contiguous().float()  # (hk, sk, dk)
        v_full = k_full[:, :, :head_size_v]               # (hk, sk, dv)
        q_rs = (q[bi].reshape(sq, hk, nq_per_hk, dk)
                     .permute(1, 0, 2, 3)
                     .reshape(hk, sq * nq_per_hk, dk))
        scores = torch.bmm(q_rs.float(), k_full.transpose(1, 2)) * softmax_scale
        if is_causal:
            for sq_idx in range(sq):
                rb = max(0, sk - (sq - sq_idx - 1))
                if rb < sk:
                    lo = sq_idx * nq_per_hk
                    scores[:, lo:lo + nq_per_hk, rb:] = float('-inf')
        probs = torch.softmax(scores, dim=-1)
        # softmax over an all -inf row is NaN; such rows contribute nothing.
        fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
        probs = probs.masked_fill(fully_masked, 0.0)
        o = torch.bmm(probs, v_full)                      # (hk, sq * nq_per_hk, dv)
        # Undo the (hk, sq, nq) grouping with the inverse of the permute applied
        # to q, so output heads come back in the same (hk, nq) order as the input.
        o_rs = (o.view(hk, sq, nq_per_hk, head_size_v)
                 .permute(1, 0, 2, 3)
                 .reshape(sq, hq, head_size_v))
        out[bi] = o_rs.to(q.dtype)
    return out
def make_inputs(batch, sq, hq, hk, sk, dtype, device, *, head_size_k=576, page_block_size=64):
    """Build random decode inputs with a contiguous, non-shared page layout.

    Args:
        batch: number of sequences.
        sq: query length per sequence.
        hq: number of query heads.
        hk: number of KV heads.
        sk: KV length, identical for every sequence.
        dtype: element type of ``q`` and ``kcache``.
        device: device on which all tensors are created.
        head_size_k: channel count of Q/K (defaults to the previously
            hard-coded 576).
        page_block_size: tokens per KV page (defaults to the previously
            hard-coded 64).

    Returns:
        ``(q, kcache, seqlens_k, block_table)`` where
        ``q`` is ``(batch, sq, hq, head_size_k)``,
        ``kcache`` is ``(pages_per_seq * batch, page_block_size, hk, head_size_k)``,
        ``seqlens_k`` is an int32 vector filled with ``sk``, and
        ``block_table`` is int32 ``(batch, pages_per_seq)`` assigning pages
        ``0 .. pages_per_seq * batch - 1`` to sequences in order.
    """
    # Random generation order (q first, then kcache) is kept so a fixed seed
    # reproduces the same tensors as before.
    q = torch.randn(batch, sq, hq, head_size_k, dtype=dtype, device=device) * 0.1
    pages_per_seq = (sk + page_block_size - 1) // page_block_size
    total_pages = pages_per_seq * batch
    kcache = torch.randn(total_pages, page_block_size, hk, head_size_k,
                         dtype=dtype, device=device) * 0.1
    seqlens_k = torch.full((batch,), sk, dtype=torch.int32, device=device)
    block_table = torch.arange(total_pages, dtype=torch.int32, device=device).view(batch, pages_per_seq)
    return q, kcache, seqlens_k, block_table
head_size_v, seqlens_k, block_table, softmax_scale, False, None, None + ) + ours_ms = bench(ours_fn, iters=200, warmup=20) + kv_bytes = batch * hk * sk * head_size_k * 2 + bw = kv_bytes / (ours_ms * 1e-3) / 1e9 + + if args.no_torch_baseline: + torch_str = 'skip' + speedup_str = '-' + else: + iters = 5 if sk * batch >= 8192 else (20 if sk * batch >= 1024 else 50) + warmup = 2 + torch_fn = lambda: torch_eager_mla(q, kcache, block_table, seqlens_k, softmax_scale, head_size_v) + torch_ms = bench(torch_fn, iters=iters, warmup=warmup) + torch_str = f'{torch_ms:.3f}' + speedup_str = f'{torch_ms / ours_ms:.1f}x' + + if args.check: + out, _, _, _ = ours_fn() + ref = torch_eager_mla(q, kcache, block_table, seqlens_k, softmax_scale, head_size_v) + diff = (out.float() - ref.float()).abs().max().item() + tag = 'OK' if diff < 0.02 else f'FAIL diff={diff:.4f}' + speedup_str = f'{speedup_str} ({tag})' + + cfg = f'b={batch} sq={sq} hq={hq} hk={hk} sk={sk}' + print(f'{cfg:<32} {ours_ms:>9.3f} {torch_str:>10} {speedup_str:>8} {bw:>14.1f}') + + +if __name__ == '__main__': + main() diff --git a/benchmark/profile_decode_step.py b/benchmark/profile_decode_step.py new file mode 100644 index 00000000..0a7b8b54 --- /dev/null +++ b/benchmark/profile_decode_step.py @@ -0,0 +1,129 @@ +"""Profile dense_decode_fwd's share of an MLA decode step. + +Uses a DeepSeek-V3-shaped 1-layer attention block: + x (b, H) -> Q proj -> q (b, hq, dk) + q + KV cache -> dense_decode_fwd -> o (b, hq, dv) + o -> O proj -> y (b, H) + +This is an under-estimate of full decode step time (no FFN / MoE / layernorm), +which means dense_decode's measured share here is an UPPER bound on its share +of a full step. 
def profile(batch, seqlen, hidden, num_q_heads, num_kv_heads, head_dim_k, head_dim_v, dtype):
    """Time one attention block and split it into Q-proj / decode / O-proj.

    Three variants are timed with ``time_fn``:
      * Q projection + ``dense_decode_fwd`` + O projection,
      * Q projection + ``dense_decode_fwd``,
      * ``dense_decode_fwd`` alone on a pre-generated query.
    The projection costs are obtained by subtraction.

    Args:
        batch: number of sequences (one query token each).
        seqlen: KV length of every sequence.
        hidden: model hidden size (input/output width of the projections).
        num_q_heads: number of query heads.
        num_kv_heads: number of KV heads in the paged cache.
        head_dim_k: per-head Q/K width.
        head_dim_v: per-head V/output width.
        dtype: element type of all tensors.

    Returns:
        ``(full, qproj, only, oproj, share)``: full-block time, Q-projection
        time, decode-only time, O-projection time (all as returned by
        ``time_fn``) and ``only / full`` as a percentage.
    """
    device = 'cuda'
    pbs = 64
    softmax_scale = 1.0 / (head_dim_k ** 0.5)

    # Linear projections (Q absorbed: hidden -> num_q_heads*head_dim_k).
    x = torch.randn(batch, hidden, dtype=dtype, device=device) * 0.1
    W_q = torch.randn(hidden, num_q_heads * head_dim_k, dtype=dtype, device=device) * 0.01
    W_o = torch.randn(num_q_heads * head_dim_v, hidden, dtype=dtype, device=device) * 0.01

    # KV cache (paged, num_kv_heads heads).
    nb = (seqlen + pbs - 1) // pbs
    total_blocks = nb * batch
    kcache = torch.randn(total_blocks, pbs, num_kv_heads, head_dim_k, dtype=dtype, device=device) * 0.1
    seqlens_k = torch.full((batch,), seqlen, dtype=torch.int32, device=device)
    block_table = torch.arange(total_blocks, dtype=torch.int32, device=device).view(batch, nb)

    # Query for the decode-only measurement. It is generated once, outside the
    # timed closure, so that random-number generation and the scaling multiply
    # are not billed to the decode kernel (which would inflate `only`/`share`
    # and deflate the derived `qproj`).
    q_decode_only = torch.randn(batch, 1, num_q_heads, head_dim_k, dtype=dtype, device=device) * 0.1

    def run_decode(q):
        out, _, _, _ = cuda.dense_decode_fwd(
            q, kcache, head_dim_v, seqlens_k, block_table,
            softmax_scale, False, None, None,
        )
        return out

    def attn_block():
        q = (x @ W_q).view(batch, 1, num_q_heads, head_dim_k)
        out = run_decode(q)
        out_flat = out.contiguous().view(batch, num_q_heads * head_dim_v)
        return out_flat @ W_o

    def attn_no_oproj():
        q = (x @ W_q).view(batch, 1, num_q_heads, head_dim_k)
        return run_decode(q)

    def decode_only():
        return run_decode(q_decode_only)

    full = time_fn(attn_block)
    no_op = time_fn(attn_no_oproj)
    only = time_fn(decode_only)
    qproj = no_op - only
    oproj = full - no_op
    share = only / full * 100.0
    return full, qproj, only, oproj, share
+ HIDDEN = 7168 + NUM_Q_HEADS = 128 + NUM_KV_HEADS = 1 + HEAD_DIM_K = 576 + HEAD_DIM_V = 512 + + print(f'DeepSeek-V3-shaped 1-layer attention block, dtype={args.dtype}') + print(f' hidden={HIDDEN} hq={NUM_Q_HEADS} hk={NUM_KV_HEADS} dk={HEAD_DIM_K} dv={HEAD_DIM_V}') + print() + print(f'{"config":<22} {"attn(ms)":>10} {"qproj":>8} {"decode":>8} {"oproj":>8} {"decode%":>9}') + print('-' * 70) + + configs = [ + # (batch, seqlen) + (1, 1024), (1, 4096), (1, 16384), (1, 65536), + (4, 1024), (4, 4096), (4, 16384), + (16, 1024), (16, 4096), (16, 16384), + (64, 1024), (64, 4096), (64, 16384), + (128, 4096), + ] + for b, sk in configs: + try: + full, qproj, only, oproj, share = profile(b, sk, HIDDEN, NUM_Q_HEADS, NUM_KV_HEADS, HEAD_DIM_K, HEAD_DIM_V, dtype) + tag = f'b={b} sk={sk}' + print(f'{tag:<22} {full:>10.3f} {qproj:>8.3f} {only:>8.3f} {oproj:>8.3f} {share:>8.1f}%') + except torch.cuda.OutOfMemoryError: + print(f'b={b} sk={sk}: OOM (skipped)') + + print() + print('Note: this 1-layer attention block excludes FFN/MoE/layernorm/residual,') + print('which together typically dominate full step time. 
The "decode%" above') + print('is therefore an UPPER bound on dense_decode share of a real decode step.') + + +if __name__ == '__main__': + main() diff --git a/csrc/api/api.cpp b/csrc/api/api.cpp index f43f2a09..029ab730 100644 --- a/csrc/api/api.cpp +++ b/csrc/api/api.cpp @@ -1,15 +1,25 @@ #include +#include "dense_decode.h" + +#if !defined(FLASH_MLA_DISABLE_SM90) #include "sparse_fwd.h" #include "sparse_decode.h" -#include "dense_decode.h" +#endif + +#if !defined(FLASH_MLA_DISABLE_SM100) #include "dense_fwd.h" +#endif PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashMLA"; - m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("dense_decode_fwd", &dense_attn_decode_interface); +#if !defined(FLASH_MLA_DISABLE_SM90) + m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); +#endif +#if !defined(FLASH_MLA_DISABLE_SM100) m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun); m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun); +#endif } diff --git a/csrc/api/common.h b/csrc/api/common.h index 2c930ed9..51e75028 100644 --- a/csrc/api/common.h +++ b/csrc/api/common.h @@ -31,6 +31,10 @@ struct Arch { num_sms = device_prop->multiProcessorCount; } + bool is_sm80() const { + return major == 8 && minor == 0; + } + bool is_sm90a() const { return major == 9 && minor == 0; } diff --git a/csrc/api/dense_decode.h b/csrc/api/dense_decode.h index 7df178a6..90cfe50b 100644 --- a/csrc/api/dense_decode.h +++ b/csrc/api/dense_decode.h @@ -6,7 +6,12 @@ #include "common.h" #include "params.h" +#ifndef FLASH_MLA_DISABLE_SM90 #include "sm90/decode/dense/splitkv_mla.h" +#endif +#ifndef FLASH_MLA_DISABLE_SM80 +#include "sm80/decode/dense/splitkv_mla.h" +#endif #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/combine/combine.h" @@ -24,8 +29,8 @@ dense_attn_decode_interface( ) { // Check arch Arch arch = Arch(); - if (!arch.is_sm90a()) { - TORCH_CHECK(false, "Dense 
decode MLA is only supported on SM90a architecture"); + if (!arch.is_sm90a() && !arch.is_sm80()) { + TORCH_CHECK(false, "Dense decode MLA is only supported on SM80 or SM90a architectures"); } // Check data types @@ -172,18 +177,45 @@ dense_attn_decode_interface( params.stream = at::cuda::getCurrentCUDAStream().stream(); +#define DISPATCH_DENSE_DECODE_KERNEL(SCALAR_T) \ + do { \ + if (arch.is_sm90a()) { \ + CALL_SM90_DENSE_DECODE(SCALAR_T); \ + } else if (arch.is_sm80()) { \ + CALL_SM80_DENSE_DECODE(SCALAR_T); \ + } else { \ + TORCH_CHECK(false, "Unsupported arch for dense MLA decode"); \ + } \ + } while (0) + +#ifndef FLASH_MLA_DISABLE_SM90 +#define CALL_SM90_DENSE_DECODE(SCALAR_T) sm90::run_flash_splitkv_mla_kernel(params) +#else +#define CALL_SM90_DENSE_DECODE(SCALAR_T) TORCH_CHECK(false, "FlashMLA was built with FLASH_MLA_DISABLE_SM90; cannot run on SM90 GPU") +#endif + +#ifndef FLASH_MLA_DISABLE_SM80 +#define CALL_SM80_DENSE_DECODE(SCALAR_T) sm80::run_flash_splitkv_mla_kernel(params) +#else +#define CALL_SM80_DENSE_DECODE(SCALAR_T) TORCH_CHECK(false, "FlashMLA was built with FLASH_MLA_DISABLE_SM80; cannot run on SM80 GPU") +#endif + if (q_dtype == torch::kBFloat16) { - sm90::run_flash_splitkv_mla_kernel(params); + DISPATCH_DENSE_DECODE_KERNEL(cutlass::bfloat16_t); } else if (q_dtype == torch::kHalf) { #ifdef FLASH_MLA_DISABLE_FP16 TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. 
Please remove this flag from your environment and re-compile FlashMLA."); #else - sm90::run_flash_splitkv_mla_kernel(params); + DISPATCH_DENSE_DECODE_KERNEL(cutlass::half_t); #endif } else { - TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); + TORCH_CHECK(false, "Unsupported dtype for dense MLA decode"); } +#undef DISPATCH_DENSE_DECODE_KERNEL +#undef CALL_SM90_DENSE_DECODE +#undef CALL_SM80_DENSE_DECODE + CombineParams combine_params = { batch_size, seqlen_q_ori, num_heads_q, head_size_v, diff --git a/csrc/sm80/decode/dense/config.h b/csrc/sm80/decode/dense/config.h new file mode 100644 index 00000000..f684fc0a --- /dev/null +++ b/csrc/sm80/decode/dense/config.h @@ -0,0 +1,13 @@ +#pragma once + +namespace sm80::Config { + +static constexpr int BLOCK_SIZE_M = 64; +static constexpr int PAGE_BLOCK_SIZE = 64; + +static constexpr int HEAD_DIM_K = 576; +static constexpr int HEAD_DIM_V = 512; + +static constexpr int NUM_THREADS = 128; + +} diff --git a/csrc/sm80/decode/dense/instantiations/bf16.cu b/csrc/sm80/decode/dense/instantiations/bf16.cu new file mode 100644 index 00000000..e76257b2 --- /dev/null +++ b/csrc/sm80/decode/dense/instantiations/bf16.cu @@ -0,0 +1,9 @@ +#include "../splitkv_mla.cuh" + +namespace sm80 { +template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); +} + +// Note: cute_traits.h sanity check moved to a separate TU +// (cute_traits_sanity.cu) so that issues with cute compose with +// splitkv_mla.cuh's includes don't block the production raw-PTX kernel. 
diff --git a/csrc/sm80/decode/dense/instantiations/fp16.cu b/csrc/sm80/decode/dense/instantiations/fp16.cu new file mode 100644 index 00000000..9f74c663 --- /dev/null +++ b/csrc/sm80/decode/dense/instantiations/fp16.cu @@ -0,0 +1,7 @@ +#include "../splitkv_mla.cuh" + +#ifndef FLASH_MLA_DISABLE_FP16 +namespace sm80 { +template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); +} +#endif diff --git a/csrc/sm80/decode/dense/splitkv_mla.cuh b/csrc/sm80/decode/dense/splitkv_mla.cuh new file mode 100644 index 00000000..fcb16979 --- /dev/null +++ b/csrc/sm80/decode/dense/splitkv_mla.cuh @@ -0,0 +1,471 @@ +#pragma once + +#include +#include +#include +#include + +#include "params.h" +#include "utils.h" +#include "config.h" +#include "traits.h" +#include "../../utils.cuh" + +namespace sm80 { + +static constexpr float LOG2_E = 1.44269504088896340736f; + +// ============================================================================= +// Phase 3 -- BLOCK_M=16 + 4-wg V-quarter + double sK buffer. +// +// Threading: 128 threads / CTA = 4 warpgroups x 1 warp x 32 lanes. +// Each wg = 1 warp covers the M dim (16 rows) entirely via mma.m16n8k16. +// Wgs split V columns into 4 quarters: wg w -> V[w*128 : (w+1)*128]. +// All wgs compute QK^T (4x duplicated, but per-warp QK^T is short). +// +// SMEM layout (162 KB / 164 KB cap): +// sQ : 16 x 576 BF16 = 18 KB +// sK[0] : 64 x 576 BF16 = 72 KB +// sK[1] : 64 x 576 BF16 = 72 KB +// +// Cross-block prefetch: while iter i computes on sK[stage], iter i+1's K is +// already cp.async-issued into sK[1-stage]. wait_group<1> at iter start blocks +// only on the current stage's load, leaving the next stage's load in flight. 
+// ============================================================================= + +template +__global__ void __launch_bounds__(cfg::NUM_THREADS, 1) +flash_fwd_splitkv_mla_kernel_sm80(__grid_constant__ const DenseAttnDecodeParams params) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + using sm80::cfg::BLOCK_SIZE_M; + using sm80::cfg::PAGE_BLOCK_SIZE; + using sm80::cfg::HEAD_DIM_K; + using sm80::cfg::HEAD_DIM_V; + using sm80::cfg::NUM_THREADS; + using sm80::cfg::NUM_WARPGROUPS; + using sm80::cfg::HEAD_DIM_V_PER_WG; + using sm80::cfg::SMEM_STRIDE_K; + using sm80::cfg::SK_STAGES; + + constexpr int N_TILES_PER_WG = HEAD_DIM_V_PER_WG / 8; // 16 PV N-tiles per wg + constexpr int QK_K_TILES = HEAD_DIM_K / 16; // 36 + constexpr int QK_N_TILES = PAGE_BLOCK_SIZE / 8; // 8 + constexpr int PV_K_TILES = PAGE_BLOCK_SIZE / 16; // 4 + + extern __shared__ char smem_buf[]; + T* sQ = reinterpret_cast(smem_buf); + T* sK[SK_STAGES]; + { + char* sK_base = smem_buf + cfg::smem_q_bytes(); + #pragma unroll + for (int s = 0; s < SK_STAGES; ++s) { + sK[s] = reinterpret_cast(sK_base + s * cfg::smem_k_bytes()); + } + } + + const int tid = threadIdx.x; + const int warp_idx = tid / 32; // 0..3 + const int lane_idx = tid % 32; + const int wg_idx = warp_idx; // 1 warp / wg, so wg_idx == warp_idx + + const int m_block_idx = blockIdx.x; + const int k_head_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + + DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx]; + if (sched_meta.begin_req_idx >= params.b) return; + + // ---- cp.async tiling (16-byte chunks) ---- + constexpr int Q_BYTES_PER_ROW = HEAD_DIM_K * sizeof(T); + constexpr int CHUNK_BYTES = 16; + constexpr int CHUNKS_PER_ROW = Q_BYTES_PER_ROW / CHUNK_BYTES; // 72 + constexpr int Q_TOTAL_CHUNKS = BLOCK_SIZE_M * CHUNKS_PER_ROW; // 16 * 72 = 1152 + constexpr int Q_PER_TID_BASE = Q_TOTAL_CHUNKS / NUM_THREADS; // 4 for 256 thread + constexpr int Q_REMAINDER = Q_TOTAL_CHUNKS - Q_PER_TID_BASE * 
NUM_THREADS; // 128 + constexpr int K_TOTAL_CHUNKS = PAGE_BLOCK_SIZE * CHUNKS_PER_ROW; // 64 * 72 = 4608 + constexpr int K_CHUNKS_PER_TID = K_TOTAL_CHUNKS / NUM_THREADS; // 18 + constexpr int ELEMS_PER_CHUNK = CHUNK_BYTES / sizeof(T); // 8 + static_assert(Q_BYTES_PER_ROW % CHUNK_BYTES == 0, "Q row not 16B aligned"); + static_assert(K_TOTAL_CHUNKS % NUM_THREADS == 0, "K chunks not divisible"); + + // ---- per-block-iter K load helper ---- + auto issue_k_load = [&](int block_idx, int stage, const int* block_table_ptr) { + int kv_block_index = __ldg(block_table_ptr + block_idx); + const T* gK_block = (const T*)params.k_ptr + + (int64_t)kv_block_index * params.k_batch_stride + + k_head_idx * params.k_head_stride; + T* sK_dst = sK[stage]; + #pragma unroll + for (int i = 0; i < K_CHUNKS_PER_TID; ++i) { + int chunk_idx = tid * K_CHUNKS_PER_TID + i; + int row = chunk_idx / CHUNKS_PER_ROW; + int chunk_in_row = chunk_idx % CHUNKS_PER_ROW; + int elem_offset = chunk_in_row * ELEMS_PER_CHUNK; + int swiz_off = swizzle_col_bf16(row, elem_offset); + const T* g_src = gK_block + row * params.k_row_stride + elem_offset; + uint32_t s_dst = cvta_to_shared_u32(sK_dst + row * SMEM_STRIDE_K + swiz_off); + cp_async_16_cg(s_dst, g_src); + } + cp_async_commit_group(); + }; + + // ----------------------------------------------------------------- + // batch loop + // ----------------------------------------------------------------- + for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { + const int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); + const int start_block_idx= batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; + const int end_block_idx = batch_idx == sched_meta.end_req_idx + ? 
sched_meta.end_block_idx + : (seqlen_k + PAGE_BLOCK_SIZE - 1) / PAGE_BLOCK_SIZE; + + const T* gQ = (const T*)params.q_ptr + + batch_idx * params.q_batch_stride + + m_block_idx* BLOCK_SIZE_M * params.q_row_stride + + k_head_idx * params.q_head_stride; + T* gO = (T*)params.o_ptr + + batch_idx * params.o_batch_stride + + m_block_idx* BLOCK_SIZE_M * params.o_row_stride + + k_head_idx * params.o_head_stride; + float* gLse = params.softmax_lse_ptr + + (batch_idx * params.h_k + k_head_idx) * params.q_seq_per_hk + + m_block_idx * BLOCK_SIZE_M; + const int* block_table_ptr = params.block_table + batch_idx * params.block_table_batch_stride; + + const int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx * BLOCK_SIZE_M, BLOCK_SIZE_M); + + // ---- per-row right border (causal + OOB) ---- + int rRightBorder[2]; + { + int base_row = m_block_idx * BLOCK_SIZE_M; + int row_lo = base_row + (lane_idx / 4); + int row_hi = row_lo + 8; + auto rb = [&](int row) -> int { + if (params.is_causal && row < params.q_seq_per_hk) { + int s_q_idx = row / params.q_head_per_hk; + int mask_len = params.s_q - s_q_idx - 1; + return max(0, seqlen_k - mask_len); + } + return seqlen_k; + }; + rRightBorder[0] = rb(row_lo); + rRightBorder[1] = rb(row_hi); + } + + // ---- per-batch register state ---- + // rO[N_TILES_PER_WG][4]: 16 N-tiles x 4 fp32 = 64 fp32/thread (half of Phase 2's 128) + float rO[N_TILES_PER_WG][4]; + #pragma unroll + for (int n = 0; n < N_TILES_PER_WG; ++n) { + #pragma unroll + for (int j = 0; j < 4; ++j) rO[n][j] = 0.0f; + } + float rL[2] = {0.0f, 0.0f}; + float rM[2] = {-INFINITY, -INFINITY}; + const float scale_log2 = params.scale_softmax_log2; + + // ---- Load Q tile ---- + // Each thread loads Q_PER_TID_BASE chunks; the first Q_REMAINDER threads + // load one extra chunk to cover Q_TOTAL_CHUNKS that isn't a multiple of NUM_THREADS. 
+ #pragma unroll + for (int i = 0; i < Q_PER_TID_BASE; ++i) { + int chunk_idx = tid * Q_PER_TID_BASE + i; + int row = chunk_idx / CHUNKS_PER_ROW; + int chunk_in_row = chunk_idx % CHUNKS_PER_ROW; + int elem_offset = chunk_in_row * ELEMS_PER_CHUNK; + int swiz_off = swizzle_col_bf16(row, elem_offset); + const T* g_src = gQ + row * params.q_row_stride + elem_offset; + uint32_t s_dst = cvta_to_shared_u32(sQ + row * SMEM_STRIDE_K + swiz_off); + cp_async_16(s_dst, g_src); + } + if (tid < Q_REMAINDER) { + int chunk_idx = NUM_THREADS * Q_PER_TID_BASE + tid; + int row = chunk_idx / CHUNKS_PER_ROW; + int chunk_in_row = chunk_idx % CHUNKS_PER_ROW; + int elem_offset = chunk_in_row * ELEMS_PER_CHUNK; + int swiz_off = swizzle_col_bf16(row, elem_offset); + const T* g_src = gQ + row * params.q_row_stride + elem_offset; + uint32_t s_dst = cvta_to_shared_u32(sQ + row * SMEM_STRIDE_K + swiz_off); + cp_async_16(s_dst, g_src); + } + cp_async_commit_group(); + + // ---- Prologue: issue first K block load (stage 0) ---- + if (start_block_idx < end_block_idx) { + issue_k_load(start_block_idx, 0, block_table_ptr); + } + // Wait for both Q and the first K to finish before starting compute. + cp_async_wait_all(); + __syncthreads(); + + // ---- K-block loop with double-buffered prefetch ---- + // sK[stage] holds K_i during iter i. We issue K_{i+1} into sK[1-stage] + // during the compute of iter i. wait_group<1> at iter start waits for + // K_i to be ready, leaving K_{i+1} (if any) in flight. + int stage = 0; + // Issue prefetch for block_idx + 1 if it exists, before entering the loop. 
+ if (start_block_idx + 1 < end_block_idx) { + issue_k_load(start_block_idx + 1, 1, block_table_ptr); + } + + for (int block_idx = start_block_idx; block_idx < end_block_idx; ++block_idx) { + const int start_token = block_idx * PAGE_BLOCK_SIZE; + T* sK_cur = sK[stage]; + + // === QK^T + softmax + rPb pack (rP fp32 inner scope to free regs before PV) === + uint32_t rPb[PV_K_TILES][4]; + { + float rP[QK_N_TILES][4]; + #pragma unroll + for (int n = 0; n < QK_N_TILES; ++n) { + #pragma unroll + for (int j = 0; j < 4; ++j) rP[n][j] = 0.0f; + } + + #pragma unroll 1 + for (int k_tile = 0; k_tile < QK_K_TILES; ++k_tile) { + int k_offset = k_tile * 16; + + // Load A (Q tile, M=16, K=16). m_block has only one warp -> ROWS_PER_WARP=16. + // ldmatrix.x4 mat layout: mat 0..3 covering (M_lo/M_hi x K_lo/K_hi). + int mat = lane_idx / 8; + int m_lo_or_hi = mat & 1; + int k_half = mat >> 1; + int q_row = m_lo_or_hi * 8 + (lane_idx % 8); + int q_col = k_offset + k_half * 8; + int q_swiz = swizzle_col_bf16(q_row, q_col); + uint32_t rQ[4]; + uint32_t s_addr_q = cvta_to_shared_u32(sQ + q_row * SMEM_STRIDE_K + q_swiz); + ldmatrix_x4(rQ, s_addr_q); + + #pragma unroll + for (int n = 0; n < QK_N_TILES; ++n) { + int n_offset = n * 8; + // K^T B operand: K stored row-major -> ldmatrix without .trans gives col-major B. + int b_mat = (lane_idx / 8) & 1; + int row_b = lane_idx & 7; + int n_row = n_offset + row_b; + int k_col = k_offset + b_mat * 8; + int k_swiz= swizzle_col_bf16(n_row, k_col); + uint32_t s_addr_b = cvta_to_shared_u32(sK_cur + n_row * SMEM_STRIDE_K + k_swiz); + uint32_t rKT[2]; + ldmatrix_x2(rKT, s_addr_b); + mma_m16n8k16_acc(rP[n], rQ, rKT); + } + } + + // Mask + scale. 
+ #pragma unroll + for (int n = 0; n < QK_N_TILES; ++n) { + int base_token = start_token + n * 8 + (lane_idx % 4) * 2; + #pragma unroll + for (int j = 0; j < 4; ++j) rP[n][j] *= scale_log2; + if (base_token + 0 >= rRightBorder[0]) rP[n][0] = -INFINITY; + if (base_token + 1 >= rRightBorder[0]) rP[n][1] = -INFINITY; + if (base_token + 0 >= rRightBorder[1]) rP[n][2] = -INFINITY; + if (base_token + 1 >= rRightBorder[1]) rP[n][3] = -INFINITY; + } + + // new rowmax. + float new_rM[2] = {rM[0], rM[1]}; + #pragma unroll + for (int n = 0; n < QK_N_TILES; ++n) { + new_rM[0] = max(new_rM[0], max(rP[n][0], rP[n][1])); + new_rM[1] = max(new_rM[1], max(rP[n][2], rP[n][3])); + } + new_rM[0] = max(new_rM[0], __shfl_xor_sync(0xffffffff, new_rM[0], 1)); + new_rM[0] = max(new_rM[0], __shfl_xor_sync(0xffffffff, new_rM[0], 2)); + new_rM[1] = max(new_rM[1], __shfl_xor_sync(0xffffffff, new_rM[1], 1)); + new_rM[1] = max(new_rM[1], __shfl_xor_sync(0xffffffff, new_rM[1], 2)); + + float scale_for_old[2]; + scale_for_old[0] = (rM[0] == -INFINITY) ? 0.0f : exp2f(rM[0] - new_rM[0]); + scale_for_old[1] = (rM[1] == -INFINITY) ? 0.0f : exp2f(rM[1] - new_rM[1]); + + #pragma unroll + for (int n = 0; n < N_TILES_PER_WG; ++n) { + rO[n][0] *= scale_for_old[0]; + rO[n][1] *= scale_for_old[0]; + rO[n][2] *= scale_for_old[1]; + rO[n][3] *= scale_for_old[1]; + } + rL[0] *= scale_for_old[0]; + rL[1] *= scale_for_old[1]; + rM[0] = new_rM[0]; + rM[1] = new_rM[1]; + + // exp + accumulate L. + #pragma unroll + for (int n = 0; n < QK_N_TILES; ++n) { + rP[n][0] = exp2f(rP[n][0] - new_rM[0]); + rP[n][1] = exp2f(rP[n][1] - new_rM[0]); + rP[n][2] = exp2f(rP[n][2] - new_rM[1]); + rP[n][3] = exp2f(rP[n][3] - new_rM[1]); + rL[0] += rP[n][0] + rP[n][1]; + rL[1] += rP[n][2] + rP[n][3]; + } + + // pack rP -> rPb. 
+ #pragma unroll + for (int kt = 0; kt < PV_K_TILES; ++kt) { + int n0 = kt * 2; + int n1 = kt * 2 + 1; + rPb[kt][0] = pack_2xfp32_to_b32(rP[n0][0], rP[n0][1]); + rPb[kt][1] = pack_2xfp32_to_b32(rP[n0][2], rP[n0][3]); + rPb[kt][2] = pack_2xfp32_to_b32(rP[n1][0], rP[n1][1]); + rPb[kt][3] = pack_2xfp32_to_b32(rP[n1][2], rP[n1][3]); + } + } // rP fp32 destroyed + + // === PV: rO[16x128] += rPb[16x64] @ V[64x128] === + // V[k, n] = sK_cur[k, wg_idx*128 + n]. + #pragma unroll 1 + for (int kt = 0; kt < PV_K_TILES; ++kt) { + int k_off_pv = kt * 16; + uint32_t rA[4] = { rPb[kt][0], rPb[kt][1], rPb[kt][2], rPb[kt][3] }; + + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WG; ++nt) { + int v_global_col = wg_idx * HEAD_DIM_V_PER_WG + nt * 8; + int b_mat = (lane_idx / 8) & 1; + int row_b = lane_idx & 7; + int k_row = k_off_pv + b_mat * 8 + row_b; + int v_swiz = swizzle_col_bf16(k_row, v_global_col); + uint32_t s_addr_v = cvta_to_shared_u32(sK_cur + k_row * SMEM_STRIDE_K + v_swiz); + uint32_t rB[2]; + ldmatrix_x2_trans(rB, s_addr_v); + mma_m16n8k16_acc(rO[nt], rA, rB); + } + } + + // ---- Prefetch K_{i+2} into sK[stage] (the buffer we just finished reading) ---- + // PV is done with sK_cur, so the *current* stage is now safe to overwrite. + // We use it as the buffer for K_{i+2}, leaving sK[1-stage] (=K_{i+1}) intact. + __syncthreads(); + if (block_idx + 2 < end_block_idx) { + issue_k_load(block_idx + 2, stage, block_table_ptr); + } + // Swap stage: next iter will compute on the buffer we previously prefetched. + stage = 1 - stage; + + // Wait for the *new* current stage (K_{i+1}) to be ready before next iter's compute. + // Because we always have at most 2 commit_groups in flight (one per stage), we wait + // for the older one to complete using wait_group<1>. 
+ if (block_idx + 1 < end_block_idx) { + cp_async_wait_group<1>(); + __syncthreads(); + } + } + + // ---- Reduce rL within warp ---- + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + #pragma unroll + for (int i = 0; i < 2; ++i) { + rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i]; + } + + // ---- Store rO ---- + const int n_split_idx_local = batch_idx == sched_meta.begin_req_idx + ? sched_meta.begin_split_idx : 0; + const bool is_no_split = batch_idx == sched_meta.begin_req_idx + ? !sched_meta.is_first_req_splitted + : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true); + const float rL_inv[2] = { 1.0f / rL[0], 1.0f / rL[1] }; + const int row_lo_local = lane_idx / 4; + const int row_hi_local = row_lo_local + 8; + + if (is_no_split) { + T* gO_row_lo = gO + row_lo_local * params.o_row_stride; + T* gO_row_hi = gO + row_hi_local * params.o_row_stride; + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WG; ++nt) { + int v_col = wg_idx * HEAD_DIM_V_PER_WG + nt * 8 + (lane_idx % 4) * 2; + T v0 = static_cast(rO[nt][0] * rL_inv[0]); + T v1 = static_cast(rO[nt][1] * rL_inv[0]); + T v2 = static_cast(rO[nt][2] * rL_inv[1]); + T v3 = static_cast(rO[nt][3] * rL_inv[1]); + if (row_lo_local < num_valid_seq_q) { + gO_row_lo[v_col + 0] = v0; + gO_row_lo[v_col + 1] = v1; + } + if (row_hi_local < num_valid_seq_q) { + gO_row_hi[v_col + 0] = v2; + gO_row_hi[v_col + 1] = v3; + } + } + // LSE: only wg 0 writes (all wgs have the same rL/rM up to noise; pick one). + if (wg_idx == 0 && (lane_idx % 4) == 0) { + int row_lo = lane_idx / 4; + int row_hi = row_lo + 8; + float lse_lo = (rL[0] == 0.0f || rL[0] != rL[0]) + ? INFINITY : (logf(rL[0]) + rM[0] / LOG2_E); + float lse_hi = (rL[1] == 0.0f || rL[1] != rL[1]) + ? 
INFINITY : (logf(rL[1]) + rM[1] / LOG2_E); + if (row_lo < num_valid_seq_q) gLse[row_lo] = lse_lo; + if (row_hi < num_valid_seq_q) gLse[row_hi] = lse_hi; + } + } else { + const int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx_local; + float* gOAccum = (float*)params.oaccum_ptr + + ((int64_t)(split_idx * params.h_k + k_head_idx) * params.q_seq_per_hk + + m_block_idx * BLOCK_SIZE_M) * HEAD_DIM_V; + float* gLseAccum = params.softmax_lseaccum_ptr + + (split_idx * params.h_k + k_head_idx) * params.q_seq_per_hk + + m_block_idx * BLOCK_SIZE_M; + float* gOA_row_lo = gOAccum + row_lo_local * HEAD_DIM_V; + float* gOA_row_hi = gOAccum + row_hi_local * HEAD_DIM_V; + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WG; ++nt) { + int v_col = wg_idx * HEAD_DIM_V_PER_WG + nt * 8 + (lane_idx % 4) * 2; + float v0 = rO[nt][0] * rL_inv[0]; + float v1 = rO[nt][1] * rL_inv[0]; + float v2 = rO[nt][2] * rL_inv[1]; + float v3 = rO[nt][3] * rL_inv[1]; + if (row_lo_local < num_valid_seq_q) { + gOA_row_lo[v_col + 0] = v0; + gOA_row_lo[v_col + 1] = v1; + } + if (row_hi_local < num_valid_seq_q) { + gOA_row_hi[v_col + 0] = v2; + gOA_row_hi[v_col + 1] = v3; + } + } + if (wg_idx == 0 && (lane_idx % 4) == 0) { + int row_lo = lane_idx / 4; + int row_hi = row_lo + 8; + float lse_lo = (rL[0] == 0.0f || rL[0] != rL[0]) + ? -INFINITY : (log2f(rL[0]) + rM[0]); + float lse_hi = (rL[1] == 0.0f || rL[1] != rL[1]) + ? 
-INFINITY : (log2f(rL[1]) + rM[1]);
+                if (row_lo < num_valid_seq_q) gLseAccum[row_lo] = lse_lo;
+                if (row_hi < num_valid_seq_q) gLseAccum[row_hi] = lse_hi;
+            }
+        }
+
+        if (batch_idx != sched_meta.end_req_idx) __syncthreads();
+    }
+#endif
+}
+
+template  // NOTE(review): template parameter list is missing here (stripped '<...>' span) -- restore from upstream
+void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
+    using namespace sm80::cfg;
+    FLASH_ASSERT(params.d == HEAD_DIM_K);    // kernel is compiled for one fixed K head dim
+    FLASH_ASSERT(params.d_v == HEAD_DIM_V);  // ... and one fixed V head dim
+
+    constexpr size_t smem_size = smem_total_bytes();  // NOTE(review): template argument appears stripped -- confirm
+    auto kernel = &flash_fwd_splitkv_mla_kernel_sm80;  // NOTE(review): template argument appears stripped -- confirm
+    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));  // raise the kernel's dynamic SMEM limit to smem_size
+
+    const int num_m_block = (params.q_seq_per_hk + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;  // ceil-div over the M dimension
+    dim3 grid(num_m_block, params.h_k, params.num_sm_parts);
+    dim3 block(NUM_THREADS, 1, 1);
+
+    kernel<<>>(params);  // NOTE(review): launch configuration is missing; presumably built from grid/block/smem_size plus a stream -- confirm
+    CHECK_CUDA_KERNEL_LAUNCH();
+}
+
+}
diff --git a/csrc/sm80/decode/dense/splitkv_mla.h b/csrc/sm80/decode/dense/splitkv_mla.h
new file mode 100644
index 00000000..36fa10b9
--- /dev/null
+++ b/csrc/sm80/decode/dense/splitkv_mla.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "params.h"
+
+namespace sm80 {
+
+template  // NOTE(review): template parameter list is missing here (stripped '<...>' span) -- restore from upstream
+void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params);
+
+}
diff --git a/csrc/sm80/decode/dense/traits.h b/csrc/sm80/decode/dense/traits.h
new file mode 100644
index 00000000..325360b4
--- /dev/null
+++ b/csrc/sm80/decode/dense/traits.h
@@ -0,0 +1,54 @@
+#pragma once
+
+namespace sm80 {
+
+namespace cfg {
+
+// Q tile / KV page geometry.
+//
+// BLOCK_SIZE_M = 16: one warp covers the M dim exactly via mma.m16n8k16.
+// This is the smallest tile that doesn't waste mma rows; it lets us fit a
+// double sK buffer in SMEM (sQ 18 KB + 2 x 72 KB sK = 162 KB <= 164 KB cap),
+// which is the key to overlapping K loading with PV compute.
+constexpr int BLOCK_SIZE_M = 16;
+constexpr int PAGE_BLOCK_SIZE = 64;
+constexpr int HEAD_DIM_K = 576;
+constexpr int HEAD_DIM_V = 512;
+
+// Threading.
+// 4 warpgroups x 1 warp/wg x 32 threads = 128 threads / CTA.
+// Each wg owns the full M dim (one warp suffices) and a quarter of V.
+// All wgs compute QK^T independently (4x duplicated, ~1.4 KFLOPS/CTA);
+// compute is not the bottleneck so this redundancy is acceptable.
+//
+// Tried 8-wg V-eighth (256 threads, spill 320B -> 80B, per-warp rO 32 fp32):
+// measured 30-40% regression vs 4-wg, attributed to extra __syncthreads +
+// duplicate QK^T cycles outweighing the spill saving (which was small to
+// begin with: 320B/iter << 73KB/iter HBM K traffic).
+constexpr int NUM_THREADS = 128;
+constexpr int NUM_WARPS = NUM_THREADS / 32; // 4
+constexpr int NUM_WARPGROUPS = 4;
+constexpr int WARPS_PER_WG = NUM_WARPS / NUM_WARPGROUPS; // 1
+constexpr int ROWS_PER_WARP = BLOCK_SIZE_M / WARPS_PER_WG; // 16
+constexpr int HEAD_DIM_V_PER_WG = HEAD_DIM_V / NUM_WARPGROUPS; // 128
+
+// SMEM row stride. Swizzle alone (no padding) gives zero bank conflicts for the
+// "all-lanes-same-column-different-row" pattern that dominates QK^T / PV.
+// Combining padding with swizzle re-creates conflicts -- DON'T.
+constexpr int SMEM_PAD_K = 0;
+constexpr int SMEM_STRIDE_K = HEAD_DIM_K + SMEM_PAD_K;
+
+// Number of sK buffers in SMEM (double-buffered for cross-block prefetch).
+constexpr int SK_STAGES = 2;
+
+// SMEM region byte sizes.
+template<typename T>
+constexpr int smem_q_bytes() { return BLOCK_SIZE_M * SMEM_STRIDE_K * sizeof(T); }
+template<typename T>
+constexpr int smem_k_bytes() { return PAGE_BLOCK_SIZE * SMEM_STRIDE_K * sizeof(T); }
+template<typename T>
+constexpr int smem_total_bytes() { return smem_q_bytes<T>() + SK_STAGES * smem_k_bytes<T>(); }
+
+}
+
+}
diff --git a/csrc/sm80/utils.cuh b/csrc/sm80/utils.cuh
new file mode 100644
index 00000000..bc0ff974
--- /dev/null
+++ b/csrc/sm80/utils.cuh
@@ -0,0 +1,166 @@
+#pragma once
+
+#include  // NOTE(review): header name missing (stripped '<...>' span) -- restore from upstream
+#include  // NOTE(review): header name missing (stripped '<...>' span) -- restore from upstream
+
+namespace sm80 {
+
+__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void* ptr) {
+    return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
+}
+
+__device__ __forceinline__ void cp_async_16(uint32_t smem_addr, const void* gmem_ptr) {
+    asm volatile(
+        "cp.async.ca.shared.global [%0], [%1], 16;\n"
+        :: "r"(smem_addr), "l"(gmem_ptr)
+    );
+}
+
+// L1-bypassing 16-byte cp.async (cache at L2 only). Useful for streaming
+// reads that won't be reused at the L1 level, e.g. KV cache in attention.
+__device__ __forceinline__ void cp_async_16_cg(uint32_t smem_addr, const void* gmem_ptr) {
+    asm volatile(
+        "cp.async.cg.shared.global [%0], [%1], 16;\n"
+        :: "r"(smem_addr), "l"(gmem_ptr)
+    );
+}
+
+__device__ __forceinline__ void cp_async_16_zfill_oob(uint32_t smem_addr, const void* gmem_ptr, bool in_bounds) {
+    int src_size = in_bounds ? 16 : 0;  // src-size 0: nothing is read from gmem and the 16 destination bytes are zero-filled
+    asm volatile(
+        "cp.async.ca.shared.global [%0], [%1], 16, %2;\n"
+        :: "r"(smem_addr), "l"(gmem_ptr), "r"(src_size)
+    );
+}
+
+__device__ __forceinline__ void cp_async_commit_group() {
+    asm volatile("cp.async.commit_group;\n");
+}
+
+__device__ __forceinline__ void cp_async_wait_all() {
+    asm volatile("cp.async.wait_all;\n");
+}
+
+template<int N>
+__device__ __forceinline__ void cp_async_wait_group() {
+    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));  // "n": N must be a compile-time immediate
+}
+
+// ldmatrix.x4 .m8n8 .shared .b16
+__device__ __forceinline__ void ldmatrix_x4(uint32_t (&out)[4], uint32_t smem_addr) {
+    asm volatile(
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
+        : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3])
+        : "r"(smem_addr)
+    );
+}
+
+__device__ __forceinline__ void ldmatrix_x4_trans(uint32_t (&out)[4], uint32_t smem_addr) {
+    asm volatile(
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
+        : "=r"(out[0]), "=r"(out[1]), "=r"(out[2]), "=r"(out[3])
+        : "r"(smem_addr)
+    );
+}
+
+__device__ __forceinline__ void ldmatrix_x2(uint32_t (&out)[2], uint32_t smem_addr) {
+    asm volatile(
+        "ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
+        : "=r"(out[0]), "=r"(out[1])
+        : "r"(smem_addr)
+    );
+}
+
+__device__ __forceinline__ void ldmatrix_x2_trans(uint32_t (&out)[2], uint32_t smem_addr) {
+    asm volatile(
+        "ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
+        : "=r"(out[0]), "=r"(out[1])
+        : "r"(smem_addr)
+    );
+}
+
+// Pack two FP32 values into one BF16x2 (.b32) register via PTX cvt.
+__device__ __forceinline__ uint32_t pack_bf16x2(float a, float b) {
+    uint32_t out;
+    asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(out) : "f"(b), "f"(a));  // cvt puts its first source in the upper half, so (b, a) keeps `a` in the low 16 bits
+    return out;
+}
+
+// Pack two FP32 values into one FP16x2 (.b32) register.
+__device__ __forceinline__ uint32_t pack_fp16x2(float a, float b) {
+    uint32_t out;
+    asm("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(out) : "f"(b), "f"(a));  // sources passed as (b, a), same order as pack_bf16x2
+    return out;
+}
+
+template  // NOTE(review): template parameter list is missing here (stripped '<...>' span) -- restore from upstream
+__device__ __forceinline__ uint32_t pack_2xfp32_to_b32(float a, float b);
+
+template<>  // NOTE(review): specialization argument missing; body forwards to the BF16 packer
+__device__ __forceinline__ uint32_t pack_2xfp32_to_b32(float a, float b) {
+    return pack_bf16x2(a, b);
+}
+
+template<>  // NOTE(review): specialization argument missing; body forwards to the FP16 packer
+__device__ __forceinline__ uint32_t pack_2xfp32_to_b32(float a, float b) {
+    return pack_fp16x2(a, b);
+}
+
+// mma.sync.aligned.m16n8k16.row.col with FP32 accumulators, accumulating into D in place; specialized below for BF16 and FP16 operands.
+template  // NOTE(review): template parameter list is missing here (stripped '<...>' span) -- restore from upstream
+__device__ __forceinline__ void mma_m16n8k16_acc(
+    float (&D)[4], const uint32_t (&A)[4], const uint32_t (&B)[2]);
+
+template<>  // NOTE(review): specialization argument missing; this is the BF16 variant per the asm string
+__device__ __forceinline__ void mma_m16n8k16_acc(
+    float (&D)[4], const uint32_t (&A)[4], const uint32_t (&B)[2]
+) {
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+        "{%0, %1, %2, %3}, "
+        "{%4, %5, %6, %7}, "
+        "{%8, %9}, "
+        "{%0, %1, %2, %3};\n"
+        : "+f"(D[0]), "+f"(D[1]), "+f"(D[2]), "+f"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
+          "r"(B[0]), "r"(B[1])
+    );
+}
+
+template<>  // NOTE(review): specialization argument missing; this is the FP16 variant per the asm string
+__device__ __forceinline__ void mma_m16n8k16_acc(
+    float (&D)[4], const uint32_t (&A)[4], const uint32_t (&B)[2]
+) {
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+        "{%0, %1, %2, %3}, "
+        "{%4, %5, %6, %7}, "
+        "{%8, %9}, "
+        "{%0, %1, %2, %3};\n"
+        : "+f"(D[0]), "+f"(D[1]), "+f"(D[2]), "+f"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
+          "r"(B[0]), "r"(B[1])
+    );
+}
+
+// XOR swizzle for SMEM K-major layout (BF16/FP16). Within each row, the
+// element offset (in BF16 units) is XOR-permuted by ((row & 7) << 3) so that
+// 8 consecutive rows accessing the same logical column land on 8 different
+// 16-byte chunks. This eliminates the 32-way bank conflict that the
+// non-swizzled K-major layout suffers from.
+//
+// Granularity: 16-byte chunks (= 8 BF16 elements). All ldmatrix and cp.async
+// accesses in this kernel are 16-byte aligned, so the XOR is well-defined.
+__device__ __forceinline__ int swizzle_col_bf16(int row, int col) {
+    return col ^ ((row & 7) << 3);
+}
+
+// 16-byte SMEM store (st.shared.v4.b32: four 32-bit words in one instruction)
+__device__ __forceinline__ void st_shared_b128(uint32_t smem_addr, const uint32_t (&data)[4]) {
+    asm volatile(
+        "st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
+        :: "r"(smem_addr), "r"(data[0]), "r"(data[1]), "r"(data[2]), "r"(data[3])
+    );
+}
+
+}
diff --git a/csrc/smxx/decode/combine/combine.cu b/csrc/smxx/decode/combine/combine.cu
index 376dadd9..de08ee07 100644
--- a/csrc/smxx/decode/combine/combine.cu
+++ b/csrc/smxx/decode/combine/combine.cu
@@ -56,7 +56,11 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const CombineParams params) {
     __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];
     // Wait for the previous kernel (the MLA kernel) to finish
+    // PDL (Programmatic Dependent Launch) requires SM90+. On older arches the
+    // launch is not overlapped, so the dependency is already satisfied.
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     cudaGridDependencySynchronize();
+#endif
     // Prefetch
     static_assert(HEAD_DIM_V % (32*4) == 0);
diff --git a/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
index 083da60c..4c4f793a 100644
--- a/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
+++ b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
@@ -8,7 +8,7 @@ namespace smxx::decode {
-__global__ void __launch_bounds__(32, 1, 1)
+__global__ void __launch_bounds__(32, 1)
 get_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) {
     int *seqlens_k_ptr = params.seqlens_k_ptr;
     DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
diff --git a/setup.py b/setup.py
index 513b4355..0459d1a9 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@ def get_arch_flags():
     DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
     DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
+    DISABLE_SM80 = is_flag_set("FLASH_MLA_DISABLE_SM80")
     if major < 12 or (major == 12 and minor <= 8):
         assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."  # TODO Implement this
@@ -43,6 +44,8 @@ def get_arch_flags():
         arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"])
     if not DISABLE_SM90:
         arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"])
+    if not DISABLE_SM80:
+        arch_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
     return arch_flags
 def get_nvcc_thread_args():
@@ -53,59 +56,100 @@ def get_nvcc_thread_args():
 this_dir = os.path.dirname(os.path.abspath(__file__))
+
+def get_nvidia_wheel_includes():
+    """Collect include paths from pip-installed nvidia-* wheels (e.g.
+    nvidia-cusparse-cu12). PyTorch headers transitively include cusparse.h,
+    which is not present in some system CUDA toolkits."""
+    paths = []
+    try:
+        import nvidia
+        for nvidia_root in nvidia.__path__:
+            root = Path(nvidia_root)
+            if not root.exists():
+                continue
+            for sub in sorted(root.iterdir()):  # sorted: iterdir() order is arbitrary; keep the include order reproducible
+                inc = sub / "include"
+                if inc.is_dir():
+                    paths.append(inc)
+    except ImportError:  # no nvidia-* wheels installed: contribute no extra include dirs
+        pass
+    return paths
+
 if IS_WINDOWS:
     cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"]
 else:
     cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"]
+DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
+DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
+DISABLE_SM80 = is_flag_set("FLASH_MLA_DISABLE_SM80")
+
+sources = ["csrc/api/api.cpp"]
+
+# Misc kernels for decoding (arch-agnostic, used by all dense decode paths)
+sources += [
+    "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
+    "csrc/smxx/decode/combine/combine.cu",
+]
+
+if not DISABLE_SM80:
+    sources += [
+        "csrc/sm80/decode/dense/instantiations/fp16.cu",
+        "csrc/sm80/decode/dense/instantiations/bf16.cu",
+    ]
+
+if not DISABLE_SM90:
+    sources += [
+        # sm90 dense decode
+        "csrc/sm90/decode/dense/instantiations/fp16.cu",
+        "csrc/sm90/decode/dense/instantiations/bf16.cu",
+        # sm90 sparse decode
+        "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
+        "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
+        "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
+        "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
+        # sm90 sparse prefill
+        "csrc/sm90/prefill/sparse/fwd.cu",
+        "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
+        "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
+        "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
+        "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
+    ]
+
+if not DISABLE_SM100:
+    sources += [
+        # sm100 dense prefill & backward
+        "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
+        "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
+        # sm100 sparse prefill
+        "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
+        "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
+        "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
+        "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
+        "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",
+        # sm100 sparse decode
+        "csrc/sm100/decode/head64/instantiations/v32.cu",
+        "csrc/sm100/decode/head64/instantiations/model1.cu",
+        "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
+    ]
+
+cxx_features = []  # -D macros mirroring the DISABLE_* flags; passed to both the host and nvcc compile args below
+if DISABLE_SM100:
+    cxx_features.append("-DFLASH_MLA_DISABLE_SM100")
+if DISABLE_SM90:
+    cxx_features.append("-DFLASH_MLA_DISABLE_SM90")
+if DISABLE_SM80:
+    cxx_features.append("-DFLASH_MLA_DISABLE_SM80")
+
 ext_modules = []
 ext_modules.append(
     CUDAExtension(
         name="flash_mla.cuda",
-        sources=[
-            # API
-            "csrc/api/api.cpp",
-
-            # Misc kernels for decoding
-            "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
-            "csrc/smxx/decode/combine/combine.cu",
-
-            # sm90 dense decode
-            "csrc/sm90/decode/dense/instantiations/fp16.cu",
-            "csrc/sm90/decode/dense/instantiations/bf16.cu",
-
-            # sm90 sparse decode
-            "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
-            "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
-            "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
-            "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
-
-            # sm90 sparse prefill
-            "csrc/sm90/prefill/sparse/fwd.cu",
-            "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
-            "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
-            "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
-            "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
-
-            # sm100 dense prefill & backward
-            "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
-            "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
-
-            # sm100 sparse prefill
-            "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
-            "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
-            "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
-            "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
-            "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",
-
-            # sm100 sparse decode
-            "csrc/sm100/decode/head64/instantiations/v32.cu",
-            "csrc/sm100/decode/head64/instantiations/model1.cu",
-            "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
-        ],
+        sources=sources,
         extra_compile_args={
-            "cxx": cxx_args + get_features_args(),
-            "nvcc": [
+            "cxx": cxx_args + cxx_features + get_features_args(),
+            "nvcc": cxx_features + [
                 "-O3",
                 "-std=c++20",
                 "-DNDEBUG",
@@ -126,10 +170,11 @@ def get_nvcc_thread_args():
         include_dirs=[
             Path(this_dir) / "csrc",
             Path(this_dir) / "csrc" / "kerutils" / "include",  # TODO Remove me
+            Path(this_dir) / "csrc" / "sm80",
             Path(this_dir) / "csrc" / "sm90",
             Path(this_dir) / "csrc" / "cutlass" / "include",
             Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
-        ],
+        ] + get_nvidia_wheel_includes(),  # also search headers shipped in pip-installed nvidia-* wheels
     )
 )