@@ -133,6 +133,23 @@ struct ActDtype<false> {
133133 using type = uint8_t ;
134134};
135135
/**
 * Scalar fallback that unpacks int4 weights and removes their zero points.
 *
 * Every byte of B carries two 4-bit values: the low nibble becomes output
 * column 2 * n and the high nibble becomes output column 2 * n + 1. The zero
 * point of the matching column is subtracted and the difference is stored as
 * int8. B is read in plain row-major order (it is not VNNI-packed).
 *
 * @tparam N    number of output columns per row; N / 2 bytes are consumed per
 *              row (with an odd N the last output column is left untouched)
 * @tparam ldb  distance, in bytes, between the starts of consecutive rows of B
 * @param B       packed input; row k starts at B + k * ldb
 * @param dqB     output buffer, written row-major with a row stride of N
 * @param qzeros  per-column zero points; indices 0 .. 2 * (N / 2) - 1 are read
 * @param K       number of rows to process; nothing is written when K <= 0
 */
template <int64_t N, int64_t ldb>
inline void _dequant_weight_zp_only_fallback(
    const uint8_t* __restrict__ B,
    int8_t* dqB,
    const int8_t* __restrict__ qzeros,
    int64_t K) {
  // Use 64-bit counters: K is int64_t, so an `int` counter could overflow
  // before reaching the bound.
  for (int64_t k = 0; k < K; ++k) {
    const uint8_t* src_row = B + k * ldb;
    int8_t* dst_row = dqB + k * N;
    for (int64_t n = 0; n < N / 2; ++n) {
      // Widen first so the nibble arithmetic happens in a signed 32-bit type.
      const int32_t packed = static_cast<int32_t>(src_row[n]);
      // Low nibble -> even column, high nibble -> odd column.
      dst_row[2 * n] =
          static_cast<int8_t>((packed & 0xf) - qzeros[2 * n]);
      dst_row[2 * n + 1] =
          static_cast<int8_t>(((packed >> 4) & 0xf) - qzeros[2 * n + 1]);
    }
  }
}
152+
136153#if defined(CPU_CAPABILITY_AVX512)
137154inline std::array<__m256i, 2 > load_zps_4vnni (const int8_t * __restrict__ zps) {
138155 // broadcast 01234567 to
@@ -190,27 +207,31 @@ inline std::array<__m256i, 2> load_uint4_as_int8(const uint8_t* __restrict__ qB)
190207 return {low, high};
191208}
192209
193- template <int64_t N, int64_t ldb>
210+ template <int64_t N, int64_t ldb, bool cpublas_can_pack >
194211void _dequant_weight_zp_only (
195212 const uint8_t * __restrict__ B,
196213 int8_t * dqB,
197214 const int8_t * __restrict__ qzeros,
198215 int64_t K) {
199216 // unpack weight: each uint8 in B holds two int4 values, expanded into two int8 values in dqB
200217 // and subtract the per-column zero point (qzeros) from every unpacked value
201- // B shape = [K, ldb] = [K, N / 2], actual shape = [K / 4, N / 2, 4]
202- // dqB shape = [K, N], actual shape = [K / 4, N, 4]
218+ if constexpr (cpublas_can_pack) {
219+ // B shape = [K, ldb] = [K, N / 2], actual shape = [K / 4, N / 2, 4]
220+ // dqB shape = [K, N], actual shape = [K / 4, N, 4]
203221#pragma GCC unroll 2
204- for (int n = 0 ; n < N; n += 16 ) {
205- auto [zps_low, zps_high] = load_zps_4vnni (&qzeros[n]);
206- for (int k = 0 ; k < K; k += 4 ) {
207- auto [vb_low, vb_high] = load_uint4_as_int8 (B + ldb * k + n / 2 * 4 );
208- vb_high = _mm256_sub_epi8 (vb_high, zps_high);
209- vb_low = _mm256_sub_epi8 (vb_low, zps_low);
210- // store vb to B
211- _mm256_storeu_si256 (reinterpret_cast <__m256i_u*>(dqB + N * k + n * 4 ), vb_low);
212- _mm256_storeu_si256 (reinterpret_cast <__m256i_u*>(dqB + N * k + (n + 8 ) * 4 ), vb_high);
222+ for (int n = 0 ; n < N; n += 16 ) {
223+ auto [zps_low, zps_high] = load_zps_4vnni (&qzeros[n]);
224+ for (int k = 0 ; k < K; k += 4 ) {
225+ auto [vb_low, vb_high] = load_uint4_as_int8 (B + ldb * k + n / 2 * 4 );
226+ vb_high = _mm256_sub_epi8 (vb_high, zps_high);
227+ vb_low = _mm256_sub_epi8 (vb_low, zps_low);
228+ // store the zero-point-adjusted vb_low / vb_high into dqB (B itself is read-only)
229+ _mm256_storeu_si256 (reinterpret_cast <__m256i_u*>(dqB + N * k + n * 4 ), vb_low);
230+ _mm256_storeu_si256 (reinterpret_cast <__m256i_u*>(dqB + N * k + (n + 8 ) * 4 ), vb_high);
231+ }
213232 }
233+ } else { // cannot pack (no AMX support): delegate to the scalar fallback, which reads B as plain [K, N / 2]
234+ _dequant_weight_zp_only_fallback<N, ldb>(B, dqB, qzeros, K);
214235 }
215236}
216237
@@ -272,21 +293,13 @@ void _dequant_and_store(
272293}
273294
274295#else
275- template <int64_t N, int64_t ldb>
296+ template <int64_t N, int64_t ldb, bool cpublas_can_pack >
276297void _dequant_weight_zp_only (
277298 const uint8_t * B,
278299 int8_t * dqB,
279300 const int8_t * qzeros,
280301 int64_t K) {
281- // B shape = [K, N / 2]
282- // dqB shape = [K, N]
283- for (int k = 0 ; k < K; ++k) {
284- for (int n = 0 ; n < N / 2 ; ++n) {
285- int32_t b = (int32_t )B[k * ldb + n];
286- dqB[k * N + n * 2 ] = (b & 0xf ) - qzeros[n * 2 ];
287- dqB[k * N + n * 2 + 1 ] = ((b >> 4 ) & 0xf ) - qzeros[n * 2 + 1 ];
288- }
289- }
302+ return _dequant_weight_zp_only_fallback<N, ldb>(B, dqB, qzeros, K); // cpublas_can_pack is not consulted in this variant; the scalar fallback is always used
290303}
291304#endif
292305
@@ -456,7 +469,7 @@ void _dequant_gemm_accum(
456469#endif
457470
458471 int8_t dqB[K * N];
459- _dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
472+ _dequant_weight_zp_only<N, ldb, cpublas_can_pack >(B, dqB, qzeros_b, K);
460473 using Tin = typename ActDtype<sym_quant_a>::type;
461474 Tin* A_ptr = (Tin*)A;
462475#if defined(CPU_CAPABILITY_AVX512)
0 commit comments