Skip to content

Commit 91db770

Browse files
Xia-Weiwen and Copilot
authored
[X86] Bug fixes and refinements in da8w4/float8 linear kernels (#4301)
* [X86] Refine da8w4/float8 linear kernels
* Fix correctness bugs in fallback path
* Refine code
--------
Co-authored-by: Copilot <copilot@github.com>
1 parent b3e0db2 commit 91db770

3 files changed

Lines changed: 42 additions & 90 deletions

File tree

torchao/csrc/cpu/aten_kernels/da8w4_linear.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,14 @@ namespace {
99

1010
#define BLOCK_N 32
1111

12-
static bool cpublas_checked = false;
12+
static std::once_flag cpublas_once;
1313
static bool cpublas_can_pack = false;
1414

15-
bool cpublas_could_pack() {
15+
static inline bool cpublas_could_pack() {
1616
// the could_pack check requires AMX support implicitly
17-
if (cpublas_checked) {
18-
return cpublas_can_pack;
19-
}
20-
cpublas_can_pack = at::native::cpublas::could_pack(at::kByte);
21-
cpublas_checked = true;
17+
std::call_once(cpublas_once, []() {
18+
cpublas_can_pack = at::native::cpublas::could_pack(at::kByte);
19+
});
2220
return cpublas_can_pack;
2321
}
2422

@@ -135,7 +133,6 @@ struct ActDtype<false> {
135133
using type = uint8_t;
136134
};
137135

138-
139136
#if defined(CPU_CAPABILITY_AVX512)
140137
inline std::array<__m256i, 2> load_zps_4vnni(const int8_t* __restrict__ zps) {
141138
// broadcast 01234567 to
@@ -286,8 +283,8 @@ void _dequant_weight_zp_only(
286283
for (int k = 0; k < K; ++k) {
287284
for (int n = 0; n < N / 2; ++n) {
288285
int32_t b = (int32_t)B[k * ldb + n];
289-
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
290-
dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
286+
dqB[k * N + n * 2] = (b & 0xf) - qzeros[n * 2];
287+
dqB[k * N + n * 2 + 1] = ((b >> 4) & 0xf) - qzeros[n * 2 + 1];
291288
}
292289
}
293290
}
@@ -407,7 +404,6 @@ void _dequant_gemm_accum_small_M(
407404
_mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
408405
};
409406
c10::ForcedUnroll<M * COLS>{}(store);
410-
411407
}
412408

413409
#define call_dequant_gemm_accum_small_M(M) \

torchao/csrc/cpu/aten_kernels/float8_linear.cpp

Lines changed: 35 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,17 @@ namespace {
1919
#define PER_ROW 2
2020
#define PER_GROUP 3
2121

22-
static bool cpublas_checked = false;
22+
static std::once_flag cpublas_flag;
2323
static bool cpublas_can_pack = false;
2424

25-
bool cpublas_could_pack() {
26-
// the could_pack check requires AMX support implicitly
27-
if (cpublas_checked) {
28-
return cpublas_can_pack;
29-
}
25+
static inline bool cpublas_could_pack() {
26+
std::call_once(cpublas_flag, []() {
3027
#ifdef CPUBLAS_BRGEMM_F8F8F32
31-
cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn);
28+
cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn);
3229
#else
33-
cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16);
30+
cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16);
3431
#endif
35-
cpublas_checked = true;
32+
});
3633
return cpublas_can_pack;
3734
}
3835

@@ -124,59 +121,34 @@ float8_linear_prepack_impl(
124121
}
125122

126123
#if defined(CPU_CAPABILITY_AVX512)
127-
// this doesn't handle NaN.
128-
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
129-
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
130-
131-
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
132-
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
133-
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
134-
const __m512i nonsign = _mm512_or_si512(exp, mant);
135-
136-
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
137-
const __m512i combined = _mm512_or_si512(nonsign, sign);
138-
139-
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
140-
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
141-
}
142-
143124
static void cvt_f8e4m3_to_bf16(
144-
const at::Float8_e4m3fn* __restrict__ in,
145-
at::BFloat16* out,
146-
int64_t rows,
147-
int64_t cols,
148-
int64_t stride) {
149-
if (stride == cols) {
150-
// A contiguous buffer
151-
size_t len = rows * cols;
152-
size_t i = 0;
153-
for (; i < len; i += 32) {
154-
__m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[i]);
155-
__m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec);
156-
_mm512_storeu_si512((__m512i*)(out + i), (__m512i)bf16_vec);
157-
}
158-
for (; i < len; ++i) {
159-
out[i] = (at::BFloat16)in[i];
160-
}
161-
} else {
162-
// Non-contiguous. Access each row with stride
163-
TORCH_CHECK(stride > cols);
164-
for (int r = 0; r < rows; ++r) {
165-
size_t i = 0;
166-
size_t vec_len = cols / 32 * 32;
167-
for (; i < vec_len; i += 32) {
168-
__m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]);
169-
__m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec);
170-
_mm512_storeu_si512((__m512i*)(out + r * cols + i), (__m512i)bf16_vec);
171-
}
172-
for (; i < cols; ++i) {
173-
out[r * cols + i] = (at::BFloat16)in[r * stride + i];
174-
}
175-
}
125+
const at::Float8_e4m3fn* __restrict__ in,
126+
at::BFloat16* out,
127+
int64_t rows,
128+
int64_t cols,
129+
int64_t stride) {
130+
constexpr int64_t vec_len = 32; // 256 bit = 32 fp8 values
131+
__m512 fp32_vec_0, fp32_vec_1;
132+
for (int r = 0; r < rows; ++r) {
133+
size_t i = 0;
134+
size_t vec_len_aligned = cols / vec_len * vec_len;
135+
for (; i < vec_len_aligned; i += vec_len) {
136+
__m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]);
137+
// Convert fp8 to fp32
138+
at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(_mm256_castsi256_si128(fp8_vec), fp32_vec_0);
139+
at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(_mm256_extracti128_si256(fp8_vec, 1), fp32_vec_1);
140+
// Convert to bf16 and store
141+
__m256i bf16_vec_0 = at::vec::cvtfp32_bf16(fp32_vec_0);
142+
__m256i bf16_vec_1 = at::vec::cvtfp32_bf16(fp32_vec_1);
143+
__m512i bf16_vec = _mm512_inserti32x8(_mm512_castsi256_si512(bf16_vec_0), bf16_vec_1, 1);
144+
_mm512_storeu_si512((__m512i*)(out + r * cols + i), bf16_vec);
145+
}
146+
for (; i < cols; ++i) {
147+
out[r * cols + i] = (at::BFloat16)in[r * stride + i];
148+
}
176149
}
177150
}
178151

179-
180152
// accumulate and store result to buffer
181153
// if act/wei are per_group quantized, apply scales
182154
template <bool accum, int64_t N, int act_quant_mode, int wei_quant_mode>
@@ -294,7 +266,8 @@ inline void store_out(
294266
if constexpr (wei_quant_mode == PER_ROW) {
295267
b_scale = scales_b[j];
296268
}
297-
c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale);
269+
float bias_val = bias ? bias[j] : 0.0f;
270+
c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale + bias_val);
298271
}
299272
} // for M
300273
}
@@ -341,7 +314,8 @@ inline void store_out(
341314
if constexpr (wei_quant_mode == PER_ROW) {
342315
b_scale = scales_b[j];
343316
}
344-
c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale);
317+
float bias_val = bias ? bias[j] : 0.0f;
318+
c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale + bias_val);
345319
}
346320
} // for M
347321
}
@@ -515,7 +489,7 @@ void _float8_linear_impl(
515489

516490
for (int mci = mc; mci < mc_end; ++mci) {
517491
int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m;
518-
zero_buffer(y_buf, m_size * block_n);
492+
memset(y_buf, 0, sizeof(float) * m_size * block_n);
519493
for (int kci = 0; kci < Kc; ++kci) {
520494
auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group;
521495
auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n;

torchao/csrc/cpu/aten_kernels/utils.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,6 @@ get_m_blocking(int64_t M) {
3535
// Cached check for AVX-512F support in this process, for use by CPU kernels
3636
// that include this header and are compiled with CPU_CAPABILITY_AVX512.
3737
inline const bool kHasAVX512 = __builtin_cpu_supports("avx512f");
38-
39-
template<typename T>
40-
void zero_buffer(T* data, int64_t size) {
41-
const int32_t vec_size = at::vec::Vectorized<T>::size();
42-
auto zero_vec = at::vec::Vectorized<T>(0);
43-
int64_t d = 0;
44-
for (; d < size - (size % vec_size); d += vec_size) {
45-
zero_vec.store(data + d);
46-
}
47-
if (d < size) {
48-
zero_vec.store(data + d, size - d);
49-
}
50-
}
51-
#else
52-
template<typename T>
53-
void zero_buffer(T* data, int64_t size) {
54-
memset(data, 0, sizeof(T) * size);
55-
}
5638
#endif
5739

5840
template <typename T> struct vnni_traits;

0 commit comments

Comments (0)