132 changes: 132 additions & 0 deletions docs/FAQ.md
@@ -0,0 +1,132 @@
# FlashMLA FAQ

Frequently asked questions about FlashMLA, based on common GitHub issues and user inquiries.

## General Questions

### Does `flash_mla_with_kvcache` work only in paged mode?

No. It supports both paged and non-paged KV cache. Paged mode is recommended for long-context workloads and memory fragmentation control, but non-paged works for shorter, fixed-length workloads.

### Can MLA/MHA be used in the prefill stage?

Yes. Prefill supports both MLA and MHA:
- Use **MLA** for sparse/long contexts with grouped-query attention
- Use **MHA** when you need standard dense attention or compatibility with existing kernels

### What GPU architectures are supported?

| Architecture | Support |
|--------------|---------|
| SM90 (Hopper - H100/H800) | Full support |
| SM100 (Blackwell - B100/B200) | Full support |
| SM120 | Not supported |
| Older (SM80, SM70) | Not supported |

Use CUDA 12.8 or newer with a compatible driver, and ensure your build targets these architectures.
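As a rough illustration of the table above, the check below maps a CUDA compute capability to the support status listed here. The helper name `is_supported_arch` is hypothetical and not part of FlashMLA; in practice the `(major, minor)` pair would typically come from `torch.cuda.get_device_capability()`.

```python
# Hypothetical helper (not part of FlashMLA): encodes the support table above.
SUPPORTED_SM = {90: "Hopper (H100/H800)", 100: "Blackwell (B100/B200)"}


def is_supported_arch(major: int, minor: int) -> bool:
    """Return True if compute capability major.minor is listed as supported."""
    return major * 10 + minor in SUPPORTED_SM


# Walk through the architectures named in the table.
for capability in [(9, 0), (10, 0), (12, 0), (8, 0)]:
    print(capability, is_supported_arch(*capability))
```

Running a check like this at startup lets a service fail early with a clear message instead of hitting a kernel launch error later.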

## Sparse Attention

### How do I test sparse attention?

1. Enable MLA sparse mode in your config
2. Load a sparse pattern (block-sparse mask)
3. Run the sparse test suite: `python tests/test_flash_mla_sparse.py`
4. Compare outputs against dense runs to validate correctness
5. Check perf counters to confirm sparse paths are used
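
Step 4 can be prototyped without a GPU. The sketch below is a pure-Python reference, not FlashMLA code: `attend` and `max_abs_diff` are illustrative names, and the sparse path is modeled (as an assumption) as attention restricted to a chosen set of key indices. When the selection covers every key, the sparse and dense outputs should agree.

```python
import math


def attend(q, keys, values, indices=None):
    """Single-query softmax attention over the selected key indices."""
    indices = list(range(len(keys))) if indices is None else list(indices)
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, keys[i])) for i in indices]
    top = max(scores)  # Subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [
        sum(w * values[i][d] for w, i in zip(weights, indices)) / total
        for d in range(len(values[0]))
    ]


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two vectors."""
    return max(abs(x - y) for x, y in zip(a, b))


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

dense = attend(q, keys, values)
full_selection = attend(q, keys, values, indices=[0, 1, 2])
pruned = attend(q, keys, values, indices=[0, 1])

print("full selection vs dense:", max_abs_diff(dense, full_selection))
print("pruned vs dense:", max_abs_diff(dense, pruned))
```

On real kernels the same comparison would use a tolerance suited to BF16 or FP8 rather than exact equality.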

### What's the difference between dense and sparse MLA?

| Feature | Dense MLA | Sparse MLA |
|---------|-----------|------------|
| Computation | Full attention matrix | Pruned blocks |
| Performance (H800) | ~660 TFlops | ~410 TFlops |
| Use case | Short sequences, full accuracy | Long sequences, structured sparsity |
| SM90 support | Yes | Yes |
| SM100 support | Yes | Yes |

Dense MLA computes full attention (higher FLOPs, higher accuracy for dense tasks). Sparse MLA prunes blocks to reduce compute/memory, trading some fidelity for speed/throughput on long or structured-sparsity workloads.

## Integration

### How do I integrate FlashMLA with vLLM/SGLang/other frameworks?

1. Use the provided FlashMLA attention interface
2. Register it as the backend kernel for attention ops
3. Follow the framework's custom op/plugin hooks:
   - **vLLM**: Custom attention registry
   - **SGLang**: Extension points
4. Ensure paged KV cache shape/block alignment matches FlashMLA's layout (block size = 64)

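Example for vLLM (illustrative only; confirm the backend registration API against the documentation for your vLLM version):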
```python
from vllm.attention.backends import register_attention_backend
from flash_mla import FlashMLABackend

register_attention_backend("flash_mla", FlashMLABackend)
```

## Performance

### What performance should I expect on different GPUs?

| GPU | Dense MLA | Sparse MLA | Notes |
|-----|-----------|------------|-------|
| H800 | ~660 TFlops | ~410 TFlops | Reference config |
| H100 | ~600 TFlops | ~380 TFlops | Slightly lower than H800 |
| B200 | ~1460 TFlops (MHA prefill) | TBD | SM100 optimizations |

Actual numbers vary with sequence length, batch shape, paging configuration, and sparsity pattern. Expect higher throughput on SM100 vs SM90 when similarly configured.
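
To put a measured kernel time next to the figures above, one common convention counts 2 FLOPs per multiply-add across the QK^T and PV matrix products. The sketch below assumes that convention; whether the numbers in the table were computed the same way is not stated here, and the function names, sizes, and timing are illustrative.

```python
def attention_flops(batch, heads_q, s_q, s_k, d_qk, d_v):
    """FLOPs for QK^T plus PV, counting 2 FLOPs per multiply-add."""
    return 2 * batch * heads_q * s_q * s_k * (d_qk + d_v)


def achieved_tflops(flops, seconds):
    """Convert a FLOP count and a measured wall time into TFlops."""
    return flops / seconds / 1e12


# Made-up shape and a made-up timing of 0.5 ms, for illustration only.
flops = attention_flops(batch=128, heads_q=128, s_q=1, s_k=4096, d_qk=576, d_v=512)
print(f"{achieved_tflops(flops, 0.5e-3):.1f} TFlops")
```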

## Memory & Batching

### How do I handle variable-length sequences in a batch?

1. Use paged KV cache with proper offsets/indirection per sequence
2. Provide per-sequence lengths via `seq_lens` tensor
3. Use masks to avoid reading/writing past valid tokens
4. Pad to block boundaries (64 tokens) only where required by the cache layout

Example:
```python
import torch

# Variable-length batch (Hk and D are illustrative values)
seq_lens = torch.tensor([128, 256, 64, 512])  # One length per sequence
B, Hk, D = seq_lens.numel(), 1, 576
nblk_per_seq = (seq_lens + 63) // 64  # Blocks needed per sequence
max_nblk = int(nblk_per_seq.max())  # Plain int so it can be used as a shape

# Allocate cache with max blocks
k_cache = torch.zeros(B, Hk, max_nblk, 64, D, dtype=torch.bfloat16)
```
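
The per-sequence block counts computed above also determine the shape of the block table. A minimal sketch of how one could be filled in, using plain Python lists: `build_block_table` is a hypothetical helper, the consecutive allocation of physical block ids is an assumption, and the `-1` padding value is only a convention chosen for this example.

```python
def build_block_table(seq_lens, block_size=64, pad=-1):
    """Give each sequence consecutive physical block ids; pad short rows."""
    blocks_needed = [(n + block_size - 1) // block_size for n in seq_lens]
    width = max(blocks_needed)
    table, next_block = [], 0
    for need in blocks_needed:
        row = list(range(next_block, next_block + need))
        next_block += need
        table.append(row + [pad] * (width - need))
    return table, next_block  # next_block == total physical blocks used


table, total_blocks = build_block_table([128, 256, 64, 512])
for row in table:
    print(row)
print("physical blocks used:", total_blocks)
```

A real allocator would hand out blocks from a free list rather than consecutively, but the table shape `[B, max_nblk]` stays the same.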

### What's the block size for paged KV cache?

Block size is **64 tokens**. This is fixed and cannot be changed. Align all allocations and paging logic to 64-token blocks.

## Quantization

### How do I enable FP8 quantization for the KV cache?

1. Ensure GPU supports FP8 (SM90/SM100)
2. Build with FP8 KV cache enabled
3. Use the FP8-compatible KV layout
4. Apply calibration/scaling utilities provided by FlashMLA

```python
import torch

# FP8 KV cache setup (dimensions are illustrative values)
B, Hk, Nblk, D = 4, 1, 8, 576
k_cache = torch.zeros(B, Hk, Nblk, 64, D, dtype=torch.float8_e4m3fn)
v_cache = torch.zeros(B, Hk, Nblk, 64, D, dtype=torch.float8_e4m3fn)

# Scaling factors: ones are placeholders; substitute calibrated values
k_scale = torch.ones(B, Hk, 1, 1, 1)
v_scale = torch.ones(B, Hk, 1, 1, 1)
```

Fall back to FP16/BF16 if FP8 is unsupported on your GPU.
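
The calibration utilities mentioned in step 4 are not shown in this FAQ. As a generic illustration only, per-tensor scaling for `float8_e4m3fn` usually picks a scale so that the largest magnitude maps to the format's maximum finite value (448). The helpers below are hypothetical and do not reflect FlashMLA's actual FP8 layout or scaling granularity.

```python
E4M3_MAX = 448.0  # Largest finite value representable in float8_e4m3fn


def per_tensor_scale(values):
    """Scale such that max(|x|) / scale == E4M3_MAX (1.0 for all-zero input)."""
    amax = max((abs(x) for x in values), default=0.0)
    return amax / E4M3_MAX if amax > 0 else 1.0


def scale_and_clamp(values, scale):
    """Divide by the scale and clamp into the representable range."""
    return [max(-E4M3_MAX, min(E4M3_MAX, x / scale)) for x in values]


k_values = [0.5, -2.0, 1.25]
k_scale = per_tensor_scale(k_values)
print(k_scale, scale_and_clamp(k_values, k_scale))
```

Dequantization multiplies by the same scale, so the scale has to be stored alongside the cache.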

---

## Still have questions?

If your question isn't answered here:
1. Search [existing issues](https://github.com/deepseek-ai/FlashMLA/issues)
2. Open a new issue with your GPU model, CUDA version, and detailed question
155 changes: 155 additions & 0 deletions docs/TENSOR_SHAPES.md
@@ -0,0 +1,155 @@
# Tensor Shapes Reference - FlashMLA

This document provides a comprehensive reference for tensor shapes expected by FlashMLA functions.

## Shape Notation

| Symbol | Meaning |
|--------|---------|
| B | Batch size |
| S | Sequence length (tokens in request) |
| P | Prompt length (prefill tokens) |
| T | Decode step count (incremental tokens) |
| H | Attention heads |
| Hq | Query heads |
| Hk | Key/Value heads (Hk <= Hq for grouped-query attention) |
| D | Per-head hidden dimension |
| G | Groups for grouped-query attention (G = Hk; Hq = G * ratio) |
| Nblk | Number of KV blocks: `ceil((P + T) / 64)` |
| Blk | Block size for paged cache (always 64) |

**Layout**: Row-major contiguous unless noted; strides allowed if stated.

## Function Reference

### 1. `flash_mla_with_kvcache` - MLA Decoding with Paged KV Cache

```python
flash_mla_with_kvcache(
q: [B, Hq, 1, D], # Current token query
k_cache: [B, Hk, Nblk, 64, D], # Paged K blocks
v_cache: [B, Hk, Nblk, 64, D], # Paged V blocks
block_table: [B, Nblk], # Maps block index -> physical cache block
seq_lens: [B], # Prompt + decoded length per batch
mask: [B, 1, 1, P+T], # Optional; causal if omitted
metadata: get_mla_metadata(...),
) -> out: [B, Hq, 1, D]
```

**Layout Requirements**:
- `q`: contiguous
- `k_cache`/`v_cache`: contiguous within last two dims (Blk, D)
- `block_table`: must be contiguous

**Notes**:
- MLA uses grouped-query attention: Hq is typically an integer multiple of Hk (e.g., Hq = 8, Hk = 4)
- Decoding assumes the last cache block may be partially filled; unused slots are ignored based on `seq_lens`
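
To make the partially filled last block concrete, the sketch below computes how many slots of the final block hold valid tokens for a given sequence length. `last_block_fill` is an illustrative helper, not a FlashMLA function.

```python
BLOCK_SIZE = 64  # Fixed block size of the paged cache


def last_block_fill(seq_len: int, block_size: int = BLOCK_SIZE) -> int:
    """Number of valid token slots in the final cache block of a sequence."""
    if seq_len <= 0:
        return 0
    return (seq_len - 1) % block_size + 1


for n in (1, 64, 130):
    print(n, "tokens ->", last_block_fill(n), "valid slots in the last block")
```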

### 2. `get_mla_metadata` - Metadata for Decoding

```python
get_mla_metadata(
q_shape: tuple[B, Hq, 1, D],
kv_shape: tuple[B, Hk, Nblk, 64, D],
block_table: [B, Nblk],
seq_lens: [B],
) -> metadata
```

**Notes**:
- Pure shape/stride validation; does not materialize tensors
- Requires consistent Hq/Hk grouping and block_table coverage up to max `seq_lens`

### 3. `flash_mla_prefill` - Prefill Attention with Sparse Patterns

```python
flash_mla_prefill(
q: [B, Hq, P, D],
k: [B, Hk, P, D],
v: [B, Hk, P, D],
sparse_layout: optional pattern (e.g., block-sparse mask),
attn_mask: [B, 1, P, P] or [B, Hq, P, P], # Causal or custom
) -> out: [B, Hq, P, D]
```

**Layout Requirements**:
- q/k/v: contiguous on last dim; leading dims can be strided but consistent
- Sparse pattern typically block-aligned to 64 for reuse with paged cache

### 4. `mha_fwd_kvcache` - MHA Prefill with KV Cache

```python
mha_fwd_kvcache(
q: [B, H, P, D],
k_cache: [B, H, Nblk, 64, D],
v_cache: [B, H, Nblk, 64, D],
block_table: [B, Nblk],
attn_mask: [B, 1, P, P] or [B, H, P, P],
) -> out: [B, H, P, D]
```

**Notes**:
- Standard MHA (no MLA head-grouping)
- Prefill writes into paged cache; P must align with block_table coverage (`ceil(P/64)` blocks)

## Memory Layout Requirements

| Requirement | Details |
|-------------|---------|
| Contiguous dims | Inner dims (D and Blk) must be contiguous for q/k/v and caches |
| Block size | Fixed at 64; `block_table` must enumerate all blocks in order |
| Mixed precision | FP8/BF16/FP16 supported; metadata and indices are integer/FP32 |
| Strides | Batch/head/sequence dims may be strided if monotonic; non-monotonic unsupported |
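
The contiguity requirement can be checked from shapes and strides alone. A minimal sketch, assuming strides are expressed in elements (as PyTorch's `Tensor.stride()` reports them); both function names are illustrative.

```python
def row_major_strides(shape):
    """Element strides of a densely packed row-major tensor."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def inner_dims_contiguous(shape, strides):
    """True if the last two dims (Blk, D) form one contiguous slab."""
    return strides[-1] == 1 and strides[-2] == shape[-1]


cache_shape = (4, 1, 8, 64, 576)  # [B, Hk, Nblk, Blk, D], illustrative sizes
dense = row_major_strides(cache_shape)
print(inner_dims_contiguous(cache_shape, dense))

# A view that keeps every other element of D has stride 2 on the last dim.
sliced_shape = (4, 1, 8, 64, 288)
sliced_strides = dense[:-1] + (2,)
print(inner_dims_contiguous(sliced_shape, sliced_strides))
```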

## Visual Layout Diagrams

### Paged KV Cache (per batch, per head)

```
k_cache[b, h] -> [ blk0 | blk1 | ... | blk(Nblk-1) ]
|
v
blkX: 64 rows x D cols (contiguous slab)
```

### Grouped-Query Attention (Hq > Hk)

```
Hq heads (e.g., 8)
├── group0 (q heads 0-1) -> maps to kv head 0
├── group1 (q heads 2-3) -> maps to kv head 1
├── group2 (q heads 4-5) -> maps to kv head 2
└── group3 (q heads 6-7) -> maps to kv head 3
```
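
The grouping in the diagram reduces to integer division by the group size. A small sketch, with `kv_head_for` as an illustrative name:

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head index to the KV head its group shares."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("Hq must be an integer multiple of Hk")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# Hq = 8, Hk = 4, as in the diagram above.
print([kv_head_for(h, 8, 4) for h in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]
```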

### Prefill to Cache Mapping (P tokens)

```
tokens 0..63 -> block_table[b, 0]
tokens 64..127 -> block_table[b, 1]
tokens 128..191 -> block_table[b, 2]
...
```
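
Put as code, the mapping above is a `divmod` by the block size followed by a lookup in that sequence's row of the block table. The helper name and the physical block ids below are made up for illustration.

```python
def locate_token(block_table_row, token_idx, block_size=64):
    """Return (physical_block, offset_in_block) for one token of a sequence."""
    logical_block, offset = divmod(token_idx, block_size)
    return block_table_row[logical_block], offset


row = [17, 3, 42]  # Made-up physical block ids for one sequence
print(locate_token(row, 0))    # (17, 0)
print(locate_token(row, 64))   # (3, 0)
print(locate_token(row, 130))  # (42, 2)
```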

## Common Shape Mismatch Errors & Fixes

| Error | Cause | Fix |
|-------|-------|-----|
| Hq vs Hk mismatch | `Hq` is not an integer multiple of `Hk` | Adjust q reshape or head-splitting for MLA |
| Wrong block count | `Nblk` != `ceil(seq_len/64)` | Update block_table or cache allocation |
| Non-contiguous inner dims | q/k_cache/v_cache not contiguous on D/Blk | Call `.contiguous()` or re-pack before kernel call |
| Mask length short | `attn_mask` last dim < `P+T` | Pad mask or recompute with correct seq length |
| Seq_lens vs cache | `seq_lens[b]` > `Nblk*64` | Grow cache or truncate sequence |
| Mixed precision | q/k/v and cache have different dtypes | Cast consistently to FP8/BF16/FP16 |

## Quick Checks Before Calling

```python
# Before flash_mla_with_kvcache
assert q.shape == (B, Hq, 1, D)
assert k_cache.shape == (B, Hk, Nblk, 64, D)
assert v_cache.shape == k_cache.shape  # V cache uses the same layout
assert Hq % Hk == 0 # MLA grouping
assert block_table.shape == (B, Nblk)
assert all(seq_lens[b] <= Nblk * 64 for b in range(B))
assert q.is_contiguous()
```