diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py
index 95c75f29..123da349 100644
--- a/benchmark/bench_flash_mla.py
+++ b/benchmark/bench_flash_mla.py
@@ -80,15 +80,14 @@ def flash_mla():
 
 @torch.inference_mode()
 def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
-    
+
     for i in range(b):
         blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
 
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
     blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
-    
-    
+
     kv_indptr = [0]
     kv_indices = []
     for i in range(b):
@@ -99,7 +98,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
         kv_indptr.append(kv_indptr[-1] + num_blocks)
     for seq_len in cache_seqlens[1:]:
         kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])
-    
+
     q_indptr = torch.arange(0, b + 1).int() * s_q
     kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
@@ -263,7 +262,7 @@ def _mla_attn(
         attn_logits.stride(1),
         attn_logits.stride(2),
         BLOCK_H=BLOCK_H,
-        BLOCK_N=BLOCK_N, 
+        BLOCK_N=BLOCK_N,
         NUM_KV_SPLITS=num_kv_splits,
         PAGE_SIZE=page_size,
         HEAD_DIM_CKV=head_dim_ckv,
@@ -313,7 +312,7 @@ def _mla_softmax_reducev_kernel(
         e_sum = e_sum * old_scale + exp_logic
         e_max = n_e_max
 
-    
+
     tl.store(
         O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
         acc / e_sum,
@@ -375,15 +374,15 @@ def mla_decode_triton(
         b_seq_len,
         num_kv_splits,
     )
-    
+
 
 @torch.inference_mode()
 def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
-    
+
     for i in range(b):
         blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
     blocked_v = blocked_k[..., :dv]
-    
+
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
     blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
@@ -406,7 +405,7 @@ def flash_mla_triton():
     "flash_infer": run_flash_infer,
     "flash_mla_triton": run_flash_mla_triton,
 }
-    
+
 def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
     device = torch.device("cuda:0")
@@ -419,7 +418,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
     assert target in FUNC_TABLE
     baseline_func = FUNC_TABLE[baseline]
     target_func = FUNC_TABLE[target]
-    
+
     total_seqlens = cache_seqlens.sum().item()
     mean_seqlens = cache_seqlens.float().mean().int().item()
     max_seqlen = cache_seqlens.max().item()
@@ -430,10 +429,10 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
     block_size = 64
     block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
-    
+
     out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
     out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
-    
+
     torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
     if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]:
         # flash_infer has a different lse return value
@@ -457,7 +456,7 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     random.seed(0)
     assert target in FUNC_TABLE
     target_func = FUNC_TABLE[target]
-    
+
     total_seqlens = cache_seqlens.sum().item()
     mean_seqlens = cache_seqlens.float().mean().int().item()
     max_seqlen = cache_seqlens.max().item()
@@ -468,7 +467,7 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     block_size = 64
     block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
-    
+
     out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
 
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
@@ -500,7 +499,7 @@ def get_args():
     args = parser.parse_args()
     return args
 
-    
+
 if __name__ == "__main__":
     args = get_args()
     benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
diff --git a/tests/lib.py b/tests/lib.py
index f8847212..154a471d 100644
--- a/tests/lib.py
+++ b/tests/lib.py
@@ -20,7 +20,7 @@ def get_cos_diff(x: torch.Tensor, y: torch.Tensor) -> float:
         sim = 2 * (x * y).sum().item() / denominator
         return 1 - sim
     assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"
-    
+
     ans = ans.clone().to(torch.float)
     ref = ref.clone().to(torch.float)
 
@@ -34,7 +34,7 @@ def deal_with_anomalies(val: float):
             print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref")
             return False
         return True
-    
+
     anomalies_check_passed = True
     anomalies_check_passed &= deal_with_anomalies(float("inf"))
     anomalies_check_passed &= deal_with_anomalies(float("-inf"))