diff --git a/csrc/api/sparse_decode.h b/csrc/api/sparse_decode.h index 6df5c841..4dd86164 100644 --- a/csrc/api/sparse_decode.h +++ b/csrc/api/sparse_decode.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "common.h" #include "params.h" @@ -359,15 +361,15 @@ sparse_attn_decode_interface( features.push_back(DecodeFeatures::EXTRA_TOPK_LENGTH); } - DecodeImplBase* impl; + std::unique_ptr impl; if (arch.is_sm100f()) { if (h_q == 64) { - impl = new Decode_Sm100_Head64_Impl(); + impl = std::make_unique(); } else if (h_q == 128) { if (d_qk == 576) { - impl = new Decode_Sm100_Head64x2_Impl(); + impl = std::make_unique(); } else if (d_qk == 512) { - impl = new Decode_Sm100_Head128_Impl(); + impl = std::make_unique(); } else { TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); } @@ -375,7 +377,7 @@ sparse_attn_decode_interface( TORCH_CHECK(false, "Unsupported h_q: ", h_q); } } else if (arch.is_sm90a()) { - impl = new Decode_Sm90_Impl(); + impl = std::make_unique(); } else { TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd"); } @@ -489,7 +491,5 @@ sparse_attn_decode_interface( }; smxx::decode::run_flash_mla_combine_kernel(combine_params); - delete impl; - return {out, lse.transpose(1, 2), tile_scheduler_metadata, num_splits}; }