Skip to content

Commit 3c2cb8c

Browse files
authored
[CPU] Add check of AVX512 at runtime (#4039)
* [CPU] Add check of AVX512 at runtime * Move flag to utils.h * Update comments
1 parent e654d74 commit 3c2cb8c

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <c10/util/Float8_e4m3fn.h>
66
#include <c10/util/Unroll.h>
77
#include <torch/all.h>
8+
#include "utils.h"
89

910
#define QTYPE_DISPATCH(TYPE, ...) \
1011
[&]() { \
@@ -180,7 +181,7 @@ inline void _scaled_embedding_bag_krnl(
180181
const index_t *offsets, const data_t *weight, const double scale,
181182
output_t *result, const int64_t num_batch) {
182183
#if defined(CPU_CAPABILITY_AVX512)
183-
if (emb_dim % 128 == 0) {
184+
if (kHasAVX512 && emb_dim % 128 == 0) {
184185
constexpr int64_t block_dim = 128;
185186
const int64_t num_blocks = emb_dim / block_dim;
186187
__m512 scale_v = _mm512_set1_ps(scale);

torchao/csrc/cpu/aten_kernels/utils.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <tuple>
1111
#include <ATen/native/cpu/utils.h>
1212

13-
int64_t get_m_block(int64_t M) {
13+
inline int64_t get_m_block(int64_t M) {
1414
if (M <= 48) {
1515
return M;
1616
} else if (M < 64) {
@@ -22,7 +22,7 @@ int64_t get_m_block(int64_t M) {
2222
}
2323
}
2424

25-
std::tuple<bool, int64_t, int64_t, int64_t>
25+
inline std::tuple<bool, int64_t, int64_t, int64_t>
2626
get_m_blocking(int64_t M) {
2727
bool parallel_on_M = M > 128;
2828
int64_t block_m = get_m_block(M);
@@ -32,6 +32,10 @@ get_m_blocking(int64_t M) {
3232
}
3333

3434
#if defined(CPU_CAPABILITY_AVX512)
35+
// Cached check for AVX-512F support in this process, for use by CPU kernels
36+
// that include this header and are compiled with CPU_CAPABILITY_AVX512.
37+
inline const bool kHasAVX512 = __builtin_cpu_supports("avx512f"); // NOTE(review): __builtin_cpu_supports is a GCC/Clang builtin — confirm CPU_CAPABILITY_AVX512 is never defined for other toolchains
38+
3539
template<typename T>
3640
void zero_buffer(T* data, int64_t size) {
3741
const int32_t vec_size = at::vec::Vectorized<T>::size();

0 commit comments

Comments
 (0)