Skip to content

Commit 9ec0d6b

Browse files
author
Copilot
committed
Compile code snippet to detect supported ISA
1 parent ea7803b commit 9ec0d6b

2 files changed

Lines changed: 127 additions & 66 deletions

File tree

setup.py

Lines changed: 120 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -387,77 +387,136 @@ def bool_to_on_off(value):
387387
class X86KernelBuild:
388388
"""Class for all x86-kernel-specific build logic"""
389389

390-
# Preferred GCC major version required for full AVX10.2 support.
391-
# Minimum GCC major version required for building x86 kernels.
392-
_PREFERRED_GCC_MAJOR = 15
393-
_MINIMUM_GCC_MAJOR = 11
394-
_cxx = None
395-
_cxx_major = None
396-
_cxx_checked = False
390+
# ISA capability levels, in increasing order of capability.
391+
# None – The default level
392+
# "avx512" – AVX-512F/BW/VL/DQ/VNNI, etc.
393+
# "avx10_2" – avx512 + AVX10.2
394+
_ISA_LEVELS = [None, "avx512", "avx10_2"]
397395

398-
@staticmethod
399-
def _gcc_major(exe: str) -> "int | None":
400-
"""Return the GCC major version reported by *exe*, or None on failure."""
401-
try:
402-
out = subprocess.run(
403-
[exe, "--version"],
404-
capture_output=True,
405-
text=True,
406-
timeout=10,
407-
).stdout
408-
# Extract the first MAJOR.MINOR.PATCH version in the output.
409-
m = re.search(r"\b(\d+)\.\d+\.\d+", out)
410-
if m:
411-
return int(m.group(1))
412-
except Exception:
413-
pass
414-
return None
396+
_cxx = None # resolved compiler path
397+
_cxx_checked = False # True once find_cxx_compiler() has run
398+
_isa_level = None # set by _probe_isa(); one of _ISA_LEVELS
399+
_isa_probed = False # True once _probe_isa() has run
415400

416401
@staticmethod
417402
def find_cxx_compiler() -> "str | None":
418-
"""Find a C++ compiler
403+
"""Find a C++ compiler.
419404
420405
Checks $CXX first, then ``g++`` on $PATH.
421-
Returns the compiler path or None.
406+
Returns the resolved compiler path, or None if not found.
422407
"""
423408
if X86KernelBuild._cxx_checked:
424409
return X86KernelBuild._cxx
425410

426-
def _check(exe: str) -> "str | None":
411+
def _resolve(exe: str) -> "str | None":
427412
if not exe:
428413
return None
429-
if os.sep not in exe:
430-
exe = shutil.which(exe) or ""
431-
if not os.path.isfile(exe):
432-
return None
433-
major = X86KernelBuild._gcc_major(exe)
434-
return exe if major is not None else None
435-
436-
cxx = _check(os.environ.get("CXX", ""))
437-
if _check(cxx):
438-
X86KernelBuild._cxx = cxx
439-
X86KernelBuild._cxx_major = X86KernelBuild._gcc_major(cxx)
440-
elif _check("g++"):
441-
X86KernelBuild._cxx = "g++"
442-
X86KernelBuild._cxx_major = X86KernelBuild._gcc_major("g++")
414+
resolved = shutil.which(exe) if os.sep not in exe else exe
415+
return resolved if resolved and os.path.isfile(resolved) else None
416+
417+
X86KernelBuild._cxx = _resolve(os.environ.get("CXX", "")) or _resolve("g++")
443418
X86KernelBuild._cxx_checked = True
444419
return X86KernelBuild._cxx
445420

421+
@staticmethod
422+
def _try_compile(cxx: str, march: str, snippet: str) -> bool:
423+
"""Return True if *cxx* can compile *snippet* with the given *march* flag."""
424+
import tempfile
425+
426+
with tempfile.TemporaryDirectory() as tmpdir:
427+
src = os.path.join(tmpdir, "probe.cpp")
428+
obj = os.path.join(tmpdir, "probe.o")
429+
with open(src, "w") as f:
430+
f.write(snippet)
431+
try:
432+
subprocess.check_call(
433+
[cxx, "-std=c++20", march, "-c", src, "-o", obj],
434+
stdout=subprocess.DEVNULL,
435+
stderr=subprocess.DEVNULL,
436+
timeout=30,
437+
)
438+
return True
439+
except Exception:
440+
return False
441+
442+
@staticmethod
443+
def _probe_isa() -> None:
444+
"""Probe which ISA level the compiler supports by compiling test snippets."""
445+
if X86KernelBuild._isa_probed:
446+
return
447+
448+
# Snippet that exercises AVX512 + VNNI (the minimum for our kernels).
449+
# _mm512_dpbusd_epi32 requires avx512f + avx512vnni.
450+
_AVX512_SNIPPET = """\
451+
#include <immintrin.h>
452+
int main() {
453+
__m512i a = _mm512_setzero_epi32();
454+
// avx512-vnni
455+
__m512i c = _mm512_dpbusd_epi32(a, a, a);
456+
(void)c;
457+
return 0;
458+
}
459+
"""
460+
# Snippet that exercises AVX10.2 fp8 hardware conversions.
461+
# _mm256_cvthf8_ph requires -march=diamondrapids and GCC >= 15.
462+
_AVX10_2_SNIPPET = """\
463+
#include <immintrin.h>
464+
int main() {
465+
__m128i a = _mm_setzero_si128();
466+
// avx10.2 fp8 -> fp16 hardware convert
467+
__m256h b = _mm256_cvthf8_ph(a);
468+
(void)b;
469+
return 0;
470+
}
471+
"""
472+
cxx = X86KernelBuild._cxx
473+
if cxx is None:
474+
X86KernelBuild._isa_probed = True
475+
return
476+
477+
if X86KernelBuild._try_compile(cxx, "-march=diamondrapids", _AVX10_2_SNIPPET):
478+
X86KernelBuild._isa_level = "avx10_2"
479+
elif X86KernelBuild._try_compile(cxx, "-march=sapphirerapids", _AVX512_SNIPPET):
480+
X86KernelBuild._isa_level = "avx512"
481+
else:
482+
X86KernelBuild._isa_level = None
483+
484+
print("[X86 Build] compiler check")
485+
print("- Found compiler:", cxx)
486+
print(
487+
"- AVX512 support:",
488+
"Yes" if X86KernelBuild._isa_at_least("avx512") else "No",
489+
)
490+
print(
491+
"- AVX10.2 support:",
492+
"Yes" if X86KernelBuild._isa_at_least("avx10_2") else "No",
493+
)
494+
X86KernelBuild._isa_probed = True
495+
496+
@staticmethod
497+
def _isa_at_least(level: str) -> bool:
498+
"""Return True if the probed ISA level is >= *level*."""
499+
levels = X86KernelBuild._ISA_LEVELS
500+
return levels.index(X86KernelBuild._isa_level) >= levels.index(level)
501+
446502
@staticmethod
447503
def is_enabled() -> bool:
448504
"""Return True when CPU aten_kernels should be included in the build."""
449505
enabled = bool(use_cpu_kernels and is_linux and is_x86_64)
450-
if enabled and not X86KernelBuild._cxx_checked:
506+
if enabled and not X86KernelBuild._isa_probed:
451507
X86KernelBuild.find_cxx_compiler()
452-
if (
453-
not X86KernelBuild._cxx
454-
or X86KernelBuild._cxx_major < X86KernelBuild._MINIMUM_GCC_MAJOR
455-
):
508+
if not X86KernelBuild._cxx:
456509
raise RuntimeError(
457-
"You are building with `USE_CPU_KERNELS=1` but "
458-
"no suitable C++ compiler found for building X86 kernels. "
459-
"Please set the CXX environment variable to a GCC compiler "
460-
"(version 15 for full features, 11 or higher required)."
510+
"[X86 Build] You are building X86 kernels but no C++ compiler was found. "
511+
"Please set the CXX environment variable to point to g++."
512+
)
513+
X86KernelBuild._probe_isa()
514+
if not X86KernelBuild._isa_at_least("avx512"):
515+
raise RuntimeError(
516+
"[X86 Build] You are building X86 kernels but the compiler "
517+
f"({X86KernelBuild._cxx}) does not support the required ISA "
518+
"features (AVX512F + AVX512-VNNI). Please install a compatible "
519+
"compiler (GCC >= 11.2, GCC 15 for full features)."
461520
)
462521
return enabled
463522

@@ -489,9 +548,10 @@ def add_compile_flags(extra_compile_args: dict) -> None:
489548
"-fno-tree-vectorize",
490549
"-fopenmp",
491550
]
492-
if X86KernelBuild._cxx_major >= X86KernelBuild._MINIMUM_GCC_MAJOR:
551+
# Gate defines on probed ISA capability, not compiler version.
552+
if X86KernelBuild._isa_at_least("avx512"):
493553
flags.append("-DBUILD_AVX512")
494-
if X86KernelBuild._cxx_major >= X86KernelBuild._PREFERRED_GCC_MAJOR:
554+
if X86KernelBuild._isa_at_least("avx10_2"):
495555
flags.append("-DBUILD_AVX10_2")
496556
extra_compile_args["cxx"].extend(flags)
497557

@@ -532,13 +592,13 @@ def precompile_isa_objects(build_temp: str, extensions: list) -> None:
532592
build_configs = [
533593
{
534594
"isa": "AVX512",
535-
"gcc_min_ver": X86KernelBuild._MINIMUM_GCC_MAJOR,
595+
"isa_level": "avx512",
536596
"defines": avx512_defines + ["-DCPU_CAPABILITY=AVX512"],
537597
"flags": ["-march=sapphirerapids"],
538598
},
539599
{
540600
"isa": "AVX10_2",
541-
"gcc_min_ver": X86KernelBuild._PREFERRED_GCC_MAJOR,
601+
"isa_level": "avx10_2",
542602
"defines": avx10_2_defines + ["-DCPU_CAPABILITY=AVX10_2"],
543603
"flags": ["-march=diamondrapids"],
544604
},
@@ -551,21 +611,22 @@ def compile_kernel(src):
551611
shutil.copy2(src, temp_src)
552612
obj = os.path.join(build_dir, f"{stem}.{config['isa']}.o")
553613
cmd = [cxx] + cxx_flags + ["-c", temp_src, "-o", obj]
554-
print(f"[X86 {config['isa']}] Compiling {src} -> {os.path.basename(obj)}")
614+
print(
615+
f"[X86 Build] {config['isa']}: Compiling {src} -> {os.path.basename(obj)}"
616+
)
555617
try:
556618
subprocess.check_call(cmd)
557619
return obj
558620
except subprocess.CalledProcessError as e:
559621
print(
560-
f"[ERROR] Unable to compile {config['isa']} variant of {src}:\n{e}\n"
622+
f"[X86 Build] [ERROR] Unable to compile {config['isa']} variant of {src}:\n{e}\n"
561623
)
562624
raise e
563625

564626
for config in build_configs:
565-
if X86KernelBuild._cxx_major < config["gcc_min_ver"]:
627+
if not X86KernelBuild._isa_at_least(config["isa_level"]):
566628
print(
567-
f"[WARNING] GCC {config['gcc_min_ver']} or higher is required to build {config['isa']} kernels. "
568-
f"Current compiler: {cxx} (GCC {X86KernelBuild._cxx_major}). "
629+
f"[X86 Build] [WARNING] Compiler does not support {config['isa']} ISA. "
569630
f"{config['isa']} kernels will not be built."
570631
)
571632
continue

torchao/csrc/cpu/aten_kernels/dispatch.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ declare_all_kernels(DEFAULT)
155155

156156
/********** DA8W4 Linear Kernel Dispatch **********/
157157
declare_da8w4_linear_prepack_impl {
158-
// BUILD_AVX10_2 should be only set when __GNUC__ >= 15. Here is just a double check. Same below.
159-
#if defined(BUILD_AVX10_2) && __GNUC__ >= 15
158+
// BUILD_AVX10_2 is only defined when the compiler passed the AVX10.2 ISA probe in setup.py.
159+
#if defined(BUILD_AVX10_2)
160160
if (kHasAVX10_2) {
161161
return AVX10_2::call_da8w4_linear_prepack_impl();
162162
}
@@ -170,7 +170,7 @@ declare_da8w4_linear_prepack_impl {
170170
}
171171

172172
declare_da8w4_linear_impl {
173-
#if defined(BUILD_AVX10_2) && __GNUC__ >= 15
173+
#if defined(BUILD_AVX10_2)
174174
if (kHasAVX10_2) {
175175
return AVX10_2::call_da8w4_linear_impl();
176176
}
@@ -185,7 +185,7 @@ declare_da8w4_linear_impl {
185185

186186
/********** FLOAT8 Linear Kernel Dispatch **********/
187187
declare_float8_linear_prepack_impl {
188-
#if defined(BUILD_AVX10_2) && __GNUC__ >= 15
188+
#if defined(BUILD_AVX10_2)
189189
if (kHasAVX10_2) {
190190
return AVX10_2::call_float8_linear_prepack_impl();
191191
}
@@ -199,7 +199,7 @@ declare_float8_linear_prepack_impl {
199199
}
200200

201201
declare_float8_linear_impl {
202-
#if defined(BUILD_AVX10_2) && __GNUC__ >= 15
202+
#if defined(BUILD_AVX10_2)
203203
if (kHasAVX10_2) {
204204
return AVX10_2::call_float8_linear_impl();
205205
}
@@ -214,7 +214,7 @@ declare_float8_linear_impl {
214214

215215
/********** Scaled Embedding Bag Kernel Dispatch **********/
216216
declare_scaled_embedding_bag_impl {
217-
#if defined(BUILD_AVX10_2) && __GNUC__ >= 15
217+
#if defined(BUILD_AVX10_2)
218218
if (kHasAVX10_2) {
219219
return AVX10_2::call_scaled_embedding_bag_impl();
220220
}
@@ -229,7 +229,7 @@ declare_scaled_embedding_bag_impl {
229229

230230
/********** Quantized SDPA Kernel **********/
231231
declare_qscaled_dot_product_impl {
232-
#if defined(BUILD_AVX10_2) && __GNUC__ >= 15
232+
#if defined(BUILD_AVX10_2)
233233
if (kHasAVX10_2) {
234234
return AVX10_2::call_qscaled_dot_product_impl();
235235
}

0 commit comments

Comments
 (0)