diff --git a/csrc/sm90/prefill/sparse/phase1.cuh b/csrc/sm90/prefill/sparse/phase1.cuh index bf2fff84..0f12c60f 100644 --- a/csrc/sm90/prefill/sparse/phase1.cuh +++ b/csrc/sm90/prefill/sparse/phase1.cuh @@ -81,7 +81,7 @@ __device__ void KernelTemplate::devfunc(const SparseAttn const int num_topk_blocks = HAVE_TOPK_LENGTH ? ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK); if (warpgroup_idx == 0 || warpgroup_idx == 1) { - cutlass::arch::warpgroup_reg_alloc<216>(); + cutlass::arch::warpgroup_reg_alloc<200>(); if (warp_idx == 0 && elect_one_sync()) { // Load Q @@ -491,7 +491,7 @@ __device__ void KernelTemplate::devfunc(const SparseAttn } }; - int64_t cache_policy = createpolicy_evict_last(); + int64_t cache_policy = createpolicy_evict_first(); auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) { // Copy some K/V tiles from global memory to shared memory // A tile has a shape of 64 (B_TOPK) x 64 @@ -527,32 +527,34 @@ __device__ void KernelTemplate::devfunc(const SparseAttn copy_tiles(block_idx+0, 0, 0, 4); commit_to_mbar(plan.bar_k0_ready[0]); - // V1R - plan.bar_k1_free[1].wait(cur_bar_wait_phase); - copy_tiles(block_idx+1, 1, 4, D_K/64); - commit_to_mbar(plan.bar_k1_ready[1]); - + auto publish_valid_mask = [&]() { + if (idx_in_group == 0) { + CUTE_UNROLL + for (int buf_idx = 0; buf_idx < 2; ++buf_idx) + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) + plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; + plan.bar_is_kv_valid_ready.arrive(); + } + }; + // V0R plan.bar_k0_free[1].wait(cur_bar_wait_phase); copy_tiles(block_idx+0, 0, 4, D_K/64); commit_to_mbar(plan.bar_k0_ready[1]); + publish_valid_mask(); + + // V1R + plan.bar_k1_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 4, D_K/64); + commit_to_mbar(plan.bar_k1_ready[1]); + // V1L plan.bar_k1_free[0].wait(cur_bar_wait_phase); copy_tiles(block_idx+1, 1, 0, 4); commit_to_mbar(plan.bar_k1_ready[0]); - // Valid mask - // NOTE: V1R's finish implies maskings of the last round have finished - if (idx_in_group == 0) { - CUTE_UNROLL - for (int buf_idx = 0; buf_idx < 2; ++buf_idx) - CUTE_UNROLL - for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) - plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; - plan.bar_is_kv_valid_ready.arrive(); - } - cur_bar_wait_phase ^= 1; } }