diff --git a/csrc/kerutils/include/kerutils/device/sm100/gemm.cuh b/csrc/kerutils/include/kerutils/device/sm100/gemm.cuh index 8af4edcf..542432e3 100644 --- a/csrc/kerutils/include/kerutils/device/sm100/gemm.cuh +++ b/csrc/kerutils/include/kerutils/device/sm100/gemm.cuh @@ -17,7 +17,7 @@ struct SM100_MMA_F16BF16_WS_TS_NOELECT { static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); static_assert(N == 64 || N == 128 || N == 256, - "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 32, 64 or 128"); + "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256"); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -114,7 +114,7 @@ struct SM100_MMA_F16BF16_WS_SS_NOELECT { static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); static_assert(N == 64 || N == 128 || N == 256, - "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 32, 64 or 128"); + "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256"); using DRegisters = void; using ARegisters = uint64_t[1];