diff --git a/docs/FAQ.md b/docs/FAQ.md new file mode 100644 index 00000000..af52581f --- /dev/null +++ b/docs/FAQ.md @@ -0,0 +1,132 @@ +# FlashMLA FAQ + +Frequently asked questions about FlashMLA, based on common GitHub issues and user inquiries. + +## General Questions + +### Does `flash_mla_with_kvcache` work only in paged mode? + +No. It supports both paged and non-paged KV cache. Paged mode is recommended for long-context workloads and memory fragmentation control, but non-paged works for shorter, fixed-length workloads. + +### Can MLA/MHA be used in the prefill stage? + +Yes. Prefill supports both MLA and MHA: +- Use **MLA** for sparse/long contexts with grouped-query attention +- Use **MHA** when you need standard dense attention or compatibility with existing kernels + +### What GPU architectures are supported? + +| Architecture | Support | +|--------------|---------| +| SM90 (Hopper - H100/H800) | Full support | +| SM100 (Blackwell - B100/B200) | Full support | +| SM120 | Not supported | +| Older (SM80, SM70) | Not supported | + +Use matching CUDA drivers (12.8+) and ensure your build targets these architectures. + +## Sparse Attention + +### How do I test sparse attention? + +1. Enable MLA sparse mode in your config +2. Load a sparse pattern (block-sparse mask) +3. Run the sparse test suite: `python tests/test_flash_mla_sparse.py` +4. Compare outputs against dense runs to validate correctness +5. Check perf counters to confirm sparse paths are used + +### What's the difference between dense and sparse MLA? + +| Feature | Dense MLA | Sparse MLA | +|---------|-----------|------------| +| Computation | Full attention matrix | Pruned blocks | +| Performance (H800) | ~660 TFlops | ~410 TFlops | +| Use case | Short sequences, full accuracy | Long sequences, structured sparsity | +| SM90 support | Yes | No | +| SM100 support | Yes | Yes | + +Dense MLA computes full attention (higher FLOPs, higher accuracy for dense tasks). Sparse MLA prunes blocks to reduce compute/memory, trading some fidelity for speed/throughput on long or structured-sparsity workloads. + +## Integration + +### How do I integrate FlashMLA with vLLM/SGLang/other frameworks? + +1. Use the provided FlashMLA attention interface +2. Register it as the backend kernel for attention ops +3. Follow the framework's custom op/plugin hooks: + - **vLLM**: Custom attention registry + - **SGLang**: Extension points +4. Ensure paged KV cache shape/block alignment matches FlashMLA's layout (block size = 64) + +Example for vLLM: +```python +from vllm.attention.backends import register_attention_backend +from flash_mla import FlashMLABackend + +register_attention_backend("flash_mla", FlashMLABackend) +``` + +## Performance + +### What performance should I expect on different GPUs? + +| GPU | Dense MLA | Sparse MLA | Notes | +|-----|-----------|------------|-------| +| H800 | ~660 TFlops | ~410 TFlops | Reference config | +| H100 | ~600 TFlops | ~380 TFlops | Slightly lower than H800 | +| B200 | ~1460 TFlops (MHA prefill) | TBD | SM100 optimizations | + +Actual numbers vary with sequence length, batch shape, paging configuration, and sparsity pattern. Expect higher throughput on SM100 vs SM90 when similarly configured. + +## Memory & Batching + +### How do I handle variable-length sequences in a batch? + +1. Use paged KV cache with proper offsets/indirection per sequence +2. Provide per-sequence lengths via `seq_lens` tensor +3. Use masks to avoid reading/writing past valid tokens +4. Pad to block boundaries (64 tokens) only where required by the cache layout + +Example: +```python +# Variable-length batch +seq_lens = torch.tensor([128, 256, 64, 512]) # Different lengths per batch +nblk_per_seq = (seq_lens + 63) // 64 # Blocks needed per sequence +max_nblk = nblk_per_seq.max() + +# Allocate cache with max blocks +k_cache = torch.zeros(B, Hk, max_nblk, 64, D, dtype=torch.bfloat16) +``` + +### What's the block size for paged KV cache? + +Block size is **64 tokens**. This is fixed and cannot be changed. Align all allocations and paging logic to 64-token blocks. + +## Quantization + +### How do I enable FP8 quantization for the KV cache? + +1. Ensure GPU supports FP8 (SM90/SM100) +2. Build with FP8 KV cache enabled +3. Use the FP8-compatible KV layout +4. Apply calibration/scaling utilities provided by FlashMLA + +```python +# FP8 KV cache setup +k_cache = torch.zeros(B, Hk, Nblk, 64, D, dtype=torch.float8_e4m3fn) +v_cache = torch.zeros(B, Hk, Nblk, 64, D, dtype=torch.float8_e4m3fn) + +# Scaling factors (calibrated) +k_scale = torch.ones(B, Hk, 1, 1, 1) +v_scale = torch.ones(B, Hk, 1, 1, 1) +``` + +Fall back to FP16/BF16 if FP8 is unsupported on your GPU. + +--- + +## Still have questions? + +If your question isn't answered here: +1. Search [existing issues](https://github.com/deepseek-ai/FlashMLA/issues) +2. Open a new issue with your GPU model, CUDA version, and detailed question diff --git a/docs/TENSOR_SHAPES.md b/docs/TENSOR_SHAPES.md new file mode 100644 index 00000000..45eaaaac --- /dev/null +++ b/docs/TENSOR_SHAPES.md @@ -0,0 +1,155 @@ +# Tensor Shapes Reference - FlashMLA + +This document provides a comprehensive reference for tensor shapes expected by FlashMLA functions. + +## Shape Notation + +| Symbol | Meaning | +|--------|---------| +| B | Batch size | +| S | Sequence length (tokens in request) | +| P | Prompt length (prefill tokens) | +| T | Decode step count (incremental tokens) | +| H | Attention heads | +| Hq | Query heads | +| Hk | Key/Value heads (Hk <= Hq for grouped-query attention) | +| D | Per-head hidden dimension | +| G | Groups for grouped-query attention (G = Hk; Hq = G * ratio) | +| Nblk | Number of KV blocks: `ceil((P + T) / 64)` | +| Blk | Block size for paged cache (always 64) | + +**Layout**: Row-major contiguous unless noted; strides allowed if stated. + +## Function Reference + +### 1. `flash_mla_with_kvcache` - MLA Decoding with Paged KV Cache + +```python +flash_mla_with_kvcache( + q: [B, Hq, 1, D], # Current token query + k_cache: [B, Hk, Nblk, 64, D], # Paged K blocks + v_cache: [B, Hk, Nblk, 64, D], # Paged V blocks + block_table: [B, Nblk], # Maps block index -> physical cache block + seq_lens: [B], # Prompt + decoded length per batch + mask: [B, 1, 1, P+T], # Optional; causal if omitted + metadata: get_mla_metadata(...), +) -> out: [B, Hq, 1, D] +``` + +**Layout Requirements**: +- `q`: contiguous +- `k_cache`/`v_cache`: contiguous within last two dims (Blk, D) +- `block_table`: must be contiguous + +**Notes**: +- MLA uses grouped-query attention: typically Hq = multiple of Hk (e.g., 8:4) +- Decoding assumes the last cache block may be partially filled; unused slots ignored via `seq_lens` + +### 2. `get_mla_metadata` - Metadata for Decoding + +```python +get_mla_metadata( + q_shape: tuple[B, Hq, 1, D], + kv_shape: tuple[B, Hk, Nblk, 64, D], + block_table: [B, Nblk], + seq_lens: [B], +) -> metadata +``` + +**Notes**: +- Pure shape/stride validation; does not materialize tensors +- Requires consistent Hq/Hk grouping and block_table coverage up to max `seq_lens` + +### 3. `flash_mla_prefill` - Prefill Attention with Sparse Patterns + +```python +flash_mla_prefill( + q: [B, Hq, P, D], + k: [B, Hk, P, D], + v: [B, Hk, P, D], + sparse_layout: optional pattern (e.g., block-sparse mask), + attn_mask: [B, 1, P, P] or [B, Hq, P, P], # Causal or custom +) -> out: [B, Hq, P, D] +``` + +**Layout Requirements**: +- q/k/v: contiguous on last dim; leading dims can be strided but consistent +- Sparse pattern typically block-aligned to 64 for reuse with paged cache + +### 4. `mha_fwd_kvcache` - MHA Prefill with KV Cache + +```python +mha_fwd_kvcache( + q: [B, H, P, D], + k_cache: [B, H, Nblk, 64, D], + v_cache: [B, H, Nblk, 64, D], + block_table: [B, Nblk], + attn_mask: [B, 1, P, P] or [B, H, P, P], +) -> out: [B, H, P, D] +``` + +**Notes**: +- Standard MHA (no MLA head-grouping) +- Prefill writes into paged cache; P must align with block_table coverage (`ceil(P/64)` blocks) + +## Memory Layout Requirements + +| Requirement | Details | +|-------------|---------| +| Contiguous dims | Inner dims (D and Blk) must be contiguous for q/k/v and caches | +| Block size | Fixed at 64; `block_table` must enumerate all blocks in order | +| Mixed precision | FP8/BF16/FP16 supported; metadata and indices are integer/FP32 | +| Strides | Batch/head/sequence dims may be strided if monotonic; non-monotonic unsupported | + +## Visual Layout Diagrams + +### Paged KV Cache (per batch, per head) + +``` +k_cache[b, h] -> [ blk0 | blk1 | ... | blk(Nblk-1) ] + | + v + blkX: 64 rows x D cols (contiguous slab) +``` + +### Grouped-Query Attention (Hq > Hk) + +``` +Hq heads (e.g., 8) +├── group0 (q heads 0-1) -> maps to kv head 0 +├── group1 (q heads 2-3) -> maps to kv head 1 +├── group2 (q heads 4-5) -> maps to kv head 2 +└── group3 (q heads 6-7) -> maps to kv head 3 +``` + +### Prefill to Cache Mapping (P tokens) + +``` +tokens 0..63 -> block_table[b, 0] +tokens 64..127 -> block_table[b, 1] +tokens 128..191 -> block_table[b, 2] +... +``` + +## Common Shape Mismatch Errors & Fixes + +| Error | Cause | Fix | +|-------|-------|-----| +| Hq vs Hk mismatch | `Hq` is not an integer multiple of `Hk` | Adjust q reshape or head-splitting for MLA | +| Wrong block count | `Nblk` != `ceil(seq_len/64)` | Update block_table or cache allocation | +| Non-contiguous inner dims | q/k_cache/v_cache not contiguous on D/Blk | Call `.contiguous()` or re-pack before kernel call | +| Mask length short | `attn_mask` last dim < `P+T` | Pad mask or recompute with correct seq length | +| Seq_lens vs cache | `seq_lens[b]` > `Nblk*64` | Grow cache or truncate sequence | +| Mixed precision | q/k/v and cache have different dtypes | Cast consistently to FP8/BF16/FP16 | + +## Quick Checks Before Calling + +```python +# Before flash_mla_with_kvcache +assert q.shape == (B, Hq, 1, D) +assert k_cache.shape == (B, Hk, Nblk, 64, D) +assert Hq % Hk == 0 # MLA grouping +assert block_table.shape == (B, Nblk) +assert all(seq_lens[b] <= Nblk * 64 for b in range(B)) +assert q.is_contiguous() +``` diff --git a/docs/TROUBLESHOOTING.md b/docs/TROUBLESHOOTING.md new file mode 100644 index 00000000..cad8583d --- /dev/null +++ b/docs/TROUBLESHOOTING.md @@ -0,0 +1,111 @@ +# FlashMLA Troubleshooting Guide + +This guide summarizes frequent issues reported by users and provides quick fixes. If a problem persists, follow the linked GitHub issues for deeper context. + +## GPU Architecture & Compatibility + +- **Supported architectures**: SM90 (Hopper), SM100 (Blackwell). SM120 support is not available yet. +- **Sparse attention support**: + - SM100: Sparse BF16/FP16 available. + - SM90: Dense attention only; Sparse BF16 is **not supported** (see error below). + - SM120: Not supported. +- **Performance reference** (dense decoding, H800): Dense MLA ~660 TFlops, Sparse ~410 TFlops. + +## GPU Compatibility Matrix + +| GPU Arch | Dense MLA | Sparse MLA | Notes | +|----------|-----------|------------|-------| +| SM90 (Hopper) | BF16/FP16 | Not supported | Use dense kernels only | +| SM100 (Blackwell) | BF16/FP16 | BF16/FP16 (config-dependent) | Ensure kernels are built with sparse enabled | +| SM120 | Unsupported | Unsupported | Not planned yet | + +## Sparse vs Dense Attention Usage + +**Symptom**: Users enable sparse attention on SM90 and get lower throughput or runtime errors. + +**Fix**: On SM90, disable sparse attention (set `use_sparse=False` or equivalent flag). Use dense kernels only. On SM100, sparse attention requires correct build flags and runtime configuration. Ensure your model config selects sparse kernels only where supported. + +## RuntimeError: Sparse BF16 MLA is not supported on SM90 + +**Error**: +``` +RuntimeError: Sparse BF16 MLA is not supported on SM90 +``` + +**Cause**: Sparse kernels are not compiled or supported on SM90. + +**Fix**: +1. Switch to dense attention +2. Rebuild without sparse flags +3. Verify your model config does not request sparse kernels + +## Build Errors by CUDA Version + +**Requirements**: CUDA 12.8+, PyTorch 2.0+ + +### Common Errors + +| Error | Cause | Solution | +|-------|-------|----------| +| `nvcc fatal : Unsupported gpu architecture 'sm_120'` | Unsupported arch flag | Remove unsupported arch flags; limit to sm_90 or sm_100 | +| `undefined symbol: ... cudaMemcpyAsync` | CUDA runtime/toolkit mismatch | Align CUDA toolkit with driver/runtime version | +| `ptxas fatal : Value 'sm_90' is not defined` | Old CUDA toolkit | Ensure CUDA >= 12.8 | + +### Build Checklist + +1. Confirm `nvcc --version` reports 12.8 or newer +2. Clear `CMAKE_CUDA_ARCHITECTURES` or `TORCH_CUDA_ARCH_LIST` to only include `90` or `100` +3. Remove previous build artifacts (`build/`, `*.so`) before rebuilding +4. Ensure PyTorch is compiled for the same CUDA major/minor as your toolkit + +## KV Cache Paging Mode Questions + +**Symptom**: Unexpected memory usage or OOM when paging is enabled. + +**Fixes**: +1. Verify paging is supported on your GPU (SM90/SM100 only) +2. Tune page size and eviction threshold; start with defaults provided by FlashMLA +3. If instability persists, disable paging to confirm the root cause, then re-enable with more conservative thresholds + +## Windows / ARM64 Build Support + +**Current status**: Official builds target Linux x86_64. Windows and ARM64 are not officially supported. + +**Workarounds (community)**: +- **Windows**: Use WSL2 with CUDA 12.8+ and compatible drivers +- **ARM64** (e.g., Grace): Cross-compile is experimental; verify toolchain supports your GPU arch and CUDA 12.8+ + +## Metadata API Compatibility Issues + +**Symptom**: Metadata API shape/dtype mismatches across releases. + +**Fixes**: +1. Align FlashMLA version with the metadata API expectations in your host framework +2. Regenerate metadata after upgrading FlashMLA or PyTorch +3. Check for breaking changes called out in release notes; adjust field names or tensor layouts accordingly + +## Sparse Attention Configuration Gotchas + +- Ensure runtime flags match build capabilities: if built without sparse, disable sparse at runtime +- **Mixed precision**: Sparse BF16 on SM90 is unsupported; use dense BF16 or switch to FP16 where allowed +- **Fallbacks**: If sparse fails to dispatch, explicitly force dense kernels to avoid silent slow paths + +## Advanced Help (Relevant Issues) + +| Category | Related Issues | +|----------|----------------| +| GPU arch support | #101, #113, #124, #134 | +| Sparse vs dense confusion | #87, #116, #142 | +| CUDA version build failures | #115, #121, #160 | +| Sparse BF16 on SM90 runtime error | #93, #110, #178 | +| KV cache paging behavior | #121, #126 | +| Windows/ARM64 build requests | #109, #119, #151 | +| Metadata API compatibility | #108, #126, #173 | + +--- + +If your issue is not covered here, please open a new GitHub issue with: +- GPU model +- CUDA/PyTorch versions +- Build flags +- Full error log