@@ -19,20 +19,18 @@ namespace {
 #define PER_ROW 2
 #define PER_GROUP 3
 
-static bool cpublas_checked = false;
+static std::once_flag cpublas_flag;
 static bool cpublas_can_pack = false;
 
-bool cpublas_could_pack() {
-  // the could_pack check requires AMX support implicitly
-  if (cpublas_checked) {
-    return cpublas_can_pack;
-  }
+static inline bool cpublas_could_pack() {
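+  // std::call_once (from <mutex>) runs the probe exactly once, safely under
+  // concurrent callers, unlike the racy check-and-set on cpublas_checked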
+  std::call_once(cpublas_flag, []() {
 #ifdef CPUBLAS_BRGEMM_F8F8F32
-  cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn);
+    cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn);
 #else
-  cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16);
+    cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16);
 #endif
-  cpublas_checked = true;
+  });
   return cpublas_can_pack;
 }
 
@@ -124,59 +121,37 @@ float8_linear_prepack_impl(
 }
 
 #if defined(CPU_CAPABILITY_AVX512)
-// this doesn't handle NaN.
-inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
-  const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
-
-  const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
-  const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
-  const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
-  const __m512i nonsign = _mm512_or_si512(exp, mant);
-
-  const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
-  const __m512i combined = _mm512_or_si512(nonsign, sign);
-
-  const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
-  return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
-}
-
 static void cvt_f8e4m3_to_bf16(
-    const at::Float8_e4m3fn* __restrict__ in,
-    at::BFloat16* out,
-    int64_t rows,
-    int64_t cols,
-    int64_t stride) {
-  if (stride == cols) {
-    // A contiguous buffer
-    size_t len = rows * cols;
-    size_t i = 0;
-    for (; i < len; i += 32) {
-      __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[i]);
-      __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec);
-      _mm512_storeu_si512((__m512i*)(out + i), (__m512i)bf16_vec);
-    }
-    for (; i < len; ++i) {
-      out[i] = (at::BFloat16)in[i];
-    }
-  } else {
-    // Non-contiguous. Access each row with stride
-    TORCH_CHECK(stride > cols);
-    for (int r = 0; r < rows; ++r) {
-      size_t i = 0;
-      size_t vec_len = cols / 32 * 32;
-      for (; i < vec_len; i += 32) {
-        __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]);
-        __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec);
-        _mm512_storeu_si512((__m512i*)(out + r * cols + i), (__m512i)bf16_vec);
-      }
-      for (; i < cols; ++i) {
-        out[r * cols + i] = (at::BFloat16)in[r * stride + i];
-      }
-    }
+    const at::Float8_e4m3fn* __restrict__ in,
+    at::BFloat16* out,
+    int64_t rows,
+    int64_t cols,
+    int64_t stride) {
+  constexpr int64_t vec_len = 32; // 256 bits = 32 fp8 values
+  __m512 fp32_vec_0, fp32_vec_1;
+  for (int r = 0; r < rows; ++r) {
+    size_t i = 0;
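+    // vectorize over the largest multiple of 32 columns; a scalar tail follows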
+    size_t vec_len_aligned = cols / vec_len * vec_len;
+    for (; i < vec_len_aligned; i += vec_len) {
+      __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]);
+      // Convert fp8 to fp32
+      at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(_mm256_castsi256_si128(fp8_vec), fp32_vec_0);
+      at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(_mm256_extracti128_si256(fp8_vec, 1), fp32_vec_1);
+      // Convert to bf16 and store
+      __m256i bf16_vec_0 = at::vec::cvtfp32_bf16(fp32_vec_0);
+      __m256i bf16_vec_1 = at::vec::cvtfp32_bf16(fp32_vec_1);
+      __m512i bf16_vec = _mm512_inserti32x8(_mm512_castsi256_si512(bf16_vec_0), bf16_vec_1, 1);
+      _mm512_storeu_si512((__m512i*)(out + r * cols + i), bf16_vec);
+    }
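+    // scalar tail for the remaining cols % 32 elements of this row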
+    for (; i < cols; ++i) {
+      out[r * cols + i] = (at::BFloat16)in[r * stride + i];
+    }
   }
 }
 
-
 // accumulate and store result to buffer
 // if act/wei are per_group quantized, apply scales
 template <bool accum, int64_t N, int act_quant_mode, int wei_quant_mode>
@@ -294,7 +266,9 @@ inline void store_out(
       if constexpr (wei_quant_mode == PER_ROW) {
         b_scale = scales_b[j];
       }
-      c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale);
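+      // bias is optional; a null pointer contributes zero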
+      float bias_val = bias ? bias[j] : 0.0f;
+      c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale + bias_val);
     }
   } // for M
 }
@@ -341,7 +314,9 @@ inline void store_out(
       if constexpr (wei_quant_mode == PER_ROW) {
         b_scale = scales_b[j];
       }
-      c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale);
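+      // optional bias handled the same way as in the overload above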
+      float bias_val = bias ? bias[j] : 0.0f;
+      c_ptr[i * lda + j] = static_cast<out_dtype>(y_buf[i * N + j] * a_scale * b_scale + bias_val);
     }
   } // for M
 }
@@ -515,7 +489,8 @@ void _float8_linear_impl(
 
     for (int mci = mc; mci < mc_end; ++mci) {
       int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m;
-      zero_buffer(y_buf, m_size * block_n);
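+      // memset is safe for y_buf: it holds plain floats, and all-zero bits are 0.0f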
+      memset(y_buf, 0, sizeof(float) * m_size * block_n);
       for (int kci = 0; kci < Kc; ++kci) {
         auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group;
         auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n;