Skip to content

Commit 92dcc96

Browse files
Xia-WeiwenCopilot
andauthored
[X86] Fix correctness bug of DA8W4 linear kernel on certain platforms (#4309)
* [X86] Fix correctness bug of DA8W4 linear kernel on certain platforms * Simplify code * Refine code --------- Co-authored-by: Copilot <copilot@github.com>
1 parent 990d155 commit 92dcc96

1 file changed

Lines changed: 36 additions & 23 deletions

File tree

torchao/csrc/cpu/aten_kernels/da8w4_linear.cpp

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,23 @@ struct ActDtype<false> {
133133
using type = uint8_t;
134134
};
135135

136+
template<int64_t N, int64_t ldb>
137+
inline void _dequant_weight_zp_only_fallback(
138+
const uint8_t* __restrict__ B,
139+
int8_t* dqB,
140+
const int8_t* __restrict__ qzeros,
141+
int64_t K) {
142+
// Unpack weight from uint8 (two int4) to int8, and subtract zero point
143+
// Weight is not packed as VNNI format, shape = [K, N / 2]
144+
for (int k = 0; k < K; ++k) {
145+
for (int n = 0; n < N / 2; ++n) {
146+
int32_t b = (int32_t)B[k * ldb + n];
147+
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n * 2];
148+
dqB[k * N + n * 2 + 1] = ((b >> 4) & 0xf) - qzeros[n * 2 + 1];
149+
}
150+
}
151+
}
152+
136153
#if defined(CPU_CAPABILITY_AVX512)
137154
inline std::array<__m256i, 2> load_zps_4vnni(const int8_t* __restrict__ zps) {
138155
// broadcast 01234567 to
@@ -190,27 +207,31 @@ inline std::array<__m256i, 2> load_uint4_as_int8(const uint8_t* __restrict__ qB)
190207
return {low, high};
191208
}
192209

193-
template <int64_t N, int64_t ldb>
210+
template <int64_t N, int64_t ldb, bool cpublas_can_pack>
194211
void _dequant_weight_zp_only(
195212
const uint8_t* __restrict__ B,
196213
int8_t* dqB,
197214
const int8_t* __restrict__ qzeros,
198215
int64_t K) {
199216
// unpack weight int8 -> two int4
200217
// subtract zero point
201-
// B shape = [K, ldb] = [K, N / 2], actual shape = [K / 4, N / 2, 4]
202-
// dqB shape = [K, N], actual shape = [K / 4, N, 4]
218+
if constexpr (cpublas_can_pack) {
219+
// B shape = [K, ldb] = [K, N / 2], actual shape = [K / 4, N / 2, 4]
220+
// dqB shape = [K, N], actual shape = [K / 4, N, 4]
203221
#pragma GCC unroll 2
204-
for (int n = 0; n < N; n += 16) {
205-
auto [zps_low, zps_high] = load_zps_4vnni(&qzeros[n]);
206-
for (int k = 0; k < K; k += 4) {
207-
auto [vb_low, vb_high] = load_uint4_as_int8(B + ldb * k + n / 2 * 4);
208-
vb_high = _mm256_sub_epi8(vb_high, zps_high);
209-
vb_low = _mm256_sub_epi8(vb_low, zps_low);
210-
// store vb to B
211-
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4), vb_low);
212-
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
222+
for (int n = 0; n < N; n += 16) {
223+
auto [zps_low, zps_high] = load_zps_4vnni(&qzeros[n]);
224+
for (int k = 0; k < K; k += 4) {
225+
auto [vb_low, vb_high] = load_uint4_as_int8(B + ldb * k + n / 2 * 4);
226+
vb_high = _mm256_sub_epi8(vb_high, zps_high);
227+
vb_low = _mm256_sub_epi8(vb_low, zps_low);
228+
// store vb to B
229+
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4), vb_low);
230+
_mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
231+
}
213232
}
233+
} else { // cannot pack (no AMX support)
234+
_dequant_weight_zp_only_fallback<N, ldb>(B, dqB, qzeros, K);
214235
}
215236
}
216237

@@ -272,21 +293,13 @@ void _dequant_and_store(
272293
}
273294

274295
#else
275-
template<int64_t N, int64_t ldb>
296+
template<int64_t N, int64_t ldb, bool cpublas_can_pack>
276297
void _dequant_weight_zp_only(
277298
const uint8_t* B,
278299
int8_t* dqB,
279300
const int8_t* qzeros,
280301
int64_t K) {
281-
// B shape = [K, N / 2]
282-
// dqB shape = [K, N]
283-
for (int k = 0; k < K; ++k) {
284-
for (int n = 0; n < N / 2; ++n) {
285-
int32_t b = (int32_t)B[k * ldb + n];
286-
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n * 2];
287-
dqB[k * N + n * 2 + 1] = ((b >> 4) & 0xf) - qzeros[n * 2 + 1];
288-
}
289-
}
302+
return _dequant_weight_zp_only_fallback<N, ldb>(B, dqB, qzeros, K);
290303
}
291304
#endif
292305

@@ -456,7 +469,7 @@ void _dequant_gemm_accum(
456469
#endif
457470

458471
int8_t dqB[K * N];
459-
_dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
472+
_dequant_weight_zp_only<N, ldb, cpublas_can_pack>(B, dqB, qzeros_b, K);
460473
using Tin = typename ActDtype<sym_quant_a>::type;
461474
Tin* A_ptr = (Tin*)A;
462475
#if defined(CPU_CAPABILITY_AVX512)

0 commit comments

Comments
 (0)