Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions csrc/sm90/prefill/sparse/phase1.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
const int num_topk_blocks = HAVE_TOPK_LENGTH ? ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK);

if (warpgroup_idx == 0 || warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_alloc<216>();
cutlass::arch::warpgroup_reg_alloc<200>();

if (warp_idx == 0 && elect_one_sync()) {
// Load Q
Expand Down Expand Up @@ -491,7 +491,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
}
};

int64_t cache_policy = createpolicy_evict_last();
int64_t cache_policy = createpolicy_evict_first();
auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) {
// Copy some K/V tiles from global memory to shared memory
// A tile has a shape of 64 (B_TOPK) x 64
Expand Down Expand Up @@ -527,32 +527,34 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
copy_tiles(block_idx+0, 0, 0, 4);
commit_to_mbar(plan.bar_k0_ready[0]);

// V1R
plan.bar_k1_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 4, D_K/64);
commit_to_mbar(plan.bar_k1_ready[1]);

auto publish_valid_mask = [&]() {
if (idx_in_group == 0) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx)
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row)
plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row];
plan.bar_is_kv_valid_ready.arrive();
}
};

// V0R
plan.bar_k0_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 4, D_K/64);
commit_to_mbar(plan.bar_k0_ready[1]);

publish_valid_mask();

// V1R
plan.bar_k1_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 4, D_K/64);
commit_to_mbar(plan.bar_k1_ready[1]);

// V1L
plan.bar_k1_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 0, 4);
commit_to_mbar(plan.bar_k1_ready[0]);

// Valid mask
// NOTE: V1R's finish implies maskings of the last round have finished
if (idx_in_group == 0) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx)
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row)
plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row];
plan.bar_is_kv_valid_ready.arrive();
}

cur_bar_wait_phase ^= 1;
}
}
Expand Down