@@ -419,44 +419,33 @@ def _gcc_major(exe: str) -> "int | None":
419419
420420 @staticmethod
421421 def find_preferred_cxx_compiler () -> "str | None" :
422- """Find a C++ compiler that meets the preferred GCC version requirement .
422+ """Find a C++ compiler at or above _PREFERRED_GCC_MAJOR .
423423
424- Search order:
425- 1. $CXX environment variable
426- 2. ``g++`` on $PATH (via ``which``)
427-
428- Returns the path of the first qualifying compiler, or None if no
429- compiler at or above _PREFERRED_GCC_MAJOR is found.
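424+ Checks $CXX first, then ``g++``; a name without a path separator is looked up on $PATH.
425+ Returns the path of the first candidate whose GCC major version is high enough, or None.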
430426 """
431427 min_major = X86KernelBuild ._PREFERRED_GCC_MAJOR
432428
433429 def _check (exe : str ) -> "str | None" :
434430 if not exe :
435431 return None
436432 if os .sep not in exe :
437- # Plain name — resolve via PATH.
438- resolved = shutil .which (exe )
439- if not resolved :
440- return None
441- exe = resolved
433+ exe = shutil .which (exe ) or ""
442434 if not os .path .isfile (exe ):
443435 return None
444436 major = X86KernelBuild ._gcc_major (exe )
445- if major is not None and major >= min_major :
446- return exe
447- return None
448-
449- # 1. Explicit $CXX
450- result = _check (os .environ .get ("CXX" , "" ))
451- if result :
452- return result
437+ return exe if major is not None and major >= min_major else None
453438
454- # 2. g++ on PATH
455- result = _check ("g++" )
456- if result :
457- return result
439+ return _check (os .environ .get ("CXX" , "" )) or _check ("g++" )
458440
459- return None
441+ @staticmethod
442+ def _resolve_cxx () -> str :
443+ """Return the C++ compiler to use: $CXX if set and non-empty, else the result of find_preferred_cxx_compiler(), else ``g++``."""
444+ return (
445+ os .environ .get ("CXX" )
446+ or X86KernelBuild .find_preferred_cxx_compiler ()
447+ or "g++"
448+ )
460449
461450 @staticmethod
462451 def get_include_flags () -> list :
@@ -477,27 +466,12 @@ def get_include_flags() -> list:
477466 def add_compile_flags (extra_compile_args : dict ) -> None :
478467 """Extend *extra_compile_args* with CPU-kernel compile options.
479468
480- The main kernel files are compiled with AVX512 + AMX flags so that
481- PyTorch's vec512 headers (which use Sleef f16 intrinsics under
482- CPU_CAPABILITY_AVX512) compile correctly. -fno-tree-vectorize
483- prevents the compiler from emitting 512-bit packed instructions in
484- scalar fallback functions; AVX512 pragma regions re-enable
485- vectorization only where explicitly desired.
486-
487- Runtime dispatch via __builtin_cpu_supports() selects the right path:
488- - no AVX512: scalar fallback paths are used.
489- - AVX512 + AMX: AVX512/AMX optimised paths are selected.
490- - AVX10.2: AVX10.2 objects (compiled separately with
491- -march=diamondrapids) are also linked in.
469+ Enables AVX512 + AMX (for PyTorch vec512 headers) and -fno-tree-vectorize
470+ (to prevent scalar paths from emitting 512-bit instructions).
471+ Runtime dispatch via __builtin_cpu_supports() selects the right path.
492472 """
493473 if not X86KernelBuild .is_enabled ():
494474 return
495- # Build with full AVX512 + AMX support so PyTorch's vec512 headers
496- # (which use Sleef f16 intrinsics under CPU_CAPABILITY_AVX512) compile
497- # correctly. -fno-tree-vectorize prevents the compiler from emitting
498- # 512-bit packed instructions in *scalar* fallback functions; AVX512
499- # pragma regions add #pragma GCC optimize("O3,tree-vectorize") to
500- # re-enable vectorization only where explicitly desired.
501475 extra_compile_args ["cxx" ].extend (
502476 [
503477 "-DCPU_CAPABILITY_AVX512" ,
@@ -528,14 +502,10 @@ def add_compile_flags(extra_compile_args: dict) -> None:
528502
529503 @staticmethod
530504 def get_link_flags () -> list :
531- """Return extra link flags for the CPU kernel .so.
532-
533- Adds an RPATH entry for every PyTorch library directory so that
534- libc10.so / libtorch_cpu.so are found at runtime without needing
535- LD_LIBRARY_PATH. Also statically links libstdc++ to carry new
536- CXXABI symbols (e.g. __cxa_call_terminate from CXXABI_1.3.15)
537- that newer GCC versions generate but PyTorch's bundled libstdc++
538- may lack.
505+ """Return extra link flags: PyTorch lib RPATHs + -static-libstdc++.
506+
507+ Static libstdc++ carries new CXXABI symbols that GCC 15 generates but
508+ PyTorch's bundled libstdc++ may lack.
539509 """
540510 if not X86KernelBuild .is_enabled ():
541511 return []
@@ -554,11 +524,9 @@ def get_link_flags() -> list:
554524 def filter_sources (sources : list , extensions_dir : str ) -> list :
555525 """Remove CPU aten_kernels sources from *sources* when not building for CPU or when the compiler's GCC major version is below _MINIMUM_GCC_MAJOR."""
556526 aten_kernels_dir = os .path .join (extensions_dir , "cpu" , "aten_kernels" )
557- cxx = os .environ .get (
558- "CXX" , X86KernelBuild .find_preferred_cxx_compiler () or "g++"
559- )
527+ cxx = X86KernelBuild ._resolve_cxx ()
560528 compiler_ok = (
561- cxx and X86KernelBuild ._gcc_major (cxx ) >= X86KernelBuild ._MINIMUM_GCC_MAJOR
529+ X86KernelBuild ._gcc_major (cxx ) >= X86KernelBuild ._MINIMUM_GCC_MAJOR
562530 )
563531 if not X86KernelBuild .is_enabled () or not compiler_ok :
564532 excluded = set (glob .glob (os .path .join (aten_kernels_dir , "*.cpp" )))
@@ -567,31 +535,19 @@ def filter_sources(sources: list, extensions_dir: str) -> list:
567535
568536 @staticmethod
569537 def precompile_isa_objects (build_temp : str , extensions : list ) -> None :
570- """Pre-compile ISA-specific CPU objects from kernel source files.
571-
572- Instead of maintaining separate *_avx10_2.cpp files, each kernel
573- source file contains ISA-specific code guarded by
574- CPU_CAPABILITY_AVX10_2. At build time we:
575- 1. Scan kernel .cpp files for the CPU_CAPABILITY_AVX10_2 marker.
576- 2. Copy each matching file to a temp path in the build dir.
577- 3. Compile that temp copy with -DCPU_CAPABILITY_AVX10_2
578- -march=diamondrapids.
579- 4. Attach the resulting .o as extra_objects on the main extension.
580-
581- The #if defined(CPU_CAPABILITY_AVX10_2) guard in each source file
582- ensures that only the AVX10.2 variant code is compiled in the temp
583- copy (the main build compiles the #else branch).
538+ """Compile AVX10.2 temp copies of kernel files that contain CPU_CAPABILITY_AVX10_2.
539+
540+ Each matching .cpp is copied to a temp dir and compiled with
541+ -DCPU_CAPABILITY_AVX10_2 -march=diamondrapids. The resulting .o is
542+ attached as an extra_object on the main torchao._C extension.
584543 """
585544 main_ext = next ((e for e in extensions if e .name == "torchao._C" ), None )
586545 if main_ext is None :
587546 return
588547
589- cxx = os .environ .get (
590- "CXX" , X86KernelBuild .find_preferred_cxx_compiler () or "g++"
591- )
548+ cxx = X86KernelBuild ._resolve_cxx ()
592549 compiler_ok = (
593- cxx
594- and X86KernelBuild ._gcc_major (cxx ) >= X86KernelBuild ._PREFERRED_GCC_MAJOR
550+ X86KernelBuild ._gcc_major (cxx ) >= X86KernelBuild ._PREFERRED_GCC_MAJOR
595551 )
596552 if not compiler_ok :
597553 print (
@@ -607,9 +563,8 @@ def precompile_isa_objects(build_temp: str, extensions: list) -> None:
607563
608564 aten_kernels_dir = os .path .join ("torchao" , "csrc" , "cpu" , "aten_kernels" )
609565
610- # --- AVX10.2 variant: copy kernel files and compile with DMR target ---
611- # Include the kernel source dir so that relative includes like
612- # utils.h still resolve when the file is compiled from a temp copy.
566+ # Copy each kernel that has AVX10.2 code to a temp dir and compile
567+ # with -march=diamondrapids so hardware fp8 instructions are available.
613568 avx10_2_flags = (
614569 ["-O3" , "-std=c++20" , "-fPIC" , "-fopenmp" ]
615570 + include_flags
0 commit comments